import urllib
import datetime

import sandbox.common.tvm as tvm
import sandbox.common.types.misc as ctm
import sandbox.common.types.task as ctt
import sandbox.common.types.user as ctu
import sandbox.common.types.database as ctd

from sandbox.yasandbox.database import mapping
import sandbox.yasandbox.controller.user as user_controller

from sandbox import common

from . import exceptions

__all__ = ("Request",)


# TVM user-ticket used on local installations. Not a real thing! Generated with tvmknife.
# `Request.get_tvm_headers` sends it as the user ticket when `server.auth.enabled` is off,
# so downstream services still receive a well-formed ticket header.
FAKE_USER_TICKET = (
    "3:user:CA0Q__________9_GhIKBAiUkQYQlJEGINKF2MwEKAE:N5aYTgsitYVMs24xyxEIiEAX"
    "77TuIP8hxWHM40G-GOdUWjla-aATuIs6xE97uLWI7PweMjTWauNejn_5RYPzlwJxDY1832GeabONwU1r-QCyyxywL"
    "KQyD09WGcY03ZTrqpK26tLbPI-g8lyAFQ_KADi7afFhTuWDyKa4Nl15Fu0"
)


class Request(object):
    """
    Wrapper around an incoming HTTP request that resolves who is making it.

    When constructed with ``authenticate=True`` it tries, in this order (later steps override earlier ones):
    the ``Session_id`` cookie, the authorization header (OAuth / task / external session token), TVM user or
    service tickets, the admin authorization header, and finally super-user impersonation via the
    "current user" header. Attributes not defined on this class are looked up on the wrapped request.
    """

    # Blackbox statuses that are treated as a successful authentication
    _VALID_AUTH_STATUSES = ("VALID", "NEED_RESET")

    def __init__(self, req, logger, authenticate=True):
        """
        :param req: `flask.Request` instance
        :type req: `flask.Request`
        :param logger: Logger for the current request
        :param authenticate: call Blackbox to authorize the user
        """

        self.req = req
        self.logger = logger

        self.source = ctt.RequestSource.API
        self.admin_user = None  # set only when a super-user token is passed in the admin authorization header
        self.session = None
        self.user = None
        self.auth_method = ctt.AuthMethod.NONE
        # Either a locally built {"status": {"value": ...}, "login"/"service": ...} dict or the Blackbox response
        self.authdata = None
        timeout = self.req.headers.get(ctm.HTTPHeader.REQUEST_TIMEOUT)
        self.timeout = timeout if timeout is None else float(timeout)
        self.profiler = None

        # Get session, user and authdata for current request
        if authenticate:
            self.authenticate_request()
        # Depends on self.user and self.source, so it must be computed after authentication
        self.read_preference = self.get_read_preference_for_request()

        self.request_started = datetime.datetime.utcnow()
        self.proxy_started = self.request_started
        self.proxy_finished = self.request_started
        self.base_url = req.base_url
        if req.headers.get(ctm.HTTPHeader.FORWARDED_SCHEME):
            # Behind a proxy: report the scheme the client actually used (first occurrence only)
            self.base_url = req.base_url.replace(req.scheme, req.headers.get(ctm.HTTPHeader.FORWARDED_SCHEME), 1)

    @common.utils.singleton_property
    def query_string(self):
        """Query arguments (including every value of repeated keys), sorted and url-encoded."""
        return urllib.urlencode(sorted(self.req.args.items(multi=True)))

    @property
    def remote_ip(self):
        """Client IP address as exposed by the wrapped request."""
        return self.req.remote_ip

    @property
    def quota_owner(self):
        """
        Identity the request is accounted to: the TVM service, the admin user, the task session owner
        or the authenticated user's login (in that priority).
        """
        tvm_service = self.tvm_service
        if tvm_service:
            return "TVM_SERVICE_{}".format(tvm_service)
        if self.admin_user is not None:
            return self.admin_user.login
        login = self.user and self.user.login
        return self.session.owner if self.is_task_session else login

    @common.utils.singleton_classproperty
    def config(self):
        """Server configuration registry."""
        return common.config.Registry()

    @common.utils.singleton_classproperty
    def blackboxer(self):
        """The ``blackboxer`` module, imported lazily on first use."""
        import blackboxer
        return blackboxer

    @common.utils.singleton_classproperty
    def blackbox(self):
        """Blackbox client bound to the configured Blackbox URL."""
        return self.blackboxer.Blackbox(self.config.server.auth.blackbox_url)

    def __getattr__(self, item):
        # Only invoked for names not found on this object: delegate them to the wrapped request
        return getattr(self.req, item)

    @property
    def id(self):
        """Request identifier taken from the wrapped request."""
        return self.req.req_id

    @property
    def is_authenticated(self):
        """
        True when authentication is disabled in config; otherwise truthy only if ``authdata`` carries
        an accepted status (VALID or NEED_RESET).
        """
        if not self.config.server.auth.enabled:
            return True
        return self.authdata and self.authdata["status"]["value"] in self._VALID_AUTH_STATUSES

    @property
    def need_reset_cookie(self):
        """Truthy when Blackbox reported the session cookie as NEED_RESET."""
        return self.authdata and self.authdata["status"]["value"] == "NEED_RESET"

    @property
    def is_task_session(self):
        return self.session and self.session.is_task_session

    @property
    def is_external_session(self):
        return self.session and self.session.is_external_session

    @property
    def is_statistics_sender(self):
        return not self.is_task_session and self.user and ctu.Role.STATISTICS in self.user.roles

    @property
    def tvm_service(self):
        """Source TVM service id when the request was authenticated by a service ticket alone."""
        return self.authdata and self.authdata.get("service")

    def split_auth_header(self, auth_header):
        """Return the second whitespace-separated word of an auth header (the token), or None."""
        return next(iter(auth_header.split()[1:]), None)

    @common.utils.singleton_property
    def oauth_token(self):
        """Token from the authorization header; raises if the header is absent."""
        auth_header = self.req.headers[ctm.HTTPHeader.AUTHORIZATION]
        return self.split_auth_header(auth_header)

    @common.utils.singleton_property
    def user_ticket(self):
        """TVM user ticket header value; raises if the header is absent."""
        return self.req.headers[ctm.HTTPHeader.USER_TICKET]

    @common.utils.singleton_property
    def service_ticket(self):
        """TVM service ticket header value; raises if the header is absent."""
        return self.req.headers[ctm.HTTPHeader.SERVICE_TICKET]

    def user_from_token(self, auth_token):
        """
        Resolve an auth token into a session and a validated user.

        :return: ``(session, user)``
        :raises exceptions.Unauthorized, exceptions.UserDismissed: see `check_validated_user`
        """
        session = self.get_auth_session_from_token(auth_token)
        user = self.check_validated_user(
            session.login, user_controller.User.validate(
                session.login, avoid_staff_validation=session.is_task_session or session.is_external_session
            )
        )
        return session, user

    def authenticate_request(self):
        """
        Fill ``session``, ``user``, ``authdata``, ``auth_method`` (and possibly ``source``/``admin_user``)
        from the request. Mechanisms are tried in order and later ones override earlier ones.
        Falls back to the anonymous user when nothing authenticated the request.
        """
        # 1. Session cookie (only when authentication is enabled)
        if self.config.server.auth.enabled:
            self.try_get_auth_from_cookie()
            if self.is_authenticated:
                self.source = ctt.RequestSource.WEB
                login = self.authdata["login"]
                # Validate like the token/ticket paths do: User.validate() may return a falsy value
                # (handled by check_validated_user), and dismissed users must be rejected here as well.
                self.user = self.check_validated_user(login, user_controller.User.validate(login))
                self.auth_method = ctt.AuthMethod.COOKIE

        # 2. OAuth / task / external session token
        if ctm.HTTPHeader.AUTHORIZATION in self.req.headers:
            auth_token = self.oauth_token
            self.session, self.user = self.user_from_token(auth_token)
            self.authdata = {"status": {"value": "VALID"}, "login": self.session.login}

            if self.session.is_task_session:
                self.auth_method = ctt.AuthMethod.TASK
                self.source = ctt.RequestSource.TASK
            elif self.session.is_external_session:
                self.auth_method = ctt.AuthMethod.EXTERNAL_SESSION
            else:
                self.auth_method = ctt.AuthMethod.OAUTH

        # 3. TVM: a user ticket (always accompanied by a service ticket) or a service ticket alone
        if ctm.HTTPHeader.USER_TICKET in self.req.headers:
            if ctm.HTTPHeader.SERVICE_TICKET not in self.req.headers:
                raise exceptions.Unauthorized("Service ticket must be in headers with user ticket")
            try:
                tvm.TVM.check_service_ticket(self.service_ticket)
                check_result = tvm.TVM.check_user_ticket(self.user_ticket)
                uid = str(check_result["uids"][0])
                self.user = self.check_validated_user(uid, user_controller.User.validate_from_uid(uid))
                self.session = None
                self.auth_method = ctt.AuthMethod.TVM
                self.authdata = {"status": {"value": "VALID"}, "login": self.user.login}
            except (common.rest.Client.HTTPError, tvm.TVM.Error) as ex:
                raise exceptions.Unauthorized("Error in user ticket validation: {}".format(ex))

        elif ctm.HTTPHeader.SERVICE_TICKET in self.req.headers:
            try:
                check_result = tvm.TVM.check_service_ticket(self.service_ticket)
                service_id = check_result["src"]
                self.session = None
                self.auth_method = ctt.AuthMethod.TVM
                self.authdata = {"status": {"value": "VALID"}, "service": service_id}
            except (common.rest.Client.HTTPError, tvm.TVM.Error) as ex:
                raise exceptions.Unauthorized("Error in service ticket validation: {}".format(ex))

        # 4. Optional admin token, which must belong to a super user
        if ctm.HTTPHeader.ADMIN_AUTHORIZATION in self.req.headers:
            auth_token = self.split_auth_header(self.req.headers[ctm.HTTPHeader.ADMIN_AUTHORIZATION])
            _, admin_user = self.user_from_token(auth_token)
            if admin_user.super_user:
                self.admin_user = admin_user
            else:
                raise exceptions.Forbidden("Token in {} must be admin token".format(ctm.HTTPHeader.ADMIN_AUTHORIZATION))

        # 5. Anonymous fallback, or super-user impersonation of another login
        if not self.user:
            self.user = user_controller.User.anonymous
        elif self.user.super_user and ctm.HTTPHeader.CURRENT_USER in self.req.headers:
            admin = self.user.login
            login = self.req.headers[ctm.HTTPHeader.CURRENT_USER]
            self.user = self.check_validated_user(login, user_controller.User.validate(login))
            if self.session:
                self.session.login = login
            if self.authdata:
                self.authdata["login"] = login
            self.logger.warning("Processing request %s from %r on behalf of %r", self.id, admin, login)

    def try_get_auth_from_cookie(self):
        """
        Set ``authdata`` from the ``Session_id`` cookie, if present.

        A cached sid -> login mapping is tried first, then Blackbox. If Blackbox is unreachable, a
        non-fresh cached mapping is accepted; if there is none, the request is rejected as unavailable.

        :raises exceptions.ServiceUnavailable: Blackbox is unreachable and no cached login exists
        """
        sid = self.req.cookies.get("Session_id", None)
        if sid is None:
            return

        login = user_controller.User.get_login_by_sid(sid)
        if login:
            self.authdata = {"status": {"value": "VALID"}, "login": login}
            return

        try:
            data = self.blackbox.sessionid(self.remote_ip, sid, self.config.server.web.address.host)
        except self.blackboxer.ConnectionError as con_exc:
            self.logger.error("BlackboxConnectionError: %s", con_exc)
            login = user_controller.User.get_login_by_sid(sid, False)
            if login:
                self.authdata = {"status": {"value": "VALID"}, "login": login}
                return
            else:
                self.req.rejected_in_progress = True
                raise exceptions.ServiceUnavailable("Blackbox unavailable.")
        except self.blackboxer.BlackboxError as exc:
            self.logger.error("BlackboxError: %s", exc)
        else:
            if "login" in data:
                self.authdata = data
            else:
                self.logger.error("BlackboxError, invalid response: %s", data)

        # Remember the sid only for sessions Blackbox accepted: the cache lookup above treats any cached
        # sid as VALID, so caching a rejected session would authenticate it on subsequent requests.
        if self.authdata and self.authdata.get("status", {}).get("value") in self._VALID_AUTH_STATUSES:
            user_controller.User.set_session_id(self.authdata["login"], sid)

    def get_auth_session_from_token(self, token):
        """
        Resolve an OAuth/session token into a cached session object, consulting Blackbox on a cache miss.

        :raises exceptions.Gone: the cached session has expired
        :raises exceptions.Unauthorized: token not cached and authentication is disabled
        :raises exceptions.ServiceUnavailable: Blackbox unreachable and no cached session exists
        :raises exceptions.Forbidden: token invalid, lacks the required scope, or cache refresh rejects it
        """
        try:
            cache = user_controller.OAuthCache.get(token)
            if cache:
                return cache
        except user_controller.OAuthCache.SessionExpired as ex:
            raise exceptions.Gone(str(ex))

        if not self.config.server.auth.enabled:
            raise exceptions.Unauthorized("Authorization disabled")

        try:
            data = self.blackbox.oauth(self.remote_ip, token)
        except self.blackboxer.ConnectionError as con_exc:
            self.logger.error("BlackboxConnectionError: %s", con_exc)
            # Degraded mode: accept whatever is stored for this token in the database
            cache = mapping.OAuthCache.objects(token=token).first()
            if cache is not None:
                return cache

            self.req.rejected_in_progress = True
            raise exceptions.ServiceUnavailable("Blackbox unavailable.")
        except self.blackboxer.BlackboxError as exc:
            self.logger.error("BlackboxError: %s", exc)
            raise

        if data["status"]["value"] != "VALID":
            # Only the first 8 characters of the token are shown; the rest is masked.
            # "error" is read with .get() so a response without it still yields the intended 403.
            raise exceptions.Forbidden(
                "OAuth-token '{}{}' is not valid. Authorization status: {}, error: {}".format(
                    str(token)[:8], "*" * max(len(str(token)) - 8, 3), data["status"], data.get("error"),
                )
            )

        required_scope = self.config.server.auth.oauth.required_scope
        if required_scope and required_scope not in data["oauth"]["scope"].split():
            raise exceptions.Forbidden(
                "OAuth-token is valid, but application '{}' has no scope '{}' in allowed '{}'".format(
                    data["oauth"]["client_id"], required_scope, data["oauth"]["scope"]
                )
            )

        try:
            return user_controller.OAuthCache.refresh(data["login"], token, app_id=data["oauth"]["client_id"])
        except ValueError as ex:
            raise exceptions.Forbidden(str(ex))

    @staticmethod
    def check_validated_user(login, validated_user, is_request_author=True):
        """
        Unwrap a `User.validate*` result, rejecting missing or dismissed users.

        :param login: login or uid, used only in error messages
        :param validated_user: result of `User.validate` / `User.validate_from_uid`
        :param is_request_author: if False, failures are reported as BadRequest instead of auth errors
        :return: the validated user's ``user`` object
        """
        if not validated_user:
            exception_type = exceptions.Unauthorized if is_request_author else exceptions.BadRequest
            raise exception_type(
                "Unable to validate user '{}' with valid oauth token or user ticket via staff".format(login)
            )
        elif validated_user.is_dismissed:
            exception_type = exceptions.UserDismissed if is_request_author else exceptions.BadRequest
            raise exception_type(
                "Unable to validate DISMISSED user '{}' with valid oauth token or user ticket via staff".format(login)
            )

        return validated_user.user

    def get_read_preference_for_request(self):
        """
        Choose the MongoDB read preference for this request.

        Non-API sources and super users always read from PRIMARY. Other authenticated users may override the
        default through the read-preference header. Anonymous task requests use PRIMARY_PREFERRED.
        """
        if self.user is not None and self.user is not user_controller.User.anonymous:
            if self.source != ctt.RequestSource.API or self.user.super_user:
                return ctd.ReadPreference.PRIMARY

            rp = self.req.headers.get(ctm.HTTPHeader.READ_PREFERENCE, None)
            if rp:
                return getattr(
                    ctd.ReadPreference, rp.upper(),
                    self.config.server.mongodb.default_read_preference
                )
        elif self.user is user_controller.User.anonymous and self.source == ctt.RequestSource.TASK:
            return ctd.ReadPreference.PRIMARY_PREFERRED

        return self.config.server.mongodb.default_read_preference

    def get_tvm_headers(self, dst, need_user_ticket=True):
        """
        Fetch TVM tickets and return them as a dictionary keyed with respective HTTP headers
        (`X-Ya-Service-Ticket` and `X-Ya-User-Ticket`). If unable to fetch, the tickets are omitted from the result.

        :param dst: destination alias (as configured in tvmtool)
        :param need_user_ticket: if True, request user ticket using credentials from the request
        :return: dictionary with successfully fetched TVM tickets
        """

        dsts = [dst]
        headers = {}

        need_get_user_ticket = need_user_ticket

        if self.auth_method == ctt.AuthMethod.TVM and need_user_ticket:
            # The caller authenticated via TVM, so forward its own user ticket instead of asking Blackbox.
            # Requests authenticated by a service ticket alone carry no user ticket; reading the missing
            # header would raise, so the user ticket is simply omitted in that case.
            if ctm.HTTPHeader.USER_TICKET in self.req.headers:
                headers[ctm.HTTPHeader.USER_TICKET] = self.user_ticket
            need_get_user_ticket = False

        if need_get_user_ticket:
            # Blackbox requires a separate service ticket
            dsts.append("blackbox")

        try:
            service_tickets = tvm.TVM.get_service_ticket(dsts)
        except tvm.TVM.Error as tvm_exc:
            self.logger.error("Could not get TVM service ticket: %s", tvm_exc)
            return {}

        headers[ctm.HTTPHeader.SERVICE_TICKET] = service_tickets[dst]

        if not need_get_user_ticket:
            return headers

        if not self.config.server.auth.enabled:
            headers[ctm.HTTPHeader.USER_TICKET] = FAKE_USER_TICKET
            return headers

        if self.auth_method not in ctt.AuthMethod.Group.BLACKBOX:
            return headers

        kwargs = dict(
            userip=self.remote_ip,
            host=self.config.server.web.address.host,
            get_user_ticket="yes",
            headers={ctm.HTTPHeader.SERVICE_TICKET: service_tickets["blackbox"]},
        )

        try:
            if self.auth_method == ctt.AuthMethod.COOKIE:
                data = self.blackbox.sessionid(
                    sessionid=self.req.cookies.get("Session_id", ""),
                    sslsessionid=self.req.cookies.get("sessionid2", ""),
                    **kwargs
                )
            elif self.auth_method == ctt.AuthMethod.OAUTH:
                data = self.blackbox.oauth(oauth_token=self.oauth_token, **kwargs)
            else:
                raise ValueError("Unsupported auth method: {}".format(self.auth_method))

            headers[ctm.HTTPHeader.USER_TICKET] = data["user_ticket"]

        except Exception:
            # Best effort: the caller still gets the service ticket even if the user ticket is unavailable
            self.logger.warning("Could not fetch TVM2 user-ticket", exc_info=True)

        return headers
