import re
import six
import time
import logging
import funcsigs
import requests

from retrying import retry
from functools import partial
from six.moves.urllib.parse import urljoin
from cached_property import cached_property, cached_property_with_ttl
from simplejson import JSONDecodeError

from saas.library.python.gencfg import GencfgGroup
from saas.library.python.singleton import Singleton
from saas.library.python.common_functions import connection_error

from .saas_slot_errors import SlotError, SlotUnreachable, CommandResultReadTimeout


class Slot(six.with_metaclass(Singleton)):
    """
    HTTP client for a single rtyserver slot identified by ``host:port``.

    Control commands are sent as GET requests to ``http://<host>:<control_port>/`` where
    ``control_port = port + control_port_offset``. Every name in ``RTYSERVER_COMMANDS`` (and every
    command the slot lists in its ``help`` output) can be called as a method, e.g.
    ``slot.get_status()``; see ``__getattr__``.

    ``_get_instance_id`` / ``_extra_actions`` are hooks for the ``Singleton`` metaclass (defined
    elsewhere) -- presumably one instance is kept per ``(host, port)``; TODO confirm there.
    """
    SLOT_ID_REGEXP = re.compile(r'(?P<host>[._a-zA-Z0-9\-]+):(?P<port>\d+)')
    # Commands exposed as methods through __getattr__ without asking the slot first.
    RTYSERVER_COMMANDS = frozenset({
        'abort', 'check_nprofile', 'clear_data_status', 'delete_file', 'disable_nprofile', 'download_configs_from_dm', 'enable_nprofile', 'execute_script',
        'get_async_command_info', 'get_config', 'get_configs_hashes', 'get_file', 'get_info_server', 'get_metric', 'get_must_be_alive', 'get_queue_perf',
        'get_status', 'help', 'no_file_operations', 'put_file', 'reopenlog', 'reset', 'restart', 'set_config', 'set_queue_opts', 'shutdown', 'stop', 'take_file'
    })
    # Retry policy shared by the ``retrying`` decorators below (times are in milliseconds).
    _RETRY_CONF = {
        'stop_max_attempt_number': 5,
        'stop_max_delay': 5 * 1000,  # stop after 5 seconds
        'wait_random_min': 120,
        'wait_random_max': 1000,
        'wrap_exception': False
    }
    LOGGER = logging.getLogger(__name__)

    @classmethod
    def _get_instance_id(cls, args, kwargs):
        """
        Singleton interface: compute the identity key for the given constructor arguments.

        :return: ``(host, port)``. ``port`` is cast to ``int`` exactly like ``__init__`` does, so
                 ``Slot('h', '80')`` and ``Slot('h', 80)`` produce the same key.
        """
        signature = funcsigs.signature(cls.__init__)
        # ``cls`` stands in for ``self`` so that positional arguments line up with __init__.
        bound_params = signature.bind(cls, *args, **kwargs)
        port = bound_params.arguments.get('port', signature.parameters['port'].default)
        return bound_params.arguments['host'], int(port)

    @classmethod
    def _extra_actions(cls, instance, args, kwargs):
        """
        Singleton customisation: fill in optional attributes of an existing ``instance``.

        ``physical_host``, ``geo`` and the shard range are copied from the constructor arguments
        only if the instance does not have them yet; the shard range is taken only when both
        ``shards_min`` and ``shards_max`` are given.
        """
        signature = funcsigs.signature(cls.__init__)
        bound_params = signature.bind(cls, *args, **kwargs)
        if 'physical_host' in bound_params.arguments and not instance._physical_host:
            instance._physical_host = bound_params.arguments['physical_host']
        if 'geo' in bound_params.arguments and not instance.geo:
            instance.geo = bound_params.arguments['geo']
        if 'shards_min' in bound_params.arguments and 'shards_max' in bound_params.arguments and not (instance._shards_min or instance._shards_max):
            instance._shards_min = bound_params.arguments['shards_min']
            instance._shards_max = bound_params.arguments['shards_max']

    def __init__(self, host, port=80, physical_host=None, control_port_offset=3, geo=None, shards_min=None, shards_max=None, nanny_service_id=None):
        """
        :param host: slot host name.
        :param port: slot port (cast to ``int``).
        :param physical_host: host machine name; looked up lazily via ``info_server`` when not given.
        :param control_port_offset: control port is ``port + control_port_offset``.
        :param geo: location label, stored as is.
        :param shards_min: lower bound of the slot's shard range; looked up lazily when not given.
        :param shards_max: upper bound of the slot's shard range; looked up lazily when not given.
        :param nanny_service_id: stored as ``_nanny_service_id``.
        """
        self.host = host
        self.port = int(port)
        self.geo = geo  # TODO: try to guess
        self.control_port = self.port + int(control_port_offset)
        self._physical_host = physical_host
        self._info_server = None
        self._shards_min = shards_min
        self._shards_max = shards_max
        self._nanny_service_id = nanny_service_id
        self.session = requests.session()

        self._id = '{}:{}'.format(self.host, self.port)
        self._command_url = 'http://{}:{}'.format(self.host, self.control_port)

    @classmethod
    def from_id(cls, slot_id, **kwargs):
        """
        Build a slot from a ``host:port`` string; extra ``kwargs`` go to the constructor.

        :raises ValueError: if ``slot_id`` does not match ``SLOT_ID_REGEXP``.
        """
        match = cls.SLOT_ID_REGEXP.match(slot_id)
        if match:
            kwargs.update({
                'host': match.group('host'),
                'port': int(match.group('port')),
            })
            return cls(**kwargs)
        else:
            raise ValueError('Can\'t construct slot from id: {}'.format(slot_id))

    def __eq__(self, other):
        if not isinstance(other, Slot):
            return NotImplemented
        return self._id == other._id

    def __ne__(self, other):
        # Propagate NotImplemented instead of turning it into a (wrong) boolean.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        return hash(self._id)

    def __repr__(self):
        return 'Slot({})'.format(self.id)

    def __str__(self):
        return self.id

    @property
    def id(self):
        """Slot identifier in ``host:port`` form."""
        return self._id

    @retry(retry_on_exception=connection_error, **_RETRY_CONF)
    def _make_request(self, path='/', **kwargs):
        """GET ``path`` on the control port; retried while ``connection_error`` accepts the exception."""
        url = urljoin(self._command_url, path)
        response = self.session.get(url, **kwargs)
        response.raise_for_status()
        return response

    @retry(**_RETRY_CONF)
    def _request_with_retry(self, path='/', **kwargs):
        """Like ``_make_request`` but retried on any exception according to ``_RETRY_CONF``."""
        self.LOGGER.debug('Request with path=%s, params=%s', path, kwargs)
        return self._make_request(path=path, **kwargs)

    def make_request(self, path='/', retryable=False, **params):
        """
        GET ``path`` on the slot's control port and translate ``requests`` errors into slot errors.

        :param path: URL path joined onto ``http://<host>:<control_port>``.
        :param retryable: if True, retry the whole request according to ``_RETRY_CONF``.
        :param params: keyword arguments for ``requests.Session.get`` (``params``, ``timeout``, ...).
        :return: the ``requests.Response`` (already checked with ``raise_for_status``).
        :raises SlotUnreachable: on connection errors.
        :raises CommandResultReadTimeout: on read timeouts.
        :raises SlotError: on any other ``requests`` error, including HTTP error statuses.
        """
        try:
            if retryable:
                return self._request_with_retry(path=path, **params)
            else:
                return self._make_request(path=path, **params)
        except requests.exceptions.ConnectionError as e:
            raise SlotUnreachable(e)
        except requests.exceptions.ReadTimeout as e:
            raise CommandResultReadTimeout(e)
        except requests.RequestException as e:
            raise SlotError(inner_exception=e, slot=self)

    def _server_config_value(self, key):
        """
        Read ``key`` from the ``Server`` section of the slot config (``info_server['config']``).

        The section may be a mapping or a list of mappings (``has_backup`` indexes it with ``[0]``);
        for a list, the first element is used.

        :return: the value, or None if the section or the key is missing.
        """
        server_section = self.info_server.get('config', {}).get('Server', {})
        if isinstance(server_section, (list, tuple)):
            server_section = server_section[0] if server_section else {}
        return server_section.get(key, None)

    @property
    def shards_min(self):
        """Lower bound of the shard range; fetched from the ``ShardMin`` server option if unknown."""
        if self._shards_min is None:
            self._shards_min = self._server_config_value('ShardMin')

        return self._shards_min

    @property
    def shards_max(self):
        """Upper bound of the shard range; fetched from the ``ShardMax`` server option if unknown."""
        if self._shards_max is None:
            self._shards_max = self._server_config_value('ShardMax')

        return self._shards_max

    @property
    def interval(self):
        """Shard range as ``{'min': ..., 'max': ...}``, or None if either bound is unknown."""
        if self.shards_min is not None and self.shards_max is not None:
            return {'min': self.shards_min, 'max': self.shards_max}
        else:
            return None

    @interval.setter
    def interval(self, value):
        self._shards_min = value['min']
        self._shards_max = value['max']

    @property
    def instance_topology(self):
        """
        ``GencfgGroup`` built from the slot's ``GENCFG_GROUP`` / ``GENCFG_RELEASE`` dynamic properties.

        :return: the group, or None if the group property is absent or the slot info is unavailable.
        """
        try:
            server_info = self.info_server['slot']
            dynamic_properties = server_info['dynamicProperties']
            gencfg_group = dynamic_properties.get('GENCFG_GROUP', None)
            gencfg_release = dynamic_properties.get('GENCFG_RELEASE', None)
            if gencfg_group:
                return GencfgGroup(gencfg_group, gencfg_release, validate=False)
            else:
                return None
        except (SlotError, KeyError):
            return None

    @property
    def physical_host(self):
        """Host machine name (``NODE_NAME`` slot property), or None if it cannot be determined."""
        if not self._physical_host:
            try:
                self._physical_host = self.info_server['slot']['properties']['NODE_NAME']
            except (SlotError, KeyError):
                return None
        return self._physical_host

    @cached_property
    def _this_server_commands(self):
        # Cached for the lifetime of the instance (no TTL).
        return self._get_this_server_commands()

    def _get_this_server_commands(self):
        """
        Ask the slot for the commands it supports (``help`` -> comma-separated ``this server commands``).

        :return: frozenset of command names; empty on ``requests.RequestException``.
        :raises JSONDecodeError: if the ``help`` response body is not JSON (logged first).
        """
        # NOTE(review): execute_command goes through make_request, which converts requests errors to
        # SlotError subclasses; the RequestException branch only fires if those subclass it -- confirm.
        response = None
        try:
            response = self.execute_command('help', retryable=True)
            return frozenset(response.json()['this server commands'].split(', '))
        except JSONDecodeError as e:
            self.LOGGER.error('Slot %s returned response %s with unparsable body %s for request %s', self, response, e.doc, response.request.url)
            raise e
        except requests.RequestException:
            return frozenset()

    def __getattr__(self, item):
        """
        Expose rtyserver commands as methods: ``slot.<command>(**kwargs)`` calls ``execute_command``.

        Called only when normal attribute lookup fails. Resolving a command name may issue a ``help``
        request to the slot (see ``_this_server_commands``).
        """
        # Private/dunder names are assumed never to be commands (none in RTYSERVER_COMMANDS are).
        # Resolving them here would hit the network and, on a partially initialised object
        # (copy/unpickle, missing session), recurse through this method until RecursionError.
        if item.startswith('_'):
            return object.__getattribute__(self, item)
        if item in object.__getattribute__(self, 'RTYSERVER_COMMANDS') or item in object.__getattribute__(self, '_this_server_commands'):
            return partial(object.__getattribute__(self, 'execute_command'), command=item)
        return object.__getattribute__(self, item)

    def execute_command(self, command, request_params=None, retryable=False, **kwargs):
        """
        Send an rtyserver control command to the slot.

        :param command: command name, sent as the ``command`` query parameter.
        :param request_params: extra keyword arguments for ``requests.Session.get``; not modified.
                               ``timeout`` defaults to 90 seconds (an explicit None disables it);
                               a ``params`` mapping is merged with the command arguments.
        :param retryable: retry according to ``_RETRY_CONF``.
        :param kwargs: command arguments, sent as query parameters (they override ``params`` keys).
        :return: the ``requests.Response``.
        :raises SlotUnreachable, CommandResultReadTimeout, SlotError: see ``make_request``.
        """
        # Copy so the caller's dict is left untouched.
        request_params = dict(request_params) if request_params else {}
        request_params.setdefault('timeout', 90)

        # dict.update() returns None, so merge into a fresh dict instead of assigning its result.
        query_params = dict(request_params.get('params') or {})
        query_params.update(kwargs)
        query_params['command'] = command
        request_params['params'] = query_params

        return self.make_request(path='/', retryable=retryable, **request_params)

    def shutdown(self):
        """
        Send the ``shutdown`` command (defined explicitly, so ``__getattr__`` is not involved).

        :return: True on success; None on a read timeout (the slot probably stopped before answering)
                 or on any other ``SlotError`` (logged as an error).
        """
        try:
            self.execute_command('shutdown')
        except CommandResultReadTimeout:
            self.LOGGER.info('Read timeout while shutting down %s. Probably stopped too fast.', self)
            return None
        except SlotError as e:
            # Python 3 exceptions have no ``message`` attribute; fall back to the exception itself.
            self.LOGGER.error('Unexpected error %s while shutting down %s.', getattr(e, 'message', e), self)
            return None
        return True

    def get_tass(self, retryable=False):
        """Return the parsed JSON of ``/tass``."""
        return self.make_request(path='/tass', retryable=retryable).json()

    def get_metrics(self, retryable=False):
        """Return the stripped text of ``/metric``."""
        return self.make_request(path='/metric', retryable=retryable).text.strip()

    def get_full_status(self, retryable=False):
        """Return the stripped text of ``/status``."""
        return self.make_request(path='/status', retryable=retryable).text.strip()

    def get_brief_status(self, retryable=False):
        """Return the stripped text of ``/status?brief=1`` (0.7 s request timeout)."""
        return self.make_request(path='/status', retryable=retryable, params={'brief': 1}, timeout=0.7).text.strip()

    def detach(self, params=None):
        """
        Run the ``synchronizer`` command with ``action=detach`` for this slot's shard range.

        ``synchronizer`` is not in ``RTYSERVER_COMMANDS``; it is resolved via the slot's ``help``
        output. The request has no timeout. ``params`` overrides the default command arguments.
        """
        request_params = {
            'timeout': None,
        }
        command_params = {
            'async': 'no',
            'min_shard': self.shards_min,
            'max_shard': self.shards_max,
            'sharding_type': 'url_hash'
        }
        if params:
            command_params.update(params)
        return self.synchronizer(action='detach', request_params=request_params, **command_params)

    @property
    def is_down(self):
        """True if the brief status request fails with ``SlotError``."""
        try:
            self.get_brief_status(retryable=True)
            return False
        except SlotError:
            return True

    def deploy(self, wait=300):
        """
        Shut the slot down via ``safe_shutdown`` and poll until it reports a ready status.

        Presumably something external restarts the process after shutdown -- confirm.
        Polls the brief status up to ``wait`` times, sleeping 1 second after each non-ready result.

        :return: True once the status is ``OK`` or ``Fusion_Banned``, otherwise False.
        """
        if not self.safe_shutdown():
            self.LOGGER.warning('%s not restarted (not UP)', self)

        for _ in range(wait):
            try:
                status = self.get_brief_status(retryable=True)
                self.LOGGER.debug('Slot %s:%s is in status %s', self.host, self.port, status)
                if status in ['OK', 'Fusion_Banned']:
                    return True
                else:
                    time.sleep(1)
            except SlotError:
                time.sleep(1)

        return False

    def restart(self, timeout=300):
        """
        Send ``shutdown`` and then poll the brief status every 3 seconds for up to ``timeout`` seconds.

        Polling continues while the slot is unreachable or in ``Starting``.

        :return: True if the first other status is one of ``Repair_Index``, ``Restoring_Backup``,
                 ``Restoring_Realtime``, ``OK``, ``Fusion_Banned``; False for any other status, on
                 timeout, or if ``shutdown`` raised ``requests.RequestException``.
        """
        self.LOGGER.debug('Restarting slot %s with timeout %d', self, timeout)
        deadline = time.time() + timeout
        try:
            self.shutdown()
        except requests.RequestException:
            return False

        while time.time() < deadline:
            try:
                status = self.get_brief_status(retryable=True)
            except SlotError:
                time.sleep(3)
                continue
            if status == 'Starting':
                time.sleep(3)
                continue
            return status in {'Repair_Index', 'Restoring_Backup', 'Restoring_Realtime', 'OK', 'Fusion_Banned'}
        return False

    def safe_shutdown(self):
        """
        Wait (up to the ``wait_up`` default) for the slot to be up, then send ``shutdown``.

        :return: True, unless ``shutdown`` raised ``requests.RequestException``; in that case,
                 whether the slot is up again per ``wait_up``.
        """
        self.wait_up()
        try:
            self.shutdown()
        except requests.RequestException:
            return self.wait_up()
        return True

    def wait_up(self, timeout=300):
        """Poll ``is_down`` once a second for up to ``timeout`` seconds; return True if the slot is up."""
        time_out = time.time() + timeout
        while self.is_down and time.time() < time_out:
            time.sleep(1)
        return not self.is_down

    @cached_property_with_ttl(30)
    def info_server(self):
        # type: () -> Mapping[str, Union[str, Mapping]]
        """``result`` of the ``get_info_server`` command, cached for 30 seconds."""
        try:
            return self.execute_command('get_info_server', retryable=True).json()['result']
        except requests.RequestException as e:
            raise SlotError(e)

    @cached_property_with_ttl(30)
    def controller_config(self):
        """
        The single ``Controller`` section of ``DaemonConfig``, cached for 30 seconds.

        :raises AssertionError: if there is not exactly one such section (checks vanish under ``-O``).
        """
        controller_sections = [section['Controller'] for section in self.info_server['config']['DaemonConfig'] if 'Controller' in section.keys()]
        assert len(controller_sections) == 1
        controller_section = controller_sections[0]
        assert len(controller_section) == 1
        return controller_section[0]

    @property
    def dm_options(self):
        """
        The single ``DMOptions`` section of the controller config.

        :raises AssertionError: if there is not exactly one such section.
        """
        dm_section = self.controller_config['DMOptions']
        assert len(dm_section) == 1
        return dm_section[0]

    def has_backup(self):
        """
        Guess from the DOCFETCHER configuration whether the slot's service has a backup source.

        An enabled DOCFETCHER module indicates a backup when:
          * it has more than one enabled stream and any of them has a non-empty ``SyncPath``; or
          * its single enabled stream has a non-empty ``SyncPath`` or ``BackupTable``; or
          * its single enabled stream has a non-empty ``SnapshotPath``, ``StreamType == 'Snapshot'``
            and ``ConsumeMode`` of ``replace`` / ``hard_replace``.

        :return: True if any indicator was found; False otherwise, including when there is no
                 DOCFETCHER module or no enabled stream.
        """
        backup_indicator_fields = ('SyncPath', 'BackupTable')
        # This logic is originally ported from https://a.yandex-team.ru/arc/trunk/arcadia/junk/anikella/saas_dm/saas_deploy/saas_deploy/templates/service_cluster_table.html?rev=4502118#L282
        # Some improvements are on the way
        has_backup = False
        conf = self.info_server['config']['Server'][0]['ModulesConfig'][0]

        docfetchers = conf.get('DOCFETCHER')
        if not docfetchers:
            self.LOGGER.debug("No DOCFETCHER module found in %s", self)
            return False

        for docfetcher in docfetchers:
            if docfetcher.get("Enabled") != '1':
                self.LOGGER.debug("DOCFETCHER module is disabled in %s", self)
                continue

            streams = [st for st in docfetcher.get('Stream', []) if st.get('Enabled', '0') != '0']
            if not streams:
                self.LOGGER.debug("No enabled DOCFETCHER streams found in %s", self)
                continue

            if len(streams) > 1:
                self.LOGGER.debug("More than 1 stream found in %s", self)
                if any(stream.get('SyncPath', "") != "" for stream in streams):
                    has_backup = True
            else:
                stream = streams[0]
                for key in backup_indicator_fields:
                    if stream.get(key, "") != "":
                        self.LOGGER.debug('Found %s, assuming service %s has backup', key, self)
                        has_backup = True
                        break
                if not has_backup and stream.get('SnapshotPath', '') != '':
                    has_backup = (stream.get('StreamType', '') == 'Snapshot'
                                  and stream.get('ConsumeMode', '') in ('replace', 'hard_replace'))
        return has_backup

    def detect_cpu(self):
        """
        Return the CPU model name of the slot's host, read from ``/proc/cpuinfo`` via ``get_file``.

        :return: the model name, or None if the slot could not be queried or the file has no
                 ``model name`` lines.
        :raises RuntimeError: if more than one distinct model name is found.
        """
        try:
            cpu_info = self.execute_command('get_file', filename='/proc/cpuinfo').json()['result']
        except (CommandResultReadTimeout, SlotError):
            return None

        # Lines look like "model name\t: <model>"; keep what follows the first colon.
        models = set(
            line.partition(':')[2].strip() for line in cpu_info.splitlines() if line.startswith('model name')
        )
        if not models:
            return None
        if len(models) != 1:
            raise RuntimeError('Looks like differrent CPUs installed on one host')
        return models.pop()
