#pragma once

#include "registers.h"

#include <data/vector.h>
#include <time/timer.h>

namespace NLibrary {
    namespace NFlash {
        /// Driver for serial NOR flash chips reached through a transport provider.
        ///
        /// TProvider must expose a static GetInstance() returning TProvider*, and
        /// Init(), DeInit(), Transmit(ptr, len), Receive(ptr, len) and GetLock().
        /// Chip-specific setup is applied in Init() for the SST25VF064C,
        /// IS25LP256D and MX25L25645G IDs.
        template<class TProvider>
        class TFlash {
        public:
            using TProviderPtr = TProvider*;

            /// Erase unit in bytes, used by SectorErase() and the sector-relative Read()/Write().
            static constexpr size_t SectorSize = 4096;
            /// Program unit in bytes; Write() splits data so no chunk crosses a page boundary.
            static constexpr size_t PageSize = 256;

            /// Number of address bytes emitted by TransmitAddress().
            enum class EAddressingMode {
                ThreeBytes,
                FourBytes
            };

            TFlash() = default;

            /// Initialise the provider, reset the chip if needed, detect it and apply chip-specific setup.
            bool Init();
            /// Clear the cached chip state and deinitialise the provider.
            void DeInit();

            /// Re-read the chip ID and pick the addressing mode from it.
            void UpdateId();

            /// Program `size` bytes starting at an absolute flash address.
            bool Write(uint32_t address, const uint8_t data[], size_t size);
            /// Read `size` bytes starting at an absolute flash address.
            bool Read(uint32_t address, uint8_t data[], size_t size);

            /// Sector-relative variants: the absolute address is sector * SectorSize + address.
            bool Write(size_t sector, uint32_t address, NData::TBufferView data);
            bool Read(size_t sector, uint32_t address, NData::TBufferView data);

            /// Send an erase command (CMD_ERASE_4K) for `address`.
            void Erase(uint32_t address);
            /// Erase sector number `sector` (address sector * SectorSize).
            void SectorErase(size_t sector);
            /// Erase the whole chip (CMD_ERASE_CHIP).
            void FullErase();

            /// ID captured by the last UpdateId() call; 0 before detection and after DeInit().
            uint32_t GetId() const {
                return Id;
            }

        private:
            // Cached chip state, refreshed by UpdateId() and reset by DeInit().
            EAddressingMode AddressingMode = EAddressingMode::ThreeBytes;
            uint32_t Id = 0;

            // Transport access: the provider is a singleton fetched on each use.
            TProviderPtr GetProvider() {
                return TProvider::GetInstance();
            }

            // Low-level command helpers.
            void TransmitByte(uint8_t data);
            uint8_t ReceiveByte();
            uint8_t ReadStatus();
            bool WriteStatus(uint16_t status);
            bool IsBusy();
            bool WaitForReady(NTime::TTime timeout = NTime::TTime::MilliSecond(300));

            // Addressed transfer helpers.
            void TransmitAddress(uint32_t address);
            void SimpleWrite(uint32_t address, const uint8_t data[], size_t size);

            // Write-enable latch and per-sector protection control.
            void WriteEnabled(bool state);
            void SetSectorProtection(uint32_t address, bool state);
            bool GetSectorProtection(uint32_t address);
        };

        /// Bring up the provider and the flash chip.
        ///
        /// Waits 16 ms after provider init. If the status register reads 0xFF,
        /// a software reset (RESET_ENABLE + RESET_MEMORY) is issued. The chip ID
        /// is then read and chip-specific setup applied:
        /// - SST25VF064C: status register cleared (protection off).
        /// - IS25LP256D: status register cleared and four-byte addressing enabled.
        /// - MX25L25645G: four-byte addressing enabled.
        /// @return always true (detection/setup failures are not reported).
        template<class TProvider>
        bool TFlash<TProvider>::Init() {
            // Each single-byte command is sent as its own locked transaction, the
            // same way every other command in this driver is framed (see
            // WriteEnabled). The reset pair must also be two separate transactions.
            auto sendCommand = [this](uint8_t command) {
                NLibrary::NLock::TLock lock(GetProvider()->GetLock());
                TransmitByte(command);
            };

            GetProvider()->Init();
            NTime::TTime::Delay(NTime::TTime::MilliSecond(16));

            // An all-ones status is taken as a sign the chip needs a software reset.
            const auto status = ReadStatus();
            if (status == 0xFF) {
                sendCommand(CMD_RESET_ENABLE);
                NTime::TTime::Delay(NTime::TTime::MilliSecond(1));
                sendCommand(CMD_RESET_MEMORY);
                NTime::TTime::Delay(NTime::TTime::MilliSecond(2));
            }
            ReadStatus(); // dummy read status

            UpdateId();

            const uint32_t id = GetId();
            if (id == SST25VF064C_ID) {
                // Reset status register to turn off memory protection
                WriteStatus(0);
            } else if (id == IS25LP256D_ID) {
                // presumably also clears block-protection bits - TODO confirm for this part
                WriteStatus(0);
                sendCommand(CMD_FOUR_BYTE_ADDRESSING);
            } else if (id == MX25L25645G_ID) {
                sendCommand(CMD_FOUR_BYTE_ADDRESSING);
            }

            return true;
        }

        /// Read the 3-byte chip identification (CMD_READ_ID) and cache it in Id.
        ///
        /// The bytes are packed big-endian: bytes[0] -> bits 23..16,
        /// bytes[1] -> bits 15..8, bytes[2] -> bits 7..0. This is presumably the
        /// JEDEC manufacturer/type/capacity triple; confirm CMD_READ_ID in registers.h.
        /// Also selects four-byte addressing for IS25LP256D and MX25L25645G.
        template<class TProvider>
        void TFlash<TProvider>::UpdateId() {
            uint8_t bytes[3] = {0};

            {
                NLibrary::NLock::TLock lock(GetProvider()->GetLock());
                TransmitByte(CMD_READ_ID);
                GetProvider()->Receive(bytes, sizeof(bytes));
            }

            // The low byte must come from bytes[2]. Using bytes[0] here would
            // duplicate the first byte and lose the third, so no ID could match.
            Id = (static_cast<uint32_t>(bytes[0]) << 16)
               | (static_cast<uint32_t>(bytes[1]) << 8)
               | static_cast<uint32_t>(bytes[2]);

            if (Id == IS25LP256D_ID || Id == MX25L25645G_ID) {
                AddressingMode = EAddressingMode::FourBytes;
            } else {
                AddressingMode = EAddressingMode::ThreeBytes;
            }
        }

        /// Hand a single byte to the provider's Transmit(). No locking is done here;
        /// callers take the provider lock around whole command sequences.
        template<class TProvider>
        void TFlash<TProvider>::TransmitByte(uint8_t data) {
            auto payload = &data;
            GetProvider()->Transmit(payload, sizeof(uint8_t));
        }

        /// Receive one byte from the provider.
        /// @return the received byte, or 0xFF if the provider leaves the buffer untouched.
        template<class TProvider>
        uint8_t TFlash<TProvider>::ReceiveByte() {
            uint8_t received = 0xFF;
            auto sink = &received;
            GetProvider()->Receive(sink, sizeof(uint8_t));
            return received;
        }

        /// Read the status register in one locked transaction (command, then one byte back).
        template<class TProvider>
        uint8_t TFlash<TProvider>::ReadStatus() {
            NLibrary::NLock::TLock guard(GetProvider()->GetLock());
            TransmitByte(CMD_READ_STATUS);
            const uint8_t statusByte = ReceiveByte();
            return statusByte;
        }

        /// True while bit 0 of the status register is set.
        template<class TProvider>
        bool TFlash<TProvider>::IsBusy() {
            const uint8_t statusByte = ReadStatus();
            return (statusByte & 0x01) != 0;
        }

        /// Poll IsBusy() roughly every millisecond until the chip reports ready.
        /// @param timeout how long to keep polling.
        /// @return true if the chip was seen ready, false if it was still busy after `timeout`.
        template<class TProvider>
        bool TFlash<TProvider>::WaitForReady(NTime::TTime timeout) {
            NTime::TTimer timer;
            timer.Start();

            while (!timer.HasExpired(timeout)) {
                if (!IsBusy()) {
                    return true;
                }
                NTime::TTime::Delay(NTime::TTime::MilliSecond(1));
            }

            // The chip may have finished during the last delay, or the timeout may
            // have been zero so the loop never polled. Check once more so a ready
            // chip is never reported as timed out.
            return !IsBusy();
        }

        /// Write a 16-bit value to the status register.
        ///
        /// Sends write-enable, then CMD_WRITE_STATUS followed by the low byte and
        /// the high byte of `status`, and waits up to 86 ms for completion.
        /// @return false if the chip was busy beforehand or did not become ready in time.
        template<class TProvider>
        bool TFlash<TProvider>::WriteStatus(uint16_t status) {
            // Do nothing while an operation is still in progress.
            if (IsBusy()) {
                return false;
            }

            WriteEnabled(true);

            const uint8_t lowByte = static_cast<uint8_t>(status & 0xFF);
            const uint8_t highByte = static_cast<uint8_t>(status >> 8);
            uint8_t frame[] = {
                CMD_WRITE_STATUS,
                lowByte,
                highByte
            };

            {
                auto transport = GetProvider();
                NLibrary::NLock::TLock guard(transport->GetLock());
                transport->Transmit(frame, sizeof(frame));
            }

            const auto completionTimeout = NTime::TTime::MilliSecond(86);
            return WaitForReady(completionTimeout);
        }

        /// Emit `address` most-significant byte first. It is 4 bytes long in
        /// FourBytes mode and 3 bytes long (bits 23..0) otherwise.
        /// The caller is expected to hold the provider lock.
        template<class TProvider>
        void TFlash<TProvider>::TransmitAddress(uint32_t address) {
            const int topShift = (AddressingMode == EAddressingMode::FourBytes) ? 24 : 16;
            for (int shift = topShift; shift >= 0; shift -= 8) {
                TransmitByte(static_cast<uint8_t>(address >> shift));
            }
        }

        /// Read `size` bytes starting at `address` in a single locked transaction.
        /// Sends CMD_READ, the address, and one 0x00 dummy byte before receiving.
        /// @return the provider's Receive() result, converted to bool.
        template<class TProvider>
        bool TFlash<TProvider>::Read(uint32_t address, uint8_t data[], size_t size) {
            NLock::TLock guard(GetProvider()->GetLock());

            TransmitByte(CMD_READ);
            TransmitAddress(address);
            TransmitByte(0x00); // dummy byte between address and data

            const bool received = GetProvider()->Receive(data, size);
            return received;
        }

        /// Send one program command (CMD_WRITE + address + payload) in a single
        /// locked transaction. This does not enable writes, split at page
        /// boundaries, or wait for completion; Write() handles all of that.
        template<class TProvider>
        void TFlash<TProvider>::SimpleWrite(uint32_t address, const uint8_t data[], size_t size) {
            NLock::TLock guard(GetProvider()->GetLock());

            TransmitByte(CMD_WRITE);
            TransmitAddress(address);

            const uint8_t* payload = data;
            GetProvider()->Transmit(payload, size);
        }

        /// Program `size` bytes from `data` starting at absolute address `address`.
        ///
        /// The data is split into chunks that never cross a PageSize boundary. For
        /// each chunk, the sector containing it is unprotected (after write-enable),
        /// then the chunk is programmed (after a fresh write-enable). The driver
        /// waits for the chip after each step.
        /// @return false if `size` is 0 or any ready-wait times out (earlier chunks
        ///         may already be programmed); true once every byte was written.
        template<class TProvider>
        bool TFlash<TProvider>::Write(uint32_t address, const uint8_t data[], size_t size) {
            if (size == 0) {
                return false;
            }

            // blockSize is always >= 1 and <= size - offset, so offset reaches
            // exactly `size` and the loop terminates without any extra guard.
            for (size_t offset = 0; offset != size;) {
                const size_t bytesBeforePageEnd = PageSize - ((address + offset) % PageSize);
                const size_t unwrittenSize = size - offset;
                const size_t blockSize = std::min(bytesBeforePageEnd, unwrittenSize);

                // Unprotect the sector holding this chunk.
                WriteEnabled(true);
                SetSectorProtection(address + offset, false);
                if (!WaitForReady()) {
                    return false;
                }

                // Program the page-bounded chunk.
                WriteEnabled(true);
                SimpleWrite(address + offset, data + offset, blockSize);
                if (!WaitForReady()) {
                    return false;
                }

                offset += blockSize;
            }
            return true;
        }

        /// Send CMD_WRITE_ENABLE (state == true) or CMD_WRITE_DISABLE as its own locked transaction.
        template<class TProvider>
        void TFlash<TProvider>::WriteEnabled(bool state) {
            NLock::TLock guard(GetProvider()->GetLock());
            if (state) {
                TransmitByte(CMD_WRITE_ENABLE);
            } else {
                TransmitByte(CMD_WRITE_DISABLE);
            }
        }

        /// Protect (state == true) or unprotect the sector addressed by `address`
        /// in one locked transaction. Callers send write-enable first.
        template<class TProvider>
        void TFlash<TProvider>::SetSectorProtection(uint32_t address, bool state) {
            NLock::TLock guard(GetProvider()->GetLock());
            if (state) {
                TransmitByte(CMD_PROTECT_SECTOR);
            } else {
                TransmitByte(CMD_UNPROTECT_SECTOR);
            }
            TransmitAddress(address);
        }

        /// Query protection for the sector addressed by `address`.
        /// @return true if the byte returned by the chip is non-zero.
        template<class TProvider>
        bool TFlash<TProvider>::GetSectorProtection(uint32_t address) {
            NLock::TLock guard(GetProvider()->GetLock());
            TransmitByte(CMD_READ_SECTOR_PROTECTION);
            TransmitAddress(address);

            const uint8_t protectionFlag = ReceiveByte();
            return protectionFlag != 0;
        }

        /// Erase at `address` using CMD_ERASE_4K.
        ///
        /// Sequence: write-enable, unprotect the sector, wait; write-enable again,
        /// send the erase, wait, then pause 3 ms. Ready-wait results are not checked.
        template<class TProvider>
        void TFlash<TProvider>::Erase(uint32_t address) {
            // Step 1: lift protection on the target sector.
            WriteEnabled(true);
            SetSectorProtection(address, false);
            WaitForReady();

            // Step 2: re-arm write enable and issue the erase command.
            WriteEnabled(true);
            {
                NLock::TLock guard(GetProvider()->GetLock());
                TransmitByte(CMD_ERASE_4K);
                TransmitAddress(address);
            }

            // Step 3: wait for the erase to finish, then settle briefly.
            WaitForReady();
            const auto settleTime = NTime::TTime::MilliSecond(3);
            NTime::TTime::Delay(settleTime);
        }

        /// Erase the whole chip with CMD_ERASE_CHIP and block until it completes
        /// (bounded to about 1000 default-length ready-waits).
        template<class TProvider>
        void TFlash<TProvider>::FullErase() {
            // Let any in-flight operation finish before arming write enable; a
            // write-enable sent while the chip is busy may not take effect.
            WaitForReady();
            WriteEnabled(true);
            {
                NLock::TLock lock(GetProvider()->GetLock());
                TransmitByte(CMD_ERASE_CHIP);
            }

            // A whole-chip erase on multi-megabyte parts takes far longer than one
            // default (300 ms) ready-wait. Keep polling in default-length slices,
            // capped at ~300 s so a dead chip cannot hang the caller forever.
            // TODO confirm the cap against the datasheets of the supported parts.
            constexpr size_t ChipEraseWaitSlices = 1000;
            for (size_t slice = 0; slice < ChipEraseWaitSlices; ++slice) {
                if (WaitForReady()) {
                    break;
                }
            }
        }

        /// Reset the cached ID and addressing mode to their defaults, then shut down the provider.
        template<class TProvider>
        void TFlash<TProvider>::DeInit() {
            AddressingMode = EAddressingMode::ThreeBytes;
            Id = 0;

            auto transport = GetProvider();
            transport->DeInit();
        }

        /// Sector-relative write: programs `data` at sector * SectorSize + address
        /// (truncated to 32 bits) via the absolute-address Write().
        template<class TProvider>
        bool TFlash<TProvider>::Write(size_t sector, uint32_t address, NData::TBufferView data) {
            const size_t sectorBase = sector * SectorSize;
            const uint32_t absoluteAddress = static_cast<uint32_t>(sectorBase + address);
            return Write(absoluteAddress, data.data(), data.size());
        }

        /// Sector-relative read: fills `data` from sector * SectorSize + address
        /// (truncated to 32 bits) via the absolute-address Read().
        template<class TProvider>
        bool TFlash<TProvider>::Read(size_t sector, uint32_t address, NData::TBufferView data) {
            const size_t sectorBase = sector * SectorSize;
            const uint32_t absoluteAddress = static_cast<uint32_t>(sectorBase + address);
            return Read(absoluteAddress, data.data(), data.size());
        }

        /// Erase sector number `sector` by erasing at its first byte (sector * SectorSize).
        template<class TProvider>
        void TFlash<TProvider>::SectorErase(size_t sector) {
            const size_t firstByte = sector * SectorSize;
            Erase(static_cast<uint32_t>(firstByte));
        }
    }
}
