/*
Package twitchclient provides an HTTP client for use with a twitchserver.
*/
package twitchclient

import (
	"context"
	"crypto/tls"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"math/rand"
	"net"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/cactus/go-statsd-client/statsd"

	"code.justin.tv/common/chitin"
	"code.justin.tv/common/config"
	"code.justin.tv/feeds/ctxlog"
	"code.justin.tv/foundation/xray"
)

// Default dialer and transport settings. applyDefaults substitutes these for
// zero-valued ClientConf / TransportConf fields.
const (
	// Dialer settings (ClientConf.DialTimeout / ClientConf.DialKeepAlive).
	defaultDialTimeout   = 30 * time.Second
	defaultDialKeepAlive = 30 * time.Second

	// Transport settings (TransportConf fields).
	defaultExpectContinueTimeout = 30 * time.Second
	defaultTLSHandshakeTimeout   = 5 * time.Second
	defaultIdleConnTimeout       = 55 * time.Second // set lower than the ALB/ELB idle timeout (60s).
)

// init seeds math/rand's global source once per process, so the
// ctxlog StartingIndex values drawn with rand.Int63 by NewHTTPClient and
// NewClient differ between runs.
func init() {
	seed := time.Now().UnixNano()
	rand.Seed(seed)
}

// Client describes the behavior a standard twitch HTTP client must implement.
// NewClient returns an implementation of it.
type Client interface {
	// NewRequest builds a request whose path is resolved against the client's host.
	NewRequest(method string, path string, body io.Reader) (*http.Request, error)

	// Do executes req and returns the raw response; the caller handles errors
	// and closes the response body.
	Do(ctx context.Context, req *http.Request, opts ReqOpts) (*http.Response, error)

	// DoNoContent executes req, closes the response body, and returns an error
	// for 4xx/5xx statuses.
	DoNoContent(ctx context.Context, req *http.Request, opts ReqOpts) (*http.Response, error)

	// DoJSON executes req and decodes a successful JSON response body into data.
	DoJSON(ctx context.Context, data interface{}, req *http.Request, opts ReqOpts) (*http.Response, error)
}

// Logger is the logging interface used by the client. It offers normal and
// debug logging, each with a variant that receives a context so that
// implementations can log dimensions stored on it, like request ID.
type Logger interface {
	// Log writes a normal log entry.
	Log(params ...interface{})
	// LogCtx writes a normal log entry with access to ctx.
	LogCtx(ctx context.Context, params ...interface{})
	// Debug writes a debug log entry.
	Debug(params ...interface{})
	// DebugCtx writes a debug log entry with access to ctx.
	DebugCtx(ctx context.Context, params ...interface{})
}

// ClientConf provides the configuration for NewClient and NewHTTPClient.
// Zero-valued fields are replaced with defaults where noted below.
type ClientConf struct {
	// Host (required by NewClient; ignored by NewHTTPClient) is the base URL that request
	// paths are resolved against. When no scheme is given (e.g. "localhost:8000"),
	// "http://" is prepended.
	Host string

	// Transport supplies configuration for the client's HTTP transport.
	Transport TransportConf

	// CheckRedirect specifies the policy for handling redirects. It is passed straight to
	// http.Client.CheckRedirect; nil uses net/http's default policy.
	CheckRedirect func(req *http.Request, via []*http.Request) error

	// Stats enables tracking of DNS and request timings. Defaults to config.Statsd() when nil.
	Stats statsd.Statter

	// DNSStatsPrefix is a custom stat prefix for DNS timing stats; defaults to "dns".
	// The sample rate depends on the sample rate in ReqOpts.
	DNSStatsPrefix string

	// SuppressRepositoryHeader avoids sending the Twitch-Repository header (which is
	// otherwise automatically included). Please set to true when calling 3rd party
	// services (rollbar, facebook, etc).
	SuppressRepositoryHeader bool

	// TimingXactName is the code.justin.tv/chat/timing sub-transaction name; defaults to "twitchhttp".
	TimingXactName string

	// TLSClientConfig is an optional TLS config for making secure requests. It is only used
	// by the default transport that is built when BaseRoundTripper is nil.
	TLSClientConfig *tls.Config

	// RoundTripperWrappers are user-provided wrappers (e.g. for hystrix support). They are
	// applied in order on top of the chitin- and xray-wrapped BaseRoundTripper:
	// RTW[2](RTW[1](RTW[0](rt))). The Twitch stats/xact/headers round tripper wraps the result.
	RoundTripperWrappers []func(http.RoundTripper) http.RoundTripper

	// BaseRoundTripper is the innermost RoundTripper that makes the http request.
	// By default it is an *http.Transport built from DialTimeout, DialKeepAlive, Transport
	// and TLSClientConfig, but it can be overridden, e.g. with a mock RoundTripper for tests.
	BaseRoundTripper http.RoundTripper

	// Logger is an optional logger. The default implementation prints normal logs to stdout,
	// discards debug logs, and does not log any context dimensions like request ID.
	Logger Logger
	// DimensionKey is the context key under which ctxlog appends logging dimensions. A logger
	// may want to log those context values. Defaults to a fresh unique key when nil.
	DimensionKey interface{}
	// ElevateKey is a key for request contexts. The value will be a boolean which dictates
	// whether or not logs should be elevated to a higher log level. Logging implementations
	// may inspect the context for this key and respond accordingly. It must be set to the
	// same value as twitchserver.
	ElevateKey interface{}

	// StatNamePrefix is used to prefix StatName in requests. One usage of this parameter is
	// to prefix all client stats with your service, e.g. "service.users_service". There
	// shouldn't be a leading period on StatName or a trailing period on StatNamePrefix.
	StatNamePrefix string

	// DialTimeout is the maximum amount of time a dial will wait for a connect to complete.
	//
	// When using TCP and dialing a host name with multiple IP addresses, the timeout may be divided
	// between them.
	//
	// With or without a timeout, the operating system may impose its own earlier timeout.
	// For instance, TCP timeouts are often around 3 minutes.
	//
	// If zero, defaults to `defaultDialTimeout`.
	// To disable timeout, set to a negative value.
	DialTimeout time.Duration

	// DialKeepAlive specifies the keep-alive period for an active network connection.
	// Network protocols that do not support keep-alives ignore this field.
	//
	// If zero, defaults to `defaultDialKeepAlive`.
	// Set to a negative value to disable keep-alive.
	DialKeepAlive time.Duration
}

// TransportConf provides configuration options for the default HTTP transport.
// These settings are only used when ClientConf.BaseRoundTripper is nil.
type TransportConf struct {
	// MaxIdleConnsPerHost controls the maximum number of idle TCP connections that can exist in the
	// connection pool for a specific host.
	// Defaults to `http.DefaultMaxIdleConnsPerHost` (2).
	MaxIdleConnsPerHost int

	// IdleConnTimeout is the maximum amount of time an idle (keep-alive) connection will remain idle
	// before closing itself.
	// If zero, defaults to `defaultIdleConnTimeout`.
	// To disable timeout, set to a negative value.
	IdleConnTimeout time.Duration

	// TLSHandshakeTimeout specifies the maximum amount of time to wait for a TLS handshake.
	// If zero, defaults to `defaultTLSHandshakeTimeout`.
	// To disable timeout, set to a negative value.
	TLSHandshakeTimeout time.Duration

	// ExpectContinueTimeout, if non-zero, specifies the amount of time to wait for a server's first
	// response headers after fully writing the request headers if the request has an
	// "Expect: 100-continue" header.
	//
	// This time does not include the time to send the request header.
	//
	// If zero, defaults to `defaultExpectContinueTimeout`.
	// To disable timeout, set to a negative value.
	//
	// If timeout is disabled, this causes the body to be sent immediately, without waiting for the
	// server to approve.
	ExpectContinueTimeout time.Duration
}

// NewHTTPClient builds a new http.Client from conf, wrapped with the stats,
// xact, headers and chitin round trippers. Unset conf fields are defaulted first.
// The option ClientConf.Host is ignored (http.Client has no host field).
//
// Use request context to activate stats and add auth headers, i.e.:
//     client := twitchclient.NewHTTPClient(ClientConf{Stats: stats})
//
//     ctx := context.Background()
//     ctx = twitchclient.WithTimingStats(ctx, "statname", 0.2) // track timing stats on "statname.status"
//     ctx = twitchclient.WithTwitchAuthorization(ctx, opts.AuthorizationToken) // add "Twitch-Authorization" header
//     ctx = twitchclient.WithTwitchClientRowID(ctx, opts.ClientRowID) // add "Twitch-Client-Row-ID" header
//     ctx = twitchclient.WithTwitchClientID(ctx, opts.ClientID) // add "Twitch-Client-ID" header
//     req, _ := http.NewRequest("GET", host+"/path", body)
//     req = req.WithContext(ctx)
//
//     response, err := client.Do(req)
//
func NewHTTPClient(conf ClientConf) *http.Client {
	applyDefaults(&conf)

	logCtx := &ctxlog.Ctxlog{
		StartingIndex: rand.Int63(),
		ElevateKey:    conf.ElevateKey,
		CtxDims:       &ctxDimensions{conf.DimensionKey},
	}

	// With a background context, chitin falls back to each request's own context.
	background := context.Background()
	return newHTTPClient(background, conf, logCtx)
}

// newTransport builds the *http.Transport used as the default BaseRoundTripper.
// It copies the dialer, TLS and transport settings from conf verbatim, so the
// caller is expected to have resolved defaults first (applyDefaults does); a
// zero duration here is passed to net/http as-is.
func newTransport(conf ClientConf) *http.Transport {
	dialer := &net.Dialer{
		KeepAlive: conf.DialKeepAlive,
		Timeout:   conf.DialTimeout,
	}
	tc := conf.Transport

	return &http.Transport{
		Proxy:                 http.ProxyFromEnvironment,
		DialContext:           dialer.DialContext,
		TLSClientConfig:       conf.TLSClientConfig,
		TLSHandshakeTimeout:   tc.TLSHandshakeTimeout,
		MaxIdleConnsPerHost:   tc.MaxIdleConnsPerHost,
		IdleConnTimeout:       tc.IdleConnTimeout,
		ExpectContinueTimeout: tc.ExpectContinueTimeout,
	}
}

// newHTTPClient assembles the round-tripper chain around conf.BaseRoundTripper
// and returns an http.Client that uses it. From innermost to outermost the layers
// are: chitin, xray, conf.RoundTripperWrappers (in slice order), and finally the
// Twitch round tripper (stats, xact and headers).
//
// ctx is handed to chitin; with context.Background() chitin uses each request's
// own context instead.
func newHTTPClient(ctx context.Context, conf ClientConf, ctxLog *ctxlog.Ctxlog) *http.Client {
	base, err := chitin.RoundTripper(ctx, conf.BaseRoundTripper)
	if err != nil {
		// Current chitin releases never return an error here; only very old ones do.
		panic("You are using a very old version of chitin, please upgrade")
	}

	var rt http.RoundTripper = xray.RoundTripper(base)
	for i := range conf.RoundTripperWrappers {
		rt = conf.RoundTripperWrappers[i](rt) // user-provided wrappers (e.g. hystrix support)
	}

	return &http.Client{
		CheckRedirect: conf.CheckRedirect,
		Transport:     wrapWithTwitchHTTPRoundTripper(rt, conf, ctxLog),
	}
}

// applyDefaults fills unset fields of conf in place: Stats, Logger,
// DimensionKey and DNSStatsPrefix; the dialer and transport timeouts (through
// defaultIfUnset, so zero means "use the default" and a negative value becomes
// 0, which disables the timeout); and finally BaseRoundTripper, which is built
// from the already-resolved timeouts.
func applyDefaults(conf *ClientConf) {
	if conf.Stats == nil {
		conf.Stats = config.Statsd()
	}
	if conf.Logger == nil {
		conf.Logger = &defaultLogger{}
	}
	if conf.DimensionKey == nil {
		// A freshly allocated *int cannot collide with any other context key.
		conf.DimensionKey = new(int)
	}
	if conf.DNSStatsPrefix == "" {
		conf.DNSStatsPrefix = "dns"
	}

	conf.DialKeepAlive = defaultIfUnset(conf.DialKeepAlive, defaultDialKeepAlive)
	conf.DialTimeout = defaultIfUnset(conf.DialTimeout, defaultDialTimeout)

	conf.Transport.ExpectContinueTimeout = defaultIfUnset(conf.Transport.ExpectContinueTimeout, defaultExpectContinueTimeout)
	conf.Transport.IdleConnTimeout = defaultIfUnset(conf.Transport.IdleConnTimeout, defaultIdleConnTimeout)
	conf.Transport.TLSHandshakeTimeout = defaultIfUnset(conf.Transport.TLSHandshakeTimeout, defaultTLSHandshakeTimeout)

	// Must come after the timeouts above: newTransport copies them.
	if conf.BaseRoundTripper == nil {
		conf.BaseRoundTripper = newTransport(*conf)
	}
}

// defaultIfUnset resolves an optional duration setting:
// a positive val is returned unchanged, zero means "unset" and yields def,
// and a negative val is an explicit request for zero (e.g. to disable a timeout).
func defaultIfUnset(val, def time.Duration) time.Duration {
	if val > 0 {
		return val
	}
	if val == 0 {
		return def
	}
	return time.Duration(0)
}

// NewClient builds a twitchclient.Client from conf. It applies the same
// defaults as NewHTTPClient and additionally requires conf.Host, which is
// parsed by sanitizeHostURL and becomes the base URL for NewRequest. An error
// is returned when Host is blank or cannot be parsed.
//
// NOTE: You should use NewHTTPClient instead.
func NewClient(conf ClientConf) (Client, error) {
	applyDefaults(&conf)

	base, err := sanitizeHostURL(conf.Host)
	if err != nil {
		return nil, err
	}

	c := &client{
		conf:   &conf,
		host:   base,
		logger: conf.Logger,
		ctxLog: &ctxlog.Ctxlog{
			CtxDims:       &ctxDimensions{conf.DimensionKey},
			ElevateKey:    conf.ElevateKey,
			StartingIndex: rand.Int63(),
		},
	}
	return c, nil
}

// client is the Client implementation returned by NewClient.
type client struct {
	conf   *ClientConf    // defaulted config; a copy is turned into a new http.Client per request
	host   *url.URL       // base URL that NewRequest resolves paths against
	logger Logger         // reports failures to close response bodies
	ctxLog *ctxlog.Ctxlog // context logger shared by every round-tripper chain
}

// Compile-time check that *client satisfies Client.
var _ Client = (*client)(nil)

// NewRequest creates an *http.Request for path, resolved against the
// configured host with url.URL.ResolveReference (RFC 3986 rules). For example,
// a relative path replaces the last segment of a host path without a trailing
// slash, and an absolute path ("/x") replaces the host's path entirely.
// It returns an error if path cannot be parsed or the request cannot be built.
func (c *client) NewRequest(method string, path string, body io.Reader) (*http.Request, error) {
	ref, err := url.Parse(path)
	if err != nil {
		return nil, err
	}

	target := c.host.ResolveReference(ref)
	return http.NewRequest(method, target.String(), body)
}

// modifyReqOpts normalizes opts in place before a request is sent.
//   - A named stat with no explicit sample rate gets a default rate of 0.1.
//   - A non-empty StatName is prefixed with ClientConf.StatNamePrefix, joined by ".".
//
// An empty StatName is left untouched, so a configured prefix cannot turn it
// into the dangling stat name "<prefix>.".
func (c *client) modifyReqOpts(opts *ReqOpts) {
	if opts.StatName == "" {
		return
	}

	if opts.StatSampleRate == 0 {
		opts.StatSampleRate = 0.1
	}

	if prefix := c.conf.StatNamePrefix; prefix != "" {
		opts.StatName = prefix + "." + opts.StatName
	}
}

// do executes req with opts but leaves error handling and body cleanup to the caller.
//
// opts are normalized by modifyReqOpts and stored on ctx with WithReqOpts for
// use by the Twitch round tripper. ctx replaces the request's context only
// when req still carries the default background context.
//
// NOTE(review): when req already has its own context, the opts-annotated ctx
// is given only to chitin and is not attached to req. Confirm that the Twitch
// round tripper still sees opts in that case.
func (c *client) do(ctx context.Context, req *http.Request, opts ReqOpts) (*http.Response, error) {
	c.modifyReqOpts(&opts)
	annotated := WithReqOpts(ctx, opts)

	if req.Context() == context.Background() {
		req = req.WithContext(annotated)
	}

	// A fresh http.Client per call lets chitin.RoundTripper receive this call's
	// ctx. The BaseRoundTripper stored in c.conf by applyDefaults is reused.
	return newHTTPClient(annotated, *c.conf, c.ctxLog).Do(req)
}

// Do executes req using ctx and opts and returns the raw response.
// It neither closes the response body nor inspects the status code. The caller
// must handle err and, when err is nil, close resp.Body.
func (c *client) Do(ctx context.Context, req *http.Request, opts ReqOpts) (*http.Response, error) {
	resp, err := c.do(ctx, req, opts)
	return resp, err
}

// DoNoContent executes a request using the given Context for Trace support.
// It is meant for callers that are not interested in the response body.
//
// Transport errors are returned with a nil response. For 4xx and 5xx statuses,
// the error built by parseErrorIfPresent is returned alongside resp. The body
// is always drained (up to a small bound) and closed before DoNoContent
// returns, so the caller must not read resp.Body.
func (c *client) DoNoContent(ctx context.Context, req *http.Request, opts ReqOpts) (*http.Response, error) {
	// Upper bound on unread bytes drained before closing, so a large body cannot
	// stall the caller just to keep the connection reusable.
	const maxDrainBytes = 4 << 10

	resp, err := c.do(ctx, req, opts)
	if err != nil {
		return nil, err
	}

	defer func() {
		// net/http can only reuse a keep-alive connection once the body has been
		// read to EOF and closed. Draining is best-effort, so its error is ignored.
		_, _ = io.Copy(ioutil.Discard, io.LimitReader(resp.Body, maxDrainBytes))
		if closeErr := resp.Body.Close(); closeErr != nil {
			c.logger.Log(fmt.Sprintf("Could not close response body: %v", closeErr))
		}
	}()

	// parseErrorIfPresent runs before the deferred drain/close, so it can still
	// read a 5xx body for the error message.
	return resp, parseErrorIfPresent(resp)
}

// DoJSON executes a request, then deserializes the response body into data.
//
// Transport errors are returned with a nil response. 4xx/5xx statuses are
// converted by parseErrorIfPresent and returned with resp. A 204 No Content
// response skips decoding. A decode failure is wrapped as
// "Unable to read response body: ...". The body is always drained (up to a
// small bound) and closed before DoJSON returns.
//
// NOTE: a *twitchclient.Error is returned on 4xx errors, but not on 5xx errors.
// This is problematic because it makes it very difficult for handlers (like Visage) to know if an error
// was caused because of a 5xx backend error or it is a logical error returned by the client.
// For this reason you should not use DoJSON. It is better if you use `Do` instead and always return an error with status.
func (c *client) DoJSON(ctx context.Context, data interface{}, req *http.Request, opts ReqOpts) (*http.Response, error) {
	// Upper bound on unread bytes drained before closing, so a large body cannot
	// stall the caller just to keep the connection reusable.
	const maxDrainBytes = 4 << 10

	resp, err := c.do(ctx, req, opts)
	if err != nil {
		return nil, err
	}

	defer func() {
		// json.Decoder stops at the end of the JSON value, and the 204/error paths may
		// not read the body at all. net/http only reuses a keep-alive connection
		// whose body was read to EOF and closed, so drain the remainder first.
		// Draining is best-effort, so its error is ignored.
		_, _ = io.Copy(ioutil.Discard, io.LimitReader(resp.Body, maxDrainBytes))
		if closeErr := resp.Body.Close(); closeErr != nil {
			c.logger.Log(fmt.Sprintf("Could not close response body: %v", closeErr))
		}
	}()

	if err = parseErrorIfPresent(resp); err != nil {
		return resp, err
	}

	if resp.StatusCode != http.StatusNoContent {
		if err = json.NewDecoder(resp.Body).Decode(data); err != nil {
			return resp, fmt.Errorf("Unable to read response body: %s", err)
		}
	}
	return resp, nil
}

// parseErrorIfPresent converts an error status on resp into a Go error.
//   - 5xx: the error text is resp.Status followed by the full response body, or
//     a fixed message when the body cannot be read.
//   - 4xx: delegated to HandleFailedResponse.
//   - anything else: nil.
//
// The 5xx path reads resp.Body but does not close it; closing is left to the caller.
func parseErrorIfPresent(resp *http.Response) error {
	switch code := resp.StatusCode; {
	case code >= 500:
		raw, readErr := ioutil.ReadAll(resp.Body)
		if readErr != nil {
			return errors.New(resp.Status + ": unable to read response body for more error information")
		}
		return errors.New(resp.Status + ": " + string(raw))
	case code >= 400:
		return HandleFailedResponse(resp)
	default:
		return nil
	}
}

// sanitizeHostURL parses host into the base URL used by NewRequest.
//
// host must be non-empty. When it does not already start with an "http://" or
// "https://" scheme (compared case-insensitively), "http://" is prepended. For
// example, "localhost:8000" and "httpbin.org" both become http URLs. Any
// url.Parse error is returned unchanged.
func sanitizeHostURL(host string) (*url.URL, error) {
	if host == "" {
		return nil, errors.New("Host cannot be blank")
	}

	// Check for a full scheme, not just the letters "http": bare hosts such as
	// "httpbin.org" or "http-proxy:8080" still need a scheme prepended.
	lower := strings.ToLower(host)
	if !strings.HasPrefix(lower, "http://") && !strings.HasPrefix(lower, "https://") {
		host = fmt.Sprintf("http://%v", host)
	}

	return url.Parse(host)
}
