import argparse
import base64
import hashlib
import json
import logging
import logging.handlers
import os
import sys
import time

from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager

from boto.s3.connection import S3Connection

import tornado.web
import tornado.gen
import tornado.ioloop
from tornado.httpclient import AsyncHTTPClient

from infra.yasm.unistat import Unistat


# Bucket that stores uploaded snapshots (and the NO_IMAGE_NAME placeholder).
S3_BUCKET_NAME = "yasm-img"
# Names of the environment variables main() reads the S3 credentials from.
S3_ACCESS_KEY_ID_ENV_VARIABLE = "S3_ACCESS_KEY_ID"
S3_SECRET_ACCESS_KEY_ENV_VARIABLE = "S3_SECRET_ACCESS_KEY"
# Object ImgHandler redirects to when the requested snapshot is missing from the bucket.
NO_IMAGE_NAME = "no_image.png"
# Timeout (seconds, passed as Tornado's request_timeout) for ImgHandler's HEAD check.
S3_HEAD_REQUEST_TIMEOUT = 1.5

log = logging.getLogger(__name__)


def make_bucket_url(s3_endpoint, filename):
    """Build the HTTPS URL of ``filename`` inside the S3_BUCKET_NAME bucket.

    The bucket name is used as a sub-domain of ``s3_endpoint``
    (``https://<bucket>.<endpoint>/<filename>``).
    """
    url_template = "https://{}.{}/{}"
    return url_template.format(S3_BUCKET_NAME, s3_endpoint, filename)


@contextmanager
def time_it(unistat, metric_name):
    """Push the time spent inside the ``with`` body to ``unistat`` as ``metric_name``.

    The duration (float seconds) is pushed whether the body finishes normally
    or raises; an exception raised by the body still propagates to the caller.
    """
    # A monotonic clock is immune to system clock adjustments; fall back to
    # time.time() on interpreters that do not provide one (Python 2).
    clock = getattr(time, "monotonic", time.time)
    start_time = clock()
    try:
        yield
    finally:
        # finally: without it the measurement is lost whenever the body raises.
        unistat.push(metric_name, clock() - start_time)


class ReadyHandler(tornado.web.RequestHandler):
    """Handler for the readiness route: always answers ``{"status": "OK"}`` as JSON."""

    @tornado.gen.coroutine
    def get(self):
        body = json.dumps({"status": "OK"})
        self.set_header("Content-Type", "application/json")
        self.write(body)


class SnapshotHandler(tornado.web.RequestHandler):
    """Accept a base64-encoded PNG in the POST body and upload it to S3.

    The object is named after the MD5 hex digest of the decoded bytes, so
    identical snapshots map to the same key. The JSON response is
    ``{"url": ...}`` with status 200 on success, or ``{"error": ...}`` with
    400 (undecodable/empty body) or 500 (upload failure).
    """

    S3_ERRORS_SIGNAL = "snapshot_s3_request_errors"
    SNAPSHOTS_UPLOADED_SIGNAL = "snapshots_uploaded"
    SNAPSHOTS_UPLOAD_TIME_SIGNAL = "snapshot_upload_time"
    SNAPSHOT_HANDLE_TIME_SIGNAL = "snapshot_handle_time"

    def initialize(self, upload_executor, s3_endpoint, s3_access_key_id, s3_secret_access_key, unistat):
        """Store per-application settings passed from the route definition."""
        # Executor that runs the blocking decode/upload work off the IOLoop.
        self._upload_executor = upload_executor
        self._s3_endpoint = s3_endpoint
        self._s3_access_key_id = s3_access_key_id
        self._s3_secret_access_key = s3_secret_access_key
        self._unistat = unistat

    @classmethod
    def prepare_unistat_signals(cls, unistat):
        """Register every signal this handler pushes; call once before serving."""
        unistat.create_float(cls.S3_ERRORS_SIGNAL)
        unistat.create_float(cls.SNAPSHOTS_UPLOADED_SIGNAL)
        unistat.create_histogram(cls.SNAPSHOTS_UPLOAD_TIME_SIGNAL)
        unistat.create_histogram(cls.SNAPSHOT_HANDLE_TIME_SIGNAL)

    @tornado.gen.coroutine
    def post(self, *args, **kwargs):
        # The boto upload is synchronous, so it runs on the executor and the
        # coroutine just waits for its (response, status) result.
        response, code = yield self._upload_executor.submit(self._handle_png, self.request.body)
        self.set_header("Content-Type", "application/json")
        self.set_status(code)
        self.write(json.dumps(response))

    def _handle_png(self, data):
        """Decode and upload one snapshot; return a ``(response_dict, http_status)`` pair.

        Bad input and upload failures are reported through the returned status
        (400 / 500) and logged, rather than raised.
        """
        with time_it(self._unistat, self.SNAPSHOT_HANDLE_TIME_SIGNAL):
            try:
                png = base64.b64decode(data)
            except Exception as decode_error:
                # Broad on purpose: Python 2 raises TypeError, Python 3 binascii.Error.
                log.exception(decode_error)
                return {"error": str(decode_error)}, 400
            if not png:
                log.info("got empty image")
                return {"error": "empty image"}, 400

            try:
                # Content-addressed name: re-uploading the same image reuses the key.
                fname = "{}.png".format(hashlib.md5(png).hexdigest())
                self._upload_png(fname, png)
                self._unistat.push(self.SNAPSHOTS_UPLOADED_SIGNAL, 1)
                log.info("snapshot %s was uploaded to %s at %s", fname, S3_BUCKET_NAME, self._s3_endpoint)
                return {"url": make_bucket_url(self._s3_endpoint, fname)}, 200
            except Exception as e:
                log.exception(e)
                return {"error": str(e)}, 500

    def _upload_png(self, fname, png):
        """Store ``png`` bytes in the bucket under ``fname``; count and re-raise failures."""
        try:
            with time_it(self._unistat, self.SNAPSHOTS_UPLOAD_TIME_SIGNAL):
                conn = S3Connection(host=self._s3_endpoint, aws_access_key_id=self._s3_access_key_id,
                                    aws_secret_access_key=self._s3_secret_access_key)
                # validate=False skips boto's extra request that checks the bucket exists.
                bucket = conn.get_bucket(S3_BUCKET_NAME, validate=False)
                key = bucket.new_key(fname)
                key.set_contents_from_string(png, headers={"Content-Type": "image/png"})
        except Exception:
            self._unistat.push(self.S3_ERRORS_SIGNAL, 1)
            raise


class ImgHandler(tornado.web.RequestHandler):
    """Redirect an image request to the snapshot's bucket URL with a 302.

    A HEAD request checks the object first. When S3 answers 404 the client is
    sent to NO_IMAGE_NAME instead. S3 trouble (a 5xx answer or a request that
    failed without any response) is counted and fails open: the client is
    still redirected to the object's own URL.
    """

    REDIRECT_TO_NO_IMAGE_SIGNAL = "img_requests_redirected_to_no_image"
    REDIRECT_TO_S3_SIGNAL = "img_requests_redirected_to_s3"
    S3_ERRORS_SIGNAL = "img_s3_request_errors"
    S3_HEAD_REQUEST_TIME = "img_s3_head_request_time"

    def initialize(self, s3_endpoint, s3_access_key_id, s3_secret_access_key, unistat):
        """Store per-application settings passed from the route definition."""
        self._s3_endpoint = s3_endpoint
        # The access keys are kept but not used: the HEAD check below is unsigned.
        self._s3_access_key_id = s3_access_key_id
        self._s3_secret_access_key = s3_secret_access_key
        self._unistat = unistat

    @classmethod
    def prepare_unistat_signals(cls, unistat):
        """Register every signal this handler pushes; call once before serving."""
        unistat.create_float(cls.REDIRECT_TO_NO_IMAGE_SIGNAL)
        unistat.create_float(cls.REDIRECT_TO_S3_SIGNAL)
        unistat.create_float(cls.S3_ERRORS_SIGNAL)
        unistat.create_histogram(cls.S3_HEAD_REQUEST_TIME)

    @tornado.gen.coroutine
    def get(self, filename, *args, **kwargs):
        """Answer with a 302 to ``filename`` in the bucket, or to the placeholder if missing."""
        redirect_to_no_image = yield self._file_not_in_s3(filename)
        if redirect_to_no_image:
            log.warning("redirecting {} to {}".format(filename, NO_IMAGE_NAME))
            redirect_to = make_bucket_url(self._s3_endpoint, NO_IMAGE_NAME)
            self._unistat.push(self.REDIRECT_TO_NO_IMAGE_SIGNAL, 1)
        else:
            redirect_to = make_bucket_url(self._s3_endpoint, filename)
            self._unistat.push(self.REDIRECT_TO_S3_SIGNAL, 1)
        self.set_header("Location", redirect_to)
        self.set_status(302)

    @tornado.gen.coroutine
    def _file_not_in_s3(self, filename):
        """Resolve to True only when a HEAD on the object URL returns 404.

        Every other outcome resolves to False: a found object, an S3 5xx, or a
        request that produced no response at all. Both 5xx answers and failed
        requests bump S3_ERRORS_SIGNAL.
        """
        url = make_bucket_url(self._s3_endpoint, filename)
        http_response = None
        with time_it(self._unistat, self.S3_HEAD_REQUEST_TIME):
            try:
                http_response = yield AsyncHTTPClient().fetch(
                    url,
                    method="HEAD",
                    request_timeout=S3_HEAD_REQUEST_TIMEOUT,
                    raise_error=False)
            except Exception:
                # In newer Tornado releases raise_error=False only suppresses
                # errors that carry an HTTP status code; connection failures and
                # timeouts are still raised. Treat them like a 5xx instead of
                # turning the image request into a 500.
                log.exception("HEAD request to %s failed", url)
        if http_response is None or http_response.code >= 500:
            self._unistat.push(self.S3_ERRORS_SIGNAL, 1)
        raise tornado.gen.Return(http_response is not None and http_response.code == 404)


class StatsHandler(tornado.web.RequestHandler):
    """Serve the current values of every unistat signal as JSON."""

    def initialize(self, unistat):
        # Unistat instance whose signals are reported by get().
        self._unistat = unistat

    @tornado.gen.coroutine
    def get(self, *args, **kwargs):
        self.set_header("Content-Type", "application/json")
        stats_json = self._unistat.to_json(all_signals=True)
        self.write(stats_json)


def parse_args():
    """Parse ``sys.argv``: ``--port`` and ``--log`` are required, ``--s3-endpoint`` is optional."""
    arg_parser = argparse.ArgumentParser()
    add = arg_parser.add_argument

    add("--port", dest="http_port", type=int, required=True,
        help="http_port")
    add("--s3-endpoint", dest="s3_endpoint", type=str,
        default="s3.mdst.yandex.net", help="s3 endpoint to use")
    add("--log", dest="log_path", type=str, required=True,
        help="path to log file")

    namespace = arg_parser.parse_args()
    return namespace


def init_log(log_file):
    """Route INFO-and-above records from the root logger into ``log_file``.

    The file is rotated at midnight and at most five rotated files are kept.
    """
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)

    file_handler = logging.handlers.TimedRotatingFileHandler(
        filename=log_file, when="midnight", backupCount=5)
    file_handler.setLevel(logging.INFO)
    log_format = "%(asctime)s %(levelname)s (%(module)s) %(message)s"
    file_handler.setFormatter(logging.Formatter(log_format))

    root_logger.addHandler(file_handler)


def main():
    """Entry point: set up logging, check S3 credentials, then serve HTTP forever.

    The credentials are read from the environment variables named by
    S3_ACCESS_KEY_ID_ENV_VARIABLE / S3_SECRET_ACCESS_KEY_ENV_VARIABLE; the
    process exits with an error message if either is missing or empty.
    """
    args = parse_args()
    init_log(args.log_path)

    access_key_id = os.environ.get(S3_ACCESS_KEY_ID_ENV_VARIABLE)
    secret_access_key = os.environ.get(S3_SECRET_ACCESS_KEY_ENV_VARIABLE)
    if not (access_key_id and secret_access_key):
        sys.exit("S3 access keys are not specified.")

    log.info("Starting snapshoter.")

    upload_executor = ThreadPoolExecutor(4)
    unistat = Unistat()
    for handler_cls in (SnapshotHandler, ImgHandler):
        handler_cls.prepare_unistat_signals(unistat)

    # Settings shared by both handlers that talk to S3.
    s3_settings = dict(
        s3_endpoint=args.s3_endpoint,
        s3_access_key_id=access_key_id,
        s3_secret_access_key=secret_access_key,
        unistat=unistat,
    )
    routes = [
        (r"/snapshot/?", SnapshotHandler, dict(s3_settings, upload_executor=upload_executor)),
        (r"/ready/?", ReadyHandler),
        (r"/img/(?P<filename>[a-z0-9]*\.png)", ImgHandler, dict(s3_settings)),
        (r"/stats/?", StatsHandler, dict(unistat=unistat)),
    ]
    app = tornado.web.Application(routes)
    app.listen(args.http_port)

    tornado.ioloop.IOLoop.current().start()
