# -*- coding: utf-8 -*-

from datetime import datetime, timedelta
from cars import settings
from .table import YtTable
from yt.wrapper import TablePath, with_context


class PlusTable(YtTable):
    """Per-uid Plus subscription state built from daily source snapshots.

    Each ``update()`` folds the daily tables under ``SOURCE_TABLE`` into the
    export table, keeping one row per ``uid`` with the latest ``is_plus``
    flag and the earliest/latest subscription timestamps seen so far.
    """

    # Directory holding one source table per day, named ``YYYY-MM-DD``.
    SOURCE_TABLE = '//home/plus_data/share/1d'
    # First day to import, handed to the base class when reading the state.
    START_TABLE = '2018-05-24'
    # The stored state is set this many days before today, so the most
    # recent daily tables are read again on the next run.
    LOOKBACK_DAYS = 2

    def get_name(self):
        """Return the export name configured as ``EXPORT['plus_part']``."""
        return settings.EXPORT['plus_part']

    def get_schema(self):
        """Return the export table columns.

        ``plus.first_subscription``, ``plus.last_subscription`` and
        ``_updated`` hold Unix timestamps in seconds, as produced by
        ``_source_table_mapper``.
        """
        return [
            {'name': 'uid',                     'type': 'int64'},
            {'name': 'is_plus',                 'type': 'boolean'},
            {'name': 'plus.first_subscription', 'type': 'int64'},
            {'name': 'plus.last_subscription',  'type': 'int64'},
            {'name': '_updated',                'type': 'int64'},
        ]

    def get_index(self):
        """Return the key columns used to sort and reduce the table."""
        return ['uid']

    @staticmethod
    def _source_table_mapper(row):
        """Convert one daily source row into an export-schema row.

        ``last_subscription_date`` seeds both subscription columns. The
        reducer later keeps the minimum as the first subscription and the
        maximum as the last one. ``_updated`` comes from the row's ``date``,
        so the reducer can tell which ``is_plus`` value is newest.
        """
        # NOTE(review): naive ``datetime.timestamp()`` interprets the date in
        # the local timezone of the job host, while ``update()`` walks days
        # in UTC. Changing it would shift values relative to rows already
        # stored, so it is left as-is -- confirm the intended timezone.
        subscription = int(
            datetime.strptime(row['last_subscription_date'], '%Y-%m-%d')
            .timestamp()
        )
        updated = int(datetime.strptime(row['date'], '%Y-%m-%d').timestamp())
        # Any status other than 'gone' (case-insensitive) counts as Plus.
        is_plus = row['status'].lower() != 'gone'
        yield {
            'uid': int(row['puid']),
            'is_plus': is_plus,
            'plus.first_subscription': subscription,
            'plus.last_subscription': subscription,
            '_updated': updated,
        }

    @staticmethod
    def _table_reducer(key, rows):
        """Collapse all rows of one uid into a single row.

        The operation reduces by ``get_index()`` (``['uid']``), so every row
        in ``rows`` has the same uid and at most one row is yielded.

        - ``plus.first_subscription`` keeps the minimum.
        - ``plus.last_subscription`` keeps the maximum.
        - ``is_plus`` and ``_updated`` come from the row with the greatest
          ``_updated``; on ties, the first-seen row wins.
        """
        merged = None
        for row in rows:
            if merged is None:
                merged = row
                continue
            merged['plus.first_subscription'] = min(
                merged['plus.first_subscription'],
                row['plus.first_subscription'],
            )
            merged['plus.last_subscription'] = max(
                merged['plus.last_subscription'],
                row['plus.last_subscription'],
            )
            if merged['_updated'] < row['_updated']:
                merged['is_plus'] = row['is_plus']
                merged['_updated'] = row['_updated']
        if merged is not None:
            yield merged

    def update(self):
        """Merge new daily source tables into the export table.

        Everything runs inside one YT transaction, in this order:

        1. Create the table (sorted by the index) if it is missing.
        2. Map every existing daily table from the stored state up to, but
           not including, today (UTC midnight) to the export schema.
        3. Reduce those rows into the table and re-sort it.
        4. Move the state to ``LOOKBACK_DAYS`` days before today.
        """
        table_path = self.get_table_path()
        with self._yt.Transaction():
            self._create_sorted_table_if_missing(table_path)
            last_state = self.get_last_state()
            today_state = datetime.utcnow().replace(
                hour=0, minute=0, second=0, microsecond=0,
            )
            source_tables = self._existing_source_tables(
                last_state, today_state,
            )
            # With no new daily tables there is nothing to merge. Skip the
            # map/sort/reduce/sort round-trip instead of starting a map with
            # an empty input list.
            if source_tables:
                self._merge_source_tables(table_path, source_tables)
            # Saved last, so that a failing YT step above cannot advance the
            # state even if the state is persisted outside this transaction.
            next_state = max(last_state, today_state)
            self.set_last_state(
                next_state - timedelta(days=self.LOOKBACK_DAYS)
            )

    def _create_sorted_table_if_missing(self, table_path):
        """Create an empty export table at ``table_path`` if it does not exist.

        The new table is sorted immediately, because the reduce in
        ``_merge_source_tables`` uses it as a sorted input.
        """
        if self._yt.exists(table_path):
            return
        self._yt.create(
            'table',
            table_path,
            recursive=True,
            ignore_existing=True,
        )
        self._yt.run_sort(
            source_table=table_path,
            destination_table=table_path,
            sort_by=self.get_index(),
        )

    def _existing_source_tables(self, start, end):
        """Return existing daily source table paths for days in [start, end).

        Days whose table does not exist are skipped silently.
        """
        tables = []
        day = start
        while day < end:
            path = '{}/{}'.format(self.SOURCE_TABLE, day.strftime('%Y-%m-%d'))
            if self._yt.exists(path):
                tables.append(path)
            day += timedelta(days=1)
        return tables

    def _merge_source_tables(self, table_path, source_tables):
        """Map ``source_tables`` and reduce the result into ``table_path``."""
        index = self.get_index()
        delta_table = self._yt.create_temp_table(settings.EXPORT['temp_path'])
        self._yt.run_map(
            self._source_table_mapper,
            source_tables,
            delta_table,
        )
        self._yt.run_sort(
            source_table=delta_table,
            destination_table=delta_table,
            sort_by=index,
        )
        # The existing table is both an input and the output, so old and
        # new rows for the same uid meet in one reducer call.
        self._yt.run_reduce(
            self._table_reducer,
            source_table=[
                table_path,
                delta_table,
            ],
            destination_table=table_path,
            reduce_by=index,
        )
        self._yt.remove(delta_table)
        # Restore the sort order; the next run's reduce needs a sorted input.
        self._yt.run_sort(
            source_table=table_path,
            destination_table=table_path,
            sort_by=index,
        )

    def get_last_state(self):
        """Return the stored state parsed from ``YYYY-MM-DD`` into a datetime.

        ``START_TABLE`` is passed to the base class, presumably as the
        fallback when no state is stored yet -- confirm against ``YtTable``.
        """
        return datetime.strptime(
            super().get_last_state(self.START_TABLE),
            '%Y-%m-%d'
        )

    def set_last_state(self, state):
        """Store ``state`` formatted as ``YYYY-MM-DD``.

        Falsy values are passed to the base class unchanged.
        """
        if state:
            state = state.strftime('%Y-%m-%d')
        super().set_last_state(state)
