#pragma once
#include <cassert>
#include <cinttypes>
#include <cstddef>
#include <limits>
#include <map>
#include <memory>
#include <string>
#include <vector>

/// Fields shared by every parsed H.264 NAL unit: the three syntax elements of
/// the NAL unit header, plus the header size recorded by the parser.
class AVCParsedNalu {
public:
    // NOTE(review): defaults to 1, a value a conforming NAL header never carries;
    // presumably this marks "not parsed yet" — confirm against the parser.
    int forbidden_zero_bit{1};
    int nal_ref_idc{0};
    int nal_unit_type{0};
    // Defaults to 1; presumably this accounts for the one-byte NAL header — confirm against the parser.
    size_t slice_header_size{1};
};

/// Fields decoded from a sequence parameter set (SPS) NAL unit.
/// Member names mirror the H.264 SPS syntax element names. Every member,
/// including the arrays, is zero-initialised. A default-constructed object
/// therefore never holds indeterminate values, even for syntax elements the
/// parser does not reach (for example, scaling lists when
/// seq_scaling_matrix_present_flag is 0).
class AVCParsedSpsNalu : public AVCParsedNalu {
public:
    // Returns {width, height}. Defined out of line. Presumably derived from the
    // picture-size and frame-cropping fields below — confirm in the implementation.
    std::pair<int, int> resolution() const; // width,height

public:
    int profile_idc = 0;
    int constraint_set0_flag = 0;
    int constraint_set1_flag = 0;
    int constraint_set2_flag = 0;
    int constraint_set3_flag = 0;
    int constraint_set4_flag = 0;
    int constraint_set5_flag = 0;
    int reserved_zero_2bits = 0;
    int level_idc = 0;
    int seq_parameter_set_id = 0;
    int chroma_format_idc = 0;
    int separate_colour_plane_flag = 0;
    int bit_depth_luma_minus8 = 0;
    int bit_depth_chroma_minus8 = 0;
    int qpprime_y_zero_transform_bypass_flag = 0;
    int seq_scaling_matrix_present_flag = 0;
    // The SPS syntax loops over ((chroma_format_idc != 3) ? 8 : 12) entries, so 12 is the maximum.
    int seq_scaling_list_present_flag[12] = {};
    int UseDefaultScalingMatrix4x4Flag[6] = {};
    int ScalingList4x4[6][16] = {};
    int UseDefaultScalingMatrix8x8Flag[6] = {};
    int ScalingList8x8[6][64] = {};
    int log2_max_frame_num_minus4 = 0;
    int pic_order_cnt_type = 0;
    int log2_max_pic_order_cnt_lsb_minus4 = 0;
    int delta_pic_order_always_zero_flag = 0;
    int offset_for_non_ref_pic = 0;
    int offset_for_top_to_bottom_field = 0;
    int num_ref_frames_in_pic_order_cnt_cycle = 0;
    // num_ref_frames_in_pic_order_cnt_cycle is limited to 0..255, hence 256 slots.
    int offset_for_ref_frame[256] = {};
    int max_num_ref_frames = 0;
    int gaps_in_frame_num_value_allowed_flag = 0;
    int pic_width_in_mbs_minus1 = 0;
    int pic_height_in_map_units_minus1 = 0;
    int frame_mbs_only_flag = 0;
    int mb_adaptive_frame_field_flag = 0;
    int direct_8x8_inference_flag = 0;
    int frame_cropping_flag = 0;
    int frame_crop_left_offset = 0;
    int frame_crop_right_offset = 0;
    int frame_crop_top_offset = 0;
    int frame_crop_bottom_offset = 0;
    // <VUI>
    int vui_parameters_present_flag = 0;
    int aspect_ratio_info_present_flag = 0;
    int aspect_ratio_idc = 0;
    int sar_width = 0;
    int sar_height = 0;
    int overscan_info_present_flag = 0;
    int overscan_appropriate_flag = 0;
    int video_signal_type_present_flag = 0;
    int video_format = 0;
    int video_full_range_flag = 0;
    int colour_description_present_flag = 0;
    int colour_primaries = 0;
    int transfer_characteristics = 0;
    int matrix_coefficients = 0;
    int chroma_loc_info_present_flag = 0;
    int chroma_sample_loc_type_top_field = 0;
    int chroma_sample_loc_type_bottom_field = 0;
    // </ VUI>
    // TODO HRD
};

/// Fields decoded from a picture parameter set (PPS) NAL unit.
/// Member names mirror the H.264 PPS syntax element names. Every member,
/// including the arrays, is zero-initialised. A default-constructed object
/// therefore never holds indeterminate values for syntax elements that are
/// absent from a given PPS.
class AVCParsedPpsNalu : public AVCParsedNalu {
public:
    int pic_parameter_set_id = 0;
    int seq_parameter_set_id = 0;
    int entropy_coding_mode_flag = 0;
    int bottom_field_pic_order_in_frame_present_flag = 0;
    int num_slice_groups_minus1 = 0;
    int slice_group_map_type = 0;
    // num_slice_groups_minus1 is at most 7, so 8 slots cover every slice group.
    int run_length_minus1[8] = {};
    int top_left[8] = {};
    int bottom_right[8] = {};
    int slice_group_change_direction_flag = 0;
    int slice_group_change_rate_minus1 = 0;
    int pic_size_in_map_units_minus1 = 0;
    // NOTE(review): the syntax carries pic_size_in_map_units_minus1 + 1 entries
    // here (slice_group_map_type 6). That can be far more than 255 for ordinary
    // frame sizes, so any writer must bound-check against this array's size.
    int slice_group_id[255] = {};
    int num_ref_idx_l0_default_active_minus1 = 0;
    int num_ref_idx_l1_default_active_minus1 = 0;
    int weighted_pred_flag = 0;
    int weighted_bipred_idc = 0;
    int pic_init_qp_minus26 = 0;
    int pic_init_qs_minus26 = 0;
    int chroma_qp_index_offset = 0;
    int deblocking_filter_control_present_flag = 0;
    int constrained_intra_pred_flag = 0;
    int redundant_pic_cnt_present_flag = 0;
    int transform_8x8_mode_flag = 0;
    int pic_scaling_matrix_present_flag = 0;
    // Up to 6 + 6 * transform_8x8_mode_flag entries, i.e. at most 12.
    int pic_scaling_list_present_flag[12] = {};
    int UseDefaultScalingMatrix4x4Flag[6] = {};
    int ScalingList4x4[6][16] = {};
    int UseDefaultScalingMatrix8x8Flag[6] = {};
    int ScalingList8x8[6][64] = {};
    int second_chroma_qp_index_offset = 0;
};

/// Fields decoded from the slice header of a VCL (coded slice) NAL unit.
/// Member names mirror the H.264 slice-header syntax element names. Every
/// member, including the arrays in the nested tables, is zero-initialised.
/// Syntax elements absent from a given slice therefore read as 0 rather than
/// indeterminate values.
class AVCParsedVclNalu : public AVCParsedNalu {
public:
    // P, B, I, SP, SI
    // The values are part of the spec. Do not modify
    enum {
        SliceType_P = 0,
        SliceType_B = 1,
        SliceType_I = 2,
        SliceType_SP = 3,
        SliceType_SI = 4,
    };

    /// pred_weight_table() values. Per-reference arrays hold 32 entries,
    /// matching the maximum of num_ref_idx_lX_active_minus1 + 1.
    /// Chroma arrays hold [Cb, Cr] pairs.
    /// NOTE(review): the *_flag members are scalars although the syntax carries
    /// one flag per reference index; presumably the last parsed value wins — confirm in the parser.
    struct PredWeightTable {
        int luma_log2_weight_denom = 0;
        int chroma_log2_weight_denom = 0;
        int luma_weight_l0_flag = 0;
        int luma_weight_l0[32] = {};
        int luma_offset_l0[32] = {};
        int chroma_weight_l0_flag = 0;
        int chroma_weight_l0[32][2] = {};
        int chroma_offset_l0[32][2] = {};
        int luma_weight_l1_flag = 0;
        int luma_weight_l1[32] = {};
        int luma_offset_l1[32] = {};
        int chroma_weight_l1_flag = 0;
        int chroma_weight_l1[32][2] = {};
        int chroma_offset_l1[32][2] = {};
    };

    /// dec_ref_pic_marking() values.
    /// NOTE(review): stores a single set of operation parameters although the
    /// syntax allows a list of memory-management operations — confirm in the parser.
    struct DecRefPicMarking {
        int no_output_of_prior_pics_flag = 0;
        int long_term_reference_flag = 0;
        int adaptive_ref_pic_marking_mode_flag = 0;
        int difference_of_pic_nums_minus1 = 0;
        int long_term_pic_num = 0;
        int long_term_frame_idx = 0;
        int max_long_term_frame_idx_plus1 = 0;
    };

    int first_mb_in_slice = 0;
    int slice_type = 0;
    int pic_parameter_set_id = 0;
    int colour_plane_id = 0;
    int frame_num = 0;
    int field_pic_flag = 0;
    int bottom_field_flag = 0;
    int idr_pic_id = 0;
    int pic_order_cnt_lsb = 0;
    int delta_pic_order_cnt_bottom = 0;
    int delta_pic_order_cnt[2] = {};
    int redundant_pic_cnt = 0;
    int direct_spatial_mv_pred_flag = 0;
    int num_ref_idx_active_override_flag = 0;
    int num_ref_idx_l0_active_minus1 = 0;
    int num_ref_idx_l1_active_minus1 = 0;
    int ref_pic_list_modification_flag_l0 = 0;
    int modification_of_pic_nums_idc = 0;
    int abs_diff_pic_num_minus1 = 0;
    int long_term_pic_num = 0;
    int ref_pic_list_modification_flag_l1 = 0;
    int cabac_init_idc = 0;
    int slice_qp_delta = 0;
    int sp_for_switch_flag = 0;
    int slice_qs_delta = 0;
    int disable_deblocking_filter_idc = 0;
    int slice_alpha_c0_offset_div2 = 0;
    int slice_beta_offset_div2 = 0;
    int slice_group_change_cycle = 0;
    PredWeightTable predWeightTable;
    DecRefPicMarking decRefPicMarking;
};

/// Stateful H.264 (AVC) NAL unit parser.
/// Parsed SPS and PPS units are kept in per-id maps so that later NAL units
/// can be interpreted against them. The static helpers are stateless utilities
/// for Annex B byte streams and decoder configuration ("extradata") records.
class AVCParser {
private:
    // -1 means "none selected yet"; updated by parseNalu (defined out of line).
    int m_currentSpsId{-1};
    int m_currentPpsId{-1};
    // Presumably keyed by seq_parameter_set_id / pic_parameter_set_id — confirm in parseNalu.
    std::map<int, AVCParsedSpsNalu> m_sps;
    std::map<int, AVCParsedPpsNalu> m_pps;

public:
    // nal_unit_type values this code refers to (H.264 Table 7-1).
    enum {
        NalTypeSlice = 1, // coded slice of a non-IDR picture
        NalTypeIDR = 5,   // coded slice of an IDR picture
        NalTypeSEI = 6,   // supplemental enhancement information
        NalTypeSPS = 7,   // sequence parameter set
        NalTypePPS = 8,   // picture parameter set
        NalTypeAUD = 9,   // access unit delimiter
    };

    AVCParser() = default;

    // TODO return a shared pointer to the parsed NALU
    /// Parses one NAL unit (without start code or length prefix).
    /// Returns the slice header size for VCL NAL units, otherwise 0.
    size_t parseNalu(const uint8_t* data, size_t size);

    /// Convenience overload forwarding the vector's contents to parseNalu(const uint8_t*, size_t).
    size_t parseNalu(const std::vector<uint8_t>& data)
    {
        const uint8_t* const bytes = data.data();
        const size_t count = data.size();
        return parseNalu(bytes, count);
    }

    /// The SPS stored under the current SPS id.
    /// Throws std::out_of_range (from std::map::at) if no SPS is stored under that id.
    const AVCParsedSpsNalu& currentSps() const
    {
        return m_sps.at(m_currentSpsId);
    }

    /// The PPS stored under the current PPS id.
    /// Throws std::out_of_range (from std::map::at) if no PPS is stored under that id.
    const AVCParsedPpsNalu& currentPps() const
    {
        return m_pps.at(m_currentPpsId);
    }

public:
    // AnnexB utils
    static size_t findStartCode(const uint8_t* data, size_t size, size_t* len);
    static size_t findStartCodeIncremental(const uint8_t* data, size_t size, size_t prevSize, size_t* len);
    static std::vector<uint8_t> toAnnexB(const std::vector<uint8_t>& frame, const std::vector<uint8_t>& extradata);

public:
    // Extradata utils
    /// Decoded fields of an extradata record, as produced by parseExtradata().
    struct Extradata {
        uint8_t version{0};
        uint8_t profile{0};
        uint8_t compatibility{0};
        uint8_t level{0};
        uint8_t lengthSize{0};
        std::vector<std::vector<uint8_t>> sps;
        std::vector<std::vector<uint8_t>> pps;
    };
    static Extradata parseExtradata(const std::vector<uint8_t>& extradata);
    static std::vector<uint8_t> getExtradataFromFrame(const std::vector<uint8_t>& frame);

public:
    // Stateless parsers. PPS and slice-header parsing take the already-known
    // parameter sets as explicit inputs instead of reading parser state.
    static AVCParsedSpsNalu parseSps(const uint8_t* data, size_t size);
    static AVCParsedPpsNalu parsePps(const uint8_t* data, size_t size, const std::map<int, AVCParsedSpsNalu>& sps);
    static AVCParsedVclNalu parseVclSliceHeader(const uint8_t* data, size_t size, const std::map<int, AVCParsedSpsNalu>& sps, const std::map<int, AVCParsedPpsNalu>& pps);

    // Vector overloads: forward the buffer to the pointer/size versions above.
    static AVCParsedSpsNalu parseSps(const std::vector<uint8_t>& data)
    {
        return parseSps(data.data(), data.size());
    }
    static AVCParsedPpsNalu parsePps(const std::vector<uint8_t>& data, const std::map<int, AVCParsedSpsNalu>& sps)
    {
        return parsePps(data.data(), data.size(), sps);
    }
    static AVCParsedVclNalu parseVclSliceHeader(const std::vector<uint8_t>& data, const std::map<int, AVCParsedSpsNalu>& sps, const std::map<int, AVCParsedPpsNalu>& pps)
    {
        return parseVclSliceHeader(data.data(), data.size(), sps, pps);
    }
};

/// Byte buffer of NAL units in length-prefixed form: every NALU is stored
/// behind a 4-byte big-endian size field (see lengthSize()).
/// Access unit delimiters are dropped on insertion. The buffer also records,
/// as a bitmask, which nal_unit_types were added since the last clear(), so
/// the sample can be classified (IDR / VCL / sync).
class NalBuffer : public std::vector<uint8_t> {
private:
    // Bit N is set when a NALU with nal_unit_type N has been added since the last clear().
    uint32_t m_nalu_mask = 0;

public:
    // constexpr (rather than static const) makes these implicitly inline in
    // C++17, so ODR-use does not need an out-of-line definition.
    static constexpr uint32_t IdSampleMask = (1u << AVCParser::NalTypeIDR);
    static constexpr uint32_t VclSampleMask = (1u << AVCParser::NalTypeIDR) | (1u << AVCParser::NalTypeSlice);
    static constexpr uint32_t SyncSampleMask = (1u << AVCParser::NalTypeIDR) | (1u << AVCParser::NalTypeSPS) | (1u << AVCParser::NalTypePPS);

    /// Empties the buffer and forgets which NAL types were seen.
    void clear()
    {
        m_nalu_mask = 0;
        std::vector<uint8_t>::clear();
    }

    /// Width in bytes of the size prefix written before each NALU.
    size_t lengthSize() const { return 4; }

    /// True when an IDR slice, an SPS and a PPS have all been added.
    bool isSyncSample() const { return SyncSampleMask == (SyncSampleMask & m_nalu_mask); }
    /// True when an IDR slice has been added.
    bool isIDRSample() const { return 0 != (IdSampleMask & m_nalu_mask); }
    /// True when an IDR or non-IDR coded slice has been added.
    bool isVideoCodingLayer() const { return 0 != (VclSampleMask & m_nalu_mask); }

    void addNalu(const std::vector<uint8_t>& nalu) { addNalu(nalu.data(), nalu.size()); }

    /// Appends one NALU (header byte first, no start code) behind a 4-byte
    /// big-endian size prefix. AUD NALUs are ignored. Empty input is rejected
    /// by the assert in debug builds and ignored in release builds.
    void addNalu(const uint8_t* data, size_t size)
    {
        assert(size);
        assert(size <= std::numeric_limits<uint32_t>::max()); // must fit the 4-byte prefix
        if (0 == size) {
            return; // there is no header byte to read; avoid an out-of-bounds read in release builds
        }

        const uint8_t nal_unit_type = data[0] & 0x1F;
        if (AVCParser::NalTypeAUD == nal_unit_type) {
            return;
        }

        // Reserve room for this NALU on top of what is already stored. Grow at
        // least geometrically so repeated appends stay amortised O(1).
        const size_t needed = std::vector<uint8_t>::size() + 4 + size;
        if (needed > capacity()) {
            const size_t doubled = 2 * capacity();
            reserve(needed > doubled ? needed : doubled);
        }

        push_back(static_cast<uint8_t>(size >> 24));
        push_back(static_cast<uint8_t>(size >> 16));
        push_back(static_cast<uint8_t>(size >> 8));
        push_back(static_cast<uint8_t>(size >> 0));
        insert(end(), data, data + size);
        m_nalu_mask |= (1u << nal_unit_type);
    }
};

/// Forward iteration over the NAL units in a length-prefixed buffer. Every
/// NALU is preceded by a big-endian size field `lengthSize` bytes wide
/// (4 by default).
///
/// Usage:  for (const auto& nal : NalIterator(buffer)) { use(nal.data, nal.size, nal.type); }
///
/// Zero-length entries are skipped. Iteration ends when any of these is reached:
/// - the end of the buffer;
/// - the first entry whose size prefix is truncated;
/// - the first entry whose size exceeds the remaining bytes;
/// - immediately, when lengthSize is 0.
/// The iterator stores a raw pointer and does not own the buffer, so the
/// buffer must outlive the iteration.
class NalIterator {
public:
    /// Cursor over the buffer that doubles as a view of the current NALU.
    /// Cursors compare by `data`. The end sentinel has data == nullptr.
    class Nal {
    public:
        const uint8_t* data = nullptr; // first byte (NAL header) of the current NALU; nullptr at end
        size_t size = 0;               // size in bytes of the current NALU
        uint8_t type = 0;              // nal_unit_type of the current NALU

    private:
        friend NalIterator;
        size_t remain = 0;     // unread bytes starting at `data`
        size_t lengthSize = 0; // width in bytes of each size prefix
        Nal(const uint8_t* data, size_t size, size_t lengthSize)
            : data(data)
            , size(0) // intentionally 0: operator++ first skips `size` bytes, and nothing has been read yet
            , remain(size) // the parameter `size`: the whole buffer is still unread
            , lengthSize(lengthSize)
        {
            ++(*this);
        }

    public:
        const Nal& operator*() const { return (*this); }
        bool operator!=(const Nal& that) const { return this->data != that.data; }

        /// Advances to the next non-empty NALU, or becomes the end sentinel.
        Nal& operator++()
        {
            // Skip zero-size NALUs: they are illegal, but it is best to be liberal.
            // With lengthSize == 0 no bytes would ever be consumed, so stop
            // rather than loop forever.
            do {
                // Step over the current payload, then decode the next big-endian size prefix.
                remain -= size, data += size, size = 0;
                for (size_t i = lengthSize; i && remain; ++data, --remain, --i) {
                    size = (size << 8) | (*data);
                }
            } while (!size && remain && lengthSize);

            if (0 != size && size <= remain) {
                type = data[0] & 0x1f; // low 5 bits of the NAL header byte
            } else {
                // Exhausted, truncated or corrupt input: turn into the end() sentinel.
                data = nullptr;
                size = lengthSize = remain = type = 0;
            }
            return *this;
        }
    };

private:
    const uint8_t* m_data = nullptr;
    size_t m_size = 0;
    size_t m_lengthSize = 0;

public:
    NalIterator() = default;
    NalIterator(const uint8_t* data, size_t size, size_t lengthSize = 4)
        : m_data(data)
        , m_size(size)
        , m_lengthSize(lengthSize)
    {
    }

    NalIterator(const std::vector<uint8_t>& data, size_t lengthSize = 4)
        : NalIterator(data.data(), data.size(), lengthSize)
    {
    }

    Nal end() const { return Nal(nullptr, 0, 0); }
    Nal begin() const { return Nal(m_data, m_size, m_lengthSize); }
};
