import travel.avia.admin.init_project  # noqa

import argparse
import logging
import sys
from abc import ABCMeta, abstractmethod
from collections import defaultdict
from datetime import datetime, timedelta

import requests
import ujson
import pytz
from django.conf import settings

from travel.avia.library.python.avia_data.models.razladki import RazladkiLastProcessedTable
from travel.avia.library.python.common.models.partner import Partner
from travel.avia.admin.lib.logs import add_stdout_handler, create_current_file_run_log
from travel.avia.admin.lib.yt_helpers import yt_client_fabric, yt_read_tables


# Razladki upload endpoints; choose_razladki_url() uses the production one only
# when settings.ENVIRONMENT == 'production', the test one otherwise.
RAZLADKI_URL_TEMPLATE = 'http://launcher.razladki.yandex-team.ru/save_new_data_json/{project}'
RAZLADKI_TEST_URL_TEMPLATE = 'http://launcher.razladki-test.yandex-team.ru/save_new_data_json/{project}'
# NOTE(review): not referenced by the code visible here; table directories come
# from RazladkiLastProcessedTable.yt_path.
YT_LOG_PATH = '//home/logfeller/logs/avia-json-redir-log'
# main() exits immediately (status 0) in any environment not listed here.
ALLOWED_ENVS = ['dev', 'production']


# Used to localize stream-mode window start times parsed from table names.
MOSCOW_TIMEZONE = pytz.timezone('Europe/Moscow')
# Timezone-aware epoch; _dt_to_unixtime() subtracts it, so its argument must be aware.
_EPOCH = datetime(1970, 1, 1, tzinfo=pytz.UTC)


logger = logging.getLogger(__name__)


# Sources reported under their own name by UtmSourceRule; 'wizard*' and
# 'unisearch*' sources are handled separately and reported as 'wizard'.
IMPORTANT_UTM_SOURCES = ('rasp', 'sovetnik', 'yamain', 'ohm_google')  # Wizard is an exception
# Empty on purpose: for 'avia-json-redir-log' every non-wizard source is reported as 'other'.
OLD_UTM_SOURCES = ()
# Partner codes; PartnerRule resolves each one to its billing_client_id via the Partner model.
IMPORTANT_PARTNERS = ('ozon', 'sindbad', 'dohop', 's_seven', 'utair', 'trip_ru', 'citytravel')
# Airline codes passed to AirlineRule for the 'rasp-popular-flights-log' log.
IMPORTANT_AIRLINES = ('SU', 'S7', 'DP', 'UT', 'U6', 'N5', 'WZ')
# Airline code substitutions applied by AirlineRule ('FV' flights are counted as 'SU').
AIRLINES_FIXLIST = {
    'FV': 'SU',
}
# Formats accepted by _parse_timestamp() for table names and --start-date/--end-date,
# tried in this order.
TIMESTAMP_FORMAT = '%Y-%m-%dT%H:%M:%S'
DATE_FORMAT = '%Y-%m-%d'


class IStatRule(object):
    """Interface of an aggregation rule plugged into StatCalculator.

    A rule declares which record fields it reads (``get_fields``) and maps
    each record to zero or more series keys (``get_keys``).
    """
    __metaclass__ = ABCMeta

    @abstractmethod
    def get_keys(self, record):
        """Return an iterable of series keys produced by ``record``."""

    @abstractmethod
    def get_fields(self):
        """Return the record fields (YT columns) this rule needs."""


class UtmSourceRule(IStatRule):
    """Buckets every record by its ``utm_source`` field.

    Sources starting with 'wizard' or 'unisearch' are reported as 'wizard',
    sources listed in ``important_utm_sources`` keep their own name, and
    anything else (including a missing source) is reported as ``other_name``.
    """

    def __init__(self, important_utm_sources, other_name='other_source'):
        self.important_utm_sources = frozenset(important_utm_sources)
        self.other_name = other_name

    def get_keys(self, record):
        """Yield exactly one bucket name for ``record``."""
        bucket = self.get_type_by_utm_source(record.get('utm_source'))
        yield bucket

    def get_fields(self):
        """The only column this rule reads."""
        return ['utm_source']

    def get_type_by_utm_source(self, utm_source):
        """Map a raw ``utm_source`` value (possibly None) to its bucket name."""
        if utm_source is not None:
            if utm_source.startswith(('wizard', 'unisearch')):
                return 'wizard'
            if utm_source in self.important_utm_sources:
                return utm_source
        return self.other_name


class PartnerRule(IStatRule):
    """Counts records per important partner, keyed by partner code.

    The partner codes are resolved to ``billing_client_id`` values through
    the Partner model once, in the constructor (one query per code).
    """

    def __init__(self, partner_codes):
        # billing_client_id -> partner code
        by_billing_id = {}
        for code in partner_codes:
            by_billing_id[Partner.objects.get(code=code).billing_client_id] = code
        self.partners = by_billing_id

    def get_keys(self, record):
        """Yield the partner code of ``record`` if it belongs to an important
        partner, and then always 'other'.

        NOTE(review): 'other' is emitted for every record, including records
        already counted under a partner code; confirm this is intended and
        not a missing ``else``.
        """
        matched_code = self.partners.get(record.get('billing_client_id'))
        for key in (matched_code, 'other'):
            if key:
                yield key

    def get_fields(self):
        """The only column this rule reads."""
        return ['billing_client_id']


class AirlineRule(IStatRule):
    """Yields the important airline codes found in a record's flight numbers.

    Each of FIELDS holds ';'-separated flight numbers whose first
    space-separated token is the airline code. Codes are passed through
    ``fix_list`` and only those in ``airline_codes`` are reported; every
    airline is yielded at most once per record.
    """

    FIELDS = ('forward_numbers', 'backward_numbers')

    def __init__(self, airline_codes, fix_list):
        """
        @param airline_codes: airline codes that get their own series
        @param fix_list: mapping applied to raw codes before matching (e.g. 'FV' -> 'SU')
        """
        self.codes = frozenset(airline_codes)
        self.fix_list = fix_list

    def get_keys(self, record):
        airlines = set()
        for field in self.FIELDS:
            # A missing or None field contributes nothing (``None.split`` would crash).
            for flight in (record.get(field) or '').split(';'):
                # ''.split(';') == [''], so empty segments must be skipped explicitly.
                tokens = flight.split()
                if not tokens:
                    continue
                airline = self.fix_list.get(tokens[0], tokens[0])
                # Only important airlines get a series; the constructor stored
                # them but the previous code never used them.
                if airline in self.codes:
                    airlines.add(airline)

        # Sorted for a deterministic order; counts do not depend on it.
        for airline in sorted(airlines):
            yield airline

    def get_fields(self):
        return self.FIELDS


# Aggregation rules per log; main() looks them up by
# RazladkiLastProcessedTable.log_name and aborts for a log missing here.
# NOTE: PartnerRule queries Partner.objects in its constructor, so importing
# this module performs DB lookups and raises if a partner code is not found.
RULES_BY_LOG_TYPE = {
    'avia-json-redir-log-stream': [PartnerRule(IMPORTANT_PARTNERS), UtmSourceRule(IMPORTANT_UTM_SOURCES)],
    'rasp-popular-flights-log': [AirlineRule(IMPORTANT_AIRLINES, AIRLINES_FIXLIST)],
    # OLD_UTM_SOURCES is empty, so every non-wizard source is counted as 'other'.
    'avia-json-redir-log': [UtmSourceRule(OLD_UTM_SOURCES, 'other')],
}


class StatCalculator(object):
    """Counts records per time window and per key produced by the added rules.

    ``calc`` returns ``{timestamp: {series_name: count}}``, which is the shape
    encode_data_to_razladki() consumes.
    """

    def __init__(self, time_field='unixtime', time_window=30):
        """
        @param time_field: field with unixtime
        @param time_window: window size for aggregation in minutes
        """

        self._time_field = time_field
        self._time_window = time_window * 60  # kept in seconds, like the unixtime values
        self._rules = []

    def add_rule(self, rule):
        """Register an IStatRule whose keys become series names."""
        self._rules.append(rule)

    def calc(self, records, key_prefix='redirects', force_timestamp=None):
        """Aggregate ``records`` into per-window counters.

        @param records: iterable of dict-like records
        @param key_prefix: every series is named '<key_prefix>_<rule key>'
        @param force_timestamp: when not None, all records are counted under
            this timestamp instead of their own ``time_field`` rounded down
            to the window size
        @return: defaultdict {timestamp: defaultdict {series_name: count}}

        Note: '<key_prefix>_total' is incremented once per emitted key, so a
        record producing several keys (across or within rules) is counted
        several times in it.
        """
        stat = defaultdict(lambda: defaultdict(int))
        total_key = '{}_total'.format(key_prefix)
        for record in records:
            if force_timestamp is not None:
                timestamp = force_timestamp
            else:
                # Honour the configured time field instead of a hard-coded 'unixtime'.
                timestamp = _round_unixtime(record[self._time_field], self._time_window)
            for rule in self._rules:
                for current_key in rule.get_keys(record):
                    stat[timestamp]['{}_{}'.format(key_prefix, current_key)] += 1
                    stat[timestamp][total_key] += 1

        return stat

    def get_fields(self):
        """Return the set of columns needed by the time field and all rules."""
        fields = {self._time_field}
        for rule in self._rules:
            fields.update(rule.get_fields())

        return fields


def _dt_to_unixtime(dt):
    """Return the whole seconds elapsed from the UTC epoch to ``dt``.

    ``dt`` must be timezone-aware because ``_EPOCH`` is; the fractional part
    of the difference is dropped by ``int``.
    """
    elapsed = dt - _EPOCH
    return int(elapsed.total_seconds())


def encode_data_to_razladki(data):
    """Flatten ``{timestamp: {series_name: value}}`` into the Razladki payload.

    Returns ``{'data': [{'ts': ..., 'param': ..., 'value': ...}, ...]}`` with
    one entry per (timestamp, series) pair; empty input gives ``{'data': []}``.
    """
    # .items() instead of the Python-2-only .iteritems(): identical output on
    # Python 2, and the function also runs on Python 3.
    return {
        'data': [
            {
                'ts': timestamp,
                'param': series_name,
                'value': series_value,
            }
            for timestamp, values in data.items()
            for series_name, series_value in values.items()
        ],
    }


def choose_razladki_url(environment):
    """Return the Razladki URL template for ``environment``.

    Only 'production' gets the production launcher; every other environment
    is routed to the test launcher.
    """
    if environment != 'production':
        return RAZLADKI_TEST_URL_TEMPLATE
    return RAZLADKI_URL_TEMPLATE


def send_data_to_razladki(project, data, environment):
    """POST ``data`` as JSON to the Razladki endpoint of ``project``.

    The target (production or test) is picked by choose_razladki_url(). The
    ``requests`` response is returned as-is; the status code is not checked
    here.
    """
    payload = encode_data_to_razladki(data)
    url = choose_razladki_url(environment).format(project=project)
    return requests.post(url, timeout=300, data=ujson.dumps(payload))


def _get_tables(yt, root_dir, start, end=None, columns=None, last_offset=None):
    """Yield TablePath objects for the tables under ``root_dir`` to process.

    A table is selected when its name, parsed as a timestamp, is after
    ``start`` and, if ``end`` is given, before ``end``. When ``last_offset``
    is set, the table named exactly ``start`` is selected too. Every path
    then gets a row range that starts at ``last_offset`` for that table and
    at row 0 for the others, and ends at the table's current ``row_count``.
    """
    for node in yt.search(root_dir, node_type='table', attributes=['row_count']):
        table_ts = _get_timestamp_from_table_path(node)

        is_after_start = table_ts > start
        resumes_start = table_ts == start and last_offset is not None
        if not (is_after_start or resumes_start):
            continue
        if end is not None and table_ts >= end:
            continue

        if last_offset is None:
            yield yt.TablePath(node, columns=columns)
            continue

        first_row = last_offset if table_ts == start else 0
        yield yt.TablePath(
            node,
            columns=columns,
            start_index=first_row,
            end_index=node.attributes['row_count'],
        )


def _round_unixtime(unixtime, base):
    return unixtime - unixtime % base


def _timestamp_to_string(timestamp):
    """Format ``timestamp`` with TIMESTAMP_FORMAT.

    Uses the shared constant instead of a duplicated literal, so the output
    stays parseable by _parse_timestamp() if the format ever changes.
    """
    return timestamp.strftime(TIMESTAMP_FORMAT)


def _init(args):
    """Resolve the processing window and the checkpoint row for ``args.log_name``.

    @return: (start_timestamp, end_timestamp, last_table). In manual mode the
        window comes from --start-date/--end-date (end may be None). Otherwise
        processing resumes from the checkpoint's saved timestamp with no upper
        bound.
    @raise ValueError: manual mode without --start-date, or a date in an
        unsupported format
    """
    last_table = RazladkiLastProcessedTable.objects.get(
        log_name=args.log_name
    )

    if args.manual:
        # argparse stores --start-date as ``start_date``; the former
        # ``args.start_timestamp`` lookup raised AttributeError.
        if not args.start_date:
            raise ValueError('--start-date is required in --manual mode')
        start_timestamp = _parse_timestamp(args.start_date)
        # --end-date is optional; None means "no upper bound" for _get_tables.
        end_timestamp = _parse_timestamp(args.end_date) if args.end_date else None
        logger.info('Manual window: %s - %s', start_timestamp, end_timestamp)
        return start_timestamp, end_timestamp, last_table

    start_timestamp = last_table.timestamp
    end_timestamp = None

    logger.info('Last saved timestamp: %s', start_timestamp)
    return start_timestamp, end_timestamp, last_table


def _get_timestamp_from_table_path(table_path):
    """Parse the last component of ``table_path`` (the table name) as a timestamp."""
    table_name = str(table_path).rsplit('/', 1)[-1]
    return _parse_timestamp(table_name)


def _parse_timestamp(timestamp):
    """Parse ``timestamp`` with TIMESTAMP_FORMAT, falling back to DATE_FORMAT.

    Returns a naive datetime; raises ValueError if neither format matches.
    """
    for fmt in (TIMESTAMP_FORMAT, DATE_FORMAT):
        try:
            parsed = datetime.strptime(timestamp, fmt)
        except ValueError:
            pass
        else:
            return parsed

    raise ValueError('Timestamp has unsupported format: {}'.format(timestamp))


def get_stat(yt_client, yt_path, start_timestamp, end_timestamp, calculator, last_offset=None, stream_mode=False, time_window=0):
    """Aggregate ``calculator`` statistics over the tables selected by _get_tables.

    @param yt_client: YT client used to list and read the tables
    @param yt_path: directory that contains the log tables
    @param start_timestamp: only tables named after this moment are read (plus
        the table named exactly so when ``last_offset`` is set)
    @param end_timestamp: exclusive upper bound for table names, or None
    @param calculator: StatCalculator holding the aggregation rules
    @param last_offset: first unprocessed row of the ``start_timestamp`` table, or None
    @param stream_mode: group tables into ``time_window``-minute windows and
        aggregate every complete window under the unixtime of its start (table
        names are localized to Moscow time)
    @param time_window: window size in minutes; used only in stream mode,
        where it must be at least 5
    @return: (stat, last processed table path); ``({}, None)`` when nothing
        was processed
    @raise ValueError: stream mode with ``time_window`` < 5
    """
    tables = sorted(
        _get_tables(
            yt_client, yt_path,
            start_timestamp, end_timestamp,
            columns=list(calculator.get_fields()),
            last_offset=last_offset,
        ),
        key=str,
    )

    logger.info('Tables: %r', tables)

    if not tables:
        logger.info('Nothing to do, abort')
        return {}, None

    if not stream_mode:
        stat = calculator.calc(yt_read_tables(yt_client, tables))
        return stat, tables[-1]

    # Stream tables are expected every 5 minutes, so a complete window holds
    # exactly time_window // 5 of them. A window shorter than 5 minutes can
    # never be complete (and used to crash with IndexError/UnboundLocalError).
    tables_per_window = time_window // 5
    if tables_per_window <= 0:
        raise ValueError(
            'time_window must be at least 5 minutes in stream mode, got {}'.format(time_window)
        )

    window = timedelta(minutes=time_window)
    stat = {}
    # Stays None when not even the first window is complete; previously this
    # name was left unbound in that case.
    last_table = None
    ind = 0
    while ind < len(tables):
        current_tables = []
        current_start_timestamp = _get_timestamp_from_table_path(tables[ind])
        current_end_timestamp = current_start_timestamp + window
        while ind < len(tables) and _get_timestamp_from_table_path(tables[ind]) < current_end_timestamp:
            current_tables.append(tables[ind])
            ind += 1

        if len(current_tables) != tables_per_window:
            logger.info(
                'Not enough tables to process in stream mode. Timestamp: %s',
                _timestamp_to_string(current_start_timestamp),
            )
            break

        # Convert to unix seconds so stream keys have the same type as the
        # rounded unixtime keys of non-stream mode, instead of relying on the
        # JSON serializer's (version-dependent) handling of datetime objects.
        force_timestamp = _dt_to_unixtime(MOSCOW_TIMEZONE.localize(current_start_timestamp))
        logger.info(
            'Processing group, start=%s, end=%s',
            _timestamp_to_string(current_start_timestamp),
            _timestamp_to_string(_get_timestamp_from_table_path(current_tables[-1])),
        )

        stat.update(
            calculator.calc(
                yt_read_tables(yt_client, current_tables),
                force_timestamp=force_timestamp,
                key_prefix='redirect_stream',
            )
        )
        last_table = current_tables[-1]

    return stat, last_table


def _handle_razladki_response(response):
    """Log the outcome of a Razladki upload; return True if the checkpoint may advance.

    200 means the data was accepted. 409 means Razladki reports it as a
    duplicate, so re-sending would not help. Any other status is an error,
    and the same tables must be retried on the next run.
    """
    if response.status_code == 200:
        logger.info('Data sending done')
        return True

    if response.status_code == 409:
        logger.warning('Duplicated data')
        return True

    logger.error(
        'Error while sending data. Status: %d, message: %s',
        response.status_code,
        response.content,
    )
    return False


def _save_checkpoint(last_table, max_processed_table):
    """Store the last processed table (and row offset, if tracked) in ``last_table``."""
    last_table.timestamp = _get_timestamp_from_table_path(max_processed_table)
    if last_table.offset is not None:
        _range = max_processed_table.attributes['ranges'][0]  # We have exactly one range
        # _get_tables treats the offset as the first row still to read, i.e.
        # the exclusive upper limit of the range just processed. Saving the
        # lower limit made the next run read the same rows again.
        last_table.offset = _range['upper_limit']['row_index']
    last_table.save()

    logger.info('New timestamp: %s', last_table.timestamp)


def main():
    """Aggregate new log tables of ``--log-name`` and upload the stats to Razladki.

    Scheduled runs resume from the RazladkiLastProcessedTable checkpoint and
    advance it only after a successful (or duplicate) upload. ``--manual``
    runs process the --start-date/--end-date window and never touch the
    checkpoint.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-v', '--verbose', action='store_true')
    parser.add_argument('--manual', action='store_true')
    parser.add_argument('--start-date', default=None)
    parser.add_argument('--end-date', default=None)
    parser.add_argument('--time-window', default=30, type=int, help='time window for aggregation (in minutes)')
    parser.add_argument('--log-name')
    args = parser.parse_args()

    create_current_file_run_log()

    if args.verbose:
        add_stdout_handler(logger)

    logger.info('Start')

    current_env = settings.ENVIRONMENT
    if current_env not in ALLOWED_ENVS:
        logger.info('Can work only in: %s', ', '.join(ALLOWED_ENVS))
        sys.exit(0)

    start_timestamp, end_timestamp, last_table = _init(args)

    yt_client = yt_client_fabric.create()

    rules = RULES_BY_LOG_TYPE.get(last_table.log_name)
    if not rules:
        logger.error('No rules for %s. Abort', last_table.log_name)
        return

    stat_calculator = StatCalculator(time_window=args.time_window)
    for rule in rules:
        stat_calculator.add_rule(rule)

    stat, max_processed_table = get_stat(
        yt_client, last_table.yt_path, start_timestamp, end_timestamp,
        calculator=stat_calculator,
        # Resume inside the checkpoint table on scheduled runs. Without this the
        # tables carry no row ranges and saving the offset raised KeyError. A
        # manual window is unrelated to the stored offset.
        last_offset=None if args.manual else last_table.offset,
        stream_mode=last_table.mode == RazladkiLastProcessedTable.STREAM,
        time_window=args.time_window,
    )

    if stat:
        logger.info('Sending to Razladki')
        response = send_data_to_razladki('avia', stat, current_env)
        sent = _handle_razladki_response(response)

        if not args.manual:
            if sent:
                _save_checkpoint(last_table, max_processed_table)
            else:
                # Advancing here would silently drop the data that failed to upload.
                logger.warning('Checkpoint is not advanced; the same tables will be retried')

    logger.info('Done')
