package loggers

import (
	"bufio"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"os"
	"path"
	"path/filepath"
	"sort"
	"time"

	"code.justin.tv/sse/malachai/pkg/config"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/firehose"
	"github.com/aws/aws-sdk-go/service/firehose/firehoseiface"
	"github.com/boltdb/bolt"
	"github.com/natefinch/lumberjack"
	"github.com/sirupsen/logrus"
)

const (
	// oneMB is the number of bytes in one mebibyte (1024 * 1024); the size
	// limits below are expressed as multiples of it.
	oneMB       = 1024 * 1024 // 1MB
	logFileName = "service.log"

	// Firehose limits enforced by this package for a single PutRecordBatch
	// request: maximum number of records and maximum total payload in bytes.
	putRecordBatchRecordLimit = 500
	putRecordBatchSizeLimit   = 4 * oneMB

	// Defaults. defaultFirehoseReportInterval is a number of seconds and
	// defaultMaxLogFileSize is a number of bytes.
	// NOTE(review): defaultServiceName is not referenced in this file; confirm
	// it is still used elsewhere in the package.
	defaultServiceName            = "unknownServiceName"
	defaultFirehoseLogDir         = "/var/log/malachai"
	defaultFirehoseReportInterval = 60
	defaultMaxLogFileSize         = 100 * oneMB

	// The mapping from S3 data to Athena table columns is case sensitive.
	// Use the following constants as field names when logging to firehose
	// with WithFields.

	// LogFieldEvent maps to the `event` field in athena
	LogFieldEvent = "event"
	// LogFieldIPAddress maps to the `ip_address` field in athena
	LogFieldIPAddress = "ip_address"
	// LogFieldServiceName maps to the `service_name` field in athena
	LogFieldServiceName = "service_name"
	// LogFieldServiceID maps to the `service_id` field in athena
	LogFieldServiceID = "service_id"
	// LogFieldHostname maps to the `hostname` field in athena
	LogFieldHostname = "hostname"
	// LogFieldLDAPUser maps to the `ldap_user` field in athena
	LogFieldLDAPUser = "ldap_user"
	// LogFieldAWSPrincipal maps to the `aws_principal` field in athena
	LogFieldAWSPrincipal = "aws_principal"
	// LogFieldCallerID maps to the `caller_id` field in athena
	LogFieldCallerID = "caller_id"
	// LogFieldCallerName maps to the `caller_name` field in athena
	LogFieldCallerName = "caller_name"
	// LogFieldCalleeID maps to the `callee_id` field in athena
	LogFieldCalleeID = "callee_id"
	// LogFieldCalleeName maps to the `callee_name` field in athena
	LogFieldCalleeName = "callee_name"
	// LogFieldJWTID maps to the `jwt_id` field in athena
	LogFieldJWTID = "jwt_id"
	// LogFieldStatusCode maps to the `status_code` field in athena
	LogFieldStatusCode = "status_code"
	// LogFieldURLPath maps to the `url_path` field in athena
	LogFieldURLPath = "url_path"
)

// LogRotatorAPI is the subset of the github.com/natefinch/lumberjack logger
// that this package depends on, declared as an interface so that it can be
// mocked out in tests.
//
// Within this file the rotator is installed as the output of the JSON logger
// (Write) and closed from FirehoseLogger.Close (Close).
type LogRotatorAPI interface {
	// Write appends p to the current log file.
	Write(p []byte) (n int, err error)
	// Close releases the underlying log file.
	Close() (err error)
	// Rotate forces a rotation of the current log file.
	// NOTE(review): not called from this file; semantics are those of the
	// lumberjack implementation — confirm against its documentation.
	Rotate() (err error)
}

// FirehoseLogger logs event to Kinesis firehose.
// the logs will be in json format and will be stored in s3.
// Eventually these logs would be queryable via athena
//
// Events logged through the embedded FieldLogger are written as JSON lines to
// a local file via logRotator. A background goroutine (reportEventsToFirehose)
// periodically reads those files and sends their lines to Firehose through fh.
//
// Fields:
//   - fh: Firehose client used by putRecordBatch.
//   - logRotator: destination of the embedded FieldLogger; closed by Close.
//   - logger: logger writing to stderr, used for this package's own diagnostics.
//   - config: configuration the logger was built with.
//   - stopReportingChan: signals the background goroutine to return.
//   - bolt: database holding the file offset and the last processed file path.
type FirehoseLogger struct {
	// FieldLogger is used to write the json formatted lines locally to be processed by firehose client
	logrus.FieldLogger
	fh         firehoseiface.FirehoseAPI
	logRotator LogRotatorAPI
	// stderr logger is used to write error logs generated by firehose
	logger            logrus.FieldLogger
	config            *FirehoseLoggerConfig
	stopReportingChan chan struct{}

	bolt *bolt.DB
}

// FirehoseLoggerConfig config related to firehose.
//
// fillDefaults populates unset fields:
//   - Environment: passed to config.GetResources to look up resources.
//   - FirehoseDeliveryStreamName: delivery stream to send records to; when
//     empty it is taken from the environment's resources.
//   - awsConfig: always derived from the environment's resources.
//   - ServiceName: appended to LogDir as a subdirectory.
//   - LogLevel: level of the loggers created for this config.
//     NOTE(review): fillDefaults currently always sets this to InfoLevel.
//   - LogDir: base directory for log files; defaults to defaultFirehoseLogDir.
//   - maxLogFileSize: maximum log file size; defaults to defaultMaxLogFileSize.
//   - reportInterval: delay between two runs of the log shipping loop;
//     defaults to defaultFirehoseReportInterval seconds.
type FirehoseLoggerConfig struct {
	Environment                string
	FirehoseDeliveryStreamName string
	awsConfig                  *aws.Config
	ServiceName                string
	LogLevel                   logrus.Level

	// The fields below are configurable so that tests can override them.
	LogDir         string
	maxLogFileSize int
	reportInterval time.Duration
}

// fillDefaults resolves the resources for cfg.Environment and fills in every
// setting the caller left at its zero value. It mutates cfg in place and
// returns the error from config.GetResources, if any, before touching cfg.
func (cfg *FirehoseLoggerConfig) fillDefaults() (err error) {
	res, err := config.GetResources(cfg.Environment)
	if err != nil {
		return err
	}

	if cfg.FirehoseDeliveryStreamName == "" {
		cfg.FirehoseDeliveryStreamName = res.FirehoseDeliveryStreamName
	}
	if cfg.reportInterval == 0 {
		cfg.reportInterval = defaultFirehoseReportInterval * time.Second
	}
	if cfg.maxLogFileSize == 0 {
		cfg.maxLogFileSize = defaultMaxLogFileSize
	}

	// Each service logs into its own subdirectory of the base directory so
	// that log files of different services never mix.
	baseDir := cfg.LogDir
	if baseDir == "" {
		baseDir = defaultFirehoseLogDir
	}
	cfg.LogDir = filepath.Join(baseDir, cfg.ServiceName)

	// NOTE(review): this overrides any LogLevel supplied by the caller —
	// confirm that forcing InfoLevel is intentional.
	cfg.LogLevel = logrus.InfoLevel

	cfg.awsConfig = config.AWSConfig(res.Region, res.FirehoseLoggerRoleArn)
	return nil
}

// NewFirehoseLogger returns a new instance of FirehoseLogger that writes its
// log lines to local disk under cfg.LogDir and tracks shipping progress in db.
//
// A nil cfg is treated as an empty configuration. cfg is completed in place
// by fillDefaults before the logger is built; an error from either step is
// returned to the caller.
func NewFirehoseLogger(cfg *FirehoseLoggerConfig, db *bolt.DB) (fhl *FirehoseLogger, err error) {
	if cfg == nil {
		cfg = new(FirehoseLoggerConfig)
	}
	if err = cfg.fillDefaults(); err != nil {
		return nil, err
	}
	// A nil rotator makes firehoseLogger create its own rotator.
	return firehoseLogger(cfg, nil, db)
}

// Close stops the background reporting goroutine and closes the log rotator.
// A failure to close the rotator is deliberately ignored: there is nothing
// useful the caller could do with it.
func (fhl *FirehoseLogger) Close() {
	fhl.stopReportingEventsToFirehose()
	_ = fhl.logRotator.Close()
}

// firehoseLogger builds a FirehoseLogger from an already completed cfg and
// starts its background reporting goroutine.
//
// It creates cfg.LogDir when missing, uses logRotator as the destination of
// the JSON event logger (creating a lumberjack rotator when logRotator is
// nil), creates the AWS session and Firehose client, and initializes the
// bolt bucket via initBoltDBPointerBucket. Any failure is returned as an
// error and no goroutine is started in that case.
func firehoseLogger(cfg *FirehoseLoggerConfig, logRotator LogRotatorAPI, db *bolt.DB) (fhl *FirehoseLogger, err error) {
	if !pathExists(cfg.LogDir) {
		if err = os.MkdirAll(cfg.LogDir, 0755); err != nil {
			return nil, err
		}
	}

	if logRotator == nil {
		// cfg.maxLogFileSize is in bytes, lumberjack's MaxSize is in
		// megabytes. Passing the byte count through unconverted yields a limit
		// so large that files are never rotated by size. A zero value is left
		// alone so lumberjack applies its own default; any other value is
		// rounded down to whole megabytes with a minimum of one.
		maxSizeMB := cfg.maxLogFileSize / oneMB
		if cfg.maxLogFileSize > 0 && maxSizeMB == 0 {
			maxSizeMB = 1
		}
		logRotator = &lumberjack.Logger{
			Filename: path.Join(cfg.LogDir, logFileName),
			MaxSize:  maxSizeMB,
			MaxAge:   7, // keep around 7 days of logs
		}
	}

	// Events are written as JSON lines to the rotated file.
	fieldLogger := &logrus.Logger{
		Out:       logRotator,
		Formatter: &logrus.JSONFormatter{},
		Level:     cfg.LogLevel,
	}

	// Diagnostics of the shipping machinery itself go to stderr.
	stderrLogger := &logrus.Logger{
		Out:       os.Stderr,
		Formatter: new(logrus.JSONFormatter),
		Hooks:     make(logrus.LevelHooks),
		Level:     cfg.LogLevel,
	}

	// Report a session failure to the caller instead of panicking.
	sess, err := session.NewSession(cfg.awsConfig)
	if err != nil {
		return nil, fmt.Errorf("could not create aws session, err: %s", err.Error())
	}

	if cfg.awsConfig != nil {
		stderrLogger.Debugf("configured with region: %s", aws.StringValue(cfg.awsConfig.Region))
	} else {
		stderrLogger.Debug("no aws config sent")
	}

	if err = initBoltDBPointerBucket(db); err != nil {
		return nil, err
	}

	fhl = &FirehoseLogger{
		fh:          firehose.New(sess),
		config:      cfg,
		FieldLogger: fieldLogger,
		logRotator:  logRotator,
		logger:      stderrLogger,
		// Buffered so that the stop signal can be queued while the goroutine
		// is busy processing logs.
		stopReportingChan: make(chan struct{}, 1),
		bolt:              db,
	}

	go fhl.reportEventsToFirehose()

	return fhl, nil
}

// stopReportingEventsToFirehose tells the background reporting goroutine to
// return by sending one value on the stop channel and then closing it.
func (fhl *FirehoseLogger) stopReportingEventsToFirehose() {
	stop := fhl.stopReportingChan
	stop <- struct{}{}
	close(stop)
}

// reportEventsToFirehose is the body of the background goroutine started by
// firehoseLogger. It calls processLogs every time the configured report
// interval elapses and returns as soon as a stop signal is received.
//
// The interval is measured from the end of one processing run to the start of
// the next. A single timer is re-armed after each run and stopped on return,
// so no timer is left pending once the goroutine exits.
func (fhl *FirehoseLogger) reportEventsToFirehose() {
	timer := time.NewTimer(fhl.config.reportInterval)
	defer timer.Stop()

	for {
		select {
		case <-timer.C:
			if err := fhl.processLogs(); err != nil {
				// if for whatever reason we were not able to process the file, it should
				// still be there.
				fhl.logger.Errorf("failed to processed rotated logs. err: '%s', willl retry.", err.Error())
			}
			// The timer has fired and its channel has been drained, so it
			// is safe to re-arm it directly.
			timer.Reset(fhl.config.reportInterval)
		case <-fhl.stopReportingChan:
			return
		}
	}
}

// processLogs performs one shipping run over the log directory: first every
// archived log file that passes the compareLogFiles check against the last
// processed file, then the current log file.
//
// Failing to look up the last processed file is only logged as a warning; any
// other failure aborts the run and is returned.
func (fhl *FirehoseLogger) processLogs() (err error) {
	fhl.logger.Debug("processing log files")

	absLogDir, err := filepath.Abs(fhl.config.LogDir)
	if err != nil {
		fhl.logger.Errorf("could not get absolute path for log dir: '%s', err: %s", fhl.config.LogDir, err.Error())
		return err
	}

	files, err := readDir(absLogDir)
	if err != nil {
		fhl.logger.Errorf("failed to read log dir '%s', err: %s", absLogDir, err.Error())
		return err
	}

	// The run continues with whatever value was returned when the lookup fails.
	lastProcessed, lookupErr := lastProcessedFilePath(fhl.bolt)
	if lookupErr != nil {
		fhl.logger.Warn("unable to retrieve the last processed log file path: " + lookupErr.Error())
	}

	currentLogFilePath := path.Join(absLogDir, logFileName)
	var archived []string
	for _, candidate := range files {
		switch {
		case !compareLogFiles(candidate, lastProcessed):
			fhl.logger.Debugf("skipping log file '%s' because it doesn't seem older than last processed file '%s'", candidate, lastProcessed)
		case candidate != currentLogFilePath:
			// The current log file is handled separately, after the archives.
			archived = append(archived, candidate)
		}
	}

	if err = fhl.processArchivedLogFiles(archived); err != nil {
		return err
	}
	// Now process the current log file.
	return fhl.processLogFile(currentLogFilePath)
}

// processArchivedLogFiles ships the given archived log files in ascending
// name order (files is sorted in place). After each file is processed the
// stored file offset is reset to zero and the file is recorded as the last
// processed one. Processing stops at the first failure, which is returned
// with context about the file or step that failed.
func (fhl *FirehoseLogger) processArchivedLogFiles(files []string) (err error) {
	// Sort file names so that we always process them in order.
	sort.Strings(files)

	for _, file := range files {
		fhl.logger.Debug("processing log file: " + file)

		if err = fhl.processLogFile(file); err != nil {
			return fmt.Errorf("could not process archived log file '%s'. err: %s", file, err.Error())
		}

		// The file is done: the next file has to be read from its beginning.
		if err = recordFileOffset(fhl.bolt, int64(0)); err != nil {
			return fmt.Errorf("could not record file offset after successfully processing file: %s, err: %s", file, err.Error())
		}

		if err = recordLastProcessedFile(fhl.bolt, file); err != nil {
			// The error text is passed as an argument, not as part of the
			// format string, so a '%' in it is not treated as a verb.
			return fmt.Errorf("could not record last processed log file: %s", err.Error())
		}
	}
	return nil
}

// processLogFile ships the not yet processed part of the log file at
// absFilePath. It resumes from the offset stored in bolt and stores the
// offset reached afterwards, even when processing stopped early, so the next
// run continues after the last successfully sent batch.
//
// It returns an error when the file is missing or cannot be opened, when the
// stored offset cannot be read, or when processing or recording the new
// offset fails. A processing error takes precedence over a recording error.
func (fhl *FirehoseLogger) processLogFile(absFilePath string) (err error) {
	if !pathExists(absFilePath) {
		return fmt.Errorf("log file '%s' does not exist", absFilePath)
	}

	// Plain read-only open: O_EXCL has no defined meaning without O_CREATE.
	file, err := os.Open(absFilePath)
	if err != nil {
		// Stop here; continuing with a nil file would hide the real error.
		return fmt.Errorf("error opening file '%s', err: %s", absFilePath, err.Error())
	}
	defer func() {
		if closeErr := file.Close(); closeErr != nil {
			fhl.logger.Warn("failed to close processed log file. err: " + closeErr.Error())
		}
	}()

	offset, err := fileOffset(fhl.bolt)
	if err != nil {
		return errors.New("failed to get file offset form boltdb. err: " + err.Error())
	}
	fhl.logger.Debugf("starting to read file '%s' from position: %d", absFilePath, offset)

	lastOffset, pErr := fhl.processLogFileWithSeeker(file, offset)
	err = recordFileOffset(fhl.bolt, lastOffset)
	if pErr != nil {
		// We care more about processing err than the bolt db error.
		err = pErr
	}
	return err
}

// processLogFileWithSeeker reads file line by line starting at offset and
// sends the non-empty lines to Firehose in batches that respect
// putRecordBatchRecordLimit and putRecordBatchSizeLimit. Every record is
// terminated by a newline so that athena can read the records correctly
// (see https://forums.aws.amazon.com/thread.jspa?threadID=244858).
//
// newOffset is the position up to which the content has been delivered: it
// only moves forward after a batch was sent successfully, so on error the
// caller can resume from newOffset without losing lines.
//
// An error is returned when seeking fails, when a batch cannot be sent, or
// when the scanner stops on an error (for example a line longer than the
// scanner's limit). On a scanner error the pending lines are not sent and
// the offset is not advanced past the last sent batch.
func (fhl *FirehoseLogger) processLogFileWithSeeker(file io.ReadSeeker, offset int64) (newOffset int64, err error) {
	newOffset = offset
	if _, err = file.Seek(offset, io.SeekStart); err != nil {
		return
	}

	// pos is the file position just past the last scanned line and lineStart
	// the position where that line began. Tracking the real number of consumed
	// bytes keeps both exact for "\r\n" endings and for a final line that has
	// no line terminator.
	pos := offset
	lineStart := offset
	s := bufio.NewScanner(file)
	s.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
		advance, token, err = bufio.ScanLines(data, atEOF)
		if advance > 0 {
			lineStart = pos
			pos += int64(advance)
		}
		return
	})

	var records [][]byte
	var requestSize int
	for s.Scan() {
		line := s.Bytes()
		if len(line) == 0 {
			continue
		}

		// s.Bytes() points into the scanner's buffer, which is overwritten by
		// later calls to Scan. Copy the line, since records are kept across
		// many iterations before they are sent.
		record := make([]byte, len(line)+1)
		copy(record, line)
		record[len(line)] = '\n'

		// Keep collecting while the batch stays within the firehose limits.
		if len(records) < putRecordBatchRecordLimit && requestSize+len(record) < putRecordBatchSizeLimit {
			records = append(records, record)
			requestSize += len(record)
			continue
		}

		if err = fhl.putRecordBatch(records); err != nil {
			return
		}

		// Everything before the current line has been delivered; the current
		// line becomes the first record of the next batch.
		newOffset = lineStart
		records = [][]byte{record}
		requestSize = len(record)
	}

	if scanErr := s.Err(); scanErr != nil {
		// The tail read before the failure cannot be trusted to hold whole
		// lines, so it is neither sent nor accounted for in newOffset.
		err = fmt.Errorf("error scanning log file after offset %d, err: %s", pos, scanErr.Error())
		return
	}

	if len(records) > 0 {
		if err = fhl.putRecordBatch(records); err != nil {
			return
		}
	}
	newOffset = pos
	return
}

// putRecordBatch sends byteRecords to the configured Firehose delivery stream
// in a single PutRecordBatch request, one Firehose record per element.
//
// It returns an error when the batch is empty, when it holds more than
// putRecordBatchRecordLimit records, when the request fails, or when the
// response reports that some records were not accepted. In the last case the
// whole batch is reported as failed, so a caller that retries may send the
// accepted records again.
func (fhl *FirehoseLogger) putRecordBatch(byteRecords [][]byte) (err error) {
	if len(byteRecords) == 0 {
		return errors.New("cannot send empty batch request.request")
	}
	if len(byteRecords) > putRecordBatchRecordLimit {
		return fmt.Errorf("cannot send %d records in one batch request, limit is %d", len(byteRecords), putRecordBatchRecordLimit)
	}

	firehoseRecords := make([]*firehose.Record, 0, len(byteRecords))
	for _, byteRecord := range byteRecords {
		firehoseRecords = append(firehoseRecords, &firehose.Record{Data: byteRecord})
	}

	output, err := fhl.fh.PutRecordBatch(&firehose.PutRecordBatchInput{
		Records:            firehoseRecords,
		DeliveryStreamName: aws.String(fhl.config.FirehoseDeliveryStreamName),
	})
	if err != nil {
		return err
	}

	// A request can succeed as a whole while individual records are rejected;
	// that is only visible through FailedPutCount in the response.
	if output != nil {
		if failed := aws.Int64Value(output.FailedPutCount); failed > 0 {
			return fmt.Errorf("firehose rejected %d of %d records in batch request", failed, len(byteRecords))
		}
	}
	return nil
}

// readDir lists the entries of the directory absDirPath and returns each
// entry name joined onto absDirPath. A directory that does not exist yields
// no files and no error; a failure to read the directory is returned as is.
func readDir(absDirPath string) (files []string, err error) {
	if !pathExists(absDirPath) {
		return nil, nil
	}

	entries, err := ioutil.ReadDir(absDirPath)
	if err != nil {
		return nil, err
	}

	for i := range entries {
		files = append(files, path.Join(absDirPath, entries[i].Name()))
	}
	return files, nil
}

// pathExists reports whether absPath exists. Only a "does not exist" result
// from os.Stat counts as missing; any other stat error (for example a
// permission problem) is treated as the path being present.
func pathExists(absPath string) bool {
	_, statErr := os.Stat(absPath)
	return !os.IsNotExist(statErr)
}
