import time
import ssl
import grpc
import random
import logging
import requests
from requests.adapters import HTTPAdapter

import greeter_service_pb2
import greeter_service_pb2_grpc

# add path with proto
import os
import sys
_PACKAGE_PATH = os.path.realpath(os.path.dirname(__file__))
sys.path.append(os.path.join(_PACKAGE_PATH, 'proto'))

try:
    from yandex.cloud.loadtesting.agent.v1 import trail_service_pb2, trail_service_pb2_grpc, \
        agent_registration_service_pb2, agent_registration_service_pb2_grpc
except ImportError:
    import trail_service_pb2
    import trail_service_pb2_grpc
    import agent_registration_service_pb2
    import agent_registration_service_pb2_grpc

try:
    from yandex.cloud.loadtesting.agent.v1 import test_service_pb2_grpc, test_service_pb2, test_pb2
except ImportError:
    import test_service_pb2_grpc
    import test_service_pb2
    import test_pb2

# Module-level logger, named after this module, used by the API clients and
# the instance-metadata helpers defined in this file.
LOGGER = logging.getLogger(__name__)  # pylint: disable=C0103


class APIClient(object):
    """Base class for load-testing API clients.

    Holds the API base URL, declares the exception types used by clients and
    provides helpers that turn one second of aggregated results into a plain
    ``dict`` "push item". Subclasses are expected to implement
    :meth:`push_test_data` and :meth:`unlock_target`.
    """

    REQUEST_ID_HEADER = 'X-Request-ID'

    def __init__(self):
        self._base_url = None

    @property
    def base_url(self):
        """Return the configured base URL.

        Raises:
            ValueError: if the base URL is unset or empty.
        """
        if not self._base_url:
            raise ValueError("Base url is not set")
        return self._base_url

    @base_url.setter
    def base_url(self, url):
        self._base_url = url

    class UnderMaintenance(Exception):
        message = "API is under maintenance"

    class NotAvailable(Exception):
        """API could not be reached; carries the request and the response."""

        desc = "API is not available"

        def __init__(self, request, response):
            self.message = "%s\n%s\n%s" % (self.desc, request, response)
            # Zero-argument super() on purpose: super(self.__class__, self)
            # resolves to the subclass again when this exception is
            # subclassed and therefore recurses until RecursionError.
            super().__init__(self.message)

    class StoppedFromOnline(Exception):
        """http code 410"""
        message = "Shooting is stopped from online"

    class JobNotCreated(Exception):
        pass

    class NetworkError(Exception):
        pass

    @staticmethod
    def _scaled_mean(metric):
        """Return ``metric["total"] / 1000.0 / metric["len"]``.

        Raises:
            ZeroDivisionError: if ``metric["len"]`` is 0 (same as before the
                helper was factored out).
        """
        return metric["total"] / 1000.0 / metric["len"]

    def second_data_to_push_item(self, data, stat, timestamp, overall, case):
        """Convert one second of aggregated data into a push-item dict.

        Args:
            data: SecondAggregateDataItem-like mapping. Read keys:
                ``interval_real`` (``total``, ``len``, ``q``, ``hist``),
                ``size_in``/``size_out`` (``total``), ``connect_time``,
                ``send_time``, ``latency``, ``receive_time`` (``total``,
                ``len``), ``net_code``/``proto_code`` (``count``).
            stat: mapping providing ``stat["metrics"]["reqps"]`` and
                ``stat["metrics"]["instances"]``.
            timestamp: stored stringified as ``trail['time']``.
            overall: stored unchanged under ``'overall'``.
            case: stored unchanged under ``'case'``.

        Returns:
            dict with keys ``overall``, ``case``, ``net_codes``,
            ``http_codes``, ``time_intervals`` and ``trail``.

        Raises:
            ZeroDivisionError: if any of the averaged metrics has ``len`` 0.
        """
        interval_real = data["interval_real"]
        trail = {
            'time': str(timestamp),
            'reqps': stat["metrics"]["reqps"],
            'resps': interval_real["len"],
            'expect': self._scaled_mean(interval_real),
            'disper': 0,
            'self_load':
                0,  # TODO abs(round(100 - float(data.selfload), 2)),
            'input': data["size_in"]["total"],
            'output': data["size_out"]["total"],
            'connect_time': self._scaled_mean(data["connect_time"]),
            'send_time': self._scaled_mean(data["send_time"]),
            'latency': self._scaled_mean(data["latency"]),
            'receive_time': self._scaled_mean(data["receive_time"]),
            'threads': stat["metrics"]["instances"],  # TODO
        }

        # One 'q<N>' entry per reported quantile, scaled like the means above.
        quantiles = interval_real["q"]
        for q, value in zip(quantiles["q"], quantiles["value"]):
            trail['q' + str(q)] = value / 1000.0

        return {
            'overall': overall,
            'case': case,
            'net_codes': [
                {'code': int(code), 'count': int(cnt)}
                for code, cnt in data["net_code"]["count"].items()
            ],
            'http_codes': [
                {'code': int(code), 'count': int(cnt)}
                for code, cnt in data["proto_code"]["count"].items()
            ],
            'time_intervals': self.convert_hist(interval_real["hist"]),
            'trail': trail,
        }

    @staticmethod
    def convert_hist(hist):
        """Pair up ``hist['bins']`` and ``hist['data']`` into interval dicts.

        Each bin ``b`` with its count becomes
        ``{"from": 0, "to": b / 1000.0, "count": count}``; pairing stops at
        the shorter of the two sequences.
        """
        data = hist['data']
        bins = hist['bins']
        return [
            {
                "from": 0,  # deprecated
                "to": b / 1000.0,
                "count": count,
            } for b, count in zip(bins, data)
        ]

    def push_test_data(self):
        """Push collected test data; must be implemented by subclasses."""
        raise NotImplementedError

    def unlock_target(self, target):
        """Unlock *target*; must be implemented by subclasses."""
        raise NotImplementedError


class LPRequisites:
    """Named requisite descriptors.

    Every constant is a 2-tuple whose first item is a ``.txt`` file name and
    whose second item is a short key.
    NOTE(review): presumably (uploaded file name, API field name) - confirm
    against the code that consumes these constants.
    """

    CONFIGINFO = ('configinfo.txt', 'configinfo')
    MONITORING = ('jobmonitoringconfig.txt', 'monitoringconfig')
    CONFIGINITIAL = ('configinitial.txt', 'configinitial')


class CloudGRPCClient(APIClient):
    """gRPC client that pushes load-test results to cloud load testing.

    On construction the client opens a secure channel to ``base_url``,
    performs a greeter handshake (``SayHello``) and, when the handshake
    reports code 0, creates the trail and test service stubs used by
    :meth:`send_trails` and :meth:`set_imbalance_and_dsc`.
    """

    class NotAvailable(Exception):
        """Raised when the cloud load testing endpoint cannot be used."""

    def __init__(
            self,
            core_interrupted,
            base_url=None,
            api_attempts=10,
            api_timeout=0.5,
            connection_timeout=100.0):
        """Open the channel and perform the initial handshake.

        Args:
            core_interrupted: event-like object (``is_set()``); when set,
                :meth:`push_test_data` stops retrying.
            base_url: ``host:port`` of the gRPC endpoint.
            api_attempts: total number of push attempts; yields
                ``api_attempts - 1`` retry delays.
            api_timeout: kept for interface compatibility; retry delays are
                produced by :meth:`api_timeouts`.
            connection_timeout: per-call gRPC timeout, in seconds.

        Raises:
            NotAvailable: if the handshake fails.
        """
        super().__init__()
        self.core_interrupted = core_interrupted
        self._base_url = base_url
        self.api_attempts = api_attempts
        # Previously accepted and silently dropped; stored so the value the
        # caller configured stays inspectable.
        self.api_timeout = api_timeout
        self.connection_timeout = connection_timeout
        self.max_api_timeout = 120
        self._token = None
        creds = self._get_creds(self._base_url)
        self.channel = grpc.secure_channel(self._base_url, creds)
        self._cloud_instance_id = None
        self._set_connection(self.token, self.cloud_instance_id)

    @staticmethod
    def _get_creds(url):
        """Build channel credentials from the certificate served at *url*.

        *url* must be ``host:port``; it is split on ``':'`` to form the
        address passed to :func:`ssl.get_server_certificate`.
        """
        # SECURITY NOTE(review): the certificate is downloaded from the very
        # endpoint it is then used to authenticate (trust on first use), so
        # this does not protect against an active man-in-the-middle. Left
        # unchanged on purpose - confirm whether a pinned CA should be used.
        cert = ssl.get_server_certificate(tuple(url.split(':')))
        creds = grpc.ssl_channel_credentials(cert.encode('utf-8'))
        return creds

    @staticmethod
    def _get_call_creds(token):
        """Return per-call access-token credentials for *token*."""
        return grpc.access_token_call_credentials(token)

    # TODO retry
    def _set_connection(self, token, cloud_instance_id):
        """Perform the greeter handshake and (re)create the service stubs.

        Args:
            token: bearer token sent in the ``authorization`` metadata.
            cloud_instance_id: sent as ``tank_id`` in the hello request.

        Raises:
            NotAvailable: if the call fails or the greeter answers with a
                non-zero code (previously a non-zero code silently left the
                stubs undefined, which surfaced later as AttributeError).
        """
        try:
            stub_greeter = greeter_service_pb2_grpc.GreeterServiceStub(self.channel)
            response = stub_greeter.SayHello(
                greeter_service_pb2.SayHelloRequest(tank_id=cloud_instance_id),
                timeout=self.connection_timeout,
                metadata=[('authorization', f'Bearer {token}')]
                # credentials=self._get_call_creds(token)
            )
            code = response.code
            if code == 0:
                self.stub = trail_service_pb2_grpc.TrailServiceStub(self.channel)
                self.test_stub = test_service_pb2_grpc.TestServiceStub(self.channel)
        except Exception as e:
            raise self.NotAvailable(f"Couldn't connect to cloud load testing: {e}") from e
        if code != 0:
            raise self.NotAvailable(
                f"Couldn't connect to cloud load testing: greeter returned code {code}")
        LOGGER.info('Init connection to cloud load testing is succeeded')

    @property
    def token(self):
        """IAM token, fetched on first access and then cached.

        NOTE(review): the cached value is never refreshed here - confirm the
        token lifetime covers the longest expected test.
        """
        if self._token is None:
            self._token = get_iam_token()
        return self._token

    @property
    def cloud_instance_id(self):
        """Id of the current instance, fetched on first access and cached."""
        if self._cloud_instance_id is None:
            self._cloud_instance_id = get_current_instance_id()
        return self._cloud_instance_id

    def api_timeouts(self):
        """Yield ``api_attempts - 1`` retry delays in seconds.

        Delays grow as ``2**attempt`` with a random 1-1.5x jitter and are
        capped at ``max_api_timeout``.
        """
        for attempt in range(self.api_attempts - 1):
            multiplier = random.uniform(1, 1.5)
            yield min(2**attempt * multiplier, self.max_api_timeout)

    @staticmethod
    def build_codes(codes_data):
        """Convert ``{'code', 'count'}`` dicts into ``Trail.Codes`` messages."""
        return [
            trail_service_pb2.Trail.Codes(code=int(code['code']), count=int(code['count']))
            for code in codes_data
        ]

    @staticmethod
    def build_intervals(intervals_data):
        """Convert ``{'to', 'count'}`` dicts into ``Trail.Intervals`` messages."""
        return [
            trail_service_pb2.Trail.Intervals(to=interval['to'], count=int(interval['count']))
            for interval in intervals_data
        ]

    def convert_to_proto_message(self, items):
        """Convert push-item dicts into a list of ``Trail`` messages.

        *items* are dicts shaped like the result of
        ``second_data_to_push_item``. Quantile keys are read with ``.get``,
        so a missing quantile is passed as ``None``.
        """
        trails = []
        for item in items:
            trail_data = item["trail"]
            trail = trail_service_pb2.Trail(
                overall=int(item["overall"]),
                case_id=item["case"],
                time=trail_data["time"],
                reqps=int(trail_data["reqps"]),
                resps=int(trail_data["resps"]),
                expect=trail_data["expect"],
                input=int(trail_data["input"]),
                output=int(trail_data["output"]),
                connect_time=trail_data["connect_time"],
                send_time=trail_data["send_time"],
                latency=trail_data["latency"],
                receive_time=trail_data["receive_time"],
                threads=int(trail_data["threads"]),
                q50=trail_data.get('q50'),
                q75=trail_data.get('q75'),
                q80=trail_data.get('q80'),
                q85=trail_data.get('q85'),
                q90=trail_data.get('q90'),
                q95=trail_data.get('q95'),
                q98=trail_data.get('q98'),
                # Was .get('99'): that key never exists, so q99 was dropped.
                q99=trail_data.get('q99'),
                q100=trail_data.get('q100'),
                http_codes=self.build_codes(item['http_codes']),
                net_codes=self.build_codes(item['net_codes']),
                time_intervals=self.build_intervals(item['time_intervals']),
            )
            trails.append(trail)
        return trails

    def send_trails(self, instance_id, trails):
        """Send *trails* for *instance_id* and return the response code.

        Raises:
            NotAvailable: if the RPC fails with ``UNAVAILABLE``.
            grpc.RpcError: any other RPC failure is re-raised unchanged.
        """
        try:
            request = trail_service_pb2.CreateTrailRequest(
                compute_instance_id=str(instance_id),
                data=trails
            )
            result = self.stub.Create(
                request,
                timeout=self.connection_timeout,
                metadata=[('authorization', f'Bearer {self.token}')]
                # credentials=self._get_call_creds(self.token)
            )
            LOGGER.debug('Send trails: %s', trails)
            return result.code
        # Public base class instead of the private grpc._channel type; this
        # also matches set_imbalance_and_dsc.
        except grpc.RpcError as err:
            if err.code() == grpc.StatusCode.UNAVAILABLE:
                raise self.NotAvailable('Connection is closed. Try to set it again.') from err
            raise

    def push_test_data(
            self,
            data_item,
            stat_item,
            interrupted_event):
        """Push one second of results, retrying while the API is unavailable.

        Builds one push item per tagged case (an empty tag is sent as
        ``"__NOTAG__"``) plus one overall item, then sends them until the
        response code is 0, *interrupted_event* is set, the core is
        interrupted, or the retry delays from :meth:`api_timeouts` run out.

        Raises:
            NotAvailable: when the retry delays are exhausted, or when
                re-establishing the connection fails.
        """
        items = []
        ts = data_item["ts"]
        for case_name, case_data in data_item["tagged"].items():
            if case_name == "":
                case_name = "__NOTAG__"
            push_item = self.second_data_to_push_item(case_data, stat_item, ts,
                                                      0, case_name)
            items.append(push_item)
        overall = self.second_data_to_push_item(data_item["overall"],
                                                stat_item, ts, 1, '')
        items.append(overall)

        api_timeouts = self.api_timeouts()
        while not interrupted_event.is_set():
            try:
                code = self.send_trails(self.cloud_instance_id, self.convert_to_proto_message(items))
            except self.NotAvailable:
                if self.core_interrupted.is_set():
                    break
                timeout = next(api_timeouts, None)
                if timeout is None:
                    # Out of retry delays: surface the last failure as is.
                    raise
                # Was called without cloud_instance_id, which raised
                # TypeError instead of reconnecting.
                self._set_connection(self.token, self.cloud_instance_id)
                LOGGER.warning("GRPC error, will retry in %ss...", timeout)
                time.sleep(timeout)
                continue
            if code == 0:
                break
            # NOTE(review): a non-zero code is retried immediately, without
            # any delay or attempt limit - confirm this is intended.

    def unlock_target(self, *args):
        """No-op for this client."""
        return

    def set_imbalance_and_dsc(self, cloud_job_id, rps, comment, timestamp):
        """Store the imbalance point of a test and return the response code.

        Args:
            cloud_job_id: test id, sent stringified.
            rps: sent as ``imbalance_point``.
            comment: sent as ``imbalance_comment``.
            timestamp: sent as ``imbalance_ts``.

        Raises:
            NotAvailable: if the RPC fails with ``UNAVAILABLE`` or
                ``DEADLINE_EXCEEDED``.
            grpc.RpcError: any other RPC failure is re-raised unchanged.
        """
        try:
            request = test_service_pb2.UpdateTestRequest(
                id=str(cloud_job_id),
                imbalance_point=rps,
                imbalance_ts=timestamp,
                imbalance_comment=comment
            )
            result = self.test_stub.Update(
                request,
                timeout=self.connection_timeout,
                metadata=[('authorization', f'Bearer {self.token}')]
            )
            LOGGER.debug('Set imbalance %s at %s. Comment: %s', rps, timestamp, comment)
            return result.code
        except grpc.RpcError as err:
            if err.code() in (grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.DEADLINE_EXCEEDED):
                raise self.NotAvailable('Connection is closed. Try to set it again.') from err
            raise

# ====== HELPER ======
# Endpoints of the instance metadata service at the link-local address
# 169.254.169.254. The helpers below query them with the
# "Metadata-Flavor: Google" request header.
# Full instance metadata, fetched recursively.
COMPUTE_INSTANCE_METADATA_URL = 'http://169.254.169.254/computeMetadata/v1/instance/?recursive=true'
# Token of the instance's default service account.
COMPUTE_INSTANCE_SA_TOKEN_URL = 'http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default/token'


def get_instance_metadata():
    """Fetch the metadata of the current VM as decoded JSON.

    The request goes to ``COMPUTE_INSTANCE_METADATA_URL`` through a session
    whose adapter retries up to 5 times.

    Returns:
        The decoded JSON body of the metadata response.

    Raises:
        RuntimeError: if the request or the JSON decoding fails; the
            original exception is attached as ``__cause__``.
    """
    url = COMPUTE_INSTANCE_METADATA_URL
    try:
        # Context manager so the session's pooled connections are released.
        with requests.Session() as session:
            session.mount(url, HTTPAdapter(max_retries=5))
            response = session.get(url, headers={"Metadata-Flavor": "Google"}).json()
        LOGGER.debug('Instance metadata %s', response)
        return response
    except Exception as e:
        LOGGER.error("Couldn't get instance metadata of current vm: %s", e)
        # The message lacked the f-prefix and contained a literal "{e}".
        raise RuntimeError(f"Couldn't get instance metadata of current vm: {e}") from e


def get_current_instance_id():
    """Return the ``id`` entry of the current instance's metadata.

    The result is ``None`` when the metadata has no ``id`` key.

    Raises:
        RuntimeError: if the fetched metadata is empty.
    """
    metadata = get_instance_metadata()
    if not metadata:
        raise RuntimeError("Metadata is empty")
    return metadata.get('id')


def get_iam_token():
    """Fetch an IAM token for the instance's default service account.

    The request goes to ``COMPUTE_INSTANCE_SA_TOKEN_URL`` through a session
    whose adapter retries up to 5 times.

    Returns:
        The ``access_token`` value of the JSON response, or ``None`` when
        the response has no such key.

    Raises:
        RuntimeError: if the request or the JSON decoding fails; the
            original exception is attached as ``__cause__``.
    """
    url = COMPUTE_INSTANCE_SA_TOKEN_URL
    try:
        # Context manager so the session's pooled connections are released.
        with requests.Session() as session:
            session.mount(url, HTTPAdapter(max_retries=5))
            raw_response = session.get(url, headers={"Metadata-Flavor": "Google"})
            response = raw_response.json()
        iam_token = response.get('access_token')
        LOGGER.debug("Get IAM token")
        return iam_token
    except Exception as e:
        LOGGER.error("Couldn't get iam token for instance service account: %s", e)
        # The message lacked the f-prefix and contained a literal "{e}".
        raise RuntimeError(f"Couldn't get iam token for instance service account: {e}") from e
