# -*- coding: utf-8 -*-
"""The script aggregates redirect-logs of the travel service
by a number of parameters: hotel ("name"), its country, provider/operator
and where the user comes from ("source", "medium")
See HOTELS-2748, HOTELS-2958"""

import datetime as dt

from collections import deque, defaultdict
from copy import copy
from itertools import combinations
from functools import partial
from string import Template
import yaml

from library.python import resource
from travel.hotels.stats.dictionaries import SOURCES, PARTNERS, OPERATORS, SOURCE_MAP

from nile.api.v1 import (
    extractors as ne,
    aggregators as na,
    filters as nf,
    statface as ns,
    files as nfi,
    Record,
    cli,
    with_hints,
    extended_schema,
)

# Statface report identity, passed to StatfaceReport.title()/.path() in make_job.
# The title is Russian for "Statistics on redir logs".
REPORT_TITLE = u'Статистика по redir-логам'
REPORT_PATH = "Travel/Hotels/Redir-logs"

# NOTE(review): REDIRECT is not referenced anywhere in the visible part of this file.
REDIRECT = "redirect.datahc.com"
# Key of a raw log record under which get_data reads the per-event fields mapping.
FIELDS_MAP = "FieldsMap"
# Name of the date column used for grouping/joining in make_job.
FIELDDATE = "fielddate"
# Input path pattern; make_job substitutes $first/$last with the period bounds.
# NOTE(review): "{a..b}" is presumably a date-range table selector — confirm.
REDIR_TEMPLATE = Template('//logs/travel-redir-log/1d/{$first..$last}')

# Company snapshot table that make_job maps with company_mapper.
COMPANY = '//home/altay/db/export/current-state/snapshot/company'

# Fallback main_rubric used by company_mapper when a company has no main rubric.
UNCLASSIFIED = 1
# Placeholder values for missing / unrecognised and empty dimensions.
UNKNOWN = 'unknown'
EMPTY = 'empty'
# Region id used as the root of the region tree in make_parents_dict.
WORLD = 10000

# Seconds in an hour.
HOUR = 60*60
# Maximum allowed "delta" (event unixtime minus offer CacheTimestamp, see
# get_data); make_job filters out rows whose delta exceeds it.
THRESHOLD = HOUR*2

# Dimensions that the report groups by and that add_totals rolls up to "all".
FIELDS = ("source", "medium", "operator",
          "status", "full_name", "offer_source")

# Fallback job root used by make_job when nirvana supplies no directories.
# NOTE(review): unlike COMPANY this path has no leading "//" — confirm it is intended.
DEFAULT_DIR = "home/travel/analytics/travel-redir-log"

# Extend the imported operator dictionary with extra string-id entries.
# This mutates the imported OPERATORS object in place.
OPERATORS.update({
    "2": 'booking',
    "4": 'ostrovok',
    "9": 'hotels101',
    "40": 'expedia'})


def fielddate_prototype(val, scale, date=None):
    """Choose the fielddate value that matches the report scale.

    For the "daily" scale the record's own value ``val`` is returned;
    for "weekly" and "monthly" the externally supplied ``date`` is
    returned instead. Any other scale raises ``KeyError``.
    """
    if scale == "daily":
        return val
    if scale in ("weekly", "monthly"):
        return date
    raise KeyError(scale)


def safe_int(val):
    """Convert ``val`` to int, returning -1 when it cannot be converted.

    Only the errors that ``int()`` raises for bad input are handled
    (wrong type, malformed string, infinite float). Anything else, such as
    KeyboardInterrupt or SystemExit, propagates instead of being silently
    turned into -1.
    """
    try:
        return int(val)
    except (TypeError, ValueError, OverflowError):
        return -1


def make_parents_dict(geobase_records):
    """Build {region_id: [ancestor ids]} for every region reachable from WORLD.

    ``geobase_records`` is an iterable of objects exposing ``id`` and
    ``parent`` attributes. The ancestor list of a region starts at WORLD and
    ends with the region's direct parent; WORLD itself maps to an empty list.
    Regions that are not connected to WORLD are absent from the result.
    """
    children_dict = defaultdict(list)
    for elem in geobase_records:
        children_dict[elem.parent].append(elem.id)

    parents = {WORLD: []}
    # Breadth-first walk from the root. A deque gives O(1) pops from the
    # left, whereas list.pop(0) shifts the whole list on every step.
    queue = deque([WORLD])
    while queue:
        elem = queue.popleft()
        ancestors = parents[elem]
        # .get() avoids inserting an empty list into the defaultdict for leaves.
        for child in children_dict.get(elem, ()):
            # A fresh list per child, so callers never share list objects.
            parents[child] = ancestors + [elem]
            queue.append(child)
    return parents


def make_tree_path(region_id, geo_dict, parents):
    """Return two parallel lists describing the path down to ``region_id``.

    The first list holds region ids (ancestors followed by the region
    itself), the second holds the matching names looked up in ``geo_dict``.
    A region with no known ancestors that is not WORLD is placed under a
    synthetic "unmapped" level with id 0, and its own id is used as its name.
    """
    ancestors = parents.get(region_id, [])
    if not ancestors and region_id != WORLD:
        # Unknown region: hang it under WORLD -> "unmapped".
        ids = [WORLD, 0, region_id]
        names = [geo_dict.get(WORLD), "unmapped", region_id]
        return ids, names

    ids = ancestors + [region_id]
    names = [geo_dict.get(node) for node in ids]
    return ids, names


def path_to_str(path):
    """Render an iterable as a tab-separated string wrapped in tabs."""
    cells = ["{}".format(node) for node in path]
    return "\t{}\t".format("\t".join(cells))


@with_hints(files=[nfi.TableFile("//home/travel/prod/indexer/geobase", "geobase")],
            output_schema=extended_schema(region_path=str, raw=bool))
def add_parents(records, **options):
    """Expand every record into one record per level of its region path.

    For a record with ``geo_id`` G the mapper yields a copy for each region
    on the path from the root down to G. Each copy has ``geo_id`` replaced
    by that region's id, ``region_path`` set to the tab-wrapped names of the
    path down to that level, and ``raw`` set to True only on the copy that
    stands for the original record (the last level).
    """
    # The geobase is walked twice (names, then tree), so read it into a list
    # once instead of relying on the file stream being re-iterable.
    geobase = list(options['file_streams']['geobase'])
    geo_dict = {elem.id: elem.name for elem in geobase}
    parents = make_parents_dict(geobase)
    for rec in records:
        proto_result = rec.to_dict()
        tree_path, name_tree_path = make_tree_path(
            rec.geo_id, geo_dict, parents)
        last_level = len(tree_path) - 1
        for level, region in enumerate(tree_path):
            result = copy(proto_result)
            result["geo_id"] = int(region)
            result["region_path"] = path_to_str(name_tree_path[:level + 1])
            # Flag by position, not by id equality: for geo_id == 0 the
            # synthetic "unmapped" level also has id 0, and comparing ids
            # would mark two copies as raw and double-count the record.
            result["raw"] = level == last_level
            yield Record(**result)


@with_hints(output_schema=dict(source=str, price=int, permalink=int, uid=str, has_uid=bool,
                               operator=str, medium=str, fielddate=str, offer_source=str, delta=int, ok=bool))
def get_data(records):
    """Map raw redirect-log records into the flat records used by the report.

    Reads the per-event mapping stored under FIELDS_MAP and yields one
    record per input record with:

    * ``source`` / ``operator`` - looked up in SOURCE_MAP / OPERATORS,
      UNKNOWN when the value is missing or not in the dictionary;
    * ``medium`` / ``offer_source`` - raw values, EMPTY when missing or empty;
    * ``price`` / ``permalink`` - integers, -1 when not convertible;
    * ``uid`` - YandexUid, falling back to Uuid; ``has_uid`` tells whether
      either of them is present;
    * ``fielddate`` - the first whitespace-separated token of ``iso_eventtime``;
    * ``delta`` - ``unixtime`` minus the offer's CacheTimestamp (a missing
      timestamp is treated as -1, see safe_int);
    * ``ok`` - False for "booking_form" clicks from "serp" that have an
      empty Query and no user id, True otherwise.
    """
    for rec in records:
        fields_map = rec[FIELDS_MAP]

        uid = fields_map.get("YandexUid")
        uuid = fields_map.get("Uuid")
        has_uid = bool(uid) or bool(uuid)

        operator = OPERATORS.get(fields_map.get("OperatorId"), UNKNOWN)
        # .get() like every other field: a row without "Source" is classified
        # as UNKNOWN instead of aborting the whole mapper with KeyError.
        source = SOURCE_MAP.get(fields_map.get("Source"), UNKNOWN)
        medium = fields_map.get("Medium") or EMPTY
        offer_source = fields_map.get("OfferSource") or EMPTY

        offer_timestamp = safe_int(fields_map.get("CacheTimestamp"))
        delta = rec.unixtime - offer_timestamp

        query = fields_map.get("Query")
        # Only an explicitly empty Query disqualifies a record; a missing
        # Query (None) does not.
        suspicious = (medium == "booking_form" and source == "serp"
                      and query == "" and not has_uid)

        yield Record(source=source,
                     price=safe_int(fields_map.get("Price")),
                     permalink=safe_int(fields_map.get("Permalink")),
                     uid=uid or uuid,
                     has_uid=has_uid,
                     operator=operator,
                     medium=medium,
                     offer_source=offer_source,
                     delta=delta,
                     fielddate=rec.iso_eventtime.split()[0],
                     ok=not suspicious)


def _prefer_ru(localized_values, default):
    """Pick one value from localized entries, preferring the "ru" locale.

    Every entry is a mapping with "locale" and "value" keys. "ru" entries
    are pushed to the front and all others to the back, then the front
    element is returned. ``default`` is returned when there are no entries.
    """
    ordered = deque()
    for item in localized_values:
        if item["locale"] == "ru":
            ordered.appendleft(item["value"])
        else:
            ordered.append(item["value"])
    return ordered[0] if ordered else default


@with_hints(output_schema=dict(full_name=str, name=str, permalink=int, country=str, region_code=str,
                               status=str, ru=str, geo_id=int, main_rubric=int))
def company_mapper(records):
    """Flatten company snapshot records into the hotel attributes of the report.

    For every company the mapper yields its permalink, main name (Russian
    if available, "" if there is none), country name (UNKNOWN if there is
    none), region code, geo id (0 if absent), main rubric id (UNCLASSIFIED
    if there is none), publishing status and a "ru"/"world" flag derived
    from the region code. Missing ``address``, ``names``, ``rubrics`` or
    address components are treated as empty.
    """
    for rec in records:
        permalink = rec.permalink
        address = rec.address or {}

        country = _prefer_ru(
            (component["name"]
             for component in (address.get("components") or [])
             if component and component.get("kind") == "country"),
            UNKNOWN)
        the_name = _prefer_ru(
            (entry["value"]
             for entry in (rec.names or [])
             if entry.get("type") == "main"),
            "")

        region_code = address.get("region_code", UNKNOWN)
        geo_id = address.get("geo_id", 0)
        # First rubric explicitly flagged as main, otherwise the fallback id.
        main_rubric = next(
            (rubric["rubric_id"]
             for rubric in (rec.rubrics or [])
             if rubric["is_main"] is True),
            UNCLASSIFIED)

        yield Record(
            full_name="{}_{}".format(permalink, the_name),
            permalink=permalink,
            name=the_name,
            country=country,
            region_code=region_code,
            geo_id=geo_id,
            main_rubric=main_rubric,
            status=rec.publishing_status,
            ru="ru" if region_code == "RU" else "world"
        )


def add_totals(records, fields, special_cases):
    """Emit every roll-up of a record over ``fields``.

    For each subset of ``fields`` (from the empty subset up to all of them)
    a copy of the record is yielded in which the fields of that subset are
    replaced by ``special_cases[field]`` when the field is listed there,
    and by "all" otherwise. Fields absent from the record are left alone.

    The mapper should be curried (``fields`` and ``special_cases`` bound)
    before use.
    """
    for rec in records:
        base = rec.to_dict()
        for size in range(len(fields) + 1):
            for subset in combinations(fields, size):
                rolled = copy(base)
                for name in subset:
                    if name not in rolled:
                        continue
                    rolled[name] = special_cases[name] if name in special_cases else "all"
                yield Record(**rolled)


# add_totals with FIELDS bound as the roll-up dimensions and no special
# cases (the empty tuple never contains a key, so every rolled-up field
# becomes "all"), wrapped in schema hints; make_job uses it as a mapper.
add_totals_curried = with_hints(output_schema=extended_schema())(partial(
    add_totals, fields=FIELDS, special_cases=()))


def none_to_unknown(val):
    """Return ``val`` unchanged when it is truthy, otherwise UNKNOWN."""
    if val:
        return val
    return UNKNOWN


def _resolve_period(scale, dates):
    """Return ``(first, last, suffix)`` for the requested scale and dates.

    * "daily": the period spans the first to the last of ``dates``; the
      suffix is the single date, or "first_last" for several dates.
    * "weekly": exactly one date is accepted; the period starts on the
      Monday of that date's week.
    * "monthly": exactly one date is accepted; the period starts on the
      first day of that date's month.

    Raises ValueError for an empty ``dates``, an unsupported scale, or more
    than one date with a non-daily scale.
    """
    if len(dates) == 0:
        raise ValueError('At least one date is required')

    if scale == "daily":
        first, last = dates[0], dates[-1]
        if len(dates) > 1:
            suffix = "{first}_{last}".format(first=first, last=last)
        else:
            suffix = first
        return first, last, suffix

    if scale not in ("weekly", "monthly"):
        raise ValueError('Unsupported scale: {!r}'.format(scale))
    if len(dates) > 1:
        raise ValueError('Non-daily scale is legitimate only for single date')

    last = dates[0]
    last_date = dt.datetime.strptime(last, "%Y-%m-%d").date()
    # Days to step back to reach the start of the week / month.
    offset = last_date.weekday() if scale == "weekly" else last_date.day - 1
    first = (last_date - dt.timedelta(offset)).isoformat()
    suffix = "{first}_{last}".format(first=first, last=last)
    return first, last, suffix


@cli.statinfra_job
def make_job(job, options, nirvana, statface_client):
    """Standard entry point according to Statistics conventions,
    see https://clubs.at.yandex-team.ru/statistics/1143

    Builds the whole pipeline: reads the redirect log for the requested
    period, joins it with the company snapshot, expands every click over
    its region path and over all roll-ups of FIELDS, aggregates clicks,
    distinct uids, clicks with a user id and total price, computes each
    region's share of the clicks, stores the result and publishes it to
    the Statface report. Returns the configured ``job``.

    Raises ValueError when ``options.scale`` is unsupported, when no date
    is given, or when a weekly/monthly scale is given more than one date.
    """

    # safe_load: the config is plain data, and yaml.load() without an
    # explicit Loader is deprecated and rejected by recent PyYAML versions.
    report_config_dict = yaml.safe_load(resource.find('yaml_config'))
    report_config_dict["dictionaries"] = {"source":
                                          {"name": "config",
                                           "values": SOURCES},
                                          "operator":
                                          {"name": "config",
                                           "values": PARTNERS}}
    scale = options.scale
    dates = options.dates
    first, last, suffix = _resolve_period(scale, dates)

    # For daily reports each record keeps its own date; for weekly and
    # monthly reports every record is stamped with the period start.
    fielddate_transform = partial(
        fielddate_prototype, scale=scale, date=first if scale in ("weekly", "monthly") else None)

    report = ns.StatfaceReport() \
        .from_dict_config(report_config_dict)\
        .path(REPORT_PATH)\
        .title(REPORT_TITLE.encode('utf8'))\
        .scale(scale)\
        .client(statface_client)

    job_root = nirvana.directories[0] if nirvana.directories else DEFAULT_DIR

    job = job.env(
        templates=dict(job_root=job_root,
                       suffix=suffix,
                       )
    )

    input_table = nirvana.input_tables[0] if nirvana.input_tables else REDIR_TEMPLATE.substitute(
        first=first, last=last)
    # NOTE(review): $scale is not among the templates set above; presumably
    # it is provided by the statinfra environment — confirm.
    output_table = nirvana.output_tables[0] if nirvana.output_tables else '$job_root/redirs/$scale/$suffix'

    company = job.table(COMPANY)\
        .map(company_mapper)

    data = job.table(input_table)\
        .map(get_data)\
        .project(ne.all(),
                 fielddate=ne.custom(fielddate_transform, FIELDDATE).with_type(str))\
        .filter(nf.custom(lambda x: x <= THRESHOLD, "delta"))\
        .filter(nf.custom(lambda x: bool(x), "ok"))\
        .join(company, by="permalink", type="left")\
        .project(ne.all(),
                 country=ne.custom(
                     none_to_unknown, "country").add_hints(type=str),
                 full_name=ne.custom(lambda x, y: x or "{}_unknown".format(
                     y), "full_name", "permalink").add_hints(type=str),
                 geo_id=ne.custom(lambda x: x or 0,
                                  "geo_id").add_hints(type=int),
                 status=ne.custom(none_to_unknown, "status").add_hints(type=str))\
        .map(add_parents)\
        .map(add_totals_curried)

    # Totals over all regions: only the "raw" copies (one per original
    # click) are counted, so region roll-ups do not inflate the total.
    not_by_region = data\
        .filter(nf.equals("raw", True))\
        .groupby(FIELDDATE, *FIELDS)\
        .aggregate(total=na.count())

    data = data\
        .groupby(FIELDDATE, "region_path", *FIELDS)\
        .aggregate(clicks=na.count(),
                   uids=na.count_distinct("uid"),
                   has_uid=na.count(predicate=nf.custom(
                       lambda x: x is True, "has_uid")),
                   total_price=na.sum("price"))\
        .join(not_by_region, by=FIELDS+(FIELDDATE,))\
        .project(ne.all(),
                 percent=ne.custom(lambda clicks, total: float(clicks)/total*100, "clicks", "total").add_hints(type=float))\
        .put(output_table)\
        .publish(report, remote_mode=True, allow_change_job=True)

    return job


# Script entry point: hand control to the nile CLI runner.
# NOTE(review): presumably this dispatches to the @cli.statinfra_job-decorated
# make_job — confirm against the nile cli documentation.
if __name__ == "__main__":
    cli.run()
