from __future__ import print_function

import logging
import socket

import grpc
import hashlib

from ci.api_client.py import ci_yp_service_discovery
from ci.proto import internal_api_pb2_grpc
from ci.tasklet.common.proto.service_pb2 import CiEnv
from ci.tasklet.common.proto.service_pb2 import GetCommitsResponse
from google.protobuf import empty_pb2, json_format
from grpc._channel import _InactiveRpcError
from ratelimit import limits
from ratelimit import sleep_and_retry
from tasklet import runtime
from tasklet.services.ci.proto import ci_pb2, ci_pb2_grpc

# Module logger; the servicers in this file log request/response payloads at INFO level.
logger = logging.getLogger(__name__)


def get_ci_env(tasklet_context):
    """Pick the CI environment a tasklet should talk to.

    Returns ``CiEnv.CI_TESTING`` when the context carries a non-empty
    ``ci_url`` that does not start with ``https://a.yandex-team.ru/``.
    In every other case (no context, empty URL, or a URL on that host)
    returns ``CiEnv.CI_STABLE``.
    """
    ci_url = tasklet_context.ci_url if tasklet_context else None
    if ci_url and not ci_url.startswith('https://a.yandex-team.ru/'):
        return CiEnv.CI_TESTING
    return CiEnv.CI_STABLE


class CiServiceLogging(ci_pb2_grpc.CiServicer):
    """Stand-in CI servicer that never contacts CI.

    Progress updates are only logged. ``GetCommits`` logs the request and
    answers with a fixed, synthetic list of commits.
    """

    def GetContext(self, request, context):
        """Return an empty ``CiContext``."""
        return ci_pb2.CiContext()

    def Inject(self, request, context):
        """Hand the request and this servicer to ``runtime.inject``."""
        runtime.inject(request, self)
        return empty_pb2.Empty()

    def UpdateProgress(self, request, context):
        """Log the progress request; nothing is sent anywhere."""
        logger.info('UpdateProgress request:\n%s', json_format.MessageToJson(request))
        return empty_pb2.Empty()

    def GetCommits(self, request, context):
        """Log ``request`` and reply with three canned commits.

        Revision numbers count down from 8402296 in steps of 7. Each hash is
        the SHA-1 of the commit message. The issue key is the text before
        the first space of the message. Authors are assigned round-robin.
        """
        logger.info('GetCommits request:\n%s', json_format.MessageToJson(request))
        authors = ('pochemuto', 'teplosvet', 'anmakon')
        messages = ('CI-1687 yav access demo', 'MBO-13312 finishing', 'CI-2225 fix ConcurrentTest')

        response = GetCommitsResponse()
        for idx, text in enumerate(messages):
            issue_key, _, _ = text.partition(' ')
            entry = response.commits.add()
            entry.message = text
            entry.revision.number = 8402296 - 7 * idx
            entry.revision.hash = self._hash(text)
            entry.issues.append(issue_key)
            entry.author = authors[idx % len(authors)]

        logger.info('GetCommits response:\n%s', json_format.MessageToJson(response))
        return response

    def _hash(self, message):
        """Return the hex SHA-1 digest of ``message`` as encoded by ``str.encode()``."""
        return hashlib.sha1(message.encode()).hexdigest()


class CiService(ci_pb2_grpc.CiServicer):
    """CI servicer that forwards tasklet requests to the CI internal API.

    ``UpdateProgress`` and ``GetCommits`` are sent through an
    ``InternalApiStub`` whose endpoint comes from YP service discovery.
    The discovered endpoint list and the stub are cached on the instance.
    On a retryable gRPC failure the stub is dropped, so the next attempt
    connects to the next cached endpoint.
    """

    # Status codes for which another attempt is not expected to help: the
    # error is re-raised immediately instead of being retried.
    _NON_RETRYABLE_CODES = frozenset((
        grpc.StatusCode.INTERNAL,
        grpc.StatusCode.NOT_FOUND,
        grpc.StatusCode.INVALID_ARGUMENT,
        grpc.StatusCode.UNIMPLEMENTED,
    ))

    def __init__(self):
        self.client_name = 'tasklet-ci-service:{}'.format(socket.gethostname())
        self.endpoints = None
        self.client = None
        # Endpoint set that the cached ``endpoints``/``client`` were resolved
        # for. Used to drop the cache when a request targets another CI env.
        self._endpoint_set_id = None

    def GetContext(self, request, context):
        """Return an empty ``CiContext``."""
        return ci_pb2.CiContext()

    def Inject(self, request, context):
        """Hand the request and this servicer to ``runtime.inject``."""
        runtime.inject(request, self)

        return empty_pb2.Empty()

    def UpdateProgress(self, request, context):
        """gRPC entry point; delegates to the rate-limited ``_update_progress``."""
        return self._update_progress(request)

    @sleep_and_retry
    @limits(calls=10, period=1)
    def _update_progress(self, request):
        """Send a tasklet progress update to CI.

        Limited to 10 calls per 1-second period. ``sleep_and_retry`` makes
        calls over the limit wait rather than fail.

        :raises ValueError: if ``request`` has no ``job_instance_id``.
        """
        logger.info('UpdateProgress request:\n%s', json_format.MessageToJson(request))

        if not request.HasField('job_instance_id'):
            raise ValueError('job_instance_id is missing in TaskletProgress. Please fill it from context')

        if request.job_instance_id.flow_launch_id == '':
            # No flow launch to attach the progress to: nothing to report.
            logger.info('job_instance_id is empty. Skip UpdateProgress')
            return empty_pb2.Empty()

        def action(client):
            client.UpdateTaskletProgress(request)
            logger.info('UpdateProgress sent successfully')
            return empty_pb2.Empty()

        return self._execute(action, request.ci_env)

    def GetCommits(self, request, context):
        """gRPC entry point; delegates to the rate-limited ``_get_commits``."""
        return self._get_commits(request)

    @sleep_and_retry
    @limits(calls=10, period=1)
    def _get_commits(self, request):
        """Fetch commits from CI and return the ``GetCommits`` response.

        Limited to 10 calls per 1-second period. ``sleep_and_retry`` makes
        calls over the limit wait rather than fail.
        """
        logger.info('GetCommits request:\n%s', json_format.MessageToJson(request))

        def action(client):
            response = client.GetCommits(request)
            logger.info('GetCommits response:\n%s', json_format.MessageToJson(response))
            return response

        return self._execute(action, request.ci_env)

    def _execute(self, action, env, retries=5):
        """Run ``action(client)`` against the CI internal API, with retries.

        :param action: callable receiving an ``InternalApiStub``; its result
            is returned from this method.
        :param env: ``CiEnv`` value. Falsy or ``CI_STABLE`` selects stable
            CI; anything else selects testing CI.
        :param retries: maximum number of attempts for retryable errors.
        :raises grpc.RpcError: immediately for a status in
            ``_NON_RETRYABLE_CODES``; otherwise after ``retries`` failed
            attempts.
        :raises RuntimeError: if service discovery returns no endpoints.
        """
        endpoint_set_id = self._get_endpoint_set_id(env)
        if endpoint_set_id != self._endpoint_set_id:
            # The cached endpoints/client belong to another CI environment.
            # Reusing them would send this request to the wrong CI.
            self.endpoints = None
            self.client = None
            self._endpoint_set_id = endpoint_set_id

        attempt = 0
        while True:
            try:
                return action(self._get_client(env))
            except grpc.RpcError as e:  # also covers grpc._channel._InactiveRpcError
                logger.exception('Error when executing gRPC call')
                if self._status_code(e) in self._NON_RETRYABLE_CODES:
                    raise
                attempt += 1
                if attempt >= retries:
                    raise
                # Drop the client so the next attempt uses another endpoint.
                self.client = None

    def _get_client(self, env):
        """Return the cached ``InternalApiStub``, creating one if needed.

        A new stub is connected to an endpoint popped from ``self.endpoints``.
        The endpoint list is (re)discovered whenever it is empty.

        :raises RuntimeError: if service discovery yields no endpoints.
        """
        if not self.client:
            if not self.endpoints:
                self.endpoints = ci_yp_service_discovery.get_all_endpoints(
                    client_name=self.client_name,
                    cluster_names=self._get_cluster_names(env),
                    endpoint_set_id=self._get_endpoint_set_id(env)
                )
                if not self.endpoints:
                    raise RuntimeError(
                        'No CI endpoints discovered for endpoint set {}'.format(
                            self._get_endpoint_set_id(env)))

            channel = grpc.insecure_channel(self.endpoints.pop())
            self.client = internal_api_pb2_grpc.InternalApiStub(channel)
        return self.client

    @staticmethod
    def _status_code(error):
        """Return the ``grpc.StatusCode`` carried by ``error``, or ``UNKNOWN``.

        Errors raised by RPC invocations expose their status through the
        ``code()`` method (``grpc.Call``), not a ``status`` attribute. A bare
        ``grpc.RpcError`` has no ``code()`` and maps to ``UNKNOWN``.
        """
        code = getattr(error, 'code', None)
        return code() if callable(code) else grpc.StatusCode.UNKNOWN

    @staticmethod
    def _get_endpoint_set_id(env):
        """Return the YP endpoint set id for ``env`` (falsy means stable)."""
        if not env or env == CiEnv.CI_STABLE:
            return ci_yp_service_discovery.STABLE_CI_ENDPOINT_SET_ID
        else:
            return ci_yp_service_discovery.TESTING_CI_ENDPOINT_SET_ID

    @staticmethod
    def _get_cluster_names(env):
        """Return the YP cluster names for ``env`` (falsy means stable)."""
        if not env or env == CiEnv.CI_STABLE:
            return ci_yp_service_discovery.STABLE_CLUSTER_NAMES
        else:
            return ci_yp_service_discovery.TESTING_CLUSTER_NAMES
