package parse

import (
	"bytes"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"io/ioutil"

	"github.com/dgryski/dgolzo"
)

// RecordScannerEOF is the sentinel returned by Scan once every data block
// declared in the file header has been consumed. It marks normal
// termination of a scan rather than a failure; Map treats it as success.
var RecordScannerEOF = errors.New("end of records")

// RecordScanner reads flow records sequentially from a block-structured,
// LZO-compressed capture file: a FileHeader and a StatRecord, followed by
// fh.NumBlocks data blocks, each holding dbh.NumRecords records.
// Call Init with the source stream, then Scan (or Map) to iterate records.
type RecordScanner struct {
	// immutable state
	// fileData is the source byte stream; the headers and compressed block
	// payloads are read from it in order.
	fileData io.Reader
	// fileName is stored by Init; it is not otherwise used by the methods
	// visible alongside this type.
	fileName string
	// scanner state
	// file-level
	// fh is the validated file header; fh.NumBlocks bounds nextBlock.
	fh         *FileHeader
	// stats is the summary stats block at top of file.
	stats      *StatRecord
	// extensions holds every extension map seen so far, keyed by MapID, and
	// is handed to ReadCommonRecord when decoding flow records.
	extensions map[uint16]*ExtensionMap
	// blockIdx counts data blocks loaded so far (loadBlock increments it
	// before reading the payload), so the current block's 0-based index is
	// blockIdx-1.
	blockIdx   uint32
	// block-level
	// dbh is the header of the current data block.
	dbh       *DataBlockHeader
	// blockData reads the decompressed bytes of the current block.
	blockData io.Reader
	// recordIdx counts record headers read from the current block; loadBlock
	// resets it to 0 and nextRecord increments it after each header.
	recordIdx uint32
	// record-level
	// rh is the header of the current record.
	rh         *RecordHeader
	// recordData is blockData limited to rh.Size-4 bytes, i.e. the current
	// record's body after its header.
	recordData io.Reader
	// record is the Record decoded from the current record, or nil when that
	// record was an extension map or of an unknown type.
	record     *Record
}

// Init prepares rs to scan records from src. It reads the file header and
// the stat record, then loads the first data block so that Scan can be
// called immediately. filename is stored on the scanner for reference.
//
// All scanner state is reset first, so a RecordScanner may be reused for
// another file.
//
// A file whose header declares zero data blocks is not an error: Init
// succeeds and the first call to Scan returns RecordScannerEOF. Any other
// failure is returned wrapped with %w so callers can inspect the underlying
// cause with errors.Is / errors.As.
func (rs *RecordScanner) Init(src io.Reader, filename string) error {
	rs.fileData = src
	rs.fileName = filename
	rs.fh = nil
	rs.stats = nil
	rs.extensions = make(map[uint16]*ExtensionMap)
	rs.blockIdx = 0
	rs.dbh = nil
	rs.blockData = nil
	rs.recordIdx = 0
	rs.rh = nil
	rs.recordData = nil
	rs.record = nil

	if err := rs.readFileHeader(); err != nil {
		return fmt.Errorf("read file header err: %w", err)
	}
	if err := rs.readStatRecord(); err != nil {
		return fmt.Errorf("read stat record err: %w", err)
	}

	err := rs.nextBlock()
	if errors.Is(err, RecordScannerEOF) {
		// No data blocks at all. Install an empty block header: with zero
		// records in the "current" block, the next Scan asks for another
		// block and reports RecordScannerEOF instead of dereferencing a nil
		// rs.dbh.
		rs.dbh = &DataBlockHeader{}
		return nil
	}
	if err != nil {
		return fmt.Errorf("read block header err: %w", err)
	}
	return nil
}

// Map calls f on every remaining record in the scanner, in file order.
//
// It stops at and returns the first error produced either by scanning or by
// f. Reaching the end of the records (RecordScannerEOF) is not an error: Map
// returns nil once every record has been passed to f.
//
// Scanning and f run on the calling goroutine, one record at a time. This
// guarantees that no record is skipped at end of input and that nothing
// keeps reading from rs after Map returns early.
func (rs *RecordScanner) Map(f func(*Record) error) error {
	for {
		rec, err := rs.Scan()
		if errors.Is(err, RecordScannerEOF) {
			return nil
		}
		if err != nil {
			return err
		}
		if err := f(rec); err != nil {
			return err
		}
	}
}

// readFileHeader decodes the little-endian FileHeader at the start of
// rs.fileData and stores it in rs.fh if it passes validate(). On a read
// error, or a header that fails validation, rs.fh is left unchanged.
func (rs *RecordScanner) readFileHeader() error {
	h := &FileHeader{}
	if err := binary.Read(rs.fileData, binary.LittleEndian, h); err != nil {
		return err
	}
	if !h.validate() {
		// Lowercase per Go convention: callers prefix this with their own
		// context when wrapping it.
		return errors.New("corrupted file header")
	}
	rs.fh = h
	return nil
}

// readStatRecord decodes the little-endian StatRecord (the summary stats
// block) that follows the file header and stores it in rs.stats. On error
// rs.stats is not modified.
func (rs *RecordScanner) readStatRecord() error {
	var sr StatRecord
	err := binary.Read(rs.fileData, binary.LittleEndian, &sr)
	if err != nil {
		return err
	}
	rs.stats = &sr
	return nil
}

// readDataBlockHeader decodes the little-endian DataBlockHeader that precedes
// each block payload in rs.fileData and makes it the current header
// (rs.dbh). On error rs.dbh keeps its previous value.
func (rs *RecordScanner) readDataBlockHeader() error {
	hdr := new(DataBlockHeader)
	err := binary.Read(rs.fileData, binary.LittleEndian, hdr)
	if err == nil {
		rs.dbh = hdr
	}
	return err
}

// loadBlock reads the rs.dbh.Size-byte compressed payload of the current data
// block from rs.fileData, LZO1X-decompresses it, and exposes the result as
// rs.blockData.
//
// blockIdx is advanced and recordIdx reset before any I/O happens, so both
// change even when loading fails.
//
// The output buffer is a fixed 20x the compressed size; a block that would
// inflate beyond that presumably makes Decompress fail, which is reported as
// "decompress err".
func (rs *RecordScanner) loadBlock() error {
	rs.blockIdx++
	rs.recordIdx = 0

	compressedLen := int(rs.dbh.Size)
	compressed := make([]byte, compressedLen)
	if _, err := io.ReadFull(rs.fileData, compressed); err != nil {
		return fmt.Errorf("buffer fill err: %v", err)
	}

	roomFor := compressedLen * 20
	out := make([]byte, roomFor)

	decomp, err := lzo.NewCompressor(lzo.Lzo1x_1)
	if err != nil {
		return fmt.Errorf("make decomp err: %v", err)
	}

	written, err := decomp.Decompress(compressed, out)
	if err != nil {
		return fmt.Errorf("decompress err: %v (%d bytes decompressed, room for %d)", err, written, roomFor)
	}

	rs.blockData = bytes.NewReader(out[:written])
	return nil
}

// nextBlock advances to the next data block: it reads the block header,
// loads (decompresses) the block payload, and resets rs.recordData to an
// empty reader. Once fh.NumBlocks blocks have been loaded it returns
// RecordScannerEOF without reading anything.
func (rs *RecordScanner) nextBlock() error {
	if rs.blockIdx >= rs.fh.NumBlocks {
		return RecordScannerEOF
	}
	if err := rs.readDataBlockHeader(); err != nil {
		return fmt.Errorf("read block header err: %v", err)
	}
	if err := rs.loadBlock(); err != nil {
		return fmt.Errorf("load block err: %v", err)
	}
	// No record has been selected in the fresh block yet.
	rs.recordData = bytes.NewReader(nil)
	return nil
}

// Scan returns the next flow record. Extension-map records and records of
// unknown type are consumed without being returned. When the last data block
// is exhausted Scan returns (nil, RecordScannerEOF). Any other failure is
// returned as a non-nil error with a nil record.
func (rs *RecordScanner) Scan() (*Record, error) {
	for {
		if err := rs.nextRecord(); err != nil {
			return nil, err
		}
		if err := rs.readRecord(); err != nil {
			return nil, err
		}
		if rec := rs.record; rec != nil {
			return rec, nil
		}
		// readRecord produced no Record (extension map or unknown type);
		// move on to the following record.
	}
}

// readRecordHeader decodes the next little-endian RecordHeader from the
// current block's decompressed data into rs.rh. rs.rh is replaced only when
// the read succeeds.
func (rs *RecordScanner) readRecordHeader() error {
	var hdr RecordHeader
	if err := binary.Read(rs.blockData, binary.LittleEndian, &hdr); err != nil {
		return err
	}
	rs.rh = &hdr
	return nil
}

// nextRecord positions the scanner on the next record. It advances to a new
// data block whenever the current one is exhausted, reads the record header
// into rs.rh, and sets rs.recordData to a reader limited to that record's
// body. It returns RecordScannerEOF once every block has been consumed.
func (rs *RecordScanner) nextRecord() error {
	// Loop rather than test once so that blocks declaring zero records are
	// skipped instead of being read past. Each successful nextBlock advances
	// blockIdx, so this terminates.
	for rs.recordIdx >= rs.dbh.NumRecords {
		if err := rs.nextBlock(); err != nil {
			return err
		}
	}
	// A failed header read must not fall through: rs.rh would still hold
	// the previous record's header.
	if err := rs.readRecordHeader(); err != nil {
		return fmt.Errorf("read record header err: %w", err)
	}
	// Size presumably counts the 4-byte header just consumed (TODO confirm
	// against the format spec). Subtract in int64 so a corrupt Size below 4
	// is caught instead of wrapping around.
	bodyLen := int64(rs.rh.Size) - 4
	if bodyLen < 0 {
		return fmt.Errorf("corrupt record header: size %d smaller than header", rs.rh.Size)
	}
	rs.recordData = io.LimitReader(rs.blockData, bodyLen)
	rs.recordIdx++
	return nil
}

// readRecord decodes the record whose body is available on rs.recordData,
// according to the header type in rs.rh:
//
//   - ExtensionMapType: the map is stored in rs.extensions under its MapID
//     for use by later common records, and rs.record is set to nil.
//   - CommonRecordType: rs.record is set to a new Record built from the
//     common record, with its extensions processed.
//   - anything else: the record is skipped and rs.record is set to nil.
//
// On error rs.record is left nil. On every path, including errors, any body
// bytes the decoder did not consume are discarded so that blockData stays
// aligned on the next record header.
func (rs *RecordScanner) readRecord() error {
	// Best-effort drain. recordData is a limited view of the in-memory
	// block, so there is nothing useful to do with a copy error here.
	defer func() { _, _ = io.Copy(ioutil.Discard, rs.recordData) }()

	rs.record = nil
	switch rs.rh.Type {
	case ExtensionMapType:
		m, err := ReadExtensionMap(rs.recordData, rs.rh)
		if err != nil {
			return fmt.Errorf("read ext map err: %w", err)
		}
		rs.extensions[m.MapID] = m
	case CommonRecordType:
		cr, err := ReadCommonRecord(rs.recordData, rs.extensions)
		if err != nil {
			return fmt.Errorf("read comrec err: %w", err)
		}
		rec := NewRecord(cr)
		if err := rec.ProcessExtensions(rs.recordData); err != nil {
			return fmt.Errorf("process ext err: %w", err)
		}
		rs.record = rec
	default:
		// Unknown record type: leave rs.record nil so Scan skips it.
	}
	return nil
}
