# coding: utf-8
import logging
import logging.handlers
import signal
import time
import threading

import ylock

import click
import raven

import nanny_rpc_client

from infra.nanny.yp_lite_api.proto import pod_sets_api_pb2
from infra.nanny.yp_lite_api.py_stubs import pod_sets_api_stub

from nanny_repo import repo_api_stub, repo_api_pb2
from infra.yp_lite_shaker.nanny_client import NannyClient

# Default RPC endpoints; used as defaults for --yp-lite-ui-rpc-url and --nanny-rpc-url.
DEFAULT_YP_LITE_UI_RPC_URL = 'https://yp-lite-ui.nanny.yandex-team.ru/api/yplite'
DEFAULT_NANNY_RPC_URL = 'https://nanny.yandex-team.ru/api'
# One hour in seconds; default for --move-cooldown.
HOUR_IN_SEC = 60 * 60
# Passed to NannyClient as req_timeout (presumably seconds -- confirm against NannyClient).
REQUEST_TIMEOUT = 3
# Passed to NannyClient as max_retries.
MAX_NANNY_RETRIES = 2
# Number of polling attempts YpLiteShaker.wait() makes before giving up.
MAX_YP_LITE_RETRIES = 10
# Default for --sleep-time: seconds between two polls in YpLiteShaker.wait().
TIME_TO_MOVE = 120
# Log line layout shared by logging.basicConfig() and the rotating file handler.
DEFAULT_FORMATTER_PATTERN = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
FORMATTER = logging.Formatter(DEFAULT_FORMATTER_PATTERN)
# Default for --yt-lock-dir; handed to ylock.create_manager() as the lock prefix.
YT_LOCK_DIR = '//tmp/yp-lite-shaker/test'
# Name of the lock taken by the main loop around every round of shaker threads.
YP_LITE_SHAKER_LOCK = 'yp-lite-shaker-lock'


class YpLiteShakerError(Exception):
    """Raised when pods are still on down nodes (or not active) after the allowed time."""


class YpLiteShaker(threading.Thread):
    """Thread that performs one "shake" round for a single Nanny service.

    A round (see :meth:`move`) reads the service's replication policy and, unless the
    policy was modified less than ``move_cooldown`` seconds ago, writes the node ids of
    all current pods into the policy's ``down_nodes`` and then polls until no pod is
    on a down node and every pod is active.  After the round the thread sleeps on the
    shared ``stopped`` event and exits; the caller starts a fresh thread for the next round.
    """

    def __init__(self, yp_lite_ui_rpc_url, yp_lite_ui_token, nanny_rpc_url, nanny_token, sleep_time, sentry_url,
                 service_id, stopped, handler, move_cooldown):
        """
        :param yp_lite_ui_rpc_url: base URL of the YP Lite UI RPC API (without trailing slash)
        :param yp_lite_ui_token: OAuth token for the YP Lite UI API
        :param nanny_rpc_url: base URL of the Nanny RPC API (without trailing slash)
        :param nanny_token: OAuth token for Nanny (used for both the repo RPC and NannyClient)
        :param sleep_time: seconds between polls in :meth:`wait`; also the delay after a failed round
        :param sentry_url: DSN passed to ``raven.Client``
        :param service_id: Nanny service id; also used as replication policy id and logger name
        :param stopped: ``threading.Event`` that is set when the process should shut down
        :param handler: optional logging handler attached to the per-service logger
        :param move_cooldown: minimal number of seconds between two policy modifications
        """
        threading.Thread.__init__(self)

        self.yp_lite_ui = nanny_rpc_client.RetryingRpcClient(rpc_url=yp_lite_ui_rpc_url + '/pod-sets/',
                                                             oauth_token=yp_lite_ui_token)
        self.nanny_repo = nanny_rpc_client.RetryingRpcClient(rpc_url=nanny_rpc_url + '/repo/',
                                                             oauth_token=nanny_token)
        self.repo_stub = repo_api_stub.RepoServiceStub(self.nanny_repo)
        self.pod_sets_stub = pod_sets_api_stub.YpLiteUIPodSetsServiceStub(self.yp_lite_ui)
        self.sleep_time = sleep_time
        self.nanny_client = NannyClient(token=nanny_token, req_timeout=REQUEST_TIMEOUT, max_retries=MAX_NANNY_RETRIES)
        self.wait_retries = MAX_YP_LITE_RETRIES

        self.raven_client = raven.Client(sentry_url)

        self.service_id = service_id
        self.stopped = stopped
        self.move_cooldown = move_cooldown

        self.logger = logging.getLogger(service_id)
        if handler:
            self.logger.addHandler(handler)
        self.logger.setLevel(logging.INFO)

    def get_clusters(self):
        """Return the set of cluster names that host YP pods of the service.

        Missing or empty ``content`` / ``yp_pod_ids`` / ``pods`` entries in the Nanny
        response are treated as "no pods" and yield an empty set.

        :rtype: set
        """
        nanny_response = self.nanny_client.get_service_instances(self.service_id)
        content = nanny_response.get('content') or {}
        pods = (content.get('yp_pod_ids') or {}).get('pods') or []
        if not pods:
            # Use the per-service logger: the module-level logging.info() would implicitly
            # call basicConfig() and attach a stderr handler when the root logger has none.
            self.logger.info('The service does not have YP pods')
            return set()
        return {pod['cluster'] for pod in pods}

    def list_all_pods(self):
        """Return the pods of the service from every cluster reported by :meth:`get_clusters`.

        :rtype: list
        """
        result = []
        for cluster in self.get_clusters():
            req_pb = pod_sets_api_pb2.ListPodsRequest()
            req_pb.cluster = cluster
            req_pb.service_id = self.service_id
            result.extend(self.pod_sets_stub.list_pods(req_pb).pods)

        return result

    @staticmethod
    def is_ready(iss_spec, iss_status):  # copied from nanny.model.yp_lite.yp_client.YpClient
        """Tell whether the configuration targeted as ACTIVE is currently ACTIVE.

        The target is the first spec instance whose ``targetState`` is ``'ACTIVE'``; the
        result is taken from the first current state with the same group state fingerprint.

        :type iss_spec: yp_proto.yp.client.api.proto.cluster_api_pb2.HostConfiguration
        :type iss_status: yp_proto.yp.client.api.proto.cluster_api_pb2.HostCurrentState
        :rtype: bool
        """
        target_conf_id = None
        for instance in iss_spec.instances:
            if instance.targetState == 'ACTIVE':
                target_conf_id = instance.id.configuration.groupStateFingerprint
                break
        if not target_conf_id:
            return False
        for state in iss_status.currentStates:
            if state.workloadId.configuration.groupStateFingerprint == target_conf_id:
                return state.currentState == 'ACTIVE'
        return False

    def is_moved_successfully(self, policy_pb=None):
        """Check that no pod sits on a down node and that every pod is active.

        :param policy_pb: replication policy to check against; fetched from Nanny when ``None``
        :rtype: bool
        """
        if policy_pb is None:
            req_pb = repo_api_pb2.GetReplicationPolicyRequest(policy_id=self.service_id)
            resp_pb = self.repo_stub.get_replication_policy(req_pb)
            policy_pb = resp_pb.policy
        # Build the lookup once instead of scanning the repeated field for every pod.
        down_nodes = set(policy_pb.spec.down_nodes)
        for pod_pb in self.list_all_pods():
            if pod_pb.spec.node_id in down_nodes:
                self.logger.info('Node {} is in down_nodes'.format(pod_pb.spec.node_id))
                return False
            if not self.is_ready(pod_pb.spec.iss, pod_pb.status.agent.iss):
                self.logger.info(
                    'Node {} is moved from down nodes, but is not activated yet'.format(pod_pb.spec.node_id))
                return False
        return True

    def wait(self):
        """Poll until the pods have moved, the stop event is set, or the retries run out.

        Sleeps ``sleep_time`` seconds on the stop event before each check and returns
        silently as soon as the event is set.

        :raises YpLiteShakerError: if the pods have not moved after ``wait_retries`` checks
        """
        for _ in range(self.wait_retries):
            if self.stopped.wait(timeout=self.sleep_time):
                return
            self.logger.info('Checking if nodes have moved successfully...')
            if self.is_moved_successfully(None):
                self.logger.info('Nodes have moved successfully')
                return
            self.logger.info('Nodes have not moved yet, sleeping for {} seconds'.format(self.sleep_time))
        # Report the time actually waited by this instance rather than the module default.
        message = "Down nodes haven't moved after {} seconds".format(self.wait_retries * self.sleep_time)
        self.logger.error(message)
        raise YpLiteShakerError(message)

    def set_down_nodes(self, nodes_id, policy_pb):
        """Replace ``down_nodes`` of the policy with ``nodes_id`` and push the update to Nanny.

        :param nodes_id: iterable of node ids to mark as down
        :param policy_pb: replication policy message; its spec is modified in place
        """
        self.logger.info('Setting down following nodes: {}'.format(nodes_id))
        policy_pb.spec.down_nodes[:] = nodes_id
        req_pb = repo_api_pb2.UpdateReplicationPolicyRequest(meta=policy_pb.meta, spec=policy_pb.spec)
        resp_pb = self.repo_stub.update_replication_policy(req_pb)
        self.logger.info(resp_pb)

    def move(self):
        """Run one shake round and return how many seconds to sleep afterwards.

        If the policy was modified less than ``move_cooldown`` seconds ago nothing is
        changed and the remaining cooldown is returned.  Otherwise the previous round is
        verified, the nodes of all current pods are written to ``down_nodes`` and
        :meth:`wait` is called.

        :return: seconds until the next round may start
        :raises YpLiteShakerError: if the previous round has not finished, or this one times out
        """
        req_pb = repo_api_pb2.GetReplicationPolicyRequest(policy_id=self.service_id)
        resp_pb = self.repo_stub.get_replication_policy(req_pb)
        policy_pb = resp_pb.policy
        last_modification_time = policy_pb.meta.last_modification_time.ToSeconds()
        elapsed = time.time() - last_modification_time
        if elapsed < self.move_cooldown:
            return self.move_cooldown - elapsed

        # The cooldown has passed, so the previous round should be complete by now.
        if not self.is_moved_successfully(policy_pb):
            message = "Down nodes haven't moved after {} seconds".format(elapsed)
            self.logger.error(message)
            raise YpLiteShakerError(message)

        self.set_down_nodes([pod_pb.spec.node_id for pod_pb in self.list_all_pods()], policy_pb)
        self.wait()
        return self.move_cooldown

    def run(self):
        """Thread body: do one round, then sleep on the stop event for the returned delay.

        Any exception from the round is logged, reported to Sentry and followed by a
        ``sleep_time`` delay instead of terminating the thread with a traceback.
        """
        if self.stopped.is_set():
            return
        try:
            delay_time = self.move()
        except Exception:
            # Keep a local trace as well; previously the failure was visible in Sentry only.
            self.logger.exception('Shake round failed for service {}'.format(self.service_id))
            self.raven_client.captureException()
            delay_time = self.sleep_time

        self.stopped.wait(timeout=delay_time)


@click.command()
@click.argument('service_ids', nargs=-1)
@click.option('--yp-lite-ui-token', envvar='YP_LITE_UI_TOKEN', required=True)
@click.option('--yp-lite-ui-rpc-url', envvar='YP_LITE_UI_RPC_URL', default=DEFAULT_YP_LITE_UI_RPC_URL)
@click.option('--nanny-token', envvar='NANNY_TOKEN', required=True)
@click.option('--nanny-rpc-url', envvar='NANNY_RPC_URL', default=DEFAULT_NANNY_RPC_URL)
@click.option('--sleep-time', default=TIME_TO_MOVE)
@click.option('--sentry-url', envvar='SENTRY_URL', required=True)
@click.option('--logging-path', default=None)
@click.option('--move-cooldown', default=HOUR_IN_SEC)
@click.option('--yt-token', envvar='YT_TOKEN', required=True)
@click.option('--yt-lock-dir', default=YT_LOCK_DIR)
def run(service_ids, yp_lite_ui_token, yp_lite_ui_rpc_url, nanny_token, nanny_rpc_url, sleep_time, sentry_url,
        logging_path, move_cooldown, yt_token, yt_lock_dir):
    """Shake the given Nanny services in rounds until SIGINT/SIGTERM is received."""
    if not service_ids:
        # nargs=-1 accepts zero arguments; with nothing to shake the loop below would
        # spin forever, taking and releasing the YT lock without doing any work.
        raise click.UsageError('At least one service id must be specified')

    yp_lite_ui_rpc_url = yp_lite_ui_rpc_url.rstrip('/')
    nanny_rpc_url = nanny_rpc_url.rstrip('/')
    stopped = threading.Event()

    # Log to stderr via the root logger by default, or to a file rotated at midnight
    # through a handler that every per-service logger attaches to itself.
    if logging_path is None:
        handler = None
        logging.basicConfig(level=logging.INFO, format=DEFAULT_FORMATTER_PATTERN)
    else:
        handler = logging.handlers.TimedRotatingFileHandler(logging_path, when='MIDNIGHT')
        handler.setLevel(logging.INFO)
        handler.setFormatter(FORMATTER)

    lock_manager = ylock.create_manager(backend='yt', prefix=yt_lock_dir, token=yt_token)
    yt_lock = lock_manager.lock(YP_LITE_SHAKER_LOCK, timeout=move_cooldown * 2, block=True)

    def stop(signum, frame):
        # Wakes every shaker sleeping on the event and ends the main loop.
        stopped.set()

    signal.signal(signal.SIGINT, stop)
    signal.signal(signal.SIGTERM, stop)

    while not stopped.is_set():
        # Threads cannot be restarted, so each round gets a fresh set of shakers.
        shakers = [YpLiteShaker(yp_lite_ui_rpc_url, yp_lite_ui_token, nanny_rpc_url,
                                nanny_token, sleep_time, sentry_url, service_id, stopped,
                                handler, move_cooldown) for service_id in service_ids]
        # NOTE(review): the value produced by entering the lock is ignored; confirm that
        # ylock guarantees the lock is actually held here when block=True.
        with yt_lock:
            for shaker in shakers:
                shaker.start()
            for shaker in shakers:
                # Join with a timeout in a loop: on Python 2 a bare join() cannot be
                # interrupted, which would postpone the stop handler until the thread ends.
                while shaker.is_alive():
                    shaker.join(1)


# Script entry point: click parses the command line and environment, then calls run().
if __name__ == '__main__':
    run()
