package db

import (
	"context"
	"time"

	"database/sql"

	"code.justin.tv/feeds/log"
	service_common "code.justin.tv/feeds/service-common"
)

// InstrumentedDBClient handles non-functional concerns for db operations, such as recording stats.
// Every DB method forwards to Inner; most of them also record a latency timing and a
// status counter through Stats (transaction methods are forwarded without stats).
type InstrumentedDBClient struct {
	// Inner is the DB implementation that actually performs each operation.
	Inner DB

	// Log is a logger; it does not appear to be used by the methods here (NOTE(review): confirm it is still needed).
	// Stats receives the timing and status-count metrics emitted by recordStats.
	Log   log.Logger
	Stats *service_common.StatSender
}

// Compile-time check that *InstrumentedDBClient satisfies the DB interface.
var _ DB = (*InstrumentedDBClient)(nil)

// GetEvent forwards to Inner.GetEvent and records its latency and outcome under
// the "get_event" stat name.
func (d *InstrumentedDBClient) GetEvent(ctx context.Context, id string, getDeleted bool) (event *Event, err error) {
	start := time.Now()
	// err is a named result, so the deferred call observes the value being returned.
	defer func() { d.recordStats(ctx, "get_event", start, err == nil) }()

	return d.Inner.GetEvent(ctx, id, getDeleted)
}

// GetEvents forwards to Inner.GetEvents and records its latency and outcome under
// the "get_events" stat name.
func (d *InstrumentedDBClient) GetEvents(ctx context.Context, ids []string, getDeleted bool) (events []*Event, err error) {
	start := time.Now()
	defer func() { d.recordStats(ctx, "get_events", start, err == nil) }()

	return d.Inner.GetEvents(ctx, ids, getDeleted)
}

// GetEventIDsOrderedByAscStartTime forwards to Inner.GetEventIDsOrderedByAscStartTime
// and records its latency and outcome under the
// "get_event_ids_ordered_by_asc_start_time" stat name.
func (d *InstrumentedDBClient) GetEventIDsOrderedByAscStartTime(
	ctx context.Context, filter *BroadcastFilter, offset *EventIDsOrderedByAscTimeOffset, limit int,
) (items []EventIDItem, hasNextPage bool, err error) {
	start := time.Now()
	defer func() { d.recordStats(ctx, "get_event_ids_ordered_by_asc_start_time", start, err == nil) }()

	return d.Inner.GetEventIDsOrderedByAscStartTime(ctx, filter, offset, limit)
}

// CreateLocalization forwards to Inner.CreateLocalization and records its latency
// and outcome under the "create_localization" stat name.
func (d *InstrumentedDBClient) CreateLocalization(ctx context.Context, params *UpdateDBLocalizationParams) (loc *Localization, err error) {
	start := time.Now()
	defer func() { d.recordStats(ctx, "create_localization", start, err == nil) }()

	return d.Inner.CreateLocalization(ctx, params)
}

// GetLocalization forwards to Inner.GetLocalization and records its latency and
// outcome under the "get_localization" stat name.
func (d *InstrumentedDBClient) GetLocalization(ctx context.Context, eventID string, language string) (loc *Localization, err error) {
	start := time.Now()
	defer func() { d.recordStats(ctx, "get_localization", start, err == nil) }()

	return d.Inner.GetLocalization(ctx, eventID, language)
}

// GetLocalizationsByEventID forwards to Inner.GetLocalizationsByEventID and records
// its latency and outcome under the "get_localizations_by_event_id" stat name.
func (d *InstrumentedDBClient) GetLocalizationsByEventID(ctx context.Context, eventID string) (locs []*Localization, err error) {
	start := time.Now()
	defer func() { d.recordStats(ctx, "get_localizations_by_event_id", start, err == nil) }()

	return d.Inner.GetLocalizationsByEventID(ctx, eventID)
}

// UpdateLocalization forwards to Inner.UpdateLocalization and records its latency
// and outcome under the "update_localization" stat name.
func (d *InstrumentedDBClient) UpdateLocalization(ctx context.Context, eventID string, language string, params *UpdateDBLocalizationParams) (loc *Localization, err error) {
	start := time.Now()
	defer func() { d.recordStats(ctx, "update_localization", start, err == nil) }()

	return d.Inner.UpdateLocalization(ctx, eventID, language, params)
}

// DeleteLocalization forwards to Inner.DeleteLocalization and records its latency
// and outcome under the "delete_localization" stat name.
func (d *InstrumentedDBClient) DeleteLocalization(ctx context.Context, eventID string, language string) (loc *Localization, err error) {
	start := time.Now()
	defer func() { d.recordStats(ctx, "delete_localization", start, err == nil) }()

	return d.Inner.DeleteLocalization(ctx, eventID, language)
}

// CreateEvent forwards to Inner.CreateEvent and records its latency and outcome
// under the "create_event" stat name.
func (d *InstrumentedDBClient) CreateEvent(ctx context.Context, params *CreateDBEventParams) (event *Event, err error) {
	start := time.Now()
	defer func() { d.recordStats(ctx, "create_event", start, err == nil) }()

	return d.Inner.CreateEvent(ctx, params)
}

// GetEventsByParentIDs forwards to Inner.GetEventsByParentIDs and records its
// latency and outcome under the "get_events_by_parent_ids" stat name.
func (d *InstrumentedDBClient) GetEventsByParentIDs(ctx context.Context, parentIDs []string, getDeleted bool) (events []*Event, err error) {
	start := time.Now()
	defer func() { d.recordStats(ctx, "get_events_by_parent_ids", start, err == nil) }()

	return d.Inner.GetEventsByParentIDs(ctx, parentIDs, getDeleted)
}

// UpdateEvent forwards to Inner.UpdateEvent and records its latency and outcome
// under the "update_event" stat name.
func (d *InstrumentedDBClient) UpdateEvent(ctx context.Context, eventID string, params *UpdateDBEventParams) (event *Event, err error) {
	start := time.Now()
	defer func() { d.recordStats(ctx, "update_event", start, err == nil) }()

	return d.Inner.UpdateEvent(ctx, eventID, params)
}

// DeleteEvent forwards to Inner.DeleteEvent and records its latency and outcome
// under the "delete_event" stat name.
func (d *InstrumentedDBClient) DeleteEvent(ctx context.Context, eventID string, params *DeleteDBEventParams) (event *Event, err error) {
	start := time.Now()
	defer func() { d.recordStats(ctx, "delete_event", start, err == nil) }()

	return d.Inner.DeleteEvent(ctx, eventID, params)
}

// DeleteEvents forwards to Inner.DeleteEvents and records its latency and outcome
// under the "delete_events" stat name.
func (d *InstrumentedDBClient) DeleteEvents(ctx context.Context, eventIDs []string, deleteParams *DeleteDBEventParams) (events []*Event, err error) {
	start := time.Now()
	defer func() { d.recordStats(ctx, "delete_events", start, err == nil) }()

	return d.Inner.DeleteEvents(ctx, eventIDs, deleteParams)
}

// HardDeleteEventsByOwnerID forwards to Inner.HardDeleteEventsByOwnerID and records
// its latency and outcome under the "hard_delete_events_by_owner_id" stat name.
func (d *InstrumentedDBClient) HardDeleteEventsByOwnerID(ctx context.Context, ownerID string) (events []*Event, err error) {
	start := time.Now()
	defer func() { d.recordStats(ctx, "hard_delete_events_by_owner_id", start, err == nil) }()

	return d.Inner.HardDeleteEventsByOwnerID(ctx, ownerID)
}

// HasEventIDsOrderedByAscStartTime forwards to Inner.HasEventIDsOrderedByAscStartTime
// and records its latency and outcome under the
// "has_event_ids_ordered_by_asc_start_time" stat name.
func (d *InstrumentedDBClient) HasEventIDsOrderedByAscStartTime(ctx context.Context, filter *BroadcastFilter, offset *HasEventIDsOrderedByAscTimeOffset) (hasIDs bool, err error) {
	start := time.Now()
	defer func() { d.recordStats(ctx, "has_event_ids_ordered_by_asc_start_time", start, err == nil) }()

	return d.Inner.HasEventIDsOrderedByAscStartTime(ctx, filter, offset)
}

// CreateBroadcast forwards to Inner.CreateBroadcast and records its latency and
// outcome under the "create_broadcast" stat name.
func (d *InstrumentedDBClient) CreateBroadcast(ctx context.Context, params *CreateDBBroadcastParams) (broadcast *Broadcast, err error) {
	start := time.Now()
	defer func() { d.recordStats(ctx, "create_broadcast", start, err == nil) }()

	return d.Inner.CreateBroadcast(ctx, params)
}

// UpdateBroadcast forwards to Inner.UpdateBroadcast and records its latency and
// outcome under the "update_broadcast" stat name.
func (d *InstrumentedDBClient) UpdateBroadcast(ctx context.Context, eventID, language string, params *UpdateDBBroadcastParams) (broadcast *Broadcast, err error) {
	start := time.Now()
	defer func() { d.recordStats(ctx, "update_broadcast", start, err == nil) }()

	return d.Inner.UpdateBroadcast(ctx, eventID, language, params)
}

// DeleteBroadcastsByEventID forwards to Inner.DeleteBroadcastsByEventID and records
// its latency and outcome under the "delete_broadcasts_by_event_id" stat name.
func (d *InstrumentedDBClient) DeleteBroadcastsByEventID(ctx context.Context, eventID string) (ids []string, err error) {
	start := time.Now()
	defer func() { d.recordStats(ctx, "delete_broadcasts_by_event_id", start, err == nil) }()

	return d.Inner.DeleteBroadcastsByEventID(ctx, eventID)
}

// DeleteBroadcastsByEventIDs forwards to Inner.DeleteBroadcastsByEventIDs and records
// its latency and outcome under the "delete_broadcasts_by_event_ids" stat name.
func (d *InstrumentedDBClient) DeleteBroadcastsByEventIDs(ctx context.Context, eventIDs []string) (ids []string, err error) {
	start := time.Now()
	defer func() { d.recordStats(ctx, "delete_broadcasts_by_event_ids", start, err == nil) }()

	return d.Inner.DeleteBroadcastsByEventIDs(ctx, eventIDs)
}

// DeleteBroadcastByEventIDAndLanguage forwards to Inner.DeleteBroadcastByEventIDAndLanguage
// and records its latency and outcome under the
// "delete_broadcast_by_event_id_and_language" stat name.
func (d *InstrumentedDBClient) DeleteBroadcastByEventIDAndLanguage(ctx context.Context, eventID string, language string) (ids []string, err error) {
	start := time.Now()
	defer func() { d.recordStats(ctx, "delete_broadcast_by_event_id_and_language", start, err == nil) }()

	return d.Inner.DeleteBroadcastByEventIDAndLanguage(ctx, eventID, language)
}

// GetEventIDsSortedByStartTime forwards to Inner.GetEventIDsSortedByStartTime and
// records its latency and outcome under the "get_event_ids_sorted_by_start_time"
// stat name.
func (d *InstrumentedDBClient) GetEventIDsSortedByStartTime(ctx context.Context, filter *BroadcastFilter, desc bool, cursor string, limit int) (eventIDs *EventIDs, err error) {
	start := time.Now()
	defer func() { d.recordStats(ctx, "get_event_ids_sorted_by_start_time", start, err == nil) }()

	return d.Inner.GetEventIDsSortedByStartTime(ctx, filter, desc, cursor, limit)
}

// GetEventIDsSortedByHype forwards to Inner.GetEventIDsSortedByHype and records its
// latency and outcome under the "get_event_ids_sorted_by_hype" stat name.
func (d *InstrumentedDBClient) GetEventIDsSortedByHype(ctx context.Context, filter *BroadcastFilter, desc bool, cursor string, limit int) (eventIDs *EventIDs, err error) {
	start := time.Now()
	defer func() { d.recordStats(ctx, "get_event_ids_sorted_by_hype", start, err == nil) }()

	return d.Inner.GetEventIDsSortedByHype(ctx, filter, desc, cursor, limit)
}

// GetEventIDsSortedByID forwards to Inner.GetEventIDsSortedByID and records its
// latency and outcome under the "get_event_ids_sorted_by_id" stat name.
func (d *InstrumentedDBClient) GetEventIDsSortedByID(ctx context.Context, filter *BroadcastFilter, cursor string, limit int) (eventIDs *EventIDs, err error) {
	start := time.Now()
	defer func() { d.recordStats(ctx, "get_event_ids_sorted_by_id", start, err == nil) }()

	return d.Inner.GetEventIDsSortedByID(ctx, filter, cursor, limit)
}

// GetBroadcastsByHype forwards to Inner.GetBroadcastsByHype and records its latency
// and outcome under the "get_broadcasts_by_hype" stat name.
func (d *InstrumentedDBClient) GetBroadcastsByHype(ctx context.Context, filter *BroadcastFilter, desc bool, cursor string, limit int) (broadcasts []*Broadcast, err error) {
	start := time.Now()
	defer func() { d.recordStats(ctx, "get_broadcasts_by_hype", start, err == nil) }()

	return d.Inner.GetBroadcastsByHype(ctx, filter, desc, cursor, limit)
}

// GetCollectionIDsByOwner forwards to Inner.GetCollectionIDsByOwner and records its
// latency and outcome under the "get_collection_ids_by_owner" stat name.
func (d *InstrumentedDBClient) GetCollectionIDsByOwner(ctx context.Context, ownerID string, desc bool, cursor string, limit int) (eventIDs *EventIDs, err error) {
	start := time.Now()
	defer func() { d.recordStats(ctx, "get_collection_ids_by_owner", start, err == nil) }()

	return d.Inner.GetCollectionIDsByOwner(ctx, ownerID, desc, cursor, limit)
}

// GetEventAttributesForEventIDs forwards to Inner.GetEventAttributesForEventIDs and
// records its latency and outcome under the "get_event_attributes_for_event_ids"
// stat name.
func (d *InstrumentedDBClient) GetEventAttributesForEventIDs(ctx context.Context, eventIDs []string, keys []string) (attributes map[string]map[string]string, err error) {
	start := time.Now()
	defer func() { d.recordStats(ctx, "get_event_attributes_for_event_ids", start, err == nil) }()

	return d.Inner.GetEventAttributesForEventIDs(ctx, eventIDs, keys)
}

// GetEventAttributes forwards to Inner.GetEventAttributes and records its latency
// and outcome under the "get_event_attributes" stat name.
func (d *InstrumentedDBClient) GetEventAttributes(ctx context.Context, eventID string) (attributes map[string]string, err error) {
	start := time.Now()
	defer func() { d.recordStats(ctx, "get_event_attributes", start, err == nil) }()

	return d.Inner.GetEventAttributes(ctx, eventID)
}

// SetEventAttributes forwards to Inner.SetEventAttributes and records its latency
// and outcome under the "set_event_attributes" stat name.
func (d *InstrumentedDBClient) SetEventAttributes(ctx context.Context, eventID string, attributes map[string]string) (err error) {
	start := time.Now()
	defer func() { d.recordStats(ctx, "set_event_attributes", start, err == nil) }()

	return d.Inner.SetEventAttributes(ctx, eventID, attributes)
}

// GetEventStats forwards to Inner.GetEventStats and records its latency and outcome
// under the "get_event_stats" stat name.
func (d *InstrumentedDBClient) GetEventStats(ctx context.Context, eventIDs []string) (eventStats []*EventStats, err error) {
	start := time.Now()
	defer func() { d.recordStats(ctx, "get_event_stats", start, err == nil) }()

	return d.Inner.GetEventStats(ctx, eventIDs)
}

// IncrementEventFollowCount forwards to Inner.IncrementEventFollowCount and records
// its latency and outcome under the "increment_event_follow_count" stat name.
func (d *InstrumentedDBClient) IncrementEventFollowCount(ctx context.Context, eventID string) (count int64, err error) {
	start := time.Now()
	defer func() { d.recordStats(ctx, "increment_event_follow_count", start, err == nil) }()

	return d.Inner.IncrementEventFollowCount(ctx, eventID)
}

// DecrementEventFollowCount forwards to Inner.DecrementEventFollowCount and records
// its latency and outcome under the "decrement_event_follow_count" stat name.
func (d *InstrumentedDBClient) DecrementEventFollowCount(ctx context.Context, eventID string) (count int64, err error) {
	start := time.Now()
	defer func() { d.recordStats(ctx, "decrement_event_follow_count", start, err == nil) }()

	return d.Inner.DecrementEventFollowCount(ctx, eventID)
}

// StartOrJoinTx passes straight through to Inner.StartOrJoinTx; no stats are
// recorded for transaction management.
// NOTE(review): the returned bool is presumably the createdTx flag later given to
// CommitTx / RollbackTxIfNotCommitted — confirm against the DB interface.
func (d *InstrumentedDBClient) StartOrJoinTx(ctx context.Context, opts *sql.TxOptions) (context.Context, bool, error) {
	txCtx, createdTx, err := d.Inner.StartOrJoinTx(ctx, opts)
	return txCtx, createdTx, err
}

// CommitTx passes straight through to Inner.CommitTx and returns its error
// unchanged; no stats are recorded for it.
func (d *InstrumentedDBClient) CommitTx(ctx context.Context, createdTx bool) error {
	if err := d.Inner.CommitTx(ctx, createdTx); err != nil {
		return err
	}
	return nil
}

// RollbackTxIfNotCommitted passes straight through to Inner.RollbackTxIfNotCommitted;
// no stats are recorded for it.
func (d *InstrumentedDBClient) RollbackTxIfNotCommitted(ctx context.Context, createdTx bool) {
	inner := d.Inner
	inner.RollbackTxIfNotCommitted(ctx, createdTx)
}

// recordStats emits the metrics for one completed DB operation.
//
// It first sends a timing named operationName+".time" covering the time elapsed
// since startTime (used for latency), then increments exactly one counter named
// operationName+".status."+status (used for throughput), where status is:
//   - "success"   when succeeded is true,
//   - "ctx_error" when the operation failed and ctx.Err() is non-nil at recording time,
//   - "db_error"  for any other failure.
func (d *InstrumentedDBClient) recordStats(ctx context.Context, operationName string, startTime time.Time, succeeded bool) {
	d.Stats.TimingDurationC(operationName+".time", time.Since(startTime), 1)

	var status string
	switch {
	case succeeded:
		status = "success"
	case ctx.Err() != nil:
		// A done context is treated as the cause of the failure rather than the database.
		status = "ctx_error"
	default:
		status = "db_error"
	}

	d.Stats.IncC(operationName+".status."+status, 1, 1)
}
