package codec

import (
	"bytes"
	"fmt"
	"io"

	"github.com/dgryski/go-bitstream"
)

// AudioSpecificConfig is the AAC sequence header (AudioSpecificConfig) from
// ISO/IEC 14496-3, holding the fields ReadASC decodes and WriteASC encodes.
type AudioSpecificConfig struct {
	// AudioObjectType is the object type signalled first in the header
	// (e.g. AotAacLc). It is coded in 5 bits, or 5+6 = 11 bits when the
	// escape value AotEscape is used.
	AudioObjectType int // 5 or 11 bits
	// SamplingFrequencyIndex is the 4-bit index into
	// Mpeg4SamplingFrequencyTable, or 15 (the escape value) when the
	// frequency is coded explicitly.
	SamplingFrequencyIndex int // 4 unsigned bits
	// SamplingFrequency is the frequency in Hz: looked up from the index, or
	// taken from the explicit 24-bit field that follows the escape index.
	// ChannelConfigurationIndex is the 4-bit channelConfiguration field;
	// 0 (free-form) is rejected by WriteASC.
	SamplingFrequency         int // 24 unsigned bits
	ChannelConfigurationIndex int // 4 unsigned bits

	// FrameLengthFlag is the frameLengthFlag bit that opens the
	// GASpecificConfig.
	FrameLengthFlag bool
}

const (
	// aotBitsBase is the width of the leading audioObjectType field. When it
	// holds its all-ones value (AotEscape), aotBitsExtra more bits follow.
	aotBitsBase  = 5
	aotBitsExtra = 6

	// AotAacLc, AotSbr and AotPs are the object types this package refers to
	// by name. AotSbr and AotPs make ReadASC expect extension fields and a
	// core object type after the channel configuration.
	AotAacLc = 2
	AotSbr   = 5
	AotPs    = 29

	// AotEscape is the 5-bit value announcing an extended object type.
	AotEscape = 1<<aotBitsBase - 1
	// AotMax is the largest object type reachable through the escape form:
	// AotEscape + 1 plus the largest aotBitsExtra-bit value.
	AotMax = AotEscape + 1<<aotBitsExtra
)

const (
	// samplingFrequencyIndexBits is the width of samplingFrequencyIndex. Its
	// all-ones value is the escape that announces an explicit frequency of
	// samplingFrequencyEscapeBits bits.
	samplingFrequencyIndexBits   = 4
	samplingFrequencyIndexEscape = 1<<samplingFrequencyIndexBits - 1
	samplingFrequencyEscapeBits  = 24

	// channelConfigurationBits is the width of channelConfiguration.
	channelConfigurationBits = 4
)

// readAot reads an escape-coded audioObjectType from r.
//
// A 5-bit (aotBitsBase) value is read first. If it equals AotEscape, a
// further aotBitsExtra-bit extension follows, and the object type is
// AotEscape + 1 + extension. This allows types up to AotMax, matching the
// encoding WriteASC emits.
func readAot(r *bitstream.BitReader) (int, error) {
	base, err := r.ReadBits(aotBitsBase)
	if err != nil {
		return 0, err
	}
	if base != AotEscape {
		return int(base), nil
	}
	// The extension is aotBitsExtra (6) bits wide, not aotBitsBase. Reading
	// only 5 bits would misalign every field that follows.
	ext, err := r.ReadBits(aotBitsExtra)
	if err != nil {
		return 0, err
	}
	return AotEscape + 1 + int(ext), nil
}

// readSamplingFrequency reads a 4-bit samplingFrequencyIndex from r and
// resolves it to a frequency.
//
// The escape index (samplingFrequencyIndexEscape) is followed by an explicit
// 24-bit frequency. Any other index is looked up in
// Mpeg4SamplingFrequencyTable, and indices beyond the table are reported as
// unknown. A resolved frequency of 0 is also an error. On success it returns
// the raw index and the frequency.
func readSamplingFrequency(r *bitstream.BitReader) (index int, freq int, err error) {
	raw, readErr := r.ReadBits(samplingFrequencyIndexBits)
	if readErr != nil {
		return 0, 0, readErr
	}
	idx := int(raw)

	var hz int
	switch {
	case idx == samplingFrequencyIndexEscape:
		explicit, escErr := r.ReadBits(samplingFrequencyEscapeBits)
		if escErr != nil {
			return 0, 0, escErr
		}
		hz = int(explicit)
	case idx < len(Mpeg4SamplingFrequencyTable):
		hz = Mpeg4SamplingFrequencyTable[idx]
	default:
		return 0, 0, fmt.Errorf("Unknown sampling frequency index %d", idx)
	}

	// Only the explicit (escape) path can yield 0; reject it uniformly.
	if hz == 0 {
		return 0, 0, fmt.Errorf("Sampling frequency is 0")
	}
	return idx, hz, nil
}

// Mpeg4SamplingFrequencyTable maps a samplingFrequencyIndex to its frequency
// in Hz. Indices past the end of the table, other than the escape value 15,
// are rejected by readSamplingFrequency.
var Mpeg4SamplingFrequencyTable = []int{
	96000, // 0x0
	88200, // 0x1
	64000, // 0x2
	48000, // 0x3
	44100, // 0x4
	32000, // 0x5
	24000, // 0x6
	22050, // 0x7
	16000, // 0x8
	12000, // 0x9
	11025, // 0xa
	8000,  // 0xb
	7350,  // 0xc
}

// ReadASC parses an AudioSpecificConfig (ISO/IEC 14496-3) from br.
//
// The returned AudioObjectType is the object type signalled first in the
// stream. When that is AotSbr or AotPs, the following are consumed and
// discarded:
//   - the extension sampling frequency;
//   - for a core type of 22, an extension channel configuration.
//
// The core object type that follows them then decides whether a
// GASpecificConfig is present. Of the GASpecificConfig only frameLengthFlag
// is decoded; sync extensions are not parsed.
func ReadASC(br *bytes.Reader) (*AudioSpecificConfig, error) {
	asc := AudioSpecificConfig{}
	r := bitstream.NewReader(br)

	aot, err := readAot(r)
	if err != nil {
		return nil, err
	}
	asc.AudioObjectType = aot

	asc.SamplingFrequencyIndex, asc.SamplingFrequency, err = readSamplingFrequency(r)
	if err != nil {
		return nil, err
	}

	chanConfig, err := r.ReadBits(channelConfigurationBits)
	if err != nil {
		return nil, err
	}
	asc.ChannelConfigurationIndex = int(chanConfig)

	// coreAot selects the type-specific config below. With explicit SBR/PS
	// signalling it is coded after the extension fields. The GASpecificConfig
	// decision must use it, not the signalled 5 or 29.
	coreAot := aot
	if aot == AotSbr || aot == AotPs {
		// Extension sampling frequency: validated but not stored.
		if _, _, err = readSamplingFrequency(r); err != nil {
			return nil, err
		}
		if coreAot, err = readAot(r); err != nil {
			return nil, err
		}
		if coreAot == 22 {
			// A 4-bit extension channel configuration follows; not stored.
			if _, err = r.ReadBits(channelConfigurationBits); err != nil {
				return nil, err
			}
		}
	}

	switch coreAot {
	case 1, 2, 3, 4, 6, 7, 17, 19, 20, 21, 22, 23:
		// GASpecificConfig starts with frameLengthFlag.
		flag, err := r.ReadBit()
		if err != nil {
			return nil, err
		}
		asc.FrameLengthFlag = bool(flag)
	}

	// TODO read the rest of GASpecificConfig and sync extensions

	return &asc, nil
}

// WriteASC encodes asc as an AudioSpecificConfig and writes it to w, padding
// the final byte with zero bits.
//
// Only object types 0 through 4 with a non-zero channel configuration are
// supported. The GASpecificConfig that follows always clears
// dependsOnCoreCoder and extensionFlag. All fields are validated before any
// bit is emitted, so an unsupported or out-of-range asc does not leave a
// partial header in w.
func WriteASC(w io.Writer, asc *AudioSpecificConfig) error {
	aot := asc.AudioObjectType
	// AotEscape is the escape marker itself and cannot be coded as a type.
	// Anything above AotMax does not fit the 5+6 bit escape form.
	if aot < 0 || aot == AotEscape || aot > AotMax {
		return fmt.Errorf("AOT %d is inexpressible", aot)
	}
	if aot > 4 {
		return fmt.Errorf("This module cannot codify exotic AOTs (aot %d)", aot)
	}

	sfi := asc.SamplingFrequencyIndex
	switch {
	case sfi == samplingFrequencyIndexEscape:
		// The explicit frequency must fit the 24-bit field. 0 is refused by
		// readSamplingFrequency, so it is refused here too.
		maxFreq := 1<<samplingFrequencyEscapeBits - 1
		if asc.SamplingFrequency <= 0 || asc.SamplingFrequency > maxFreq {
			return fmt.Errorf("explicit sampling frequency %d out of range [1, %d]", asc.SamplingFrequency, maxFreq)
		}
	case sfi < 0 || sfi >= len(Mpeg4SamplingFrequencyTable):
		// Negative, reserved (13, 14) or wider than 4 bits: WriteBits would
		// silently truncate it, or the reader would reject it.
		return fmt.Errorf("unknown sampling frequency index %d", sfi)
	}

	cc := asc.ChannelConfigurationIndex
	if cc < 0 || cc >= 1<<channelConfigurationBits {
		return fmt.Errorf("channel configuration %d does not fit in %d bits", cc, channelConfigurationBits)
	}
	if cc == 0 {
		// Also checked by writeGASpecificConfig. Checking here means the error
		// is reported before anything reaches w.
		return fmt.Errorf("The module cannot codify free form channel configurations")
	}

	b := bitstream.NewWriter(w)

	switch {
	case aot < AotEscape:
		if err := b.WriteBits(uint64(aot), aotBitsBase); err != nil {
			return err
		}
	default:
		// Escape form: AotEscape, then aot-AotEscape-1 in aotBitsExtra bits.
		// Currently unreachable because of the aot > 4 check above.
		if err := b.WriteBits(AotEscape, aotBitsBase); err != nil {
			return err
		}
		if err := b.WriteBits(uint64(aot)-AotEscape-1, aotBitsExtra); err != nil {
			return err
		}
	}

	if err := b.WriteBits(uint64(sfi), samplingFrequencyIndexBits); err != nil {
		return err
	}
	if sfi == samplingFrequencyIndexEscape {
		if err := b.WriteBits(uint64(asc.SamplingFrequency), samplingFrequencyEscapeBits); err != nil {
			return err
		}
	}

	if err := b.WriteBits(uint64(cc), channelConfigurationBits); err != nil {
		return err
	}

	if err := writeGASpecificConfig(b, asc); err != nil {
		return err
	}

	return b.Flush(bitstream.Zero)
}

// writeGASpecificConfig emits the minimal GASpecificConfig supported by this
// package: frameLengthFlag from asc, then dependsOnCoreCoder and
// extensionFlag, both 0. A ChannelConfigurationIndex of 0 (free-form) is
// rejected before anything is written.
func writeGASpecificConfig(b *bitstream.BitWriter, asc *AudioSpecificConfig) error {
	if asc.ChannelConfigurationIndex == 0 {
		return fmt.Errorf("The module cannot codify free form channel configurations")
	}

	flags := [...]bitstream.Bit{
		bitstream.Bit(asc.FrameLengthFlag), // frameLengthFlag
		bitstream.Zero,                     // dependsOnCoreCoder
		bitstream.Zero,                     // extensionFlag
	}
	for _, bit := range flags {
		if err := b.WriteBit(bit); err != nil {
			return err
		}
	}
	return nil
}
