# coding: utf-8

from __future__ import unicode_literals

import logging
import re
from datetime import datetime, timedelta
from collections import defaultdict

from django.db import transaction

from common.importinfo.models import Express2Country
from common.models.geo import Country, Station
from common.models.schedule import Supplier, Route, RThread, RTStation, Company, DeLuxeTrain, RThreadType
from common.models.transport import TransportType
from travel.rasp.library.python.common23.date import environment
from common.utils.caching import cache_method_result
from common.utils.date import RunMask, get_pytz, MSK_TIMEZONE
from travel.rasp.admin.importinfo.models import TrainTransportOverride, BlackList
from travel.rasp.admin.importinfo.models.express_2_country_cache import get_tz_by_express
from travel.rasp.admin.lib.mail import mail_train_import
from travel.rasp.admin.scripts.schedule.tis_train.utils import ExpressStationsGetter
from travel.rasp.admin.scripts.schedule.utils.route_loader import RouteUpdater, CompactThreadNumberBuilder, MaskUpdater
from travel.rasp.admin.www.models.schedule import Route2Company, RouteImportInfo, ExpressNumber
from travel.rasp.admin.www.utils.data import rus2translit


log = logging.getLogger(__name__)


RZD_ID = 112


def is_kazakhstan_suburban(route, digital_number):
    """
    EXRASP-3336

     * начальная или конечная станция в Казахстане (или в городе в Казахстане)
     and
     (
     * номер 7001-7998
     or
     * номер 851-898
     )
    """

    if digital_number is None:
        return False

    if (7001 <= digital_number <= 7998) or (851 <= digital_number <= 898):
        return all(Country.KAZAKHSTAN_ID == t.rtstations[0].station.country_id == t.rtstations[-1].station.country_id
                   for t in route.threads)

    return False


def is_litva_suburban(route, digital_number):
    """
    EXRASP-4075

    * начальная или конечная станция в Литве (или в городе в Литве)
    * номер 601-698
    * номер 801-898
    """
    if digital_number is None:
        return False

    if (601 <= digital_number <= 698) or (801 <= digital_number <= 898):
        return all(Country.LITVA_ID == t.rtstations[0].station.country_id == t.rtstations[-1].station.country_id
                   for t in route.threads)

    return False


def is_ukraine_suburban_express(route, digital_number):
    """
    EXRASP-12009
    """

    if digital_number is None:
        return False

    if 800 <= digital_number <= 899:
        return all(Country.UKRAINE_ID == t.rtstations[0].station.country_id == t.rtstations[-1].station.country_id
                   for t in route.threads)

    return False


def process_local_suburbans(route, digit_number):
    if is_kazakhstan_suburban(route, digit_number):
        # EXRASP-3384
        route.t_type_id = TransportType.SUBURBAN_ID

        for thread in route.threads:
            thread.t_type_id = TransportType.SUBURBAN_ID
            thread.express_type = None

    elif is_litva_suburban(route, digit_number):
        # EXRASP-4075
        route.t_type_id = TransportType.SUBURBAN_ID

        for thread in route.threads:
            thread.t_type_id = TransportType.SUBURBAN_ID
            thread.express_type = None

    elif is_ukraine_suburban_express(route, digit_number):
        # EXRASP-12009
        route.t_type_id = TransportType.SUBURBAN_ID

        for thread in route.threads:
            thread.t_type_id = TransportType.SUBURBAN_ID
            thread.express_type = 'express'

    return route


def get_digital_part(number):
    try:
        return int(number)
    except ValueError:
        try:
            return int(number[:-1])
        except ValueError:
            log.error('Не смогли выделить цифровую часть номера %s', number)


class TisImporter(object):
    def __init__(self, records_iter):
        self.records_iter = records_iter

        self.routes = {}
        self.processed_thread_keys = set()
        self.skip_processed_thread_keys = False

        self.tis = Supplier.objects.get(code='tis')
        self.af = Supplier.objects.get(code='af')
        self.company_dict = dict((r2c.number, r2c.company) for r2c in Route2Company.objects.all())
        self.deluxe_subtypes = defaultdict(set)
        self.set_deluxe_subtypes()

    def set_deluxe_subtypes(self):
        for train in DeLuxeTrain.objects.exclude(t_subtype__isnull=True):
            self.deluxe_subtypes[train.t_subtype].update(train.numbers.split('/'))

    @transaction.atomic
    def do_import(self):
        with Express2Country.using_precache():
            self.process_records()

        updater = RouteUpdater.create_route_updater(
            Route.objects.filter(supplier=self.tis), log=log, mask_updater=MaskUpdater,
            thread_number_builder=CompactThreadNumberBuilder
        )

        bad_af_numbers = []
        for number, route in sorted(self.routes.items()):
            if route.threads:
                if self.in_af_import(number, route.threads, bad_af_numbers):
                    log.info(u"Маршрут %s уже есть в импорте от Фетисова", number)
                    continue

                route = process_local_suburbans(route, get_digital_part(number))

                if route.has_script_protected_copy_in_base():
                    log.info(u"Маршрут %s уже есть в базе и помечен как защищенный, игнорируем его",
                             route.route_uid)
                    continue

                updater.add_route(route)

        updater.update()

        if bad_af_numbers:
            message = ('Маршруты от ТИС совпадающие по номеру с маршрутами от А.Ф.,'
                       ' но не совпадающие по конечным станциям:\n')
            message += '\n'.join(bad_af_numbers)
            mail_train_import('Не совпадения у А.Ф. и ТИСа', message)

        return len(updater.processed_route_uids)

    @ExpressStationsGetter.cache
    def process_records(self):
        records_count = 0

        for train_record in self.records_iter:
            records_count += 1
            try:
                self.process_train_record(train_record)
            except Exception:
                log.exception('Ошибка при обработке данных от Тис %s', train_record.record_key)
            if not records_count % 1000:
                log.info('%d records have been processed', records_count)

    def process_train_record(self, train_record):
        number = train_record.number
        thread_type = train_record.thread_type

        if self.in_ignore_list(number, thread_type):
            log.debug('Поезд %s %s попал в ignore list', number, thread_type.code)
            return

        try:
            start_station = train_record.start_station
        except Station.DoesNotExist:
            log.error('Не нашли начальную станцию %s маршрута %s', train_record.start_station_code, number)
            return

        # При обновления из файла за прошлую неделю, не учитываем те рейсы, которые
        # уже пришли в основном файле.
        processed_thread_key = train_record.record_key
        if self.skip_processed_thread_keys and processed_thread_key in self.processed_thread_keys:
            return
        elif not self.skip_processed_thread_keys:
            self.processed_thread_keys.add(processed_thread_key)

        if number in self.routes:
            route = self.routes[number]
        else:
            route = Route(t_type_id=train_record.t_type_id, supplier=self.tis)
            route.threads = []
            route.thread_uid_map = {}
            try:
                route_info = RouteImportInfo.objects.get(number=number, supplier=self.tis)
                route.t_type = route_info.t_type
            except RouteImportInfo.DoesNotExist:
                route_info = None

            transport_override = self.get_transport_override(number)
            if transport_override:
                route.t_type = transport_override.t_type

            if transport_override and route_info:
                log.error('Траспорт для рейса %s переопределяется дважды в RouteImportInfo'
                          ' и в TrainTransportOverride')

            self.routes[number] = route

        if len(train_record.stop_records) < 2:
            log.warning('Не достаточно станций указано для записи %s', train_record.record_key)
            return

        thread = RThread(
            year_days=RunMask.EMPTY_YEAR_DAYS,
            number=number,
            express_type=ExpressNumber.get_express_type(number, self.tis),
            company=self.get_company(train_record),
            route=route,
            t_type=route.t_type,
            supplier=route.supplier
        )

        transport_override = self.get_transport_override(number)
        if transport_override:
            thread.t_subtype = transport_override.t_subtype

        if thread.t_subtype is None:
            thread.t_subtype = self.get_t_subtype_by_deluxe_number(number)

        try:
            if thread_type.id == RThreadType.THROUGH_TRAIN_ID:
                thread.canonical_uid = 'R_{}_{}_{}_{}'.format(
                    rus2translit(thread.number),
                    train_record.start_station.id,
                    train_record.end_station.id,
                    thread.company.id
                )
            else:
                thread.canonical_uid = 'R_{}_{}'.format(rus2translit(thread.number), thread.company.id)
        except Exception as ex:
            log.error('Не смогли заполнить canonical_uid. %s', repr(ex))

        thread.type = thread_type
        thread.rtstations = []

        start_time = train_record.stop_records[0].departure_time

        if not train_record.run_mask:
            return

        mask_first_day = train_record.run_mask.iter_dates(past=True).next()
        # Местное время отправления
        naive_start_dt = datetime.combine(mask_first_day, start_time)
        # Увеличиваемвремя отправления на на start_day_shift
        naive_start_dt += timedelta(days=train_record.start_day_shift)
        start_pytz = self.get_tz(start_station, train_record.start_station_code)

        thread.rtstations = self.make_rtstations(train_record, naive_start_dt, start_pytz)
        if len(thread.rtstations) < 2:
            log.warning('Не достаточно станций разобралось у записи %s', train_record.record_key)
            return

        thread.tz_start_time = naive_start_dt.time()
        thread.time_zone = start_pytz.zone
        thread.gen_title()

        # Для генерации import_uid нужен route_uid
        if not route.route_uid:
            route.route_uid = thread.gen_route_uid()
        thread.gen_import_uid()

        already_added_thread = next((t for t in route.threads if t.import_uid == thread.import_uid), None)
        if already_added_thread is None:
            thread.mask = RunMask(today=environment.today())
            route.threads.append(thread)
        else:
            thread = already_added_thread

        thread.mask |= train_record.run_mask
        thread.year_days = str(thread.mask)

    def get_t_subtype_by_deluxe_number(self, number):
        for t_subtype, numbers in self.deluxe_subtypes.items():
            if number in numbers:
                return t_subtype

    @cache_method_result
    def get_transport_override(self, number):
        for override in self.get_all_overrides():
            if override.number_re.match(number):
                return override

    @cache_method_result
    def get_all_overrides(self):
        try:
            overrides = list(TrainTransportOverride.objects.all())
            for override in overrides:
                override.number_re = re.compile(override.number, re.UNICODE + re.IGNORECASE)

            return overrides
        except Exception:
            log.exception('Не смогли сформировать переопределения для TIS')
            return []

    def in_af_import(self, number, threads, bad_af_numbers):
        digital_number = number

        if not number[-1].isdigit():
            digital_number = number[:-1]

        af_threads = list(RThread.objects.filter(route__supplier=self.af, number=digital_number))

        if not af_threads:
            return False

        af_directions = {(af_thread.path[0].station.id, af_thread.path.reverse()[0].station.id)
                         for af_thread in af_threads}

        directions = {(t.rtstations[0].station.id, t.rtstations[-1].station.id) for t in threads}

        has_same_direction = any(d in af_directions for d in directions)

        if not has_same_direction:
            bad_af_numbers.append(number)

        return has_same_direction

    def in_ignore_list(self, number, thread_type):
        return BlackList.has_train_number(number, self.tis, thread_type)

    @cache_method_result
    def get_tz(self, station, express_id):
        time_zone = get_tz_by_express(express_id)

        if not time_zone:
            log.error('Не нашли временную зону по коду %s, берем страну станции или Москву', express_id)
            if station.country:
                time_zone = station.country.get_capital_tz()

        return get_pytz(time_zone or MSK_TIMEZONE)

    def make_rtstations(self, train_record, naive_start_dt, start_pytz):
        rtstations = []
        start_dt = start_pytz.localize(naive_start_dt)
        prev_dt = start_dt

        start_date = naive_start_dt.date()
        for index, stop_record in enumerate(train_record.stop_records):
            try:
                station = stop_record.station
            except Station.DoesNotExist:
                log.error('Не нашли станции по %s в рейсе %s', stop_record.station_code, train_record.number)
                continue

            st_dep_time = stop_record.departure_time

            naive_departure_dt = datetime.combine(start_date, st_dep_time)
            naive_departure_dt += timedelta(days=stop_record.day_shift)
            rts_pytz = self.get_tz(station, stop_record.station_code)
            departure_dt = rts_pytz.localize(naive_departure_dt)

            if prev_dt > departure_dt:
                log.error(
                    '%s %s время старта поезда или время отправления с предыдущей станции %s'
                    ' меньше, чем время отправления %s со станции %s %s',
                    train_record.number, train_record.record_key, prev_dt, departure_dt, station.id, station.title
                )

            prev_dt = departure_dt

            tz_departure = (naive_departure_dt - naive_start_dt).total_seconds() / 60

            tz_arrival = tz_departure - stop_record.stop_time

            rtstation = RTStation(
                tz_arrival=tz_arrival,
                tz_departure=tz_departure,
                station=station,
                time_zone=rts_pytz.zone,
                is_technical_stop=stop_record.is_technical
            )

            if hasattr(stop_record, 'departure_number'):
                rtstation.train_number = stop_record.departure_number
            elif hasattr(stop_record, 'arrival_number'):
                rtstation.train_number = stop_record.arrival_number

            rtstations.append(rtstation)

        if rtstations:
            rtstations[-1].tz_departure = None
            rtstations[0].tz_arrival = None

        return rtstations

    def get_company_by_number(self, number):
        return self.company_dict.get(number)

    def get_company(self, train_record):
        company_by_code = self.get_company_by_code(train_record.company_code)
        company_by_number = self.get_company_by_number(train_record.number)

        if company_by_code and company_by_number and company_by_code != company_by_number:
            log.warning('Компания по коду "%s" %s отличается от компании по номеру "%s" %s берем компанию по номеру',
                        train_record.company_code, company_by_code.L_title(),
                        train_record.number, company_by_number.L_title())

        if company_by_number or company_by_code:
            return company_by_number or company_by_code

        company = self.get_company_by_kg(train_record.country_code)
        log.error('Не нашли компанию по коду "%s", поезд %s. %s',
                  train_record.company_code, train_record.number,
                  'Берем компанию {} {}'.format(company.id, company.L_title())
                  if company else 'Не нашли компании по коду государства и/или номеру поезда.')

        return company

    @cache_method_result
    def get_company_by_code(self, code):
        if not code:
            return
        try:
            return Company.objects.get(tis_code=code)
        except Company.DoesNotExist:
            pass

    @cache_method_result
    def get_company_by_kg(self, kg):
        """ Ищет компанию-перевозчика по коду государства """

        company = None  # дефолтовое значение компании-перевозчика
        company_country = Express2Country.get_country(kg)
        if not company_country:
            log.error('Не найдена страна с кодом = %s!', kg)
            return company

        if company_country.id == Country.RUSSIA_ID:
            return Company.objects.get(id=RZD_ID)

        companies = Company.objects.filter(country=company_country, t_type_id=TransportType.TRAIN_ID)

        if not companies:
            company = Company()
            company.title = 'ЖД ' + company_country.title
            company.country = company_country
            company.t_type_id = TransportType.TRAIN_ID
            company.short_title = company.title if len(company.title) < 10 else '*'
            company.save()
            return company
        else:
            return companies[0]
