# coding: utf8
from __future__ import unicode_literals, absolute_import, division, print_function

import json
import hashlib
import logging

from collections import defaultdict
from datetime import datetime, timedelta
from requests import Session
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

from common.db.mds.clients import mds_s3_common_client
from common.cysix.builder import (CarrierBlock, ChannelBlock, GroupBlock,
                                  ThreadBlock, ScheduleBlock, StationBlock, StoppointBlock)
from travel.rasp.library.python.common23.date.environment import today
from rasp_vault.api import get_secret

# Module-level logger used for request and progress messages in this module.
log = logging.getLogger(__name__)


def parse_datetime(text):
    """Parse an API timestamp of the form ``YYYY-MM-DD HH:MM:SS`` into a naive datetime.

    :raises ValueError: if ``text`` does not match the format.
    """
    fmt = '%Y-%m-%d %H:%M:%S'
    parsed = datetime.strptime(text, fmt)
    return parsed


class OKRoute(object):
    """One ride (``ride_id``) on one date, assembled from its API segments.

    Segments are ordered by arrival time (``datetime_end``); their ``station_end``
    values form the ordered list of stops after the start station.
    """

    def __init__(self, ride_id, date, segments):
        """
        :param ride_id: identifier shared by all ``segments``.
        :param date: the date string the segments were requested for.
        :param segments: non-empty list of segment dicts from the API ``ride_list``.
        """
        self.ride_id = ride_id
        self.date = date
        # Order segments by arrival so that ``stations`` follow the route's stop order.
        self.segments = sorted(segments, key=lambda s: parse_datetime(s['datetime_end']))
        self.stations = [s['station_end'] for s in self.segments]
        self.path = '-'.join(str(s['station_id']) for s in self.stations)

        # Take departure data from the first *sorted* segment. That is the same segment
        # OKThread uses for the base station and route name; the raw input order is arbitrary.
        main_segment = self.segments[0]
        self.start_dt = parse_datetime(main_segment['datetime_start'])
        self.start_time = self.start_dt.strftime('%H:%M')

        self.carrier_title = main_segment['carrier_title']
        # Grouping key. Routes with equal keys on different dates are merged into one thread.
        # Empty parts (e.g. a missing carrier title) are left out.
        self.key = ':'.join(k for k in [self.start_time, self.path, self.carrier_title] if k)


class OKStation(object):
    """A stop point of a thread, built from an API station dict."""

    def __init__(self, station, distance, arrival_shift=None, departure_shift=None):
        """
        :param station: API station dict; ``station_id`` is required, other keys are optional.
        :param distance: distance value stored as-is for the stop point.
        :param arrival_shift: arrival offset from the thread start (string) or None.
        :param departure_shift: departure offset from the thread start (string) or None.
        """
        self.code = str(station['station_id'])
        self.title = self.make_station_title(station)
        self.code_system = 'vendor'
        self.distance = distance
        self.departure_shift = departure_shift
        self.arrival_shift = arrival_shift
        self.lat = station.get('lat')
        # The API names the longitude field ``lng``.
        self.lon = station.get('lng')
        self.country_code = station.get('country_iso')

    def make_station_title(self, station):
        """Return ``"<station> (<city>, [<district>, ]<region>)"``.

        Missing (empty or None) city, district or region parts are omitted rather than
        rendered as ``None``. If none of them is present, only the station title is returned.
        """
        station_title = station.get('station_title')
        details = [station.get(key) for key in ('city_title', 'district_title', 'region_title')]
        details = ['{}'.format(part) for part in details if part]

        if not details:
            return '{}'.format(station_title)
        return '{} ({})'.format(station_title, ', '.join(details))


class OKCarrier(object):
    """A carrier identified by a locally assigned ``code`` and its display ``title``."""

    def __init__(self, code, title):
        self.title = title
        self.code = code


class OKThread(object):
    """A schedule thread: routes sharing the same key, grouped across dates.

    The first route supplies stops, prices, carrier, start time and title. The other
    routes only contribute their dates.
    """

    # Shared by all instances for the lifetime of the process. Maps carrier title to
    # OKCarrier, so a carrier keeps the same code across threads and groups.
    global_carriers = {}

    def __init__(self, routes):
        """
        :param routes: non-empty list of OKRoute objects sharing the same ``key``.
        """
        self.dates = [r.date for r in routes]
        main_route = routes[0]

        self.start_time = main_route.start_time
        self._make_stops(main_route)
        self._make_carrier(main_route)

        self.number, self.title = self._split_route_name(main_route.segments[0]['route_name'])

    @staticmethod
    def _split_route_name(raw_title):
        """Split a route name of the form ``"<number> <title>"`` into ``(number, title)``.

        Names that do not start with a digit have no number: ``('', raw_title)``.
        A name made only of a number (no space) keeps the full string as the number and
        also uses it as the title, so the title is never empty. Empty or None names give
        ``('', '')``.
        """
        raw_title = raw_title or ''
        # Slicing instead of indexing so an empty name does not raise IndexError.
        if not raw_title[:1].isdigit():
            return '', raw_title

        number, _, title = raw_title.partition(' ')
        return number, title or raw_title

    def _make_carrier(self, route):
        """Set ``self.carrier`` to the shared OKCarrier for the route's carrier title, or None.

        New carriers get sequential string codes ("1", "2", ...) in first-seen order.
        """
        carrier_title = route.carrier_title
        if not carrier_title:
            self.carrier = None
            return

        if carrier_title not in self.global_carriers:
            code = str(len(self.global_carriers) + 1)
            self.global_carriers[carrier_title] = OKCarrier(code, carrier_title)

        self.carrier = self.global_carriers[carrier_title]

    def _make_stops(self, route):
        """Build ``self.stops`` (base station plus one stop per segment) and ``self.prices``.

        Shifts are whole seconds since ``route.start_dt``, stored as strings. Every
        non-base stop gets arrival = shift - 1 and departure = shift, where shift is
        computed from the segment's ``datetime_end``. The last stop then keeps only an
        arrival shift. Each price runs from the base station to a segment's end station.
        """
        self.stops = []
        self.prices = []

        self.base_station = OKStation(
            route.segments[0]['station_start'],
            distance=0,
            departure_shift='0'
        )
        self.stops.append(self.base_station)

        for segment in route.segments:
            end_dt = parse_datetime(segment['datetime_end'])
            shift = int((end_dt - route.start_dt).total_seconds())

            station = OKStation(segment['station_end'], segment['distance'], str(shift - 1), str(shift))
            self.stops.append(station)
            self.prices.append({
                'value': str(segment['price_source_tariff']),
                'currency': 'RUR',
                'stop_from': self.base_station,
                'stop_to': station
            })

        # The terminal stop has no departure: turn its departure shift into the arrival.
        last_stop = self.stops[-1]
        last_stop.arrival_shift = last_stop.departure_shift
        last_stop.departure_shift = None


class OKBuilder(object):
    """Accumulates OKThread data into a cysix bus channel and serializes it to XML."""

    def __init__(self):
        self.channel_block = ChannelBlock(
            'bus',
            station_code_system='vendor',
            carrier_code_system='temporary_vendor',
            timezone='start_station',
        )

    def add_group(self, city, threads):
        """Add a group for ``city`` containing all ``threads`` to the channel.

        The group is titled after the first thread's base station. Each thread gets its
        thread block, a carrier block (when it has a carrier) and its own local fare block.
        """
        group = GroupBlock(
            self.channel_block,
            title=threads[0].base_station.title,
            code=str(city),
            t_type='bus',
        )

        for thr in threads:
            self.build_thread_block(group, thr)
            self.build_carrier_block(group, thr.carrier)
            self.build_prices(group.add_local_fare(), thr)

        self.channel_block.add_group_block(group)

    def build_thread_block(self, group_block, thread):
        """Add a thread block with its stop points, stations and schedule to ``group_block``."""
        block = ThreadBlock(group_block, title=thread.title, number=thread.number, carrier=thread.carrier)

        for stop in thread.stops:
            self.build_stoppoint_block(block, stop)
            self.build_station_block(group_block, stop)

        days = ';'.join(thread.dates)
        block.add_schedule_block(ScheduleBlock(block, days, times=thread.start_time))

        group_block.add_thread_block(block)

    def build_carrier_block(self, group_block, carrier):
        """Add a carrier block for ``carrier``; do nothing when the carrier is missing."""
        if not carrier:
            return

        carrier.code_system = 'temporary_vendor'
        group_block.add_carrier_block(CarrierBlock(group_block, carrier.title, carrier.code))

    def build_prices(self, fare_block, thread):
        """Add one price block per entry of ``thread.prices`` to ``fare_block``."""
        for price in thread.prices:
            args = (price['value'], price['currency'], price['stop_from'], price['stop_to'])
            fare_block.add_price_block(*args)

    def build_stoppoint_block(self, thread_block, station):
        """Add a stop point for ``station``; ``distance`` is passed only when it is not None."""
        kwargs = {
            'departure_shift': station.departure_shift,
            'arrival_shift': station.arrival_shift,
        }
        if station.distance is not None:
            kwargs['distance'] = station.distance

        thread_block.add_stoppoint_block(StoppointBlock(thread_block, station, **kwargs))

    def build_station_block(self, group, station):
        """Add a station block (title, vendor code, coordinates, country) to ``group``."""
        block = StationBlock(group, station.title, station.code, code_system='vendor')
        block.lat, block.lon, block.country_code = station.lat, station.lon, station.country_code
        group.add_station_block(block)

    def to_xml(self):
        """Return the whole channel serialized as unicode XML."""
        channel = self.channel_block
        return channel.to_unicode_xml()


class OKParser(object):
    """Fetches rides from the Odna Kassa ride list API and builds a cysix schedule from them."""

    def __init__(self,
                 agent_id,
                 secret_key,
                 base_url='https://api-gds.odnakassa.ru/ride/list',
                 cities=(1375, 1193, 1227),
                 start_date=None,
                 days_shift=14,
                 **kwargs):
        """
        :param agent_id: agent identifier, sent with every request and used in the request hash.
        :param secret_key: used only to compute the request hash; it is not sent itself.
        :param base_url: ride list endpoint.
        :param cities: API ids of the start cities to fetch; one group is built per city.
        :param start_date: first date to fetch (supports ``+ timedelta`` and ``strftime``).
            Defaults to ``today()``, evaluated when the parser is constructed.
        :param days_shift: how many days after ``start_date`` to fetch as well
            (``days_shift + 1`` dates in total).
        :param kwargs: ignored; lets callers pass extra options (e.g. ``mds_path``) through.
        """
        self.base_url = base_url
        self.agent_id = agent_id
        self.secret_key = secret_key
        self.cities = cities
        # Resolve the default here. A ``today()`` default in the signature would be
        # evaluated once at import time and go stale in long-lived processes.
        self.start_date = today() if start_date is None else start_date
        self.days_shift = days_shift

        session = Session()
        # Up to 3 retries with backoff. urllib3's Retry defaults decide which failures qualify.
        retries = Retry(total=3, backoff_factor=0.5)
        session.mount(base_url, HTTPAdapter(max_retries=retries))
        self._session = session

    def get_hash_code(self, start_city_id, date):
        """Return the request signature.

        The signature is the MD5 hex digest of the ``key=value`` pairs (agent_id,
        city_id_start, date, secret_key), concatenated in that order.
        """
        key = 'agent_id={}city_id_start={}date={}secret_key={}'.format(
            self.agent_id, start_city_id, date, self.secret_key)

        # Explicit UTF-8: a bare ``encode()`` on Python 2 unicode uses the ASCII codec.
        return hashlib.md5(key.encode('utf-8')).hexdigest()

    def get_segments(self, city, date):
        """Fetch the raw ride segments departing from ``city`` on ``date``.

        :returns: list of segment dicts. It is empty when the response has no ``ride_list``.
        :raises requests.HTTPError: on an HTTP error status (via ``raise_for_status``).
        """
        response = self._session.get(self.base_url, timeout=120, params={
            'agent_id': self.agent_id,
            'city_id_start': city,
            'date': date,
            'hash': self.get_hash_code(city, date)
        })

        log.info('Request: %s, response code: %s', response.request.url, response.status_code)

        response.raise_for_status()
        json_data = json.loads(response.content)['data']

        # A missing or null ride_list means "no rides"; callers need a list for len()/iteration.
        return json_data.get('ride_list') or []

    def get_routes_for_date(self, city, date):
        """Fetch segments for ``city``/``date`` and group them into OKRoute objects by ``ride_id``."""
        segments = self.get_segments(city, date)
        log.info('Found %s segments for city %s on %s', len(segments), city, date)

        segment_groups = defaultdict(list)
        for ride in segments:
            segment_groups[ride['ride_id']].append(ride)

        return [OKRoute(ride_id, date, rides) for ride_id, rides in segment_groups.items()]

    def get_threads_for_date_range(self, city, start_date, days):
        """Fetch routes from ``city`` for ``start_date`` and the ``days`` following days.

        The routes are grouped into threads by ``OKRoute.key``.

        :returns: list of OKThread. It is empty when ``days`` is negative or nothing was found.
        """
        dates = [(start_date + timedelta(days=d)).strftime('%Y-%m-%d') for d in range(days + 1)]
        if not dates:
            return []

        routes = []
        for ride_date in dates:
            routes.extend(self.get_routes_for_date(city, ride_date))

        log.info('Created %s routes for city %s on dates %s - %s', len(routes), city, dates[0], dates[-1])

        groups = defaultdict(list)
        for route in routes:
            groups[route.key].append(route)

        return [OKThread(group) for group in groups.values()]

    def build_schedule(self):
        """Build the XML schedule for all configured cities; cities without threads are skipped."""
        xml_builder = OKBuilder()

        for city in self.cities:
            threads = self.get_threads_for_date_range(city, self.start_date, self.days_shift)
            log.info('Created %s threads for city %s', len(threads), city)

            if threads:
                xml_builder.add_group(city, threads)

        return xml_builder.to_xml()


def run(**kwargs):
    """Build the Odna Kassa bus schedule and upload it to MDS.

    Originally requested in https://st.yandex-team.ru/RASPFRONT-8294
    ("Create a schedule file for Tver from the Odna Kassa API").

    :param kwargs: passed to ``OKParser``. ``agent_id`` and ``secret_key`` default to the
        vault secrets. ``mds_path`` (default ``schedule/bus/odnakassa_tmp.xml``) is where
        the schedule is saved.
    """
    # Only query the vault for credentials that were not passed explicitly.
    # ``dict.setdefault`` would evaluate ``get_secret`` unconditionally.
    if 'agent_id' not in kwargs:
        kwargs['agent_id'] = get_secret('rasp-common.RASP_ODNAKASSA_AGENT_ID')
    if 'secret_key' not in kwargs:
        kwargs['secret_key'] = get_secret('rasp-common.RASP_ODNAKASSA_SECRET_KEY')

    parser = OKParser(**kwargs)
    schedule = parser.build_schedule()

    mds_path = kwargs.get('mds_path', 'schedule/bus/odnakassa_tmp.xml')
    mds_s3_common_client.save_data(mds_path, schedule)

    log.info('Schedule uploaded to %s', mds_path)
