package s3

import (
	"bufio"
	"bytes"
	"encoding/csv"
	"fmt"
	"io"
)

// TsvMerger merges multiple TSV files into one, fixing up the different
// number/order of columns in source TSVs.
//
// The output column layout is fixed when the merger is created; every appended
// source is rewritten to that layout before it reaches the output.
type TsvMerger struct {
	// columns is the output header, in output order.
	columns    []string
	// columnsMap maps a column name to its index in columns.
	columnsMap map[string]int

	// writer is the writer returned by NewCompressedWriter; all TSV rows are
	// written to it.
	writer           *countingWriter
	// compressedWriter wraps the destination writer and is what
	// NewCompressedWriter is given to write into.
	compressedWriter *countingWriter
}

// NewTsvMerger creates a TsvMerger whose output has the given columns, in the
// given order, and immediately writes the header row.
//
// innerWriter is wrapped in a countingWriter, which is then passed together
// with alg to NewCompressedWriter; all TSV output goes through the writer that
// call returns.
//
// An error is returned when the compressed writer cannot be created, or when
// the header row cannot be written or flushed. This function never calls Close
// on innerWriter, including on the error paths.
func NewTsvMerger(columns []string, alg CompressionAlg, innerWriter io.WriteCloser) (*TsvMerger, error) {
	compressedWriter := &countingWriter{writer: innerWriter}
	writer, err := NewCompressedWriter(alg, compressedWriter)
	if err != nil {
		return nil, err
	}

	columnsMap := make(map[string]int, len(columns))
	for i, c := range columns {
		columnsMap[c] = i
	}

	csvW := csv.NewWriter(writer)
	csvW.Comma = '\t'
	if err := csvW.Write(columns); err != nil {
		return nil, fmt.Errorf("writing TSV header failed: %v", errToString(err))
	}
	// csv.Writer buffers its output, so a failure of the underlying writer
	// only becomes visible through Error() once Flush has run.
	csvW.Flush()
	if err := csvW.Error(); err != nil {
		return nil, fmt.Errorf("writing TSV header failed: %v", errToString(err))
	}

	return &TsvMerger{
		columns:          columns,
		columnsMap:       columnsMap,
		writer:           writer,
		compressedWriter: compressedWriter,
	}, nil
}

// Append reads one TSV source from r and appends its data rows to the merged
// output.
//
// columns is the header the source is expected to have. r is passed together
// with alg to NewCompressedReader, and the TSV is read from the reader that
// call returns.
//
// When columns equals the merger's columns the rows are copied verbatim (see
// appendFastPath). Otherwise each row is parsed and re-written in the merger's
// column order; merger columns that the source does not have are written as
// empty strings.
//
// An error is returned when:
//   - the compressed reader cannot be created,
//   - the source cannot be read or parsed,
//   - the source's header row differs from columns,
//   - columns contains a name that is not one of the merger's columns,
//   - a row does not have exactly len(columns) fields,
//   - writing or flushing the output fails.
//
// Rows converted before a failure are still flushed to the output.
func (m *TsvMerger) Append(columns []string, alg CompressionAlg, r io.Reader) (retErr error) {
	compressedR, err := NewCompressedReader(alg, r)
	if err != nil {
		// Propagate the failure: returning nil here would silently skip the
		// whole source while reporting success.
		return err
	}

	if stringSlicesEqual(columns, m.columns) {
		return m.appendFastPath(compressedR)
	}

	csvR := csv.NewReader(compressedR)
	csvR.Comma = '\t'
	// Safe to reuse: every field is copied into writeRecord before the next Read.
	csvR.ReuseRecord = true

	header, err := csvR.Read()
	if err != nil {
		return err
	}
	if !stringSlicesEqual(columns, header) {
		return fmt.Errorf("different columns: expected %v, got %v", columns, header)
	}

	// writeMap[i] is the output index of source column i.
	writeMap := make([]int, len(columns))
	for fromIdx, c := range columns {
		toIdx, ok := m.columnsMap[c]
		if !ok {
			return fmt.Errorf("column %s not included in write columns", c)
		}
		writeMap[fromIdx] = toIdx
	}
	// Slots that no source column maps to are never assigned and stay "".
	writeRecord := make([]string, len(m.columns))

	csvW := csv.NewWriter(m.writer)
	csvW.Comma = '\t'
	// Flush on every exit path so already-converted rows reach the output, and
	// report a flush/write failure unless an earlier error is being returned.
	defer func() {
		csvW.Flush()
		if flushErr := csvW.Error(); flushErr != nil && retErr == nil {
			retErr = flushErr
		}
	}()

	for {
		record, err := csvR.Read()
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return err
		}

		if len(record) != len(columns) {
			return fmt.Errorf("strange record length: %v (length must be %d)", record, len(columns))
		}
		for i, field := range record {
			writeRecord[writeMap[i]] = field
		}

		if err := csvW.Write(writeRecord); err != nil {
			return err
		}
	}
}

// appendFastPath appends a source whose columns already match the merger's
// columns by copying its data rows verbatim, without parsing them.
//
// Only the first line is parsed, to verify that the header equals m.columns;
// the header itself is not copied. Everything after it is copied byte-for-byte
// to the output. If the copied data does not end with a newline, one is added
// so that the first row of the next appended source is not glued onto this
// source's last row.
//
// An error is returned when reading or parsing the header fails (an empty
// source yields io.EOF), when the header differs from m.columns, or when
// reading the source or writing the output fails.
func (m *TsvMerger) appendFastPath(r io.Reader) error {
	bufR := bufio.NewReader(r)

	// Validate that the file is indeed a TSV with the same columns.
	// io.EOF is tolerated here: a header-only source may lack a final newline.
	headerBytes, err := bufR.ReadBytes('\n')
	if err != nil && err != io.EOF {
		return err
	}
	csvR := csv.NewReader(bytes.NewReader(headerBytes))
	csvR.Comma = '\t'
	csvR.ReuseRecord = true
	header, err := csvR.Read()
	if err != nil {
		return err
	}
	if !stringSlicesEqual(m.columns, header) {
		return fmt.Errorf("different columns in fast path: expected %v, got %v", m.columns, header)
	}

	// Copy the remaining bytes while remembering the last one written.
	// Starting at '\n' means a source without data rows adds nothing.
	buf := make([]byte, 32*1024)
	lastByte := byte('\n')
	for {
		n, readErr := bufR.Read(buf)
		if n > 0 {
			lastByte = buf[n-1]
			written, writeErr := m.writer.Write(buf[:n])
			if writeErr != nil {
				return writeErr
			}
			if written != n {
				return io.ErrShortWrite
			}
		}
		if readErr == io.EOF {
			break
		}
		if readErr != nil {
			return readErr
		}
	}

	if lastByte != '\n' {
		// Terminate the final row so the output stays one record per line.
		if _, err := m.writer.Write([]byte{'\n'}); err != nil {
			return err
		}
	}
	return nil
}

// Close closes the writer that all TSV rows are written to and returns the
// error reported by that Close call.
// NOTE(review): whether this also closes the destination writer passed to
// NewTsvMerger depends on NewCompressedWriter's implementation — confirm.
func (m *TsvMerger) Close() error {
	closeErr := m.writer.Close()
	return closeErr
}

// TotalBytes returns the `written` counter of the writer that receives the TSV
// rows (presumably the size of the output before compression — verify against
// countingWriter).
func (m *TsvMerger) TotalBytes() int64 {
	outer := m.writer
	return outer.written
}

// CompressedBytes returns the `written` counter of the writer that wraps the
// destination passed to NewTsvMerger (presumably the size of the output after
// compression — verify against countingWriter).
func (m *TsvMerger) CompressedBytes() int64 {
	inner := m.compressedWriter
	return inner.written
}

// stringSlicesEqual reports whether s1 and s2 have the same length and hold
// equal strings at every index. A nil slice and an empty slice compare equal.
func stringSlicesEqual(s1 []string, s2 []string) bool {
	if len(s1) != len(s2) {
		return false
	}
	for idx, left := range s1 {
		if right := s2[idx]; left != right {
			return false
		}
	}
	return true
}
