package clickhouse

import (
	"database/sql/driver"
	"errors"
	"fmt"
	"io"
	"log"
	"sort"
	"strings"

	ch "github.com/ClickHouse/clickhouse-go"
)

// Table-related types and functions.

// ColumnType enumerates the column types this package can translate to and
// from Clickhouse type names (see GetClickhouseType and
// GetColumnTypeFromClickhouse).
type ColumnType int

// Supported column types. ColumnUnknown is the zero value and stands for any
// Clickhouse type that has no mapping here.
const (
	ColumnUnknown ColumnType = iota
	ColumnFloat64
	ColumnInt64
	ColumnString
	ColumnDateTime64
)

// TableSchema describes a table's columns as a map from column name to its
// ColumnType. Map iteration order is random; use SortedColumns when a stable
// column order is needed.
type TableSchema map[string]ColumnType

// clickhouseTypeNames is the Clickhouse type used in DDL for each supported
// ColumnType. DateTime64(0) depends on the Clickhouse server timezone setting.
var clickhouseTypeNames = map[ColumnType]string{
	ColumnFloat64:    "Float64",
	ColumnInt64:      "Int64",
	ColumnString:     "String",
	ColumnDateTime64: "DateTime64(0)",
}

// GetClickhouseType returns the Clickhouse type name used when declaring a
// column of type t, or "UnknownType" if t has no mapping (including
// ColumnUnknown).
func GetClickhouseType(t ColumnType) string {
	if name, found := clickhouseTypeNames[t]; found {
		return name
	}
	return "UnknownType"
}

// columnTypesByClickhouseName maps the exact Clickhouse type strings this
// package understands back to their ColumnType.
var columnTypesByClickhouseName = map[string]ColumnType{
	"Float64":       ColumnFloat64,
	"Int64":         ColumnInt64,
	"String":        ColumnString,
	"DateTime64(0)": ColumnDateTime64,
}

// GetColumnTypeFromClickhouse converts a Clickhouse type name (as reported by
// DESCRIBE TABLE) into a ColumnType. Only exact matches of the names listed
// above are recognized; any other spelling yields ColumnUnknown.
func GetColumnTypeFromClickhouse(columnTypeStr string) ColumnType {
	columnType, known := columnTypesByClickhouseName[columnTypeStr]
	if !known {
		return ColumnUnknown
	}
	return columnType
}

// SortedColumns returns the schema's column names in ascending lexical order.
// The result is always a non-nil slice, empty when the schema has no columns.
func (s TableSchema) SortedColumns() []string {
	names := make([]string, len(s))
	i := 0
	for name := range s {
		names[i] = name
		i++
	}
	sort.Sort(sort.StringSlice(names))
	return names
}

// TableExists reports whether tableName exists, using Clickhouse's
// `EXISTS TABLE` statement. The name is unqualified, so it resolves against
// the connection's current database (presumably pool.database — verify).
// The query must return exactly one row, and the value read from that row
// must be a uint8; 1 means the table exists.
func TableExists(pool *ClickhousePool, tableName string) (bool, error) {
	query := fmt.Sprintf("EXISTS TABLE `%s`", tableName)
	found := false
	check := func(rows []map[string]driver.Value) error {
		if n := len(rows); n != 1 {
			return fmt.Errorf("strange number of rows: %d", n)
		}
		only := rows[0]
		// The row is expected to hold a single column; whichever value the map
		// yields first is used.
		for _, v := range only {
			flag, isByte := v.(uint8)
			if !isByte {
				return fmt.Errorf("strange column: %v", only)
			}
			found = flag == 1
			return nil
		}
		return errors.New("strange: no columns")
	}
	if err := RunQuery(pool, query, check); err != nil {
		return false, fmt.Errorf("error checking if table %s exists: %v", tableName, err)
	}
	return found, nil
}

// CreateDistributedTable creates tableName, unless it already exists, with
// ON CLUSTER '{cluster}'. The new table copies its structure from
// pool.database.shardTableName (the AS clause) and uses the Distributed engine
// over that table, with rand() as the sharding key. tableName itself is
// unqualified, so it is created in the connection's current database.
// Any failure is wrapped with the table name.
func CreateDistributedTable(pool *ClickhousePool, tableName string, shardTableName string) error {
	query := fmt.Sprintf("CREATE TABLE IF NOT EXISTS `%s` ON CLUSTER '{cluster}' AS `%s`.`%s` ENGINE = Distributed('{cluster}', '%s', '%s', rand())",
		tableName, pool.database, shardTableName, pool.database, shardTableName)

	txErr := RunTx(pool, func(conn ch.Clickhouse) error {
		log.Printf("running %s\n", query)
		stmt, err := conn.Prepare(query)
		if err != nil {
			return err
		}
		defer func() {
			if closeErr := stmt.Close(); closeErr != nil {
				log.Printf("ERROR: closing statement failed: %v\n", closeErr)
			}
		}()

		//goland:noinspection GoDeprecation
		if _, err := stmt.Exec([]driver.Value{}); err != nil {
			return err
		}

		log.Printf("created distributed table %s over table %s in Clickhouse: %s\n", tableName, shardTableName, query)
		return nil
	})
	if txErr == nil {
		return nil
	}
	return fmt.Errorf("creating table %s failed: %v", tableName, txErr)
}

// CreateReplicatedTable creates tableName, unless it already exists, with
// ON CLUSTER '{cluster}'. It uses the ReplicatedMergeTree engine with
// replication path '/clickhouse/tables/{shard}/{database}/{table}' and replica
// name '{replica}'.
//
// Columns are declared in sorted name order, with the Clickhouse types given
// by GetClickhouseType. partitionBy and orderBy are spliced into the statement
// verbatim, so they must already be valid Clickhouse expressions. Any failure
// is wrapped with the table name.
func CreateReplicatedTable(pool *ClickhousePool, tableName string, table TableSchema, orderBy string, partitionBy string) error {
	const engine = "ReplicatedMergeTree('/clickhouse/tables/{shard}/{database}/{table}', '{replica}')"

	names := table.SortedColumns()
	defs := make([]string, len(names))
	for i, name := range names {
		defs[i] = fmt.Sprintf("`%s` %s", name, GetClickhouseType(table[name]))
	}
	query := fmt.Sprintf("CREATE TABLE IF NOT EXISTS `%s` ON CLUSTER '{cluster}' (%s) ENGINE = %s PARTITION BY %s ORDER BY %s",
		tableName, strings.Join(defs, ", "), engine, partitionBy, orderBy)

	txErr := RunTx(pool, func(conn ch.Clickhouse) error {
		log.Printf("running %s\n", query)
		stmt, err := conn.Prepare(query)
		if err != nil {
			return err
		}
		defer func() {
			if closeErr := stmt.Close(); closeErr != nil {
				log.Printf("ERROR: closing statement failed: %v\n", closeErr)
			}
		}()

		//goland:noinspection GoDeprecation
		if _, err := stmt.Exec([]driver.Value{}); err != nil {
			return err
		}

		log.Printf("created table %s in Clickhouse: %s\n", tableName, query)
		return nil
	})
	if txErr == nil {
		return nil
	}
	return fmt.Errorf("creating table %s failed: %v", tableName, txErr)
}

// DescribeTable runs `DESCRIBE TABLE` for tableName and returns its columns as
// a TableSchema.
//
// Each result row must carry non-empty string "name" and "type" values, and
// the type must be one GetColumnTypeFromClickhouse recognizes. Otherwise the
// whole call fails with an error wrapped with the table name, and the returned
// schema is nil.
func DescribeTable(pool *ClickhousePool, tableName string) (TableSchema, error) {
	query := fmt.Sprintf("DESCRIBE TABLE `%s`", tableName)
	var schema TableSchema
	collect := func(rows []map[string]driver.Value) error {
		schema = make(TableSchema, len(rows))
		for _, row := range rows {
			rawName, rawType := row["name"], row["type"]

			if rawName == nil {
				return fmt.Errorf("no column 'name' in describe results: %v", row)
			}
			name, isStr := rawName.(string)
			if !isStr || name == "" {
				return fmt.Errorf("strange column 'name' in describe: %v", row)
			}

			if rawType == nil {
				return fmt.Errorf("no column 'type' in describe results: %v", row)
			}
			typeName, isStr := rawType.(string)
			if !isStr || typeName == "" {
				return fmt.Errorf("strange column 'type' in describe: %v", row)
			}

			colType := GetColumnTypeFromClickhouse(typeName)
			if colType == ColumnUnknown {
				return fmt.Errorf("unsupported column type '%s' in describe: %v", typeName, row)
			}
			schema[name] = colType
		}
		return nil
	}
	if err := RunQuery(pool, query, collect); err != nil {
		return nil, fmt.Errorf("describing table %s failed: %v", tableName, err)
	}
	return schema, nil
}

// AlterTable adds to tableName (with ON CLUSTER '{cluster}') every column that
// is present in table but missing from prevTable.
//
// Columns present in both schemas must have the same type. Changing an
// existing column is not supported, so a mismatch is reported as an error
// before any query is issued. Columns that exist only in prevTable are left
// untouched. If there is nothing to add, no query is run and nil is returned.
// Query failures are wrapped with the table name.
func AlterTable(pool *ClickhousePool, tableName string, table TableSchema, prevTable TableSchema) error {
	var addColumnsDesc []string
	// Walk columns in sorted order so that the generated ALTER statement, and
	// the first type mismatch reported, do not depend on map iteration order.
	for _, column := range table.SortedColumns() {
		columnType := table[column]
		if prevType, ok := prevTable[column]; ok {
			if prevType != columnType {
				return fmt.Errorf("previous column '%s' type: %d, now required: %d", column, prevType, columnType)
			}
			continue
		}
		addColumnsDesc = append(addColumnsDesc, fmt.Sprintf("ADD COLUMN `%s` %s",
			column, GetClickhouseType(columnType)))
	}
	if len(addColumnsDesc) == 0 {
		// Nothing to add: building the statement anyway would yield an
		// ALTER TABLE with no commands, which Clickhouse rejects.
		return nil
	}

	sql := fmt.Sprintf("ALTER TABLE `%s` ON CLUSTER '{cluster}' %s", tableName, strings.Join(addColumnsDesc, ", "))
	err := RunTx(pool, func(conn ch.Clickhouse) error {
		log.Printf("running %s\n", sql)
		stmt, err := conn.Prepare(sql)
		if err != nil {
			return err
		}
		defer func() {
			err := stmt.Close()
			if err != nil {
				log.Printf("ERROR: closing statement failed: %v\n", err)
			}
		}()

		//goland:noinspection GoDeprecation
		_, err = stmt.Exec([]driver.Value{})
		if err != nil {
			return err
		}

		log.Printf("altered table %s in Clickhouse: %s\n", tableName, sql)
		return nil
	})
	if err != nil {
		return fmt.Errorf("altering table %s failed: %v", tableName, err)
	}
	return nil
}

// TableInfo holds size statistics for one table, as aggregated by GetTableInfos.
type TableInfo struct {
	Bytes int64 // sum of system.tables total_bytes over the rows returned
	Rows  int64 // sum of system.tables total_rows over the rows returned
}

// GetTableInfos returns the total row count and byte size of each table in
// pool.database. The figures are summed over all rows that
// cluster('{cluster}', system.tables) returns for that table name.
//
// A row whose total_bytes or total_rows is NULL is skipped. Presumably Clickhouse
// reports NULL for engines whose size it cannot determine, such as Distributed;
// confirm. On error the returned map is nil, never partially filled.
func GetTableInfos(pool *ClickhousePool) (map[string]TableInfo, error) {
	sql := fmt.Sprintf("SELECT name, total_rows, total_bytes from cluster('{cluster}', system.tables) where database = '%s'", pool.database)
	var result map[string]TableInfo
	err := RunQuery(pool, sql, func(rows []map[string]driver.Value) error {
		result = make(map[string]TableInfo, len(rows))
		for _, row := range rows {
			table := row["name"]
			if table == nil {
				return fmt.Errorf("no column 'name' in table size results: %v", row)
			}
			tableStr, ok := table.(string)
			if !ok || tableStr == "" {
				return fmt.Errorf("strange column 'name' in table size results: %v", row)
			}
			totalBytes := row["total_bytes"]
			if totalBytes == nil {
				continue
			}
			totalBytesInt, ok := totalBytes.(uint64)
			if !ok {
				return fmt.Errorf("strange column 'total_bytes' in table size results: %v", row)
			}
			totalRows := row["total_rows"]
			if totalRows == nil {
				continue
			}
			totalRowsInt, ok := totalRows.(uint64)
			if !ok {
				return fmt.Errorf("strange column 'total_rows' in table size results: %v", row)
			}
			// A missing entry reads as the zero TableInfo, which is the correct
			// starting point for the sums.
			info := result[tableStr]
			info.Bytes += int64(totalBytesInt)
			info.Rows += int64(totalRowsInt)
			result[tableStr] = info
		}

		return nil
	})
	if err != nil {
		return nil, err
	}
	return result, nil
}

// stringset is a set of strings: a key mapped to true is a member.
type stringset map[string]bool

// GetTablePartitions returns, for each table in pool.database that has at least
// one part, the distinct values of the `partition` column read from
// cluster('{cluster}', system.parts). Each table's partition list is sorted, so
// the result is deterministic. On error the returned map is nil.
//
// NOTE(review): system.parts also lists inactive parts (for example, ones that
// have already been merged away). Partitions may therefore be reported that no
// longer hold live data. Consider filtering on `active` if that matters to
// callers; confirm.
func GetTablePartitions(pool *ClickhousePool) (map[string][]string, error) {
	sql := fmt.Sprintf("SELECT table, partition FROM cluster('{cluster}', system.parts) WHERE database = '%s'", pool.database)
	var parts map[string]stringset
	err := RunQuery(pool, sql, func(rows []map[string]driver.Value) error {
		parts = make(map[string]stringset, len(rows))
		for _, row := range rows {
			table := row["table"]
			if table == nil {
				return fmt.Errorf("no column 'table' in parts results: %v", row)
			}
			tableStr, ok := table.(string)
			if !ok || tableStr == "" {
				return fmt.Errorf("strange column 'table' in parts results: %v", row)
			}
			partition := row["partition"]
			if partition == nil {
				return fmt.Errorf("no column 'partition' in parts results: %v", row)
			}
			partitionStr, ok := partition.(string)
			if !ok || partitionStr == "" {
				return fmt.Errorf("strange column 'partition' in parts results: %v", row)
			}
			if parts[tableStr] == nil {
				parts[tableStr] = stringset{}
			}
			parts[tableStr][partitionStr] = true
		}
		return nil
	})
	if err != nil {
		return nil, err
	}

	ret := make(map[string][]string, len(parts))
	for table, set := range parts {
		partitions := make([]string, 0, len(set))
		for partition := range set {
			partitions = append(partitions, partition)
		}
		// Set iteration order is random; sort for stable, reproducible output.
		sort.Strings(partitions)
		ret[table] = partitions
	}
	return ret, nil
}

// GetServerTimezone returns the value of Clickhouse's timezone() function, as
// evaluated by the server that runs the query. The query must yield exactly one
// row, and the value read from that row must be a string.
func GetServerTimezone(pool *ClickhousePool) (string, error) {
	var tz string
	read := func(rows []map[string]driver.Value) error {
		if len(rows) != 1 {
			return fmt.Errorf("strange number of rows: %d", len(rows))
		}
		only := rows[0]
		for _, v := range only {
			s, isStr := v.(string)
			if !isStr {
				return fmt.Errorf("strange column in timezone results: %v", only)
			}
			tz = s
			return nil
		}
		return errors.New("strange: no columns")
	}
	err := RunQuery(pool, "SELECT timezone()", read)
	return tz, err
}

// TableDropPartition runs `ALTER TABLE ... ON CLUSTER '{cluster}' DROP PARTITION`
// for partition part of tableName.
//
// part is spliced into the SQL verbatim, so it must already be a valid
// Clickhouse partition expression. Presumably it comes from GetTablePartitions;
// verify at call sites. Failures are wrapped with the partition and table name,
// as the other DDL helpers here do.
func TableDropPartition(pool *ClickhousePool, tableName string, part string) error {
	err := RunTx(pool, func(conn ch.Clickhouse) error {
		sql := fmt.Sprintf("ALTER TABLE `%s` ON CLUSTER '{cluster}' DROP PARTITION %s", tableName, part)
		log.Printf("running %s\n", sql)
		stmt, err := conn.Prepare(sql)
		if err != nil {
			return err
		}
		defer func() {
			err := stmt.Close()
			if err != nil {
				log.Printf("ERROR: closing statement failed: %v\n", err)
			}
		}()

		//goland:noinspection GoDeprecation
		_, err = stmt.Exec([]driver.Value{})
		if err != nil {
			return err
		}

		log.Printf("dropped partition %s for table %s in Clickhouse\n", part, tableName)
		return nil
	})
	if err != nil {
		return fmt.Errorf("dropping partition %s of table %s failed: %v", part, tableName, err)
	}
	return nil
}

// RunTx wraps pool.Run and brackets proc in a transaction. It calls Begin on
// the connection, runs proc, and calls Commit if proc succeeds.
//
// If Begin fails, proc is not called. If proc or Commit fails, that error is
// returned and Rollback is attempted on the way out. Rollback also runs if proc
// panics, because it is deferred.
func RunTx(pool *ClickhousePool, proc ConnProc) error {
	return pool.Run(func(conn ch.Clickhouse) error {
		if _, err := conn.Begin(); err != nil {
			return err
		}
		committed := false
		defer func() {
			if committed {
				return
			}
			// Best effort: the error that led here (from proc or Commit) is the
			// one worth reporting, so a Rollback failure is deliberately dropped.
			_ = conn.Rollback()
		}()

		if err := proc(conn); err != nil {
			return err
		}
		if err := conn.Commit(); err != nil {
			return err
		}
		committed = true
		return nil
	})
}

// RowProc receives the fully materialized result of a RunQuery call: one map
// per row, keyed by column name. A non-nil return value makes RunQuery fail; the
// error is propagated through RunTx, which then rolls the transaction back.
type RowProc func(rows []map[string]driver.Value) error

// RunQuery runs sql inside a transaction (see RunTx) and reads the entire result
// set into memory, converting each row into a column name -> value map. It then
// calls proc once with all rows.
//
// proc is called at most once. It is skipped when preparing, executing, or
// reading the query fails, and its returned error is propagated through RunTx.
func RunQuery(pool *ClickhousePool, sql string, proc RowProc) error {
	return RunTx(pool, func(conn ch.Clickhouse) error {
		statement, prepErr := conn.Prepare(sql)
		if prepErr != nil {
			return prepErr
		}
		defer func() {
			if closeErr := statement.Close(); closeErr != nil {
				log.Printf("ERROR: closing statement failed: %v\n", closeErr)
			}
		}()

		//goland:noinspection GoDeprecation
		cursor, queryErr := statement.Query([]driver.Value{})
		if queryErr != nil {
			return queryErr
		}
		defer func() {
			if closeErr := cursor.Close(); closeErr != nil {
				log.Printf("ERROR: closing rows failed: %v\n", closeErr)
			}
		}()

		// NOTE: Every row is read before proc runs, instead of handing rows to
		// proc one at a time. This costs memory, but the original author marked
		// the ordering as required. The reason is not visible here; presumably
		// proc must not run while the result set is still open. Confirm before
		// changing this.
		all, readErr := rowsToMaps(cursor)
		if readErr != nil {
			return readErr
		}
		return proc(all)
	})
}

// rowsToMaps drains rows and returns one map per row, keyed by column name.
// It returns a nil slice when there are no rows. io.EOF from rows.Next marks
// the end of the result set. Any other error from Next is returned as is, and
// the rows read so far are discarded.
func rowsToMaps(rows driver.Rows) ([]map[string]driver.Value, error) {
	var result []map[string]driver.Value
	// Column names describe the whole result set; fetch them once rather than
	// on every cell.
	columns := rows.Columns()
	// values is a scratch buffer that Next fills for each row; its contents are
	// copied into a fresh map before the next call.
	values := make([]driver.Value, len(columns))
	for {
		if err := rows.Next(values); err != nil {
			if errors.Is(err, io.EOF) {
				break
			}
			return nil, err
		}
		m := make(map[string]driver.Value, len(columns))
		for i, column := range columns {
			m[column] = values[i]
		}
		result = append(result, m)
	}
	return result, nil
}
