# -*- coding: utf-8 -*-

from argparse import ArgumentParser
from itertools import groupby, product
# noinspection PyUnresolvedReferences
from tzlocal import get_localzone
from xml.etree.ElementTree import Element, ElementTree
import logging
import sys
import os
import shutil
import datetime
import csv
import yaml

import boto3
import jinja2

from library.python import resource
from travel.library.python.tools import replace_args_from_env
from travel.hotels.lib.python3.yql import yqllib
from travel.hotels.lib.python3.yt import ytlib
from travel.hotels.lib.python3.yt.versioned_path import VersionedPath

# strftime pattern (YYYY-MM-DD) used for the date-range parameters passed to the regions feed queries.
DATE_FORMAT = '%Y-%m-%d'

# A region needs at least this many hotels to be included in the market regions feed;
# the market hotels feed requires one hotel more than this.
MARKET_HOTELS_FOR_CATEGORY_MIN = 3
# Currency id declared in the market feed skeleton and set on every market offer.
RUB_CURRENCY_ID = 'RUB'


class Runner(object):
    def __init__(self, args):
        """Create the YQL and YT clients from parsed arguments and remember today's date.

        Args:
            args: parsed command-line namespace. The attributes read here are
                ``yt_proxy``, ``yql_token``, ``yql_token_path``, ``yt_token``
                and ``yt_token_path``; the whole namespace is kept as ``self.args``.
        """
        self.yql_client = yqllib.create_client(
            db=args.yt_proxy,
            token=args.yql_token,
            token_path=args.yql_token_path,
        )
        self.yt_client = ytlib.create_client(
            proxy=args.yt_proxy,
            config={'token': args.yt_token, 'token_path': args.yt_token_path},
        )
        self.args = args
        # Local date at construction time in ISO format (YYYY-MM-DD).
        self.today_str = datetime.date.today().isoformat()

    def run(self):
        if sum([self.args.build_hotels_feed, self.args.build_regions_feed, self.args.build_regions_filters_feed]) > 1:
            raise Exception('Only one feed can be built in a run')

        if self.args.build_hotels_feed:
            with VersionedPath(self.args.yt_path, yt_client=self.yt_client) as work_path:
                self.do_work_hotels(work_path)
        if self.args.build_market_feed:
            with VersionedPath(self.args.yt_path, yt_client=self.yt_client) as work_path:
                self.do_work_market(work_path)
        if self.args.build_regions_feed:
            with VersionedPath(self.args.yt_path, yt_client=self.yt_client) as work_path:
                self.do_work_regions(work_path)
        if self.args.build_regions_filters_feed:
            with VersionedPath(self.args.yt_path, yt_client=self.yt_client) as work_path:
                self.do_work_regions_with_filters(work_path)
        if self.args.build_bboxes_feed:
            with VersionedPath(self.args.yt_path, yt_client=self.yt_client) as work_path:
                self.do_work_bboxes_feed(work_path)

    def do_work_hotels(self, work_path):
        """Build the hotels feeds under ``work_path``, dump them to local files and sync them to S3.

        Args:
            work_path: YT directory in which the ``_full`` table and the per-feed tables are created.
        """
        feeds_dir = 'feeds'
        full_table = ytlib.join(work_path, '_full')
        self._recreate_local_dir(feeds_dir)

        self.prepare_full_feed(full_table)
        names = self.generate_feeds(full_table, work_path)
        # Each hotels feed is written to a local file named exactly like its YT table.
        table_and_file_names = [(name, name) for name in names]
        self.put_feeds_to_files(work_path, feeds_dir, table_and_file_names)
        self.sync_with_s3(feeds_dir, names)

    def do_work_regions(self, work_path):
        """Build the regions feed table, write it to ``regions.csv`` and sync it to S3.

        Args:
            work_path: YT directory in which the ``regions`` table is created.
        """
        yt_name, file_name = 'regions', 'regions.csv'
        feeds_dir = 'feeds'
        regions_table = ytlib.join(work_path, yt_name)
        self._recreate_local_dir(feeds_dir)

        # True: this query also receives the region-queries date range parameters.
        self.prepare_regions_feed(regions_table, 'Generate regions feed', 'prepare_feed.yql', True)
        self.put_feeds_to_files(work_path, feeds_dir, [(yt_name, file_name)])
        self.sync_with_s3(feeds_dir, [file_name])

    def do_work_regions_with_filters(self, work_path):
        """Build the regions-with-filters feed table, write it to ``regions-with-filters.csv`` and sync it to S3.

        Args:
            work_path: YT directory in which the ``regions_with_filters`` table is created.
        """
        yt_name, file_name = 'regions_with_filters', 'regions-with-filters.csv'
        feeds_dir = 'feeds'
        filters_table = ytlib.join(work_path, yt_name)
        self._recreate_local_dir(feeds_dir)

        # True: this query also receives the region-queries date range parameters.
        self.prepare_regions_feed(
            filters_table,
            'Generate regions with filters feed',
            'prepare_feed_with_filters.yql',
            True,
        )
        self.put_feeds_to_files(work_path, feeds_dir, [(yt_name, file_name)])
        self.sync_with_s3(feeds_dir, [file_name])

    def do_work_bboxes_feed(self, work_path):
        """Build the bboxes feed table, write it to ``bboxes.csv`` and sync it to S3.

        Args:
            work_path: YT directory in which the ``bboxes`` table is created.
        """
        yt_name, file_name = 'bboxes', 'bboxes.csv'
        feeds_dir = 'feeds'
        bboxes_table = ytlib.join(work_path, yt_name)
        self._recreate_local_dir(feeds_dir)

        # False: the bboxes query is run without the region-queries date range parameters.
        self.prepare_regions_feed(bboxes_table, 'Generate bboxes feed', 'bboxes_feed.yql', False)
        self.put_feeds_to_files(work_path, feeds_dir, [(yt_name, file_name)])
        self.sync_with_s3(feeds_dir, [file_name])

    def do_work_market(self, work_path):
        """Build the market feed tables, render the market XML files locally and sync them to S3.

        Args:
            work_path: YT directory in which the ``_full`` and ``_market`` tables are created.
        """
        full_table = ytlib.join(work_path, '_full')
        market_table = ytlib.join(work_path, '_market')

        # The market table is derived from the full hotels feed table.
        self.prepare_full_feed(full_table)
        self.prepare_market_feed(full_table, market_table)

        feeds_dir = 'feeds'
        self._recreate_local_dir(feeds_dir)
        generated_files = self.generate_market_feeds(market_table, feeds_dir)
        self.sync_with_s3(feeds_dir, generated_files)

    def prepare_regions_feed(self, full_feed_table, title, yql_name, add_regions_queries_params):
        """Run one of the ``regions_feed_queries`` YQL files with ``full_feed_table`` as its output table.

        The query receives a date range ending now and starting
        ``args.regions_feed_stats_days`` days earlier, formatted with ``DATE_FORMAT``.

        Args:
            full_feed_table: value passed as the ``$output_table`` parameter.
            title: suffix appended to ``YQL:MarketingFeedBuilder:`` to form the query title.
            yql_name: file name of the query inside ``regions_feed_queries/``.
            add_regions_queries_params: when true, the same date range is also passed
                as ``$region_queries_from`` / ``$region_queries_to``.
        """
        logging.info("Preparing regions feed table")
        window_end = datetime.datetime.now()
        window_start = window_end - datetime.timedelta(days=self.args.regions_feed_stats_days)
        from_date = window_start.strftime(DATE_FORMAT)
        to_date = window_end.strftime(DATE_FORMAT)

        if add_regions_queries_params:
            region_queries_params = {
                '$region_queries_from': from_date,
                '$region_queries_to': to_date,
            }
        else:
            region_queries_params = {}
        params = {
            '$output_table': full_feed_table,
            '$portal_prefix': self.args.portal_prefix,
            '$orders_from': from_date,
            '$orders_to': to_date,
            '$cluster_permalink_prices': self.args.cluster_permalink_prices,
            **region_queries_params,
        }

        # The shared library is attached with a USE statement for the client's db prepended to it.
        use_statement = 'USE `{}`;\n'.format(self.yql_client.config.db)
        library_text = resource.find('regions_feed_queries/region_feeds_lib.sql').decode('utf-8')

        yqllib.run_yql_file(
            self.yql_client,
            f'regions_feed_queries/{yql_name}', 'MarketingFeedBuilder',
            title=f'YQL:MarketingFeedBuilder:{title}',
            parameters=params,
            attaches={'region_feeds_lib.sql': use_statement + library_text},
        )

    def _recreate_local_dir(self, path):
        if os.path.exists(path):
            shutil.rmtree(path)
        os.mkdir(path)

    def prepare_full_feed(self, full_feed_table):
        """Run ``hotels_feed_queries/prepare_feed.yql`` with ``full_feed_table`` as its output table.

        Args:
            full_feed_table: value passed as the ``$output_table`` parameter.
        """
        logging.info("Preparing full feed table")
        query_parameters = {
            '$output_table': full_feed_table,
            '$portal_prefix': self.args.portal_prefix,
            '$cluster_permalink_prices': self.args.cluster_permalink_prices,
        }
        yqllib.run_yql_file(
            self.yql_client,
            'hotels_feed_queries/prepare_feed.yql',
            'MarketingFeedBuilder',
            parameters=query_parameters,
        )

    def prepare_market_feed(self, full_feed_table: ytlib.YPath, market_feed_table: ytlib.YPath) -> None:
        """Run ``market_feed_queries/prepare_feed.yql`` to produce the market feed table.

        Args:
            full_feed_table: value passed as the ``$full_feed_table`` parameter.
            market_feed_table: value passed as the ``$output_table`` parameter.
        """
        logging.info("Preparing market feed table")
        query_parameters = {
            '$full_feed_table': full_feed_table,
            '$output_table': market_feed_table,
        }
        yqllib.run_yql_file(
            self.yql_client,
            'market_feed_queries/prepare_feed.yql',
            'MarketingFeedBuilder',
            parameters=query_parameters,
        )

    def generate_market_feeds(self, market_feed_table: ytlib.YPath, local_work_dir: str) -> list[str]:
        """Read the market feed table and write the two market XML feeds into ``local_work_dir``.

        Args:
            market_feed_table: YT table to read; every row must carry a ``Destination name`` key.
            local_work_dir: existing local directory that receives the generated files.

        Returns:
            File names (relative to ``local_work_dir``) of the generated feeds:
            ``market_regions.xml`` and ``market_hotels.xml``.
        """
        # Accumulate rows per destination instead of using itertools.groupby. groupby only
        # merges *consecutive* rows with equal keys, so on a table that is not sorted by
        # destination a later run of the same key would overwrite the hotels collected
        # earlier and silently drop them. For sorted input the result is identical.
        by_region = {}
        for row in self.yt_client.read_table(market_feed_table):
            by_region.setdefault(row['Destination name'], []).append(row)

        # One timestamp (local timezone) shared by both feeds.
        now = str(datetime.datetime.now(get_localzone()))

        regions_fn = 'market_regions.xml'
        self.generate_market_feed_regions(now, by_region, os.path.join(local_work_dir, regions_fn))

        hotels_fn = 'market_hotels.xml'
        self.generate_market_feed_hotels(now, by_region, os.path.join(local_work_dir, hotels_fn))

        return [regions_fn, hotels_fn]

    def generate_market_feed_regions(
        self,
        now: str,
        by_region: dict[str, list[dict[str, ...]]],
        regions_path: str,
    ) -> None:
        """Write the market feed in which each category is a destination and its offers are that destination's hotels.

        Regions with fewer than ``MARKET_HOTELS_FOR_CATEGORY_MIN`` hotels are left out.

        Args:
            now: value of the ``date`` attribute on the feed root element.
            by_region: hotel rows grouped by region; rows are read through the
                ``Destination id`` and ``Destination name`` keys plus those used by ``get_offer``.
            regions_path: local path of the XML file to write.
        """
        tree, categories, offers = self.get_market_feed_skeleton(now)
        known_category_ids = set()

        for region_hotels in by_region.values():
            if len(region_hotels) >= MARKET_HOTELS_FOR_CATEGORY_MIN:
                for hotel in region_hotels:
                    category_id = str(hotel['Destination id'])
                    category_name = hotel['Destination name']

                    # Declare every category once, on its first hotel.
                    if category_id not in known_category_ids:
                        category = self.element_with_text('category', category_name, id=category_id)
                        categories.append(category)
                        known_category_ids.add(category_id)

                    offers.append(self.get_offer(category_id, hotel))

        with open(regions_path, 'wb') as xml_file:
            tree.write(xml_file)

    def generate_market_feed_hotels(
        self,
        now: str,
        by_region: dict[str, list[dict[str, ...]]],
        hotels_path: str,
    ) -> None:
        """Write the market feed in which each category is a hotel and its offers are the other hotels of its region.

        Regions with fewer than ``MARKET_HOTELS_FOR_CATEGORY_MIN + 1`` hotels are left out.

        Args:
            now: value of the ``date`` attribute on the feed root element.
            by_region: hotel rows grouped by region; rows are read through the
                ``Property ID`` and ``Property name`` keys plus those used by ``get_offer``.
            hotels_path: local path of the XML file to write.
        """
        tree, categories, offers = self.get_market_feed_skeleton(now)
        known_category_ids = set()
        min_hotels = MARKET_HOTELS_FOR_CATEGORY_MIN + 1

        for region_hotels in by_region.values():
            if len(region_hotels) < min_hotels:
                continue
            # Every ordered pair of rows from the region: the first acts as the category, the second as the offer.
            for category in region_hotels:
                for hotel in region_hotels:
                    if category == hotel:
                        # A hotel is not offered inside its own category (rows compare by content).
                        continue
                    category_id = str(category['Property ID'])
                    category_name = category['Property name']

                    if category_id not in known_category_ids:
                        categories.append(self.element_with_text('category', category_name, id=category_id))
                        known_category_ids.add(category_id)

                    offers.append(self.get_offer(category_id, hotel))

        with open(hotels_path, 'wb') as xml_file:
            tree.write(xml_file)

    def get_market_feed_skeleton(self, now: str) -> (ElementTree, Element, Element):
        root = Element('yml_catalog', date=now)
        tree = ElementTree(root)
        shop = Element('shop')
        root.append(shop)

        name = self.element_with_text('name', 'Яндекс.Путешествия')
        company = self.element_with_text('company', 'ООО "Яндекс"')
        url = self.element_with_text('url', 'https://travel.yandex.ru/hotels')
        currencies = Element('currencies')
        currencies.append(Element('currency', id=RUB_CURRENCY_ID, rate='1'))
        categories = Element('categories')
        offers = Element('offers')
        shop.extend([name, company, url, currencies, categories, offers])

        return tree, categories, offers

    @staticmethod
    def element_with_text(tag: str, text: str, **attrs) -> Element:
        element = Element(tag, attrs)
        element.text = text
        return element

    def get_offer(self, category_id: str, hotel: dict[str, ...]) -> Element:
        """Build the ``offer`` element describing ``hotel`` inside category ``category_id``.

        Args:
            category_id: text of the offer's ``categoryId`` child.
            hotel: feed row; the ``Property ID``, ``Property name``, ``Final URL``,
                ``Image URL`` and ``Price`` keys are read.

        Returns:
            An ``offer`` element with ``name``, ``url``, ``picture``, ``price``,
            ``currencyId`` and ``categoryId`` children, in that order.
        """
        offer = Element('offer', id=str(hotel['Property ID']), available='true')
        # (tag, text) of each child, in output order.
        children = (
            ('name', hotel['Property name']),
            ('url', hotel['Final URL']),
            ('picture', hotel['Image URL']),
            ('price', str(hotel['Price'])),
            ('currencyId', RUB_CURRENCY_ID),
            ('categoryId', category_id),
        )
        for tag, text in children:
            offer.append(self.element_with_text(tag, text))
        return offer

    def generate_feeds(self, full_feed_table, work_path):
        """Render the feed-generation query from the bundled feeds configuration and run it.

        The configuration is loaded from the ``hotels_feed_queries/feeds.yaml`` resource
        and passed to the ``generate_feeds.yql.template`` Jinja template as ``filters``,
        together with ``work_path`` as ``output_path``.

        Args:
            full_feed_table: value passed as the ``$full_feed_table`` query parameter.
            work_path: value passed to the template as ``output_path``.

        Returns:
            The keys view of the feeds configuration.
        """
        feeds_dict = yaml.safe_load(resource.find("hotels_feed_queries/feeds.yaml"))
        logging.info(f'Feeds configuration: {feeds_dict}')

        template_source = resource.find("hotels_feed_queries/generate_feeds.yql.template").decode('utf-8')
        # StrictUndefined makes rendering fail on any variable the template uses but is not given.
        template = jinja2.Template(template_source, undefined=jinja2.StrictUndefined)
        query_text = template.render(filters=feeds_dict, output_path=work_path)

        running_query = yqllib.run_query(
            self.yql_client,
            query_text,
            title='YQL:MarketingFeedBuilder:Generate feeds',
            parameters={'$full_feed_table': full_feed_table},
        )
        yqllib.wait_results(running_query)
        return feeds_dict.keys()

    def put_feeds_to_files(self, work_path, local_work_dir, feed_names):
        """Dump YT feed tables into local CSV files.

        Args:
            work_path: YT directory that contains the feed tables.
            local_work_dir: existing local directory that receives the CSV files.
            feed_names: iterable of ``(yt_name, file_name)`` pairs; table ``yt_name``
                under ``work_path`` is written to ``file_name`` under ``local_work_dir``.

        The columns are taken from the first row of each table. Keys that appear only in
        later rows are ignored and keys missing from a row are written as empty strings.
        An empty table produces an empty file without a header line.
        """
        for yt_name, file_name in feed_names:
            feed_table = ytlib.join(work_path, yt_name)
            local_file = os.path.join(local_work_dir, file_name)
            # Explicit UTF-8: without it the encoding depends on the process locale, so
            # non-ASCII values could be mis-encoded or raise UnicodeEncodeError.
            # newline='' lets the csv module control line endings itself.
            with open(local_file, 'wt', newline='', encoding='utf-8') as fout:
                writer = None
                for row in self.yt_client.read_table(self.yt_client.TablePath(feed_table)):
                    if writer is None:  # Determine columns by first row
                        writer = csv.DictWriter(
                            fout,
                            list(row.keys()),
                            restval='',
                            dialect='excel',
                            delimiter=',',
                            extrasaction='ignore',
                            quoting=csv.QUOTE_ALL,
                            quotechar='"',
                        )
                        writer.writeheader()
                    writer.writerow(row)

    def sync_with_s3(self, local_work_dir, file_list):
        """Upload the given files to S3 and delete every other object under the prefix.

        Args:
            local_work_dir: local directory that contains the files.
            file_list: iterable of file names relative to ``local_work_dir``; each is
                uploaded under the key ``<s3_prefix>/<file name>``.

        After the upload, every object in the bucket whose key starts with the prefix
        and was not uploaded by this call is removed.
        """
        s3_pfx = self.args.s3_prefix
        if not s3_pfx.endswith('/'):
            s3_pfx += '/'
        bucket = self.args.s3_bucket
        session = boto3.session.Session(
            aws_access_key_id=self.args.s3_access_key,
            aws_secret_access_key=self.args.s3_access_secret_key,
        )
        # NOTE(review): verify=False disables TLS certificate verification for the S3
        # endpoint. Kept as is; confirm that this is intentional.
        s3 = session.client(service_name='s3', endpoint_url=self.args.s3_endpoint, verify=False)

        allowed_keys = set()
        for fn in file_list:
            local_fn = os.path.join(local_work_dir, fn)
            key = s3_pfx + fn
            logging.info(f'Uploading {local_fn} -> {key}')

            # The context manager closes the file handle even if the upload fails.
            with open(local_fn, 'rb') as body:
                s3.put_object(Bucket=bucket, Key=key, Body=body)
            allowed_keys.add(key)

        # Paginate: a single list_objects call returns only one page of keys, and a page
        # has no 'Contents' entry at all when nothing matches the prefix.
        # Stale keys are collected first so that nothing is deleted while listing.
        stale_keys = []
        for page in s3.get_paginator('list_objects').paginate(Bucket=bucket, Prefix=s3_pfx):
            for obj in page.get('Contents', []):
                if obj['Key'] not in allowed_keys:
                    stale_keys.append(obj['Key'])

        for key in stale_keys:
            logging.info(f"Remove '{key}'")
            s3.delete_object(Bucket=bucket, Key=key)


def main():
    """Configure logging, parse the command-line arguments and run the feed builder."""
    log_format = "%(asctime)-15s | %(module)s | %(levelname)s | %(message)s"
    logging.basicConfig(level=logging.INFO, format=log_format, stream=sys.stdout)
    # Only warnings and above are let through from this logger.
    logging.getLogger('yt.packages.urllib3.connectionpool').setLevel(logging.WARNING)

    cli = ArgumentParser()
    cli.add_argument('--yt-proxy', default='hahn')
    for token_option in ('--yt-token', '--yt-token-path', '--yql-token', '--yql-token-path'):
        cli.add_argument(token_option)
    cli.add_argument('--yt-path', required=True)
    cli.add_argument('--s3-endpoint', default='https://s3.mds.yandex.net')
    cli.add_argument('--s3-bucket', default='travel-indexer')
    for required_option in (
        '--s3-prefix',
        '--s3-access-key',
        '--s3-access-secret-key',
        '--cluster-permalink-prices',
    ):
        cli.add_argument(required_option, required=True)
    cli.add_argument('--portal-prefix', default='https://travel.yandex.ru')
    # One switch per feed kind.
    for feed_flag in (
        '--build-hotels-feed',
        '--build-market-feed',
        '--build-regions-feed',
        '--build-regions-filters-feed',
        '--build-bboxes-feed',
    ):
        cli.add_argument(feed_flag, action='store_true')
    cli.add_argument('--regions-feed-stats-days', type=int, default=28)

    parsed_args = cli.parse_args(args=replace_args_from_env())
    Runner(parsed_args).run()


# Run the feed builder only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
