# coding: utf-8
from __future__ import absolute_import, division, print_function, unicode_literals

import datetime
import logging
import time

import yt.wrapper as yt
from sqlalchemy import case, func
from yt.wrapper import YtClient

from travel.rasp.bus.db import session_scope
from travel.rasp.bus.db.models.matching import PointMatching
from travel.rasp.bus.db.models.supplier import Supplier

# Root logger; EndpointsStatistic sets its level and attaches a stream handler.
log = logging.getLogger()

# Schema of the single-row YT table holding the timestamp of the last upload.
LAST_STAT_UPLOAD_SCHEMA = [
    {"name": "last_upload_time", "type": "int64"},
]

# Schema of the YT table that per-group endpoint statistics are appended to.
# Built from (column name, YT type) pairs; a generator expression is used so
# no loop variables leak into module scope under Python 2.
ENDPOINTS_STAT_SCHEMA = list(
    {"name": column_name, "type": column_type}
    for column_name, column_type in (
        ("supplier", "string"),
        ("point_type", "string"),
        ("disabled", "boolean"),
        ("in_segments", "boolean"),
        ("total", "int64"),
        ("matched", "int64"),
        ("outdated", "int64"),
        ("new", "int64"),
        ("log_time", "int64"),
    )
)


class EndpointsStatistic(object):
    """Collect point-matching statistics from the DB and upload them to YT.

    Each :meth:`run` aggregates ``PointMatching`` rows per supplier / point
    type / ``disabled`` / ``in_segments``, appends the counters to the
    ``endpoints_stats`` table under ``//home/buses/<env>/endpoints`` and then
    stores the run start time in the ``last_stat_time`` table, so that the
    next run can count rows created since then as "new".
    """

    # Stream handler shared by all instances. It is created lazily so that it
    # binds to the stderr that is current when the first instance is built.
    _log_handler = None

    def __init__(self, yt_token, env):
        """
        :param yt_token: token passed to the YT client.
        :param env: environment name; becomes part of the YT directory path.
        """
        self.yt_token = yt_token
        self.env = env

        self.default_stat_path = yt.ypath_join('//home', 'buses', env, 'endpoints')
        self.stat_table = 'endpoints_stats'
        self.last_upload_table = 'last_stat_time'
        self.yt_proxy = 'hahn'

        self.yt_client = YtClient(self.yt_proxy, self.yt_token)
        log.setLevel(logging.INFO)
        # `log` is the process-wide root logger. Adding a fresh handler per
        # instance would duplicate every log line, so reuse a single handler;
        # Logger.addHandler does nothing if that handler is already attached.
        if EndpointsStatistic._log_handler is None:
            EndpointsStatistic._log_handler = logging.StreamHandler()
        log.addHandler(EndpointsStatistic._log_handler)

    def _get_stats(self, last_update_datetime):
        """Aggregate matching counters from the database.

        :param last_update_datetime: naive UTC datetime; rows with
            ``created_at`` at or after it are counted as "new".
        :returns: list of rows ``(supplier_code, point_type, disabled,
            in_segments, total, matched, outdated, new)``, one per group.
        """
        log.info('Loading stats from db')

        with session_scope() as session:
            endpoints_stats = session.query(
                Supplier.code,
                # NOTE(review): this selects via the column's `.name` attribute
                # while GROUP BY uses `PointMatching.type`; confirm it yields the
                # intended per-row type value rather than a constant column name.
                PointMatching.type.name,
                PointMatching.disabled,
                PointMatching.in_segments,
                func.count(),  # total
                func.count(PointMatching.point_key),  # matched: non-NULL point_key only
                func.sum(case([((PointMatching.outdated.is_(True)), 1)], else_=0)),  # outdated
                func.sum(case([((PointMatching.created_at >= last_update_datetime), 1)], else_=0)),  # new
            ).join(PointMatching).group_by(
                Supplier.code,
                PointMatching.type,
                PointMatching.disabled,
                PointMatching.in_segments,
            ).all()

            log.info('Stats loaded')
            return endpoints_stats

    def run(self):
        """Collect statistics and upload them, all YT operations in one transaction."""
        with self.yt_client.Transaction():
            log.info('Start')
            start_time = int(time.time())
            last_update_timestamp = self._get_last_stat_time()
            if not last_update_timestamp:
                # Missing/invalid previous timestamp: count "new" from now on,
                # so practically no rows are reported as new in this run.
                log.error("got invalid last update time: {}".format(last_update_timestamp))
                last_update_datetime = datetime.datetime.utcnow()
            else:
                last_update_datetime = datetime.datetime.utcfromtimestamp(last_update_timestamp)
                log.info('Got last update time: {}'.format(last_update_datetime))
            stats = self._get_stats(last_update_datetime)
            log.info('got stats: {}'.format(stats))
            self._write_stats(stats, start_time)
            self._set_last_stat_time(start_time)

            log.info('Done')

    def _get_last_stat_time(self):
        """Return the stored last-upload unix timestamp, or None.

        None is returned when the table does not exist, does not contain
        exactly one row, or is otherwise empty.
        """
        path = yt.ypath_join(self.default_stat_path, self.last_upload_table)
        if not self.yt_client.exists(path):
            return None
        row_count = self.yt_client.row_count(path)
        if row_count != 1:
            log.error('Unexpected row count in table {}, need 1, got {}'.format(path, row_count))
            return None
        rows = list(self.yt_client.read_table(path))
        if not rows:
            return None
        return rows[0]['last_upload_time']

    def _set_last_stat_time(self, tmstmp):
        """Overwrite the last-upload table with a single row holding ``tmstmp``.

        The table is created with ``LAST_STAT_UPLOAD_SCHEMA`` if it is missing.
        """
        log.info('Writing log time: {}'.format(tmstmp))
        path = yt.ypath_join(self.default_stat_path, self.last_upload_table)
        if not self.yt_client.exists(path):
            log.info('Table at path is not exists: {}. Creating'.format(path))
            self.yt_client.create("table", path, attributes={'schema': LAST_STAT_UPLOAD_SCHEMA})
        rows = ({'last_upload_time': tmstmp},)
        # No append modifier: write_table replaces the previous row.
        self.yt_client.write_table(path, rows)

    def _write_stats(self, stats, update_timestamp):
        """Append aggregated rows to the stats table, tagging each with ``update_timestamp``.

        The table schema is (re)applied on every call, and the table is
        created if it does not exist yet.
        """
        log.info('Start writing to YT')
        path = yt.ypath_join(self.default_stat_path, self.stat_table)
        if self.yt_client.exists(path):
            self.yt_client.alter_table(path, ENDPOINTS_STAT_SCHEMA)
        else:
            log.info('Table at path is not exists: {}. Creating'.format(path))
            self.yt_client.create("table", path, attributes={'schema': ENDPOINTS_STAT_SCHEMA})

        # Counters are coerced to int: depending on the DB driver, SUM() may be
        # returned as a non-int numeric (e.g. Decimal), while the YT columns
        # are int64.
        data = [
            {
                'supplier': supplier,
                'point_type': point_type,
                'disabled': disabled,
                'in_segments': in_segments,
                'total': int(total),
                'matched': int(matched),
                'outdated': int(outdated),
                'new': int(new),
                'log_time': update_timestamp,
            }
            for (supplier, point_type, disabled, in_segments,
                 total, matched, outdated, new) in stats
        ]
        self.yt_client.write_table('<append=%true>' + path, data)
