import collections
import datetime
import json
import logging
import re

import jinja2
import pytz

import sandbox.common.types.resource as ctr
from sandbox import sdk2

# Stamped into the `version` attribute of every resource written by
# MonitorValues.save and required by MonitorValues.load when it searches
# for previously saved values.
_CURRENT_RESOURCE_FORMAT_VERSION = 2
"""
Format version of cache resource.
Should be incremented when resource format was updated.
"""


class MetrikaMobileLogValues(sdk2.Resource):
    """
    YQL query results.

    Written by MonitorValues.save as a JSON file and looked up again by
    MonitorValues.load using the attributes declared below.
    """
    # Name of the monitor config the values were saved for.
    config = sdk2.Attributes.String('Config', required=True)
    # Query timestamp of the saved values (MonitorValues.query_timestamp).
    timestamp = sdk2.Attributes.Integer('Timestamp', required=True)
    # Resource format version; set from _CURRENT_RESOURCE_FORMAT_VERSION on save.
    version = sdk2.Attributes.Integer('Version', required=True)
    # NOTE(review): presumably the resource time-to-live in days -- confirm against sdk2.Resource.
    ttl = 8


class MonitorError(Exception):
    """
    Error raised by the monitor; carries a numeric code next to the message.

    Derives from Exception (not BaseException): BaseException is reserved for
    interpreter-level exits such as SystemExit and KeyboardInterrupt, and an
    error derived from it slips past every ``except Exception`` handler.
    """
    CODE_CANNOT_FETCH_DATA = 0

    def __init__(self, code, message):
        """
        :type code: int
        :type message: str
        """
        # Initialise the base class explicitly so `args` is (code, message).
        super(MonitorError, self).__init__(code, message)
        self.code = code
        self.message = message


class MonitorQuery(object):
    """
    YQL query text for the 30-minute slot that contains a given timestamp.
    """
    STEP_1DAY = int(datetime.timedelta(days=1).total_seconds())
    STEP_30MIN = int(datetime.timedelta(minutes=30).total_seconds())
    _STALE_DELTA = datetime.timedelta(days=2)
    _MSK_TIMEZONE = pytz.timezone('Europe/Moscow')

    def __init__(self, monitor_config, static_pool_name, timestamp):
        """
        :type monitor_config: sandbox.projects.browser.monitor.MetrikaMobileLogMonitor.config.MonitorConfig
        :type timestamp: int
        """
        self.static_pool_name = static_pool_name
        self.query_timestamp = self._get_query_timestamp(timestamp)
        self.yql_query_text = self._load_yql_query_text(
            self.static_pool_name,
            self.query_timestamp,
            monitor_config.yql_stale_template_path,
            monitor_config.yql_fresh_template_path,
        )

    @classmethod
    def _get_query_timestamp(cls, timestamp):
        """
        Round the timestamp down to a multiple of STEP_30MIN.
        """
        slot_index = int(timestamp / cls.STEP_30MIN)
        return slot_index * cls.STEP_30MIN

    @classmethod
    def _load_yql_template(cls, template_path, **kwargs):
        """
        Read the file at template_path and render it as a Jinja2 template
        with kwargs as the template variables.
        """
        with open(template_path, 'r') as template_file:
            template_text = template_file.read()
        environment = jinja2.Environment(loader=jinja2.BaseLoader)
        return environment.from_string(template_text).render(**kwargs)

    @classmethod
    def _load_yql_query_text(cls, static_pool_name, query_timestamp,
                             yql_stale_template_path, yql_fresh_template_path):
        """
        Render the "fresh" template for slots not older than _STALE_DELTA
        and the "stale" template for everything older.
        """
        slot_utc = pytz.UTC.localize(datetime.datetime.utcfromtimestamp(query_timestamp))
        slot_msk = slot_utc.astimezone(cls._MSK_TIMEZONE)
        stale_border = pytz.UTC.localize(datetime.datetime.utcnow()) - cls._STALE_DELTA

        if slot_utc >= stale_border:
            return cls._load_yql_template(
                yql_fresh_template_path,
                static_pool_name=static_pool_name,
                datetime=slot_msk.strftime('%Y-%m-%dT%H:%M:%S'))

        return cls._load_yql_template(
            yql_stale_template_path,
            static_pool_name=static_pool_name,
            date=slot_msk.strftime('%Y-%m-%d'),
            ts1=query_timestamp,
            ts2=query_timestamp + cls.STEP_30MIN)


class MonitorValues(object):
    """
    Raw YQL query result. This table contains the following columns:
    - report_timestamp: timestamp to use in Solomon report.
    - app_version_name: value of "version" metric.
    - <other>: other columns contain values of the sensors.
    """

    def __init__(self, query_timestamp, column_names, rows):
        """
        :type query_timestamp: int
        :type column_names: list[str]
        :type rows: list[list[any]]
        """
        self.query_timestamp = query_timestamp
        self.column_names = column_names
        self.rows = rows

        # One dict per row, keyed by column name.
        records = []
        for row in rows:
            record = {}
            for position, name in enumerate(column_names):
                record[name] = row[position]
            records.append(record)
        self.data_list = records

    @classmethod
    def fetch(cls, monitor_query, yql_token):
        """
        Run the query through the YQL client and wrap the full result table.

        :type monitor_query: MonitorQuery
        :type yql_token: str
        :rtype: MonitorValues
        :raises MonitorError: when the YQL query does not finish successfully.
        """
        import yql.api.v1.client

        with yql.api.v1.client.YqlClient(token=yql_token) as client:
            logging.info('Fetch monitor values for %s',
                         datetime.datetime.utcfromtimestamp(monitor_query.query_timestamp))
            request = client.query(query=monitor_query.yql_query_text, syntax_version=1)
            request.run()
            request.get_results(wait=True)
            if not request.is_success:
                raise MonitorError(
                    MonitorError.CODE_CANNOT_FETCH_DATA,
                    'YQL query failed: {}'.format(request.status))

            request.table.fetch_full_data()

            names = request.table.column_names
            table_rows = request.table.rows
            return MonitorValues(monitor_query.query_timestamp, names, table_rows)

    @classmethod
    def load(cls, config_name, query_timestamp):
        """
        Load values from the newest READY resource saved for the given config
        and timestamp in the current format version; None when there is none.

        :type config_name: str
        :type query_timestamp: int
        :rtype: MonitorValues
        """
        search_attrs = {
            'config': config_name,
            'timestamp': query_timestamp,
            'version': _CURRENT_RESOURCE_FORMAT_VERSION,
        }
        found = sdk2.Resource.find(
            resource_type=MetrikaMobileLogValues,
            state=ctr.State.READY,
            attrs=search_attrs,
        )
        resource = found.order(-sdk2.Resource.created).first()
        if not resource:
            return None

        logging.info('Load monitor values for %s from resource #%s',
                     datetime.datetime.utcfromtimestamp(query_timestamp), resource.id)

        resource_path = str(sdk2.ResourceData(resource).path)
        with open(resource_path, 'r') as resource_file:
            payload = json.load(resource_file)
        return MonitorValues(query_timestamp, payload['column_names'], payload['rows'])

    @classmethod
    def load_history(cls, config_name, timestamp):
        """
        Load saved values for every 30-minute slot of the day that ends at
        (and includes) the given timestamp. Missing slots map to None.

        :type config_name: str
        :type timestamp: int
        :rtype: dict[int, MonitorValues]
        """
        history = {}
        oldest_exclusive = timestamp - MonitorQuery.STEP_1DAY
        for slot_timestamp in range(timestamp, oldest_exclusive, -MonitorQuery.STEP_30MIN):
            history[slot_timestamp] = MonitorValues.load(config_name, slot_timestamp)
        return history

    def save(self, my_task, config_name):
        """
        Store column names and rows as a JSON resource of the given task.

        :type my_task: sdk2.Task
        :type config_name: str
        """
        moment = datetime.datetime.utcfromtimestamp(self.query_timestamp)
        description = moment.strftime('YQL[%Y-%m-%d %H:%M:%S]')
        file_name = moment.strftime('%Y-%m-%d-%H-%M-%S.json')

        resource = MetrikaMobileLogValues(
            my_task,
            description,
            my_task.path(file_name),
            config=config_name,
            timestamp=self.query_timestamp,
            version=_CURRENT_RESOURCE_FORMAT_VERSION)

        payload = {
            'column_names': self.column_names,
            'rows': self.rows,
        }
        with open(str(resource.path), 'w') as resource_file:
            resource_file.write(json.dumps(payload))

        logging.info('Save monitor values for %s to resource #%s', moment, resource.id)
        sdk2.ResourceData(resource).ready()


class VersionDataPoint(object):
    """
    Sensor values of a single version at a single report timestamp.
    """

    def __init__(self, report_timestamp, version_name, weight, values):
        """
        :type report_timestamp: int
        :type version_name: str
        :type weight: int
        :type values: dict[str, int]
        """
        # Key of the point: which version it describes and when.
        self.version_name = version_name
        self.report_timestamp = report_timestamp
        # Payload of the point.
        self.values = values
        self.weight = weight


class VersionDataTable(object):
    """
    Collection of VersionDataPoint objects indexed by version name and
    report timestamp. All ``build_*`` methods return a new table and leave
    this one untouched.
    """

    def __init__(self, query_timestamp, points):
        """
        :type query_timestamp: int
        :type points: list[VersionDataPoint]
        :raises AssertionError: when two points share version and timestamp.
        """
        self.query_timestamp = query_timestamp
        self.points = points
        self.versions = sorted({p.version_name for p in points})
        self.timestamps = sorted({p.report_timestamp for p in points})

        data = collections.defaultdict(dict)
        for point in points:
            version_points = data[point.version_name]
            if point.report_timestamp in version_points:
                # Raised explicitly (instead of `assert`) so the check also
                # runs under `python -O`; the exception type is unchanged.
                raise AssertionError('version={} timestamp={} already set'.format(
                    point.version_name, point.report_timestamp))
            version_points[point.report_timestamp] = point
        # :type: dict[str, dict[int, VersionDataPoint]]
        self.data = dict(data)

    def build_without_versions(self, ignore_versions):
        """
        Build a table that keeps every version except the listed ones.

        :type ignore_versions: list[str]
        :rtype: VersionDataTable
        """
        ignored = frozenset(ignore_versions)
        return self.build_with_version_mapping({
            v: v for v in self.versions if v not in ignored})

    def build_with_version_mapping(self, version_mapping):
        """
        Build a table whose versions are the keys of version_mapping; each
        one receives the points of the source version it maps to. Source
        versions absent from this table contribute no points.

        :type version_mapping: dict[str, str]
        :rtype: VersionDataTable
        """
        new_points = []
        # .items() works on both Python 2 and 3 (iteritems is Python 2 only).
        for target_version, source_version in version_mapping.items():
            version_data = self.data.get(source_version, {})
            if source_version == target_version:
                # Same name: the existing point objects are reused as is.
                new_points.extend(version_data.values())
            else:
                new_points.extend(
                    VersionDataPoint(
                        point.report_timestamp, target_version,
                        point.weight, point.values)
                    for point in version_data.values())
        new_points.sort(key=lambda p: (p.report_timestamp, p.version_name))
        return VersionDataTable(self.query_timestamp, new_points)

    def build_sum(self, sum_version_name):
        """
        Build a table with a single version whose weight and values at every
        timestamp are the sums over all versions of this table.

        :type sum_version_name: str
        :rtype: VersionDataTable
        """
        new_points = []
        for report_timestamp in self.timestamps:
            sum_weight = 0
            sum_values = collections.defaultdict(int)
            for version_name in self.versions:
                point = self.data.get(version_name, {}).get(report_timestamp)
                if point:
                    sum_weight += point.weight
                    for sensor_name, sensor_value in point.values.items():
                        sum_values[sensor_name] += sensor_value

            new_points.append(VersionDataPoint(
                report_timestamp, sum_version_name,
                sum_weight, dict(sum_values)))
        return VersionDataTable(self.query_timestamp, new_points)

    def build_merged_with(self, other):
        """
        Build a table with all points of `other` plus those points of this
        table that `other` has no (version, timestamp) entry for. The result
        keeps this table's query timestamp.

        :type other: VersionDataTable
        :rtype: VersionDataTable
        """
        new_points = list(other.points)
        new_points.extend(
            point for point in self.points
            if point.report_timestamp not in other.data.get(point.version_name, {}))
        new_points.sort(key=lambda p: (p.report_timestamp, p.version_name))
        return VersionDataTable(self.query_timestamp, new_points)


class VersionDataManager(object):
    """
    Helpers that turn MonitorValues into VersionDataTable objects, per-version
    weights and Solomon metrics.
    """
    _COLUMN_REPORT_TIMESTAMP = 'report_timestamp'
    _COLUMN_VERSION_NAME = 'app_version_name'
    _SOLOMON_LABEL_SENSOR = 'sensor'
    _SOLOMON_LABEL_VERSION = 'version'
    _SOLOMON_SENSOR_DAILY_WEIGHT = '_daily_weight'

    @classmethod
    def load_version_data_table(cls, monitor_values, weight_column_name):
        """
        Convert every row that has a version name into a VersionDataPoint.
        The weight column stays among the point values as well.

        :type monitor_values: MonitorValues
        :type weight_column_name: str
        :rtype: VersionDataTable
        """
        points = []
        for data in monitor_values.data_list:
            values = dict(data)
            report_timestamp = values.pop(cls._COLUMN_REPORT_TIMESTAMP)
            version_name = values.pop(cls._COLUMN_VERSION_NAME)
            # Sometimes YQL query result contains 'None' version.
            if not version_name:
                continue

            weight = values[weight_column_name]
            points.append(VersionDataPoint(report_timestamp, version_name, weight, values))
        return VersionDataTable(monitor_values.query_timestamp, points)

    @classmethod
    def get_wrong_versions(cls, versions, correct_version_pattern):
        """
        Return the versions that do not match the pattern (re.match, i.e.
        anchored at the start of the string only).

        :type versions: list[str]
        :type correct_version_pattern: str
        :rtype: list[str]
        """
        return [v for v in versions if not re.match(correct_version_pattern, v)]

    @classmethod
    def get_version_weights(cls, monitor_values_dict, weight_column_name):
        """
        Estimate the weight of every version over the whole set of slots:
        the average weight over the slots where the version is present,
        multiplied by the total number of slots (slots without data included).

        :type monitor_values_dict: dict[int, MonitorValues]
        :type weight_column_name: str
        :rtype: dict[str, int]
        """
        version_weight_lists = collections.defaultdict(list)
        for query_timestamp in sorted(monitor_values_dict):
            monitor_values = monitor_values_dict[query_timestamp]
            if monitor_values is None:
                continue

            weights = collections.defaultdict(int)
            for data in monitor_values.data_list:
                version_name = data[cls._COLUMN_VERSION_NAME]
                # Sometimes YQL query result contains 'None' version.
                if not version_name:
                    continue
                weights[version_name] += data[weight_column_name]

            # .items() works on both Python 2 and 3 (iteritems is Python 2 only).
            for version_name, weight in weights.items():
                version_weight_lists[version_name].append(weight)

        slot_count = len(monitor_values_dict)
        # NOTE: `/` is kept to leave Python 2 results untouched; with integer
        # weights it floors on Python 2 and yields a float on Python 3.
        return {
            version_name: sum(weight_list) * slot_count / len(weight_list)
            for version_name, weight_list in version_weight_lists.items()
        }

    @classmethod
    def create_version_weights_table(cls, query_timestamp, version_weights):
        """
        Build a table with one point per version that reports its weight as
        the '_daily_weight' sensor; points are ordered by descending weight.

        :type query_timestamp: int
        :type version_weights: dict[str, int]
        :rtype: VersionDataTable
        """
        points = [
            VersionDataPoint(
                query_timestamp, version_name, weight,
                {cls._SOLOMON_SENSOR_DAILY_WEIGHT: weight})
            for version_name, weight in version_weights.items()
        ]
        points.sort(key=lambda p: p.weight, reverse=True)
        return VersionDataTable(query_timestamp, points)

    @classmethod
    def create_solomon_metrics(cls, version_data_table, version_labels):
        """
        Emit one metric dict ('ts', 'labels', 'value') per sensor value of
        every point. The 'sensor' and 'version' labels override same-named
        entries of version_labels.

        :type version_data_table: VersionDataTable
        :type version_labels: dict[str, dict[str, str]]
        :rtype: list[dict[str, any]]
        """
        solomon_metrics = []
        for point in version_data_table.points:
            base_labels = version_labels.get(point.version_name, {})
            for sensor_name, sensor_value in point.values.items():
                labels = dict(base_labels)
                labels[cls._SOLOMON_LABEL_SENSOR] = sensor_name
                labels[cls._SOLOMON_LABEL_VERSION] = point.version_name
                solomon_metrics.append({
                    'ts': point.report_timestamp,
                    'labels': labels,
                    'value': sensor_value,
                })
        return solomon_metrics


def parse_version(version):
    """
    Split a dotted version string into components; components that are
    valid integers become ints, the rest stay strings. None gives [].

    :type version: str | None
    :rtype: list[int | str]
    """
    def _to_component(text):
        try:
            return int(text)
        except ValueError:
            return text

    if version is None:
        return []
    return [_to_component(part) for part in version.split('.')]


def compare_versions(version1, version2):
    """
    cmp-style comparison of two dotted version strings.

    Components are compared left to right as produced by parse_version.
    A numeric component always sorts before a non-numeric one, which is the
    order Python 2 gives for int vs str; tagging components with a rank makes
    the same comparison work on Python 3, where int and str are unorderable.

    :type version1: str | None
    :type version2: str | None
    :rtype: int
    :return: 0 when equal, -1 when version1 is lower, +1 otherwise.
    """
    import numbers

    def _comparison_key(version):
        # (0, number) < (1, text) regardless of the values themselves.
        return [
            (0, component) if isinstance(component, numbers.Integral) else (1, component)
            for component in parse_version(version)
        ]

    key1 = _comparison_key(version1)
    key2 = _comparison_key(version2)
    if key1 == key2:
        return 0
    return -1 if key1 < key2 else +1


class WeightAnalyser(object):
    """
    Derives Solomon labels and generated version aliases ('release',
    'top-N') from per-version weights.
    """
    _LABEL_LEADER = 'leader'
    _LABEL_RELEASE = 'release'
    _LABEL_SIGNIFICANT = 'significant'
    _LABEL_YES = 'yes'
    _LABEL_NO = 'no'
    _VERSION_RELEASE = 'release'
    _VERSION_TOP_FMT = 'top-{}'

    @classmethod
    def get_significant_versions(cls, significant_threshold, version_weights):
        """
        Return the versions whose weight is at least the threshold.

        :type significant_threshold: int
        :type version_weights: dict[str, int]
        :rtype: set[str]
        """
        # .items() works on both Python 2 and 3 (iteritems is Python 2 only).
        return {
            version_name for version_name, weight in version_weights.items()
            if weight >= significant_threshold}

    @classmethod
    def get_sorted_releases(cls, versions):
        """
        Return the versions sorted in ascending compare_versions order.

        :type versions: list[str]
        :rtype: list[str]
        """
        import functools

        # list.sort(cmp) exists on Python 2 only; cmp_to_key gives the same
        # ordering on both Python 2.7 and Python 3.
        return sorted(versions, key=functools.cmp_to_key(compare_versions))

    @classmethod
    def find_release_version(cls, release_threshold, version_weights):
        """
        Return the highest version whose weight reaches the threshold,
        or None when no version does.

        :type release_threshold: int
        :type version_weights: dict[str, int]
        :rtype: str | None
        """
        sorted_releases = cls.get_sorted_releases(version_weights.keys())
        for version_name in reversed(sorted_releases):
            if version_weights[version_name] >= release_threshold:
                return version_name
        return None

    @classmethod
    def get_leader_versions(cls, leader_percent, version_weights):
        """
        Return the heaviest versions, taken in descending weight order, until
        their cumulative weight reaches leader_percent of the total weight.
        A non-empty input always yields at least one leader.

        :type leader_percent: int
        :type version_weights: dict[str, int]
        :rtype: set[str]
        """
        total_weight = sum(version_weights.values())
        version_weights_pairs = sorted(
            version_weights.items(), key=lambda pair: pair[1], reverse=True)

        leader_versions = set()
        sum_weight = 0
        for version_name, weight in version_weights_pairs:
            leader_versions.add(version_name)
            sum_weight += weight
            # Cross-multiplied instead of `total * percent / 100`: integer
            # division truncated the threshold, so the leaders could cover
            # less than leader_percent of the total.
            if sum_weight * 100 >= total_weight * leader_percent:
                break
        return leader_versions

    @classmethod
    def update_version_labels_with_hardcoded_metrics(cls, versions_labels, monitor_config):
        """
        Add config-defined metric labels to every version, in place: the
        value whose version list contains the version, else the metric's
        default when one is configured.

        :type versions_labels: dict[str, dict[str, str]]
        :type monitor_config: sandbox.projects.browser.monitor.MetrikaMobileLogMonitor.config.MonitorConfig
        :raises ValueError: when a version is listed under several values of one metric.
        """
        for version_name, labels in versions_labels.items():
            for metric_name, metric_values in monitor_config.metrics_values.items():
                values = [
                    value for value, version_name_list in metric_values.items()
                    if version_name in version_name_list]
                if len(values) > 1:
                    raise ValueError(
                        'Cannot choose value of metric "{}" for version "{}": too many values: {}'.format(
                            metric_name, version_name, values))
                if values:
                    labels[metric_name] = values[0]
                elif metric_name in monitor_config.metrics_default:
                    labels[metric_name] = monitor_config.metrics_default[metric_name]

    @classmethod
    def get_version_labels(cls, monitor_config, version_weights):
        """
        Build the 'leader' / 'release' / 'significant' yes-no labels for every
        version and extend them with the config-defined metric labels.

        :type monitor_config: sandbox.projects.browser.monitor.MetrikaMobileLogMonitor.config.MonitorConfig
        :type version_weights: dict[str, int]
        :rtype: dict[str, dict[str, str]]
        """
        significant_versions = cls.get_significant_versions(
            monitor_config.significant_daily_weight_threshold, version_weights)
        last_release_version = cls.find_release_version(
            monitor_config.release_daily_weight_threshold, version_weights)
        leader_versions = cls.get_leader_versions(
            monitor_config.leaders_daily_weight_percent, version_weights)

        def yes_no(flag):
            return cls._LABEL_YES if flag else cls._LABEL_NO

        version_labels = {
            version_name: {
                cls._LABEL_LEADER: yes_no(version_name in leader_versions),
                cls._LABEL_RELEASE: yes_no(version_name == last_release_version),
                cls._LABEL_SIGNIFICANT: yes_no(version_name in significant_versions),
            }
            for version_name in version_weights
        }
        cls.update_version_labels_with_hardcoded_metrics(version_labels, monitor_config)
        return version_labels

    @classmethod
    def get_generated_versions_mapping(cls, monitor_config, version_weights):
        """
        Map generated version names to real ones: 'release' to the found
        release version (when there is one) and 'top-N' to the N-th heaviest
        version, for the first top_versions_count versions.

        :type monitor_config: sandbox.projects.browser.monitor.MetrikaMobileLogMonitor.config.MonitorConfig
        :type version_weights: dict[str, int]
        :rtype: dict[str, str]
        """
        version_mapping = {}

        last_release_version = cls.find_release_version(
            monitor_config.release_daily_weight_threshold, version_weights)
        if last_release_version:
            version_mapping[cls._VERSION_RELEASE] = last_release_version

        version_weight_pairs = sorted(
            version_weights.items(), key=lambda pair: pair[1], reverse=True)
        top_pairs = version_weight_pairs[:monitor_config.top_versions_count]
        for place, (version_name, _) in enumerate(top_pairs, 1):
            version_mapping[cls._VERSION_TOP_FMT.format(place)] = version_name

        return version_mapping
