from typing import Dict, List, Optional
from enum import unique
from dataclasses import dataclass

import csv
import io
import logging
import re
from pathlib import Path
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
from datetime import date, timedelta, datetime
from uuid import uuid4
from base64 import urlsafe_b64encode

from library.python import resource
from yql.api.v1.client import YqlClient
from yt.wrapper import YtClient, read_table, row_count, run_map, YPath

from travel.library.python.tools import replace_args_from_env
from travel.hotels.lib.python3.cli.cli import create_progress_bar
from travel.hotels.lib.python3.lang.enum import ExtendedEnum, UseValueAsToStringMixin
from travel.hotels.lib.python3.yql.yqllib import run_yql_file as _run_yql_file
from travel.hotels.lib.python3.yt.ytlib import join, ensure_table_exists, get_default_user_path, schema_from_dict
from travel.hotels.lib.python3.yt.versioned_path import VersionedPath, parse_cleanup_strategy, DEFAULT_CLEANUP_STRATEGY

from travel.hotels.proto.suggest.suggest_pb2 import THotelSuggest, TRegionSuggest


# Text encoding used in this module when decoding resources and opening output files.
UTF_8 = 'utf-8'
# strptime/strftime pattern used for the --date argument and the from/to query dates.
DATE_FORMAT = '%Y-%m-%d'

# Configure the root logger at import time with the INFO level.
logging.basicConfig(level=logging.INFO)


def get_b64uuid() -> str:
    """Return a random UUID4 as URL-safe base64 text with the last two characters dropped.

    The 16 UUID bytes encode to 24 base64 characters, the final two of which are
    '=' padding, so the result is 22 characters long.
    """
    encoded = urlsafe_b64encode(uuid4().bytes)
    without_padding = encoded[:-2]
    return without_padding.decode('ascii')


class RequestDescription(object):
    """Query text loaded through ``resource.find`` plus its parameters and processor."""

    def __init__(self, request_resource, request_processor, parameters=None):
        """Load and decode the query, then remember the processor and parameters.

        :param request_resource: resource key passed to ``resource.find``
        :param request_processor: stored as-is in ``self.request_processor``
        :param parameters: stored as-is in ``self.parameters`` (``None`` by default)
        """
        raw_query = resource.find(request_resource)
        self.query = raw_query.decode(UTF_8)
        self.parameters = parameters
        self.request_processor = request_processor


def run_yql_file(client, resource_name: str, title: str, parameters: Dict):
    """Call the shared ``_run_yql_file`` helper with the fixed ``'BuildSuggest'`` argument.

    :param client: YQL client forwarded to the helper
    :param resource_name: name of the query resource to run
    :param title: title forwarded to the helper
    :param parameters: query parameters forwarded to the helper
    :return: whatever ``_run_yql_file`` returns
    """
    return _run_yql_file(
        client,
        resource_name,
        'BuildSuggest',
        title,
        parameters,
    )


def proto_message_to_b64(message) -> str:
    """Serialize ``message`` with ``SerializeToString()`` and return it as URL-safe base64 text."""
    serialized = message.SerializeToString()
    encoded = urlsafe_b64encode(serialized)
    return encoded.decode('ascii')


class DictionaryBuilder:
    """Build the hotel and region suggest dictionaries and dump them to local text files.

    The command-line interface is described by the class-level ``parser``.
    ``run()`` builds the parts selected with ``--suggest-parts``. Each part runs
    YQL queries against tables under the YT working path and then writes
    ``ready.txt``, ``data.txt``, ``streams.txt`` and ``groups.txt`` into its own
    sub-directory of ``--output-path``. All output files are written as UTF-8.
    """

    @unique
    class SuggestParts(UseValueAsToStringMixin, ExtendedEnum):
        """Parts of the suggest dictionary that can be built independently."""

        HOTELS = "hotels"
        REGIONS = "regions"

        def __repr__(self):
            # repr() mirrors str(); presumably so argparse prints readable choices.
            return str(self)

    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '--yql-token',
        required=True,
    )
    parser.add_argument(
        '--yt-token',
        required=True,
    )
    parser.add_argument(
        '--yt-cluster',
        default='hahn',
        help='YT cluster',
    )
    parser.add_argument(
        '--yt-working-path',
        default=get_default_user_path('suggest/dictionary'),
        help='Working path in YT',
    )
    parser.add_argument(
        '--output-path',
        default='./build',
        type=lambda p: Path(p).resolve(),
        help='Local output path for text files',
    )
    parser.add_argument(
        '--suggest-parts',
        type=SuggestParts,
        nargs='+',
        default=SuggestParts.enums(),
        choices=SuggestParts.enums(),
        help='Build only selected suggest parts',
    )
    # NOTE: the default is computed once, when this class body is executed.
    parser.add_argument(
        "--date",
        type=lambda x: datetime.strptime(x, DATE_FORMAT).date(),
        default=date.today(),
        help='Use logs up to DATE for search suggest build',
    )
    parser.add_argument(
        "--days",
        type=int,
        default=31,
        help='Use logs for DAYS days for search suggest build',
    )
    parser.add_argument(
        '--cleanup-strategy',
        type=parse_cleanup_strategy,
        default=DEFAULT_CLEANUP_STRATEGY,
        help=parse_cleanup_strategy.__doc__,
    )

    def __init__(self, argv):
        """Parse ``argv`` and create the YQL and YT clients.

        :param argv: command-line arguments; they go through
            ``replace_args_from_env`` before being parsed by ``parser``.
        """
        self.logger = logging.getLogger(self.__class__.__name__)

        args = self.parser.parse_args(replace_args_from_env(argv))

        self.yql_client = YqlClient(
            token=args.yql_token,
            db=args.yt_cluster,
            db_proxy=args.yt_cluster,
        )
        self.yt_client = YtClient(
            proxy=args.yt_cluster,
            token=args.yt_token,
        )

        self.output_path = args.output_path
        # The log window is [date - days, date], stored as formatted strings for the query.
        self.from_date = (args.date - timedelta(days=args.days)).strftime(DATE_FORMAT)
        self.to_date = args.date.strftime(DATE_FORMAT)
        self.yt_working_path = args.yt_working_path
        self.cleanup_strategy = args.cleanup_strategy

        self.suggest_parts = args.suggest_parts

    def run(self):
        """Build every selected suggest part inside one ``VersionedPath`` context."""
        with VersionedPath(self.yt_working_path, yt_client=self.yt_client, cleanup_strategy=self.cleanup_strategy) as vp:
            if DictionaryBuilder.SuggestParts.HOTELS in self.suggest_parts:
                self.process_top_hotels_suggest(vp)
            if DictionaryBuilder.SuggestParts.REGIONS in self.suggest_parts:
                self.process_region_suggest(vp)

    def process_top_hotels_suggest(self, yt_path: YPath):
        """Build the hotel suggest files under ``<output_path>/top_hotels_suggest``.

        Steps: run ``/hotels.yql`` into ``hotels-raw``, map it with
        ``HotelNameRowMapper`` into ``hotels-processed``, run ``/hotels_2.yql`` to
        produce ``hotels-ready`` and ``hotels-groups``, then write the local files.

        :param yt_path: YT directory in which the intermediate tables are created
        :raises Exception: any error while writing a row is logged and re-raised
        """
        hotels_table = join(yt_path, 'hotels-raw')
        processed_table = join(yt_path, 'hotels-processed')
        ready_table = join(yt_path, 'hotels-ready')
        groups_table = join(yt_path, 'hotels-groups')

        run_yql_file(
            self.yql_client,
            '/hotels.yql',
            title='Prepare hotels data',
            parameters={
                '$output_table': str(hotels_table),
            }
        )

        path = self.output_path / 'top_hotels_suggest'
        path.mkdir(parents=True, exist_ok=True)

        ready_file = path / 'ready.txt'
        data_file = path / 'data.txt'
        streams_file = path / 'streams.txt'
        groups_file = path / 'groups.txt'

        ensure_table_exists(processed_table, yt_client=self.yt_client, schema=[
            {'name': 'lowerName', 'type': 'string'},
            {'name': 'name', 'type': 'string'},
            {'name': 'description', 'type': 'string'},
            {'name': 'permalink', 'type': 'int64'},
            {'name': 'popularity', 'type': 'double'},
        ])
        run_map(HotelNameRowMapper(), hotels_table, processed_table, client=self.yt_client)

        run_yql_file(
            self.yql_client,
            '/hotels_2.yql',
            title='Sort and reduce hotels data',
            parameters={
                '$hotels_processed': str(processed_table),
                '$hotels_groups': str(groups_table),
                '$hotels_ready': str(ready_table),
            },
        )

        total_weight = 0.
        total_records = row_count(ready_table, client=self.yt_client)
        processed_records = 0

        # Explicit UTF-8: names may be non-ASCII and must not depend on the process locale.
        with open(ready_file, 'w+', encoding=UTF_8) as ready, open(data_file, 'w+', encoding=UTF_8) as data:
            with create_progress_bar('Hotel names processed', total_records) as bar:

                for row in read_table(ready_table, client=self.yt_client):
                    permalink = row['permalink']
                    lower_name = row['lowerName']
                    name = row['name']
                    description = row['description']
                    popularity = row['popularity']
                    suggest_id = f'hotel-{permalink}-ru'

                    total_weight += popularity

                    try:
                        ready.write(f'{lower_name}\t\t{popularity}\t\n')
                        payload = proto_message_to_b64(
                            THotelSuggest(Id=suggest_id, Name=name, Description=description, Permalink=permalink)
                        )
                        data.write(f'{lower_name}\t{payload}\n')
                        # NOTE(review): the bar receives the count before this row is added;
                        # confirm whether create_progress_bar expects that.
                        bar.update(processed_records)
                        processed_records += 1
                    except Exception:
                        self.logger.exception(f'Error at name: {name}')
                        raise

        self._write_groups_to_file(groups_table, groups_file, 'Hotel')

        with open(streams_file, 'w+', encoding=UTF_8) as streams:
            streams.write('ALL\t{}'.format(total_weight))

        self.logger.info('Top hotels suggest successfully build')

    def _upload_csv(self, table_path, resource_name, schema):
        """Create a YT table and fill it with rows parsed from a CSV resource.

        :param table_path: path of the YT table to create and write
        :param resource_name: key of the CSV resource; its header must contain every schema field
        :param schema: mapping ``field name -> YT type``; only ``int32`` and ``string`` are supported
        :raises Exception: if the schema contains an unsupported type; this is
            checked before the table is created
        """
        converters = {
            'int32': int,
            'string': str,
        }
        # Resolve converters once and fail before any YT side effect happens.
        field_converters = {}
        for field, yt_type in schema.items():
            converter = converters.get(yt_type)
            if converter is None:
                raise Exception(f"Cannot find convertor for field {field}, type is {yt_type}")
            field_converters[field] = converter

        self.yt_client.create("table", table_path, attributes={"schema": schema_from_dict(schema)})

        csv_data = resource.find(resource_name).decode(UTF_8)
        reader = csv.DictReader(io.StringIO(csv_data), delimiter=',', quotechar='"')
        rows = [
            {field: convert(line[field]) for field, convert in field_converters.items()}
            for line in reader
        ]
        self.yt_client.write_table(table_path, rows)

    def _upload_region_synonyms(self, synonyms_table):
        """Upload ``/geo_synonyms.csv`` (``geoId``, ``name``) into ``synonyms_table``."""
        self._upload_csv(synonyms_table, '/geo_synonyms.csv', {
            'geoId': 'int32',
            'name': 'string',
        })

    def _upload_region_popularity_patches(self, synonyms_table):
        """Upload ``/geo_popularity_patches.csv`` (``geoId``, ``boost``) into the given table.

        The parameter is named ``synonyms_table`` but receives the popularity
        patches table path.
        """
        self._upload_csv(synonyms_table, '/geo_popularity_patches.csv', {
            'geoId': 'int32',
            'boost': 'int32',
        })

    def process_region_suggest(self, yt_path: YPath):
        """Build the region suggest files under ``<output_path>/region_suggest``.

        Steps: upload the synonym and popularity-patch CSV resources, run
        ``/regions.yql`` for the ``[from_date, to_date]`` window to produce
        ``regions-ready`` and ``regions-groups``, then write the local files.

        :param yt_path: YT directory in which the intermediate tables are created
        :raises Exception: any error while writing a row is logged and re-raised
        """
        regions_table = join(yt_path, 'regions-ready')
        groups_table = join(yt_path, 'regions-groups')
        synonyms_table = join(yt_path, 'regions-synonyms-config')
        popularity_patches_table = join(yt_path, 'regions-popularity-patches-config')

        self._upload_region_synonyms(synonyms_table)
        self._upload_region_popularity_patches(popularity_patches_table)

        run_yql_file(
            self.yql_client,
            '/regions.yql',
            title='Prepare region data',
            parameters={
                '$regions_ready': str(regions_table),
                '$from': self.from_date,
                '$to': self.to_date,
                '$regions_synonyms': str(synonyms_table),
                '$regions_groups': str(groups_table),
                '$region_popularity_patches': str(popularity_patches_table),
            },
        )

        path = self.output_path / 'region_suggest'
        path.mkdir(parents=True, exist_ok=True)

        ready_file = path / 'ready.txt'
        data_file = path / 'data.txt'
        streams_file = path / 'streams.txt'
        groups_file = path / 'groups.txt'

        total_weight = 0.
        total_records = row_count(regions_table, client=self.yt_client)
        processed_records = 0

        # Explicit UTF-8: names may be non-ASCII and must not depend on the process locale.
        with open(ready_file, 'w+', encoding=UTF_8) as ready, open(data_file, 'w+', encoding=UTF_8) as data:
            with create_progress_bar('Regions processed', total_records) as bar:

                for row in read_table(regions_table, client=self.yt_client):
                    name = row['name']
                    lower_name = row['lowerName']
                    geo_id = row['geoId']
                    weight = row['popularity']
                    total_weight += weight
                    suggest_id = f'region-{geo_id}-ru'

                    try:
                        ready.write(f'{lower_name}\t\t{weight}\t\n')
                        payload = proto_message_to_b64(TRegionSuggest(Id=suggest_id, Name=name, GeoId=geo_id))
                        data.write(f'{lower_name}\t{payload}\n')
                        # NOTE(review): the bar receives the count before this row is added;
                        # confirm whether create_progress_bar expects that.
                        bar.update(processed_records)
                        processed_records += 1
                    except Exception:
                        self.logger.exception(f'Error at name: {name}')
                        raise

        with open(streams_file, 'w+', encoding=UTF_8) as streams:
            streams.write('ALL\t{}'.format(total_weight))

        self._write_groups_to_file(groups_table, groups_file, 'Region')

        self.logger.info('Region suggest successfully build')

    def _write_groups_to_file(self, groups_table, groups_file, name):
        """Write every row of ``groups_table`` as one tab-separated line of ``row['group']``.

        :param groups_table: YT table whose rows have a ``group`` column of strings
        :param groups_file: local path of the UTF-8 text file to (re)create
        :param name: prefix for the progress bar title
        :raises Exception: any error while writing a row is logged and re-raised
        """
        total_records = row_count(groups_table, client=self.yt_client)
        processed_records = 0
        with open(groups_file, 'w+', encoding=UTF_8) as groups:
            with create_progress_bar(f'{name} groups processed', total_records) as bar:
                for row in read_table(groups_table, client=self.yt_client):
                    try:
                        groups.write('{}\n'.format('\t'.join(row['group'])))
                        bar.update(processed_records)
                        processed_records += 1
                    except Exception:
                        self.logger.exception(f'Error at group: {row["group"]}')
                        raise


@dataclass
class HotelNames:
    """Names of one hotel in one locale, as collected by ``HotelNameRowMapper.group_hotel_names``."""

    # Stripped name of type 'main'; None until such a name has been seen.
    main: Optional[str]
    # Lookup keys built from every name: stripped, lowercased and suffixed with
    # '_<permalink>_<index of the name in the row>'.
    all: List[str]


class HotelNameRowMapper:
    """Mapper that expands one hotel row into one output row per hotel name.

    Input rows are expected to provide ``permalink``, ``popularity``, ``names``,
    ``rubric`` and ``formatted``. For every locale in ``LOCALES`` that has both a
    main name and a description, one row is emitted per name of that locale.
    """

    # Popularity values below this threshold are raised to it.
    MINIMAL_POPULARITY = 0.1

    # Locales for which output rows are produced.
    LOCALES = {
        "ru",
    }

    # Characters removed from both ends of every hotel name.
    CHARACTERS_TO_STRIP = ' '

    def __call__(self, row):
        """Yield output rows (``name``, ``lowerName``, ``description``, ``popularity``, ``permalink``).

        :param row: input hotel row (mapping)
        :raises Exception: any failure is wrapped in an exception that names the
            permalink, with the original error chained as its cause
        """
        permalink = row['permalink']
        try:
            names = self.group_hotel_names(row)
            descriptions = self.group_hotel_descriptions(row)

            for locale in self.LOCALES:
                if locale not in names or locale not in descriptions:
                    continue

                hotel_names = names[locale]
                description = descriptions[locale]
                popularity = row['popularity']

                # Raise low popularity to the minimal value.
                if popularity < HotelNameRowMapper.MINIMAL_POPULARITY:
                    popularity = HotelNameRowMapper.MINIMAL_POPULARITY

                for unique_name in hotel_names.all:
                    yield {
                        'name': hotel_names.main,
                        'lowerName': unique_name,
                        'description': description,
                        'popularity': popularity,
                        'permalink': permalink,
                    }
        except Exception as e:
            raise Exception(f'Error while performing Map operation on permalink {permalink}') from e

    @staticmethod
    def group_hotel_names(row):
        """Group the names of ``row`` by locale.

        :param row: mapping with ``permalink`` and ``names``; every name is a
            mapping with ``locale``, ``type`` and ``name``
        :return: ``{locale: HotelNames}`` containing only locales that have a main name
        :raises Exception: if a locale has more than one name of type ``'main'``
        """
        permalink = row['permalink']
        strip_chars = HotelNameRowMapper.CHARACTERS_TO_STRIP

        grouped = {}
        for name_index, name in enumerate(row['names']):
            locale = name['locale']
            name_type = name['type']
            value = name['name']

            hotel_names = grouped.setdefault(locale, HotelNames(None, []))
            is_main = name_type == 'main'
            if is_main and hotel_names.main is not None:
                raise Exception(f'Many main names for permalink {permalink} for locale {locale}')

            stripped = value.strip(strip_chars)
            if is_main:
                hotel_names.main = stripped
            # The permalink and index suffixes keep every key distinct.
            hotel_names.all.append(stripped.lower() + '_' + str(permalink) + '_' + str(name_index))

        # Locales without a main name cannot be displayed, so they are dropped.
        return {
            locale: hotel_names
            for locale, hotel_names in grouped.items()
            if hotel_names.main is not None
        }

    @staticmethod
    def group_hotel_descriptions(row):
        """Build a description per locale as ``'<rubric> · <address>'``.

        The rubric prefix is omitted when the row has no rubric for the locale.
        Locales without a formatted address get no description at all.

        :param row: mapping with ``rubric`` and ``formatted``; each is a sequence
            of mappings with ``locale`` and ``value`` (or a falsy value)
        :return: ``{locale: description}`` for locales from ``LOCALES``
        """
        locales = HotelNameRowMapper.LOCALES

        rubric_prefixes = {}
        for rubric in row['rubric'] or ():
            locale = rubric['locale']
            if locale in locales:
                rubric_prefixes[locale] = rubric['value'] + ' · '

        addresses = {}
        for address in row['formatted'] or ():
            if not address:
                continue
            locale = address['locale']
            if locale in locales:
                addresses[locale] = address['value']

        result = {}
        for locale in locales:
            # Use a separate name per locale so the addresses dict is never overwritten.
            address = addresses.get(locale)
            if address is None:
                continue
            result[locale] = rubric_prefixes.get(locale, '') + address
        return result


class PornoFilterRowMapper:
    """Mapper that drops rows whose ``text`` contains a word from ``/porno.lst``."""

    # A word is a run of word characters, '#' or '@' delimited by word boundaries.
    WORD_PATTERN = re.compile(r'\b[\w#@]+\b')

    def __init__(self):
        """Load the stop-word list from the ``/porno.lst`` resource, lowercased."""
        word_list = resource.find('/porno.lst').decode(UTF_8)
        self.porno_words = {line.lower() for line in word_list.splitlines()}

    def __call__(self, row):
        """Yield ``row`` unchanged unless its ``text`` is flagged by ``is_porno``."""
        if self.is_porno(row['text']):
            return
        yield row

    def is_porno(self, text):
        """Return True if any word of ``text`` (case-insensitively) is in the stop-word set."""
        return any(
            found.group().lower() in self.porno_words
            for found in self.WORD_PATTERN.finditer(text)
        )
