﻿// Hitachi.Software.Cryptography.Mugi.cs
// 1.0.*.1

using System;
using System.Runtime.CompilerServices;
using System.Runtime.ConstrainedExecution;
using System.Runtime.InteropServices;
using System.Security;
using System.Security.Cryptography;
using System.Threading;
using Microsoft.Win32.SafeHandles;

namespace Hitachi.Software.Cryptography
{
    partial struct Mugi
    {
        // The three 64-bit state words of the generator.
        // Rho() replaces all three on every round; GetA2() serializes a2 as output.
        private struct State
        {
            public UInt64 a0;
            public UInt64 a1;
            public UInt64 a2;
        }
        // Sixteen 64-bit buffer words. Lambda() shifts them by one position per round
        // (with feedback into b0, b4 and b10); Rho() reads b4 and b10.
        // Fields are declared from b15 down to b0, matching the original declaration order.
        private struct Buffer
        {
            public UInt64 b15;
            public UInt64 b14;
            public UInt64 b13;
            public UInt64 b12;
            public UInt64 b11;
            public UInt64 b10;
            public UInt64 b9;
            public UInt64 b8;
            public UInt64 b7;
            public UInt64 b6;
            public UInt64 b5;
            public UInt64 b4;
            public UInt64 b3;
            public UInt64 b2;
            public UInt64 b1;
            public UInt64 b0;
        }
        /// <summary>
        /// Safe handle that owns a COM task-memory block holding the 256-entry
        /// <see cref="UInt32"/> lookup table indexed by F32/F64.
        /// The table is filled by the constructor and released by <see cref="ReleaseHandle"/>.
        /// </summary>
        private sealed class SafeSboxMdsTableMemoryBlock : SafeHandleZeroOrMinusOneIsInvalid
        {
            /// <summary>
            /// Gets the raw address of the first table entry.
            /// </summary>
            /// <exception cref="ObjectDisposedException">The handle has already been closed.</exception>
            public unsafe UInt32* Address
            {
                get
                {
                    if(IsClosed)
                        throw new ObjectDisposedException(GetType().FullName);
                    return (UInt32*)handle;
                }
            }
            // The native parameter is SIZE_T (pointer-sized), so it is declared as UIntPtr;
            // an Int32 declaration does not match the native signature in a 64-bit process.
            [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)]
            [SuppressUnmanagedCodeSecurity]
            [DllImport("Ole32.dll", ExactSpelling = true)]
            private static extern IntPtr CoTaskMemAlloc(UIntPtr cb);
            /// <summary>
            /// Allocates the table memory and fills all 256 entries.
            /// </summary>
            /// <exception cref="OutOfMemoryException">The native allocation failed.</exception>
            public unsafe SafeSboxMdsTableMemoryBlock()
                : base(true)
            {
                UIntPtr byteCount = new UIntPtr((UInt32)(256 * sizeof(UInt32)));
                // The allocation and the assignment to 'handle' happen inside a finally block
                // of a prepared constrained region so the two steps are not separated.
                RuntimeHelpers.PrepareConstrainedRegions();
                try
                {
                }
                finally
                {
                    handle = CoTaskMemAlloc(byteCount);
                }
                if(IsInvalid)
                    throw new OutOfMemoryException();
                // Build the byte substitution table. 'i' walks through every non-zero byte value
                // (each step multiplies it by 3 with reduction constant 0x1b) until it returns to 1;
                // 'j' is updated alongside it and then passed through the bit-mixing step with
                // constant 0x63. NOTE(review): this matches the usual AES S-box generation - confirm.
                Byte[] sbox = new Byte[256];
                UInt32 i = 1, j = 1;
                do
                {
                    i = (Byte)(i ^ (i << 1) ^ ((i & 0x80) != 0 ? 0x1b : 0));
                    j ^= j << 1;
                    j ^= j << 2;
                    j ^= j << 4;
                    if((j & 0x80) != 0)
                        j ^= 0x09;
                    j = (Byte)j;
                    sbox[i] = (Byte)(j ^ j << 1 ^ j << 2 ^ j << 3 ^ j << 4 ^ j >> 4 ^ j >> 5 ^ j >> 6 ^ j >> 7 ^ 0x63);
                }
                while(i != 1);
                // Index 0 is never visited by the walk above, so it is set explicitly.
                sbox[0] = 0x63;
                // mul2[x] is x shifted left by one bit, reduced with 0x1b when the top bit was set.
                Byte[] mul2 = new Byte[256];
                i = 0;
                do
                    mul2[i] = (Byte)(i << 1);
                while(++i < 128);
                do
                    mul2[i] = (Byte)(i << 1 ^ 0x1b);
                while(++i < 256);
                // Each entry packs m1 = sbox[i], m2 = mul2[m1] and m3 = m2 ^ m1 into one word.
                // The byte arrangement depends on the process bitness because F32 (IntPtr.Size == 4)
                // and F64 (otherwise) consume the entries differently.
                UInt32* t = Address;
                i = 0;
                do
                {
                    UInt32 m1 = sbox[i], m2 = mul2[m1], m3 = m2 ^ m1;
                    t[i] = IntPtr.Size == 4 ? m1 | m1 << 8 | m2 << 16 | m3 << 24 : m1 | m2 << 8 | m3 << 16 | m1 << 24;  // 3211(32bit) / 1231(64bit)
                }
                while(++i < 256);
            }
            [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)]
            [SuppressUnmanagedCodeSecurity]
            [DllImport("Ole32.dll", ExactSpelling = true)]
            private static extern void CoTaskMemFree(IntPtr ptr);
            /// <summary>
            /// Frees the table memory obtained from CoTaskMemAlloc.
            /// </summary>
            /// <returns>Always <c>true</c>.</returns>
            protected override Boolean ReleaseHandle()
            {
                CoTaskMemFree(handle);
                return true;
            }
        }
        // Holder shared by all Mugi instances; published at most once via CompareExchange.
        private static WeakObjectHolder<SafeSboxMdsTableMemoryBlock> _WeakSboxMdsTableHolder;
        // Returns the Target of the shared holder, creating and publishing the holder on first use.
        // NOTE(review): WeakObjectHolder is declared elsewhere; Target presumably creates the
        // table on demand and keeps only a weak reference to it - confirm.
        private static SafeSboxMdsTableMemoryBlock SboxMdsTable
        {
            get
            {
                // Fast path: a holder has already been published.
                WeakObjectHolder<SafeSboxMdsTableMemoryBlock> published = _WeakSboxMdsTableHolder;
                if(published != null)
                    return published.Target;
                // Try to publish a fresh holder; a null previous value means this thread won.
                WeakObjectHolder<SafeSboxMdsTableMemoryBlock> candidate = new WeakObjectHolder<SafeSboxMdsTableMemoryBlock>();
                if(Interlocked.CompareExchange(ref _WeakSboxMdsTableHolder, candidate, null) == null)
                    return candidate.Target;
                // Another thread published first: close whatever the discarded holder exposes
                // and fall back to the holder that was actually published.
                SafeSboxMdsTableMemoryBlock orphan = candidate.Target;
                if(orphan != null)
                    orphan.Close();
                return _WeakSboxMdsTableHolder.Target;
            }
        }
        // F function variant used by Rho() when IntPtr.Size == 4.
        // pt : the 256-entry lookup table (layout built for the 32-bit case).
        // o  : 64-bit input word.
        // Each 32-bit half of 'o' is split into four bytes; the table entries selected by those
        // bytes are rotated by 16, 24 (right 8), 0 and 8 bits respectively and XORed together.
        // The low 16 bits of the two mixed halves are then exchanged to form the result.
        private static unsafe UInt64 F32(UInt32* pt, UInt64 o)
        {
            UInt32 lowerIn = (UInt32)o;
            UInt32 upperIn = (UInt32)(o >> 32);

            UInt32 e0 = pt[lowerIn & 0xff];
            UInt32 e1 = pt[(lowerIn >> 8) & 0xff];
            UInt32 e2 = pt[(lowerIn >> 16) & 0xff];
            UInt32 e3 = pt[lowerIn >> 24];
            UInt32 mixedLower = (e0 >> 16 | e0 << 16) ^ (e1 >> 8 | e1 << 24) ^ e2 ^ (e3 << 8 | e3 >> 24);

            e0 = pt[upperIn & 0xff];
            e1 = pt[(upperIn >> 8) & 0xff];
            e2 = pt[(upperIn >> 16) & 0xff];
            e3 = pt[upperIn >> 24];
            UInt32 mixedUpper = (e0 >> 16 | e0 << 16) ^ (e1 >> 8 | e1 << 24) ^ e2 ^ (e3 << 8 | e3 >> 24);

            // Result low word  : low 16 bits of mixedLower, high 16 bits of mixedUpper.
            // Result high word : low 16 bits of mixedUpper, high 16 bits of mixedLower.
            UInt32 resultLower = (mixedLower & 0x0000ffff) | (mixedUpper & 0xffff0000);
            UInt32 resultUpper = (mixedUpper & 0x0000ffff) | (mixedLower & 0xffff0000);
            return (UInt64)resultUpper << 32 | resultLower;
        }
        // F function variant used by Rho() when IntPtr.Size != 4.
        // pt : the 256-entry lookup table (layout built for the 64-bit case).
        // o  : 64-bit input word.
        // For each 32-bit half of 'o', the table entries selected by its four bytes are folded
        // from the lowest byte to the highest: the accumulator is rotated right by 8 bits and
        // XORed with the next entry. The low-half accumulator becomes the upper word of the
        // intermediate value, which is finally rotated left by 16 bits.
        private static unsafe UInt64 F64(UInt32* pt, UInt64 o)
        {
            UInt32 lowerIn = (UInt32)o;
            UInt32 upperIn = (UInt32)(o >> 32);

            UInt32 accLower = pt[lowerIn & 0xff];
            accLower = pt[(lowerIn >> 8) & 0xff] ^ (accLower >> 8 | accLower << 24);
            accLower = pt[(lowerIn >> 16) & 0xff] ^ (accLower >> 8 | accLower << 24);
            accLower = pt[lowerIn >> 24] ^ (accLower >> 8 | accLower << 24);

            UInt32 accUpper = pt[upperIn & 0xff];
            accUpper = pt[(upperIn >> 8) & 0xff] ^ (accUpper >> 8 | accUpper << 24);
            accUpper = pt[(upperIn >> 16) & 0xff] ^ (accUpper >> 8 | accUpper << 24);
            accUpper = pt[upperIn >> 24] ^ (accUpper >> 8 | accUpper << 24);

            UInt64 joined = (UInt64)accLower << 32 | accUpper;
            return joined << 16 | joined >> 48;
        }
        // State words a0..a2; rewritten by Rho() every round.
        private State _State;
        // Buffer words b0..b15; rewritten by Lambda() every round.
        private Buffer _Buffer;
        // Raw address of the lookup table, cached by the constructor from _SboxMdsTable.Address
        // and passed to F32/F64 by Rho(). It is null in a default-initialized Mugi.
        private unsafe UInt32* _SboxMdsTableAddress;
        // Advances the three state words by one round and returns the previous a0
        // (the value Lambda() needs as its feedback input).
        //   new a0 = old a1
        //   new a1 = F(old a1 ^ b4) ^ old a2 ^ C1
        //   new a2 = F(old a1 ^ rotl(b10, 17)) ^ old a0 ^ C2
        // F is F32 when IntPtr.Size == 4 and F64 otherwise. The buffer is only read here.
        private unsafe UInt64 Rho()
        {
            const UInt64 C1 = 0xbb67ae8584caa73bUL, C2 = 0x3c6ef372fe94f82bUL;
            UInt64 oldA0 = _State.a0;
            UInt64 oldA1 = _State.a1;
            UInt64 oldA2 = _State.a2;
            UInt64 firstInput = oldA1 ^ _Buffer.b4;
            UInt64 secondInput = oldA1 ^ (_Buffer.b10 << 17 | _Buffer.b10 >> 47);
            UInt64 firstF, secondF;
            if(IntPtr.Size == 4)
            {
                firstF = F32(_SboxMdsTableAddress, firstInput);
                secondF = F32(_SboxMdsTableAddress, secondInput);
            }
            else
            {
                firstF = F64(_SboxMdsTableAddress, firstInput);
                secondF = F64(_SboxMdsTableAddress, secondInput);
            }
            _State.a0 = oldA1;
            _State.a1 = firstF ^ oldA2 ^ C1;
            _State.a2 = secondF ^ oldA0 ^ C2;
            return oldA0;
        }
        // Advances the sixteen buffer words by one round.
        // prev_a0 : the a0 value from before the matching Rho() call (Rho's return value).
        //   new b0  = old b15 ^ prev_a0
        //   new b4  = old b3  ^ old b7
        //   new b10 = old b9  ^ rotl(old b13, 32)
        //   new bj  = old b(j-1) for every other j
        private void Lambda(UInt64 prev_a0)
        {
            // Compute the three feedback words from the old buffer contents first.
            UInt64 feedback0 = _Buffer.b15 ^ prev_a0;
            UInt64 feedback4 = _Buffer.b3 ^ _Buffer.b7;
            UInt64 feedback10 = _Buffer.b9 ^ (_Buffer.b13 << 32 | _Buffer.b13 >> 32);

            // Shift from the top down so each source is read before it is overwritten.
            _Buffer.b15 = _Buffer.b14;
            _Buffer.b14 = _Buffer.b13;
            _Buffer.b13 = _Buffer.b12;
            _Buffer.b12 = _Buffer.b11;
            _Buffer.b11 = _Buffer.b10;
            _Buffer.b10 = feedback10;
            _Buffer.b9 = _Buffer.b8;
            _Buffer.b8 = _Buffer.b7;
            _Buffer.b7 = _Buffer.b6;
            _Buffer.b6 = _Buffer.b5;
            _Buffer.b5 = _Buffer.b4;
            _Buffer.b4 = feedback4;
            _Buffer.b3 = _Buffer.b2;
            _Buffer.b2 = _Buffer.b1;
            _Buffer.b1 = _Buffer.b0;
            _Buffer.b0 = feedback0;
        }
        // Writes the current a2 state word to p[0..7], most significant byte first.
        // p : destination with room for at least 8 bytes.
        // The IntPtr.Size == 4 branch works on two 32-bit halves; the other branch shifts the
        // full 64-bit value. Both produce the same bytes.
        private unsafe void GetA2(Byte* p)
        {
            UInt64 word = _State.a2;
            if(IntPtr.Size == 4)
            {
                UInt32 upper = (UInt32)(word >> 32);
                UInt32 lower = (UInt32)word;
                p[0] = (Byte)(upper >> 24);
                p[1] = (Byte)(upper >> 16);
                p[2] = (Byte)(upper >> 8);
                p[3] = (Byte)upper;
                p[4] = (Byte)(lower >> 24);
                p[5] = (Byte)(lower >> 16);
                p[6] = (Byte)(lower >> 8);
                p[7] = (Byte)lower;
            }
            else
            {
                p[0] = (Byte)(word >> 56);
                p[1] = (Byte)(word >> 48);
                p[2] = (Byte)(word >> 40);
                p[3] = (Byte)(word >> 32);
                p[4] = (Byte)(word >> 24);
                p[5] = (Byte)(word >> 16);
                p[6] = (Byte)(word >> 8);
                p[7] = (Byte)word;
            }
        }
        // Reads 8 bytes from p and returns them as one 64-bit word, p[0] being the most
        // significant byte.
        // p : source with at least 8 readable bytes.
        // The IntPtr.Size == 4 branch assembles two 32-bit halves; the other branch assembles
        // the 64-bit value directly. Both return the same result.
        private static unsafe UInt64 GetUnit(Byte* p)
        {
            if(IntPtr.Size == 4)
            {
                UInt32 upper = (UInt32)p[0] << 24 | (UInt32)p[1] << 16 | (UInt32)p[2] << 8 | p[3];
                UInt32 lower = (UInt32)p[4] << 24 | (UInt32)p[5] << 16 | (UInt32)p[6] << 8 | p[7];
                return (UInt64)upper << 32 | lower;
            }
            return (UInt64)p[0] << 56
                 | (UInt64)p[1] << 48
                 | (UInt64)p[2] << 40
                 | (UInt64)p[3] << 32
                 | (UInt64)p[4] << 24
                 | (UInt64)p[5] << 16
                 | (UInt64)p[6] << 8
                 | p[7];
        }
        // The table block obtained from SboxMdsTable in the constructor; its Address is cached
        // in _SboxMdsTableAddress. NOTE(review): presumably kept here so the block stays
        // reachable while this instance uses the raw pointer - confirm.
        private SafeSboxMdsTableMemoryBlock _SboxMdsTable;
        // Initializes the generator from a key and an initialization vector.
        // key : at least 16 bytes; only the first 16 are read.
        // iv  : at least 16 bytes; only the first 16 are read.
        // Throws ArgumentNullException when key or iv is null, and CryptographicException when
        // either is shorter than 16 bytes.
        public unsafe Mugi(Byte[] key, Byte[] iv)
        {
            const UInt64 C0 = 0x6a09e667f3bcc908UL;
            if(key == null)
                throw new ArgumentNullException("key", EnvironmentExtension.GetResourceString("ArgumentNull_Array"));
            if(iv == null)
                throw new ArgumentNullException("iv", EnvironmentExtension.GetResourceString("ArgumentNull_Array"));
            if(key.Length < 16)
                throw new CryptographicException(EnvironmentExtension.GetResourceString("Cryptography_InvalidKeySize"));
            if(iv.Length < 16)
                throw new CryptographicException(EnvironmentExtension.GetResourceString("Cryptography_InvalidIVSize"));

            // Take the shared table and cache its raw address for Rho().
            SafeSboxMdsTableMemoryBlock table = SboxMdsTable;
            _SboxMdsTable = table;
            _SboxMdsTableAddress = table.Address;

            // Load key and IV as two big-endian 64-bit words each.
            UInt64 keyWord0, keyWord1, ivWord0, ivWord1;
            fixed(Byte* keyBytes = key, ivBytes = iv)
            {
                keyWord0 = GetUnit(keyBytes);
                keyWord1 = GetUnit(keyBytes + 8);
                ivWord0 = GetUnit(ivBytes);
                ivWord1 = GetUnit(ivBytes + 8);
            }

            // Key setup: state from the key, buffer all zero.
            _State.a0 = keyWord0;
            _State.a1 = keyWord1;
            _State.a2 = (keyWord0 << 7 | keyWord0 >> 57) ^ (keyWord1 >> 7 | keyWord1 << 57) ^ C0;
            _Buffer = default(Buffer);

            // Run Rho sixteen times, recording a0 after each call from b15 down to b0.
            // The recorded words are kept aside; _Buffer itself stays zero for now.
            Buffer keyBuffer;
            Rho(); keyBuffer.b15 = _State.a0;
            Rho(); keyBuffer.b14 = _State.a0;
            Rho(); keyBuffer.b13 = _State.a0;
            Rho(); keyBuffer.b12 = _State.a0;
            Rho(); keyBuffer.b11 = _State.a0;
            Rho(); keyBuffer.b10 = _State.a0;
            Rho(); keyBuffer.b9 = _State.a0;
            Rho(); keyBuffer.b8 = _State.a0;
            Rho(); keyBuffer.b7 = _State.a0;
            Rho(); keyBuffer.b6 = _State.a0;
            Rho(); keyBuffer.b5 = _State.a0;
            Rho(); keyBuffer.b4 = _State.a0;
            Rho(); keyBuffer.b3 = _State.a0;
            Rho(); keyBuffer.b2 = _State.a0;
            Rho(); keyBuffer.b1 = _State.a0;
            Rho(); keyBuffer.b0 = _State.a0;

            // Mix the IV into the state, then run sixteen more Rho rounds (buffer still zero).
            _State.a0 ^= ivWord0;
            _State.a1 ^= ivWord1;
            _State.a2 ^= (ivWord0 << 7 | ivWord0 >> 57) ^ (ivWord1 >> 7 | ivWord1 << 57) ^ C0;
            for(Int32 round = 0; round < 16; round++)
                Rho();

            // Install the recorded buffer and run sixteen full rounds (Rho followed by Lambda).
            _Buffer = keyBuffer;
            for(Int32 round = 0; round < 16; round++)
                Lambda(Rho());
        }
        /// <summary>
        /// Produces <paramref name="count"/> output units of 8 bytes each and advances the
        /// generator by the same number of rounds. Each unit is the current a2 state word,
        /// most significant byte first, written before the round update.
        /// </summary>
        /// <param name="count">Number of 8-byte units to generate; 0 returns an empty array.</param>
        /// <returns>A new array of <c>count * 8</c> bytes.</returns>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <paramref name="count"/> is negative, or so large that <c>count * 8</c> does not fit
        /// in an <see cref="Int32"/>.
        /// </exception>
        public unsafe Byte[] NextRound(Int32 count)
        {
            if(count < 0)
                throw new ArgumentOutOfRangeException("count", EnvironmentExtension.GetResourceString("ArgumentOutOfRange_NegativeCount"));
            // Without this guard count * 8 can wrap around in an unchecked context, yielding an
            // array smaller than the number of bytes the pointer loop below writes.
            if(count > Int32.MaxValue / 8)
                throw new ArgumentOutOfRangeException("count");
            Byte[] r = new Byte[count * 8];
            if(count > 0)
                fixed(Byte* p = r)
                {
                    Int32 i = 0;
                    do
                    {
                        // Emit the output word first, then step the state and the buffer.
                        GetA2(&p[i * 8]);
                        Lambda(Rho());
                    }
                    while(++i < count);
                }
            return r;
        }
    }
}
