# -*- coding: utf8 -*-

import csv
import re
from collections import defaultdict
from zipfile import ZipFile, ZIP_DEFLATED

from django.utils.xmlutils import SimplerXMLGenerator

from travel.rasp.admin.lib.tmpfiles import get_tmp_filepath, clean_temp


@clean_temp
def convert_file(file_name, country_code):
    """Convert a zipped GTFS feed into a zip archive of per-transport-type XML files.

    Reads agency.txt, calendar.txt, frequencies.txt, routes.txt, stop_times.txt,
    stops.txt, trips.txt and the optional shapes.txt from ``file_name`` and writes
    one ``<t_type>.xml`` "channel" document for every transport type present in
    routes.txt (route_type '3' -> 'bus', '2' -> 'train', '4' -> 'sea').

    :param file_name: the GTFS zip archive (anything ``zipfile.ZipFile`` accepts).
    :param country_code: value written into the ``country_code`` of every station.
    :returns: the raw bytes of the generated zip archive.
    :raises KeyError: when a required GTFS member is missing, a route has an
        unmapped ``route_type``, or a trip references an unknown route, agency or
        service, or has no stop_times.txt / frequencies.txt rows.
    """
    types = {
        '3': 'bus',
        '2': 'train',
        '4': 'sea',
    }
    bom = '\xef\xbb\xbf'
    # re.match anchors only at the start, so any title that begins with digits matches.
    # NOTE(review): confirm whether only purely numeric titles (e.g. "12,14") were meant.
    numeric_title = re.compile(r'((\d+),?)+')

    def time_to_secs(t):
        """Convert an 'H:MM:SS' string to seconds; hours above 23 are accepted as-is."""
        h, m, s = [int(part) for part in t.split(':')]  # int() tolerates leading zeros
        return h * 3600 + m * 60 + s

    def read_table(feed, name):
        """Read CSV member ``name`` of ``feed`` as a list of {header: value} dicts."""
        member = feed.open(name)  # raises KeyError when the member does not exist
        try:
            rows = list(csv.reader(member, delimiter=',', quotechar='"'))
        finally:
            member.close()
        if not rows:
            return []
        headers = rows[0]
        # Some exporters prepend a UTF-8 byte-order mark to the first header name.
        headers[0] = headers[0].replace(bom, '')
        return [dict(zip(headers, row)) for row in rows[1:]]

    def sorted_by(rows, field):
        """Return ``rows`` ordered (stably) by the integer value of ``field``."""
        return sorted(rows, key=lambda row: int(row[field]))

    feed = ZipFile(file_name)
    try:
        agency = read_table(feed, 'agency.txt')
        calendar = read_table(feed, 'calendar.txt')
        frequencies = read_table(feed, 'frequencies.txt')
        routes = read_table(feed, 'routes.txt')
        stop_times = read_table(feed, 'stop_times.txt')
        stops = read_table(feed, 'stops.txt')
        trips = read_table(feed, 'trips.txt')
        try:
            shapes = read_table(feed, 'shapes.txt')
        except KeyError:  # shapes.txt is optional
            shapes = []
    finally:
        feed.close()

    calendar_by_id = {c['service_id']: c for c in calendar}
    routes_by_id = {r['route_id']: r for r in routes}
    agency_by_id = {a['agency_id']: a for a in agency}
    stops_by_id = {s['stop_id']: s for s in stops}
    # Lat/lon are compared as the exact strings found in the files.
    stops_by_geo = {(s['stop_lat'], s['stop_lon']): s for s in stops}
    # Only the last frequencies.txt row of each trip is kept.
    frequencies_by_trip_id = {fr['trip_id']: fr for fr in frequencies}

    # Every route's type must be mapped, even for routes that have no trips.
    route_types = {types[r['route_type']] for r in routes}

    # File order is not authoritative in GTFS: order stop times and shape points
    # by their sequence fields so first/last stops and geometry are correct.
    grouped_stop_times = defaultdict(list)
    for st in stop_times:
        grouped_stop_times[st['trip_id']].append(st)
    stop_times_by_trip = {trip_id: sorted_by(sts, 'stop_sequence')
                          for trip_id, sts in grouped_stop_times.items()}

    grouped_shapes = defaultdict(list)
    for point in shapes:
        grouped_shapes[point['shape_id']].append(point)
    shapes_by_id = {shape_id: sorted_by(points, 'shape_pt_sequence')
                    for shape_id, points in grouped_shapes.items()}

    trips_by_type = defaultdict(list)
    stops_by_type = defaultdict(set)  # unique stop ids served per transport type
    for trip in trips:
        trip['route'] = routes_by_id[trip['route_id']]
        trip['stop_times'] = stop_times_by_trip[trip['trip_id']]
        trip['agency'] = agency_by_id[trip['route']['agency_id']]
        t_type = types[trip['route']['route_type']]
        trips_by_type[t_type].append(trip)
        stops_by_type[t_type].update(st['stop_id'] for st in trip['stop_times'])

    def convert_date(s):
        """Turn a 'YYYYMMDD' date string into 'YYYY-MM-DD'."""
        return '-'.join((s[:4], s[4:6], s[6:]))

    def transform_route_name(trip):
        """Return the thread title; numeric-looking names become '<first> – <last>' stop names."""
        title = trip['route']['route_long_name']
        if not numeric_title.match(title):
            return title
        first = stops_by_id[trip['stop_times'][0]['stop_id']]['stop_name']
        last = stops_by_id[trip['stop_times'][-1]['stop_id']]['stop_name']
        # Stop names are decoded as UTF-8, joined as unicode and encoded back.
        return (u'%s – %s' % (first.decode('utf8'), last.decode('utf8'))).encode('utf-8')

    def add_stoppoints(xml, trip):
        """Emit <stoppoints> with one <stoppoint> per stop time of ``trip``."""
        xml.startElement('stoppoints', {})
        for st in trip['stop_times']:
            xml.startElement('stoppoint', {
                'station_code': st['stop_id'],
                'station_title': stops_by_id[st['stop_id']]['stop_name'],
                # NOTE(review): this is the absolute departure time in seconds, not an
                # offset from the first stop - confirm that is what consumers expect.
                'departure_shift': str(time_to_secs(st['departure_time'])),
            })
            xml.endElement('stoppoint')
        xml.endElement('stoppoints')

    def add_schedule(xml, trip):
        """Emit <schedules> holding a single frequency-based <schedule> for ``trip``."""
        service = calendar_by_id[trip['service_id']]
        frequency = frequencies_by_trip_id[trip['trip_id']]
        xml.startElement('schedules', {})
        xml.startElement('schedule', {
            # headway_secs rounded to whole minutes
            'period_int': str(int(round(float(frequency['headway_secs']) / 60))),
            'period_start_date': convert_date(service['start_date']),
            'period_end_date': convert_date(service['end_date']),
            'period_start_time': frequency['start_time'],
            'period_end_time': frequency['end_time'],
        })
        xml.endElement('schedule')
        xml.endElement('schedules')

    def add_points(xml, trip):
        """Emit <geometry> built from the trip's shape points (empty if it has no shape).

        The first and last points are labelled with the trip's first and last stops;
        any other point is labelled only when a stop has exactly the same lat/lon.
        """
        xml.startElement('geometry', {})
        first_stop = stops_by_id[trip['stop_times'][0]['stop_id']]
        last_stop = stops_by_id[trip['stop_times'][-1]['stop_id']]
        points = shapes_by_id.get(trip.get('shape_id'), [])
        last_index = len(points) - 1
        for i, point in enumerate(points):
            lat = point['shape_pt_lat']
            lon = point['shape_pt_lon']
            # The last-point check wins for single-point shapes.
            if i == last_index:
                stop = last_stop
            elif i == 0:
                stop = first_stop
            else:
                stop = stops_by_geo.get((lat, lon))
            attrs = {'lat': lat, 'lon': lon}
            if stop:
                attrs['station_title'] = stop['stop_name']
                attrs['station_code'] = stop['stop_id']
            xml.startElement('point', attrs)
            xml.endElement('point')
        xml.endElement('geometry')

    def write_channel(file_xml, t_type):
        """Write the complete <channel> XML document for transport type ``t_type``."""
        xml = SimplerXMLGenerator(file_xml, encoding='utf8')
        xml.startDocument()
        xml.startElement("channel", {
            'version': '1.0',
            't_type': t_type,
            'station_code_system': 'vendor',
            'timezone': 'local',
            'carrier_code_system': 'vendor',
            'vehicle_code_system': 'vendor',
        })
        xml.startElement('group', {
            'code': 'all',
            'title': 'all from gtfs',
        })

        # Every agency is listed as a carrier in every channel.
        xml.startElement('carriers', {})
        for a in agency:
            xml.startElement('carrier', {'code': a['agency_id'], 'title': a['agency_name']})
            xml.endElement('carrier')
        xml.endElement('carriers')

        xml.startElement('stations', {})
        # Sorted so that the generated document is deterministic.
        for stop_id in sorted(stops_by_type.get(t_type, ())):
            station = stops_by_id[stop_id]
            xml.startElement('station', {
                'code': station['stop_id'],
                'title': station['stop_name'],
                'lat': station['stop_lat'],
                'lon': station['stop_lon'],
                'country_code': country_code,
            })
            xml.endElement('station')
        xml.endElement('stations')

        xml.startElement('threads', {})
        # A route type with routes but no trips yields an empty <threads>.
        for trip in trips_by_type.get(t_type, ()):
            xml.startElement('thread', {
                'route_long_name': transform_route_name(trip),
                't_type': t_type,
                'carrier_code': trip['agency']['agency_id'],
                'carrier_title': trip['agency']['agency_name'],
                'timezone': trip['agency']['agency_timezone'],
            })
            add_stoppoints(xml, trip)
            add_schedule(xml, trip)
            add_points(xml, trip)
            xml.endElement('thread')
        xml.endElement('threads')
        xml.endElement('group')
        xml.endElement('channel')
        xml.endDocument()

    output_name = get_tmp_filepath('gtfs')
    output = ZipFile(output_name, 'w', compression=ZIP_DEFLATED)
    try:
        for t_type in sorted(route_types):
            out_xml_name = get_tmp_filepath('gtfs')
            with open(out_xml_name, 'w') as file_xml:
                write_channel(file_xml, t_type)
            output.write(out_xml_name, '%s.xml' % t_type)
    finally:
        output.close()

    with open(output_name, 'rb') as archive:
        return archive.read()
