# coding: utf-8
from __future__ import unicode_literals

import argparse
import logging  # noqa
import os
import sys
from collections import namedtuple
from datetime import datetime

import yt
from django.conf import settings
from lxml import etree

from common.db.mds.clients import mds_s3_public_client
from common.models.geo import Settlement, Country, Region, StationMajority, Station
from common.models.transport import TransportType
from travel.rasp.library.python.common23.date.environment import now_utc
from common.utils.lock import lock
from travel.rasp.library.python.common23.logging.scripts import script_context
from common.data_api.yt.instance import update_yt_wrapper_config
from route_search.models import ZNodeRoute2
from travel.rasp.tasks.min_prices.yt_min_train_prices import YtMinTrainPrice

# Module-level logger used by every function in this script.
log = logging.getLogger(__name__)

# A feed category: ``id`` is written into each offer's <categoryId>,
# ``value`` becomes the text of the matching <category> element.
Category = namedtuple('Category', 'id value')
EVERYDAY_CATEGORY = Category('1', 'everyday')
NOT_EVERYDAY_CATEGORY = Category('2', 'noteveryday')


class TrainMinPriceCache(object):
    """In-memory cache of minimal train prices keyed by settlement pair.

    The cache is filled from a YT table by :meth:`init_prices` and queried
    with :meth:`get_price_by_settlements`.

    :param trip_days_count: number of distinct ``date_forward`` values a
        route must have to be reported as "everyday".
    :param yt_table_path: path of the YT table to read prices from.
    """

    class TrainMinPrice(object):
        """Aggregated price data of one route.

        :param dates: set of ``date_forward`` values seen for the route.
        :param min_price: lowest price seen for the route so far.
        """

        def __init__(self, dates, min_price):
            self.min_price = min_price
            self.dates = dates

    def __init__(self, trip_days_count, yt_table_path):
        self.trip_days_count = trip_days_count
        self.yt_table_path = yt_table_path
        # (settlement_from_id, settlement_to_id) -> TrainMinPrice
        self._min_price_cache = {}

    def init_prices(self):
        """Reset the cache and reload it from ``self.yt_table_path``.

        Logs an error and terminates the process with exit code 1 when the
        table does not exist.
        """
        self._min_price_cache = {}

        update_yt_wrapper_config()

        if not yt.wrapper.exists(self.yt_table_path):
            log.error('Path %s does not exist', self.yt_table_path)
            sys.exit(1)

        for row in yt.wrapper.read_table(self.yt_table_path, format=b'dsv', raw=False):
            self._add_row(row)

    def _add_row(self, row):
        """Merge one table row into the cache.

        ``row`` must provide ``settlement_from_id``, ``settlement_to_id``,
        ``price`` (anything ``float()`` accepts) and ``date_forward``.
        """
        key = (row['settlement_from_id'], row['settlement_to_id'])
        price = float(row['price'])

        item = self._min_price_cache.get(key)
        if item is None:
            # Allocate the entry only for a new route; ``setdefault`` with a
            # freshly built default would create a throwaway object per row.
            item = self._min_price_cache[key] = self.TrainMinPrice(dates=set(), min_price=price)
        elif price < item.min_price:
            item.min_price = price
        item.dates.add(row['date_forward'])

    def get_price_by_settlements(self, settlement_from, settlement_to):
        """Return ``(min_price, everyday)`` for a pair of settlements.

        The lookup key is built from ``str(settlement.id)``, so the ids
        stored by :meth:`init_prices` are expected to be strings.

        ``everyday`` is True when at least ``trip_days_count`` distinct
        dates were seen for the route. A route that is missing from the
        cache, or whose minimal price is falsy (e.g. ``0.0``), yields
        ``(0.0, False)``.
        """
        cached = self._min_price_cache.get((str(settlement_from.id), str(settlement_to.id)))
        if cached and cached.min_price:
            return cached.min_price, len(cached.dates) >= self.trip_days_count
        return 0.0, False


def get_routes():
    """Yield ``(settlement_from, settlement_to)`` pairs linked by train routes.

    Distinct settlement id pairs are read from ``ZNodeRoute2`` for the train
    transport type; pairs whose two ends are the same settlement are dropped.
    The settlements are then loaded in one bulk query and a pair is yielded
    only when both ends are in Russia and neither end is in the Kaliningrad
    region.
    """
    id_pairs_query = (
        ZNodeRoute2.objects
        .filter(
            settlement_from__isnull=False,
            settlement_to__isnull=False,
            t_type_id=TransportType.TRAIN_ID,
        )
        .values_list('settlement_from_id', 'settlement_to_id')
        .distinct()
    )

    id_pairs = [(from_id, to_id) for from_id, to_id in id_pairs_query if from_id != to_id]
    wanted_ids = {settlement_id for pair in id_pairs for settlement_id in pair}

    # One query for all settlements instead of two lookups per route.
    settlements = Settlement.objects.in_bulk(wanted_ids)

    for from_id, to_id in id_pairs:
        settlement_from = settlements[from_id]
        settlement_to = settlements[to_id]

        both_in_russia = (
            settlement_from.country_id == Country.RUSSIA_ID
            and settlement_to.country_id == Country.RUSSIA_ID
        )
        if not both_in_russia:
            continue

        touches_kaliningrad = (
            settlement_from.region_id == Region.KALININGRAD_REGION_ID
            or settlement_to.region_id == Region.KALININGRAD_REGION_ID
        )
        if touches_kaliningrad:
            continue

        yield settlement_from, settlement_to


def make_offer_element(settlement_from, settlement_to, price=None, everyday=False):
    """Build the ``<offer>`` element of a route between two settlements.

    :param settlement_from: departure settlement; its ``point_key``, ``slug``
        and Russian titles are used.
    :param settlement_to: arrival settlement, used the same way.
    :param price: minimal ticket price; a falsy value (``None``, ``0``)
        marks the offer as unavailable and is written as ``0``.
    :param everyday: selects ``EVERYDAY_CATEGORY`` when true, otherwise
        ``NOT_EVERYDAY_CATEGORY``.
    :return: the ``offer`` element, or ``None`` when the
        "<from> — <to>" title is longer than 33 characters.
    """
    title_from = settlement_from.L_title(lang='ru')
    title_to = settlement_to.L_title(lang='ru')

    # The longest prefix that keeps the title within 33 characters is used:
    # 22 + 11, 26 + 7 and 27 + 6 are all 33. Titles of 28..33 characters go
    # out without a prefix, longer ones produce no offer at all.
    # NOTE(review): 33 looks like an ad title length limit — confirm.
    name = '{} — {}'.format(title_from, title_to)
    if len(name) < 23:
        name = 'Ж/д билеты {}'.format(name)
    elif len(name) < 27:
        name = 'Поезда {}'.format(name)
    elif len(name) < 28:
        name = 'Поезд {}'.format(name)
    elif len(name) > 33:
        # Checked before any element is built so nothing is created in vain.
        return None

    el_id = '{}-{}'.format(settlement_from.point_key, settlement_to.point_key)
    offer_el = etree.Element('offer', attrib={'id': el_id, 'available': 'true' if price else 'false'})
    etree.SubElement(offer_el, 'url').text = 'https://{}/{}--{}'.format(
        settings.TRAINS_SITE_RU,
        settlement_from.slug,
        settlement_to.slug
    )
    etree.SubElement(offer_el, 'name').text = name
    # ``str(price) or '0'`` never fell back to '0' because ``str()`` of any
    # value is non-empty; a missing price was emitted as 'None' or '0.0'.
    etree.SubElement(offer_el, 'price').text = str(price) if price else '0'
    etree.SubElement(offer_el, 'currencyId').text = 'RUB'
    etree.SubElement(offer_el, 'description').text = 'Билеты на поезд {} {}'.format(
        settlement_from.L_title_phrase_from(lang='ru'),
        settlement_to.L_title_phrase_to(lang='ru'),
    )
    category = EVERYDAY_CATEGORY if everyday else NOT_EVERYDAY_CATEGORY
    etree.SubElement(offer_el, 'categoryId').text = category.id
    etree.SubElement(offer_el, 'param', attrib={'name': 'fromName'}).text = title_from
    etree.SubElement(offer_el, 'param', attrib={'name': 'toName'}).text = title_to
    etree.SubElement(offer_el, 'picture').text = 'https://rasp.s3.yandex.net/train-ticket.png'
    return offer_el


def generate_xml(min_price_cache):
    """Render the whole YML catalog and return it serialized.

    The document contains one currency, the two feed categories and one
    ``<offer>`` per route returned by ``get_routes()`` for which
    ``make_offer_element`` produced an element.

    :param min_price_cache: object whose
        ``get_price_by_settlements(settlement_from, settlement_to)`` returns
        a ``(price, everyday)`` pair.
    :return: pretty-printed UTF-8 XML with an XML declaration and the
        ``shops.dtd`` doctype, as produced by ``etree.tostring``.
    """
    log.info('start generate_xml')
    yml_catalog = etree.Element('yml_catalog', attrib={'date': now_utc().strftime("%Y-%m-%d %H:%M")})
    shop = etree.SubElement(yml_catalog, 'shop')

    currencies = etree.SubElement(shop, 'currencies')
    # NOTE(review): the currency is declared as 'RUR' while every offer is
    # written with currencyId 'RUB' — confirm the feed consumer treats them
    # as the same currency before unifying.
    etree.SubElement(currencies, 'currency', attrib={'id': 'RUR', 'rate': '1'})

    categories = etree.SubElement(shop, 'categories')
    for cat in (EVERYDAY_CATEGORY, NOT_EVERYDAY_CATEGORY):
        etree.SubElement(categories, 'category', attrib={'id': cat.id}).text = cat.value

    offers = etree.SubElement(shop, 'offers')
    for settlement_from, settlement_to in get_routes():
        price, everyday = min_price_cache.get_price_by_settlements(settlement_from, settlement_to)
        offer_el = make_offer_element(settlement_from, settlement_to, price, everyday)
        # The truth value of an lxml element means "has children" and is
        # deprecated; the skipped case is signalled by None, so test for it.
        if offer_el is not None:
            offers.append(offer_el)

    res = etree.tostring(yml_catalog, xml_declaration=True, encoding='UTF-8', pretty_print=True,
                         doctype='<!DOCTYPE yml_catalog SYSTEM "shops.dtd">')
    log.info('done generate_xml')
    return res


def init_min_prices(log_search_depth, trip_days_count):
    """Run the YT min-price job and load its result into a price cache.

    The table path is ``<YT_ROOT_PATH>/<environment>/
    rasp-min-train-prices-by-route/<today as YYYY-MM-DD>``. A
    ``YtMinTrainPrice`` runner is started with that path, and the same path
    is then read by ``TrainMinPriceCache``.
    NOTE(review): presumably the runner writes that table — confirm against
    ``YtMinTrainPrice``.

    :param log_search_depth: forwarded to ``YtMinTrainPrice`` unchanged.
    :param trip_days_count: forwarded to both ``YtMinTrainPrice`` and
        ``TrainMinPriceCache``.
    :return: an initialized ``TrainMinPriceCache``.
    """
    environment = settings.YANDEX_ENVIRONMENT_TYPE
    today_str = datetime.today().strftime("%Y-%m-%d")
    table_path = os.path.join(
        settings.YT_ROOT_PATH, environment, 'rasp-min-train-prices-by-route', today_str
    )
    log.info('start init_min_prices. table_path: %s', table_path)

    stations_to_settlements = make_station_settlement_map()
    log.info('done station_settlement_map')

    by_routes_path = os.path.join(settings.YT_ROOT_PATH, environment, 'rasp-min-prices-by-routes')
    yt_job = YtMinTrainPrice(
        stations_to_settlements,
        log_search_depth,
        trip_days_count,
        by_routes_path,
        table_path,
        settings.YT_ROOT_PATH,
        settings.YT_TOKEN,
        settings.YT_PROXY
    )
    yt_job.start()
    log.info('done yt_work')

    cache = TrainMinPriceCache(trip_days_count, table_path)
    cache.init_prices()
    log.info('done min_price_cache')
    return cache


def make_station_settlement_map():
    """Map station ids (as strings) to the ids of their settlements.

    Only train stations that have a settlement and whose majority is
    ``MAIN_IN_CITY_ID`` or ``EXPRESS_FAKE_ID`` are included.
    """
    accepted_majorities = [StationMajority.MAIN_IN_CITY_ID, StationMajority.EXPRESS_FAKE_ID]
    stations = Station.objects.filter(
        settlement__isnull=False,
        t_type__id=TransportType.TRAIN_ID,
        majority__id__in=accepted_majorities,
    )
    id_pairs = stations.values_list('id', 'settlement__id')
    return {str(station_id): settlement_id for station_id, settlement_id in id_pairs}


def run(log_search_depth, trip_days_count):
    """Build the trains feed and upload it to the public MDS S3 bucket.

    Everything runs under the ``direct_feeds`` lock and script context. The
    feed is stored under the key ``api_public/trains_catalog.xml``. Any
    exception is logged with its traceback and re-raised.

    :param log_search_depth: passed on to ``init_min_prices``.
    :param trip_days_count: passed on to ``init_min_prices``.
    """
    with lock('direct_feeds'):
        with script_context('direct_feeds'):
            try:
                start_message = 'Start with log_search_depth={}, trip_days_count={}'.format(
                    log_search_depth, trip_days_count)
                log.info(start_message)

                prices = init_min_prices(log_search_depth, trip_days_count)
                feed = generate_xml(prices)
                mds_s3_public_client.save_data(data=feed, key='api_public/trains_catalog.xml')

                log.info('Done')
            except Exception:
                log.exception('Ошибка')
                raise


if __name__ == "__main__":
    # Command-line entry point: both options are integers with defaults and
    # are handed to ``run`` as-is.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '-l', '--log_search_depth',
        dest='log_search_depth',
        type=int,
        default=5,
        help=u'поиск по указанному количеству таблиц за прошлые даты'
    )
    cli.add_argument(
        '-t', '--trip_days_count',
        dest='trip_days_count',
        type=int,
        default=14,
        help=u'поиск цен на указанное количество дней вперед'
    )
    options = cli.parse_args()
    run(options.log_search_depth, options.trip_days_count)
