import os
import re
import cgi
import json
import time
import uuid
import urllib
import httplib
import logging
import urlparse
import traceback
import collections
import distutils.util
import datetime as dt
import cStringIO as sio

import requests

import tornado.web
import tornado.ioloop

from sandbox import common
import sandbox.common.types.user as ctu
import sandbox.common.types.misc as ctm
import sandbox.common.types.database as ctd

from sandbox.yasandbox import context
from sandbox.yasandbox import controller
from sandbox.yasandbox.database import mapping

import sandbox.serviceapi.web.exceptions

from sandbox.web import api
from sandbox.web import helpers
from sandbox.web import response
from sandbox.web import controller as web_controller
from sandbox.web.server import workers


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class SandboxRequestMeasure(common.patterns.Abstract):
    """
    Performance-related data of a single request.

    - `received`: UTC datetime at which the server accepted the request
    - `processing_time`: how long the request has been processed, in milliseconds
    - `completed`: UTC datetime at which the request was replied to
    - `high_priority`: whether the request was handled as high priority (came from UI)
    - `code`: HTTP status code of the reply
    """

    __slots__ = (
        "received",
        "processing_time",
        "completed",
        "high_priority",
        "code",
    )
    __defs__ = (None, None, None, False, None)

    def __init__(self):
        super(SandboxRequestMeasure, self).__init__()
        # Stamp the arrival time immediately; the other fields are filled in as processing goes on
        self.received = dt.datetime.utcnow()

    @property
    def duration(self):
        """Seconds between `received` and `completed`, or until now if the request is not completed yet."""
        finished_at = self.completed or dt.datetime.utcnow()
        return (finished_at - self.received).total_seconds()


# Identity of the task on whose behalf a request is made: client id, task id, vault data and OAuth token.
TaskSession = collections.namedtuple("TaskSession", "client task vault token")


class SandboxRequest(object):
    """
    Server-side view of an HTTP request being processed.

    Besides the transport data (headers, path, body, parameters, cookies) it resolves
    the acting user, the request source, OAuth-derived task session data, the MongoDB
    read preference and a default execution context for the request.
    """

    Source = mapping.Audit.RequestSource
    settings = common.config.Registry().server
    production = common.config.Registry().common.installation == ctm.Installation.PRODUCTION

    # Literal values recognized by `getboolean` (matched after stripping and lower-casing)
    _BOOLEAN_LITERALS = {
        'false': False,
        '0': False,
        'undefined': False,
        'true': True,
        '1': True,
    }

    def __init__(
        self,
        remote_ip,
        headers,
        method,
        path,
        raw_data='',
        target=(),
        client_address='',
        params=None,
        cookies=None,
        is_binary=False,
        profid=None,
        query=None,
    ):
        """
        :param remote_ip: IP address of the request's originator
        :param headers: request headers as a dictionary
        :param method: HTTP method name
        :param path: request path; anything after '?' is dropped for `path` and `uri`
        :param raw_data: raw request body
        :param target: `(scheme, host)` pair used to build the absolute `uri`
        :param client_address: client address, reported in `measure_stat_data`
        :param params: mapping of parameter name to a list of values
        :param cookies: mapping of cookie name to its value
        :param is_binary: when set, parameter values are returned as-is instead of being UTF-8 decoded
        :param profid: performance profile identifier, if any
        :param query: raw query string
        :raises response.HttpErrorResponse: (410 Gone) when the `Authorization` header carries
            a token that is absent from the OAuth cache
        """
        params = params or {}
        cookies = cookies or {}
        self.remote_ip = remote_ip
        self.headers = headers
        self.path = path.split('?')[0]
        # Outside production, a request made from an HTTPS page is considered an HTTPS one
        if not self.production and self.headers.get("Referer", "").startswith("https"):
            self.uri = 'https://{}{}'.format(target[1], self.path)
        else:
            scheme = self.headers.get(ctm.HTTPHeader.FORWARDED_SCHEME, target[0])
            self.uri = '{}://{}{}'.format(scheme, target[1], self.path)
        self.raw_path = path
        self.raw_data = raw_data
        self.client_address = client_address
        # Reuse the caller-provided request ID; generated ones start with '_' and keep the 32-char length
        self.id = self.headers.get(ctm.HTTPHeader.REQUEST_ID) or ('_' + uuid.uuid4().hex[1:])
        self.internal_id = uuid.uuid4().hex
        self.session = None
        self.token_source = None

        self.method = method
        self.params = params
        self.cookies = cookies
        self.is_binary = is_binary
        self.profid = profid
        self.query = query

        timeout = headers.get(ctm.HTTPHeader.REQUEST_TIMEOUT)
        self.timeout = timeout if timeout is None else float(timeout)

        if headers.get(ctm.HTTPHeader.CURRENT_USER):
            self.user = controller.User.get(headers[ctm.HTTPHeader.CURRENT_USER])
        else:
            self.user = controller.User.anonymous
        if ctm.HTTPHeader.REQUEST_SOURCE in headers:
            self.source = headers[ctm.HTTPHeader.REQUEST_SOURCE].lower()
        else:
            self.source = self.Source.API
        self.measures = SandboxRequestMeasure()
        self.remote_method = None
        self.handler = None
        self.ctx = None

        self.quota_owner = self.user.login
        if 'Authorization' in self.headers:
            # Try to authorize user with OAuth token if he has no session cookies.
            authorization_header = self.headers['Authorization']
            # The token is the second whitespace-separated word of the header ("<scheme> <token>")
            oauth_token = next(iter(authorization_header.split()[1:]), None)
            oauth = mapping.OAuthCache.objects(token=oauth_token).first()

            if not oauth:
                raise response.HttpErrorResponse(
                    httplib.GONE, "Authorization failed after proxying to legacy: there is no such token in the cache."
                )

            # NOTE: `oauth` is only set when an Authorization header is present
            self.oauth = oauth

            self.token_source = oauth.source
            if self.source == self.Source.TASK:
                # Task requests are accounted to the token owner; the second ':'-separated part
                # of the token source is taken as the client id
                self.quota_owner = oauth.owner
                self.token_source = ctu.TokenSource.CLIENT
                self.session = TaskSession(oauth.source.split(":")[1], oauth.task_id, oauth.vault, oauth_token)

        # Super users may act on behalf of another user; an unknown forwarded user is ignored
        if self.user.super_user and ctm.HTTPHeader.FORWARDED_USER in headers:
            euser = controller.User.get(headers.get(ctm.HTTPHeader.FORWARDED_USER))
            self.user = euser if euser else self.user

        self.read_preference = self.settings.mongodb.default_read_preference
        if self.user is not controller.User.anonymous:
            if self.source not in (self.Source.API, self.Source.RPC) or self.user.super_user:
                self.read_preference = ctd.ReadPreference.PRIMARY
            else:
                # API/RPC clients may pick a read preference by name; unknown names are ignored
                rp = self.headers.get(ctm.HTTPHeader.READ_PREFERENCE)
                if rp:
                    self.read_preference = getattr(ctd.ReadPreference, rp.upper(), self.read_preference)
        elif self.source == self.Source.TASK:
            self.read_preference = ctd.ReadPreference.PRIMARY_PREFERRED

        # Create default context for current request
        # No logger because logger is non-serializable
        self.ctx = context.Context(None, self.user)

    @common.utils.singleton_property
    def query_string(self):
        """URL-encoded request parameters, sorted by name."""
        return urllib.urlencode(sorted(self.params.items()))

    def response_ctx(self):
        return {'request': self}

    @property
    def is_authenticated(self):
        """True when authentication is disabled or the user is not anonymous."""
        return not self.settings.auth.enabled or self.user != controller.User.anonymous

    def __contains__(self, key):
        return key in self.params

    def __getitem__(self, key):
        """All values of parameter `key` (an empty list if absent), UTF-8 decoded unless the request is binary."""
        res = self.params.get(key, [])
        if self.is_binary:
            return res
        return [v.decode('utf-8') for v in res]

    def __setitem__(self, key, value):
        self.params[key] = [value]

    def __repr__(self):
        return "Request({}, {}, {}, {}, '{}', {!r})".format(
            self.id, self.user.login, self.remote_ip, self.method, self.path, self.params
        )

    def get(self, key, default_value=None):
        """First value of parameter `key`, or `default_value` if there is none."""
        values = self[key]
        if values:
            return values[0]
        else:
            return default_value

    def getlist(self, key):
        return self[key]

    def getint(self, key, default_value=0):
        """First value of `key` converted to int; `default_value` if absent or not convertible."""
        try:
            return int(self.get(key))
        except (ValueError, TypeError):
            return default_value

    def getfloat(self, key, default_value=0.):
        """First value of `key` converted to float; `default_value` if absent or not convertible."""
        try:
            return float(self.get(key))
        except (ValueError, TypeError):
            return default_value

    def getboolean(self, key, default_value=False):
        """
        First value of `key` interpreted as a boolean.

        "true"/"1" give True and "false"/"0"/"undefined" give False (case-insensitive,
        surrounding whitespace ignored). `default_value` is returned when the parameter is
        absent or empty, or when its value is not one of these literals.
        """
        v = self.get(key, default_value='').strip().lower()
        if not v:
            # Consistent with `getint`/`getfloat`: a missing parameter yields the caller's default
            return default_value
        return self._BOOLEAN_LITERALS.get(v, default_value)

    def getcookie(self, key):
        return self.cookies.get(key)

    def has_cookie(self, key):
        return key in self.cookies

    @property
    def measure_stat_data(self):
        """Summary of the request's identity and performance measures as a plain dictionary."""
        res = {
            'id': self.id,
            'code': self.measures.code,
            'raddr': self.client_address,
            'login': self.user.login,
            'source': self.source.lower(),
            'duration': self.measures.duration,
            'received': helpers.utcdt2iso(self.measures.received),
            'remote_method': self.remote_method or self.path,
            'high_priority': self.measures.high_priority,
        }
        return res


class LocalSandboxRequest(SandboxRequest):
    """
    This class represents a HTTP request sent locally on the serverside.
    For more information, see `controller.dispatch.RestClient`.
    """

    # Resource attribute endpoints (any numeric resource id) under the API base path
    RES_ATTR_RE = re.compile("^" + api.v1.Api.basePath + api.v1.resource.ResourceAttribute.path.format(id=r"\d+"))

    def __init__(self, method, path, params, headers, task_id, author, jailed=True):
        """
        :param method: HTTP request type
        :param path: endpoint to send the request into; the API base path is prepended when missing
        :param params: request parameters: a dictionary with optional "data" (raw body) and
            "params" (name -> value; dictionary values are serialized to JSON) keys.
            The passed dictionaries are not modified.
        :param headers: HTTP headers, applied over those of the request currently being served, if any
        :param task_id: identifier of a task which performs the request, if there's one
        :param author: request author (typically that of a task)
        :param jailed: disallow entity creation (= POST method), except for the task's own
            endpoints and resource attributes
        :raises ValueError: on a POST to a disallowed endpoint in jailed mode
        """

        path_prefix = api.v1.Api.basePath
        if not path.startswith(path_prefix):
            path = path_prefix + path
        if jailed and method == ctm.RequestMethod.POST:
            if not (
                path.startswith(path_prefix + "/task/current/") or
                path.startswith(path_prefix + "/task/{}/".format(task_id)) or
                self.RES_ATTR_RE.match(path)
            ):
                raise ValueError(
                    "Creation via REST API is disabled on the serverside (in on_create, on_save and on enqueue hooks)"
                )

        remote_ip = "127.0.0.1"
        client_address = "127.0.0.1"
        cookies = None
        is_binary = False
        profid = None
        # Remember the request being served by this thread (if any) so that `restore` can reinstate it
        self.__request = request = getattr(mapping.base.tls, "request", None)
        headers = dict(headers or {})
        if request:
            # Use headers from the original request AND headers passed via dispatched client
            request_headers = dict(request.headers or {})
            request_headers.update(headers)
            headers = request_headers
            cookies = dict(request.cookies or {})
            remote_ip = request.remote_ip
            client_address = request.client_address
            is_binary = request.is_binary
            profid = request.profid
        if not params:
            params = {}
        data = params.get("data")
        # Serialize nested dictionaries to JSON into a NEW mapping: the caller's dictionary
        # must not be rewritten (nor mutated while being iterated).
        params = {
            name: json.dumps(value) if isinstance(value, dict) else value
            for name, value in (params.get("params") or {}).iteritems()
        }
        SandboxRequest.__init__(
            self,
            remote_ip=remote_ip,
            headers=headers,
            method=method,
            path=path,
            raw_data=data,
            target=("http", "localhost"),
            client_address=client_address,
            params={k: map(str, common.utils.chain(v)) for k, v in params.iteritems()},
            cookies=cookies,
            is_binary=is_binary,
            profid=profid,
            query=urllib.urlencode(params) if method == ctm.RequestMethod.GET else None
        )
        self.session = TaskSession(common.config.Registry().this.id, task_id, "", "") if task_id else None
        if self.__request:
            self.read_preference = self.__request.read_preference
        # The author overrides the inherited user unless this is a jailed call nested into a live request
        if author and (not request or not jailed):
            self.user = controller.User.get(author) or controller.User.anonymous

    def restore(self):
        """Reinstate the request (and its context) that was being served when this one was created."""
        mapping.base.tls.request = self.__request
        if self.__request:
            context.set_current(self.__request.ctx)


class BaseSandboxHandler(tornado.web.RequestHandler):
    """ Abstract base request handler class: service-state headers and real client address resolution. """
    settings = common.config.Registry()

    def _service_headers(self):
        """Set headers naming the serving backend and the current service mode (plus the DB lock holder)."""
        self.set_header(ctm.HTTPHeader.BACKEND_NODE, self.settings.this.fqdn)
        settings_ctl = controller.Settings
        current_mode = settings_ctl.mode()
        self.set_header(ctm.HTTPHeader.SERVICE_MODE, current_mode.lower())
        if current_mode == settings_ctl.OperationMode.NORMAL:
            return

        locker = (
            settings_ctl.model.updates.executor
            if settings_ctl.state != settings_ctl.DatabaseState.OK else
            settings_ctl.model.operation_mode_set_by
        )
        if locker:
            self.set_header(ctm.HTTPHeader.DB_LOCKED_BY, locker)

    @property
    def remote_ip(self):
        """
        Client address: the last non-"0.0.0.0" X-Forwarded-For entry, or the connection's peer address.

        ServiceApi appends 0.0.0.0 to X-Forwarded-For, so it has to be skipped. Tornado >=4.5
        could filter it via `trusted_downstream`, but with 4.3 this is done by hand.
        """
        forwarded = self.request.headers.get(ctm.HTTPHeader.FORWARDED_FOR, None)
        if forwarded:
            candidates = (part.strip() for part in reversed(forwarded.split(",")))
            for candidate in candidates:
                if candidate != "0.0.0.0":
                    return candidate
        return self.request.connection.context._orig_remote_ip


class SandboxHandler(BaseSandboxHandler):
    """
    Generic request handler.

    Turns tornado requests into `SandboxRequest` objects, dispatches them through
    `web_controller.dispatch` on the worker pool and writes the replies back.
    """

    _NormalizeURL = re.compile(r'[/\\]+')
    reply_headers = {}
    http_responses = httplib.responses.copy()
    http_responses.update({
        422: "Unprocessable Entity"
    })

    # Legacy UI links (`/sandbox/...`, matched together with the query string) and the current
    # UI paths they redirect to. A `None` target means "matched, but process the request as is".
    # Compiled once here rather than on every request.
    _LEGACY_UI_ALIASES = (
        (re.compile(r'^/sandbox/resources/redirect\?resource_id=(\d+)(?:&relpath=(.+))?'), None),
        (re.compile(r'^/sandbox/scheduler/\w+\?scheduler_id=(\d+)'), r'/scheduler/{0}/view'),
        (re.compile(r'^/sandbox/resources/\w+\?resource_id=(\d+)'), r'/resource/{0}/view'),
        (re.compile(r'^/sandbox/tasks/\w+\?task_id=(\d+)'), r'/task/{0}/view'),
        (re.compile(r'^/sandbox/scheduler/list'), r'/schedulers'),
        (re.compile(r'^/sandbox/resources/list'), r'/resources'),
        (re.compile(r'^/sandbox/releases/list'), r'/releases'),
        (re.compile(r'^/sandbox/clients/list'), r'/clients'),
        (re.compile(r'^/sandbox/tasks/list'), r'/tasks'),
        (re.compile(r'^/sandbox/oauth/token'), r'/oauth'),
        (re.compile(r'^/sandbox/admin/vault'), r'/admin/vault'),
        (re.compile(r'^/sandbox/admin/groups'), r'/admin/groups'),
    )

    class Profile(common.patterns.Abstract):
        """ Performance profiling state of a single request. """
        __slots__ = ("id", "collector", "started")
        __defs__ = [None] * 3

    def send_ok(self, content_type='text/html', content='', status=200):
        """Write `content` with the given status and content type (UTF-8) plus common and service headers."""
        self.set_status(status, reason=self.http_responses.get(status))
        # Common headers
        self.set_header('Content-Type', content_type + '; charset=utf-8')
        for kv in self.reply_headers.iteritems():
            self.set_header(*kv)

        self._service_headers()
        self.write(content)

    def send_data(self, content_type, fp):
        """Stream the file-like `fp` to the client in 2 KiB chunks, flushing after each one."""
        self.set_status(200)
        self.set_header('Content-Type', content_type + '; charset=utf-8')

        while fp:
            chunk = fp.read(2048)
            if not chunk:
                break
            self.write(chunk)
            self.flush()

    def send_redirect(self, url, set_content_disposition=False):
        """
        Sends a redirect to the given (optionally relative) URL.

        :param set_content_disposition: also mark the target as an attachment named after the URL's last path segment
        :raises common.errors.ViewError: if headers have already been written
        """
        if self._headers_written:
            raise common.errors.ViewError("Cannot redirect after headers have been written")
        self.set_status(httplib.FOUND)
        # Remove whitespace
        if isinstance(url, unicode):
            url = url.encode("utf-8")
        url = re.sub(r"[\x00-\x20]+", "", url)
        self._service_headers()
        self.set_header("Location", urlparse.urljoin(self.request.uri, url))
        for kv in self.reply_headers.iteritems():
            self.set_header(*kv)
        if set_content_disposition:
            filename = urlparse.urlsplit(url).path.split('/')[-1]
            self.set_header("Content-Disposition", 'attachment; filename=\"{0}\";'.format(filename))

    def _url_args(self):
        return urlparse.parse_qs(self.request.query)

    def _get_cookies(self):
        res = {}
        for k, v in self.cookies.items():
            res[k] = v.value
        return res

    def _parse_request(self, method=ctm.RequestMethod.GET):
        """
        Extract request parameters and body.

        :return: `(params, raw_data, is_binary)`; URL query arguments take precedence over
            form fields for POST/PUT
        :raises response.HttpErrorResponse: (400) for POST/PUT without a Content-Type header
        """
        if method in (ctm.RequestMethod.GET, ctm.RequestMethod.HEAD):
            return self._url_args(), "", False
        if method == ctm.RequestMethod.DELETE:
            return {}, self.request.body, False
        if method not in (ctm.RequestMethod.POST, ctm.RequestMethod.PUT):
            return {}, "", False

        content_type = self.request.headers.get('content-type')
        if content_type is None:
            raise response.HttpErrorResponse(httplib.BAD_REQUEST, 'Content-type header is missed')
        content_type, pdict = cgi.parse_header(content_type)

        raw_data = self.request.body
        request_dict = {}
        is_binary = False
        if 'application/x-www-form-urlencoded' in content_type:
            request_dict = urlparse.parse_qs(raw_data)

        elif 'multipart/form-data' in content_type:
            request_dict = cgi.parse_multipart(sio.StringIO(raw_data), pdict)
            is_binary = True

        request_dict.update(self._url_args())
        return request_dict, raw_data, is_binary

    def _process(self, hp):
        """
        Parse, dispatch and reply to the current request; executed by a worker.

        :param hp: initial high-priority hint; it decides whether a performance profile is collected
            and is recomputed from the parsed request before dispatching
        :return: the result of the synchronous reply or of `async_reply` when the request has been
            dispatched, `False` otherwise
        """
        # Drop incoming requests if server is overloaded
        if (
            len(workers.Workers().in_progress) > self.settings.server.web.max_requests_in_progress and
            not self.request.path.startswith("/api/v1.0/service")
        ):
            self.reply(response.HttpErrorResponse(httplib.SERVICE_UNAVAILABLE, "Too many requests in progress"))
            return False

        pr = None
        mapping.base.tls.request = None
        if self.settings.server.profiler.performance.enabled:
            if any(re.search(_, self.request.path) for _ in self.settings.server.profiler.performance.paths):
                pr = self.Profile()
                pr.id = self.request.headers.get(ctm.HTTPHeader.REQUEST_ID)
                if not pr.id:
                    pr.id = uuid.uuid4().hex
                if hp:
                    pr.started = time.time()
                    pr.collector = common.profiler.Profiler()
                    pr.collector.enable()
                setattr(self, "profid", pr.id)

        # Collapse runs of slashes and backslashes into single slashes
        path = '/'.join(self._NormalizeURL.split(self.request.path.split('?')[0]))

        req = None
        try:
            params, raw_data, is_binary = self._parse_request(self.request.method)
            request_host = self.request.host
            if ctm.HTTPHeader.REQUEST_TARGET in self.request.headers:
                request_host = urlparse.urlparse(self.request.headers[ctm.HTTPHeader.REQUEST_TARGET]).netloc
            req = SandboxRequest(
                remote_ip=self.remote_ip,
                headers=dict(self.request.headers),
                method=self.request.method,
                path=self.request.path,
                raw_data=raw_data,
                target=(self.request.protocol, request_host),
                client_address=self.remote_ip,
                params=params,
                cookies=self._get_cookies(),
                is_binary=is_binary,
                profid=pr.id if pr and not hp else None,
                query=self.request.query,
            )

            self.id = req.id
            self.user = req.user
            self.source = req.source
            hp = self.hp = req.source == req.Source.WEB or self.request.headers.get("X-High-Priority")

            self.set_header(ctm.HTTPHeader.CURRENT_USER, self.user.login)
            m = web_controller.dispatch(path, req)
            if not callable(m):
                raise response.HttpErrorResponse(httplib.NOT_FOUND, 'url not found %s' % path)
            req.measures.high_priority = bool(hp)

            workers.Workers().in_progress[req.id] = req
            # Status checks are answered synchronously, everything else goes to the worker queue
            return (
                self.reply(m(req), req)
                if "service/status" in req.path else
                workers.Workers().async_reply(self, hp, m, req)
            )

        except response.HttpResponseBase as reply:
            self.reply(reply, req)

        except sandbox.serviceapi.web.exceptions.RETRIABLE_EXCEPTIONS as ex:
            if mapping.is_query_error(ex):
                http_code = httplib.BAD_REQUEST
            else:
                http_code = httplib.SERVICE_UNAVAILABLE

            if isinstance(ex, mapping.AutoReconnect):
                controller.Settings.on_master_lost()

            self.reply(
                response.HttpErrorResponse(http_code, traceback.format_exc(100), 'text/plain'), req,
            )

        # While ServiceAPI protects us from excessively large requests,
        # XMLRPC seems to double the body size (see: test__large_xmlrpc_request)
        except common.errors.DataSizeError as ex:
            logger.exception(
                "Request %r to path %s is rejected due to its large size", req and req.id, path
            )
            if req is not None:
                ret = workers.make_reply_to_large_request(req.source, ex)
                self.reply(ret, workers.FakeRequest(req))
            else:
                # No request object yet, so the source-specific reply can't be built; still answer
                # the client, otherwise the asynchronous request is never finished
                self.reply(response.HttpErrorResponse(
                    httplib.REQUEST_ENTITY_TOO_LARGE, "Request is too large", 'text/plain'
                ))

        except Exception:
            logger.exception("Error processing request %s", req and req.id)
            message = traceback.format_exc(100)
            message += "\nServer: " + self.settings.this.fqdn
            self.reply(response.HttpErrorResponse(httplib.INTERNAL_SERVER_ERROR, message, 'text/plain'), req)

        finally:
            if pr and pr.collector:
                pr.collector.disable()
                th = self.settings.server.profiler.performance.threshold
                ts = int((time.time() - pr.started) * 1000)
                if not th or ts >= th:
                    dump_path = os.path.join(
                        self.settings.server.profiler.performance.data_dir, '{}_{}'.format(ts, pr.id)
                    )
                    pr.collector.dump_to_file(dump_path)
                    logger.info("Request performance profile dump saved to '%s'", dump_path)
        return False

    def reply(self, reply, req=None):
        """
        Write `reply` to the client and schedule finishing of the request on the IO loop.

        :param reply: an `HttpResponseBase` instance, or raw content sent as `text/html`
        :param req: the request being answered; when given, it is removed from the in-progress
            registry and its completion measures are recorded (timing headers are added when
            its processing time is known)
        :return: always True
        """
        if req and req.measures.processing_time:
            self.set_header(ctm.HTTPHeader.REQ_DURATION, str(float(req.measures.processing_time) / 1000))
            self.set_header(ctm.HTTPHeader.REQ_METRICS, req.ctx.spans.as_short_string(group_by_name=True))
            self.set_header(ctm.HTTPHeader.REQUEST_MEASURES, json.dumps({
                "method": req.handler,
                "mongodb_duration": req.ctx.spans.total_duration_for("mongodb_op") * 1000,
                "serviceq_duration": req.ctx.spans.total_duration_for("serviceq_call") * 1000,
            }))

        if isinstance(reply, response.HttpResponseBase):
            for name, value in reply.headers.iteritems():
                self.set_header(name, value)
            if isinstance(reply, response.HttpResponse):
                self.send_ok(reply.content_type, content=reply.content, status=reply.code)
            elif isinstance(reply, response.HttpRedirect):
                self.send_redirect(reply.redirect_url, reply.set_content_disposition)
            elif isinstance(reply, response.HttpErrorResponse):
                self.send_ok(reply.content_type, reply.msg, status=reply.code)
        else:
            self.send_ok('text/html', content=reply)
        tornado.ioloop.IOLoop.instance().add_callback(self._finish_request)

        if req:
            setattr(self, "rp", req.read_preference[0])
            workers.Workers().in_progress.pop(req.id, None)
            req.measures.completed = dt.datetime.utcnow()
            req.measures.code = self.get_status()
            workers.Workers().last_requests << req

        return True

    def _finish_request(self):
        """Finish the request unless the connection is closed or it has already been finished."""
        if self.request.connection.stream.closed():
            return
        if self._finished:
            return
        try:
            self.finish()
        except IOError:
            logger.exception('Error in _finish_request')

    @staticmethod
    def _legacy_ui_forced(cookie):
        """Whether the `force_legacy_ui` cookie disables legacy link redirects; malformed values mean "no"."""
        if not cookie:
            return False
        try:
            return bool(distutils.util.strtobool(cookie.value))
        except ValueError:
            # A bogus cookie value must not fail the whole request
            return False

    def __get(self):
        """Redirect known legacy UI links from external sources, otherwise queue the request for processing."""
        path = self.request.path
        referer = self.request.headers.get("Referer")
        # Support for legacy UI links from external sources
        if (
            path.startswith('/sandbox') and
            not (referer and urlparse.urlparse(referer).netloc == self.request.host) and
            not self._legacy_ui_forced(self.request.cookies.get("force_legacy_ui"))
        ):
            qpath = path
            if self.request.query:
                qpath += "?" + self.request.query
            logger.info(
                "Processing legacy UI redirect path %r, referer %r (%r), host: %r",
                qpath, referer, urlparse.urlparse(referer or '').hostname, self.request.host
            )
            for regex, pattern in self._LEGACY_UI_ALIASES:
                m = regex.match(qpath)
                if m:
                    if not pattern:
                        break
                    self.reply(response.HttpRedirect(pattern.format(*m.groups())))
                    return

        return workers.Workers().put(self._process, False)

    def _browser_like(self):
        """Initial high-priority hint for requests with a body: cookies or a Referer header are present."""
        return bool(self.request.cookies or "Referer" in self.request.headers)

    @tornado.web.asynchronous
    def get(self):
        return self.__get()

    @tornado.web.asynchronous
    def head(self):
        return self.__get()

    @tornado.web.asynchronous
    def options(self):
        return self.__get()

    @tornado.web.asynchronous
    def post(self):
        return workers.Workers().put(self._process, self._browser_like())

    @tornado.web.asynchronous
    def put(self):
        return workers.Workers().put(self._process, self._browser_like())

    @tornado.web.asynchronous
    def delete(self):
        return workers.Workers().put(self._process, self._browser_like())


class MDSProxyHandler(BaseSandboxHandler):
    """ This handler proxies all requests to static files (only favicon and index) from MDS S3 storage. """

    # Installation-specific infix inserted before the file extension; empty means plain file names
    _suffix = ""

    def __init__(self, application, request, **kwargs):
        self.mds_url = self.settings.server.web.static.mds_s3_url
        super(MDSProxyHandler, self).__init__(application, request, **kwargs)

    def initialize(self, installation_type=None):
        """
        :param installation_type: installation kind from the route's handler arguments;
            only pre-production and local installations get a file name suffix
        """
        if installation_type and installation_type in (ctm.Installation.PRE_PRODUCTION, ctm.Installation.LOCAL):
            # Per-instance: assigning to the class would leak the suffix into all later requests
            self._suffix = installation_type.lower()
        super(MDSProxyHandler, self).initialize()

    def _get(self, head=False):
        """Fetch the favicon (or index.html for any other path) from MDS and relay status, headers and body."""
        path = self.request.path.lstrip("/")
        if not path.startswith("favicon"):
            path = "index.html"
        # Local installations use the plain index.html; everything else gets the suffix
        if self._suffix and (path != "index.html" or self._suffix != "local"):
            # "name.ext" -> "name.<suffix>.ext"; copes with names having several dots or none
            basename, dot, ext = path.rpartition(".")
            path = ".".join([basename, self._suffix, ext]) if dot else ".".join([path, self._suffix])
        url = "/".join([self.mds_url, path])
        logger.info("Proxy %s from %r", "HEAD" if head else "content", url)
        hdrs = dict(self.request.headers)
        # Don't forward our own Host header to the storage
        for name in [k for k in hdrs if k.lower() == "host"]:
            hdrs.pop(name)
        r = (requests.head if head else requests.get)(url, headers=hdrs, timeout=60)
        self.set_status(r.status_code)
        for k, v in r.headers.iteritems():
            # Framing headers are dropped; tornado computes its own for the written body
            if k.lower() not in ("content-length", "content-encoding", "transfer-encoding", "connection"):
                self.set_header(k, v)
        self._service_headers()
        if not head and r.status_code != httplib.NOT_MODIFIED:
            self.write(r.content)

    def get(self):
        return self._get()

    def head(self):
        return self._get(True)
