import logging
import os
from functools import lru_cache
from typing import Optional

import boto3
import click
import sys
from pydantic import BaseSettings
from yql.api.v1.client import YqlClient
from yt.wrapper import YtClient

from travel.avia.ad_feed.ad_feed.airport_blacklist import AirportBlacklist
from travel.avia.ad_feed.ad_feed.metrics import send_file_metrics
from travel.avia.ad_feed.ad_feed.runner.tools import EnumType
from travel.avia.ad_feed.ad_feed.converter.factory import create_smart_banners_converter
from travel.avia.ad_feed.ad_feed.converter.runner import S3Path, dump_yt_to_s3_csv
from travel.avia.ad_feed.ad_feed.entities import NationalVersion
from travel.avia.ad_feed.ad_feed.environment import Environment
from travel.avia.ad_feed.ad_feed.feed_generator.const import OUTPUT_TABLE_PREFIX_BY_ENV, SOURCE_TABLE_BY_ENV
from travel.avia.ad_feed.ad_feed.feed_generator.factory import (
    create_yql_generator,
    get_feed_config,
    get_solomon_reporter,
)
from travel.avia.ad_feed.ad_feed.settings import MdsSettings, YtSettings, YqlSettings
from travel.avia.library.python.lib_yt.client import configured_client
from travel.avia.library.python.boto3_entities import S3ClientProto

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class AppSettings(BaseSettings):
    """Table paths handed to ``AirportBlacklist`` by ``create_airport_blacklist``.

    Subclasses pydantic's ``BaseSettings``; the values below are the defaults
    declared in code. The paths are presumably YT table paths, since they are
    passed to ``AirportBlacklist`` together with a YT client -- confirm against
    ``AirportBlacklist`` itself.
    """

    # Forwarded to AirportBlacklist as ``stations_table``.
    stations_table: str = '//home/rasp/reference/station'
    # Forwarded to AirportBlacklist as ``blacklist_table``.
    airport_blacklist_table: str = '//home/avia/data/ad-feed/blacklist'


@lru_cache(maxsize=None)
def create_yt_client() -> YtClient:
    """Return the YT client configured from ``YtSettings``.

    The function takes no arguments and is wrapped in ``lru_cache``, so the
    client is built on the first call and the same object is returned on every
    later call.
    """
    settings = YtSettings()
    proxy, token = settings.proxy, settings.token
    return configured_client(proxy, token)


@lru_cache(maxsize=None)
def create_yql_client() -> YqlClient:
    """Return the YQL client for this process.

    The ``db`` argument comes from ``YtSettings().proxy`` while the token comes
    from ``YqlSettings().token``. The result is memoised by ``lru_cache``, so
    repeated calls return the same client object.
    """
    yt_proxy = YtSettings().proxy
    yql_token = YqlSettings().token
    return YqlClient(db=yt_proxy, token=yql_token)


@lru_cache(maxsize=None)
def create_s3_client() -> S3ClientProto:
    """Return a boto3 ``s3`` client built from ``MdsSettings``.

    Credentials are set on a dedicated boto3 session; the endpoint URL is
    taken from the same settings object. The result is memoised by
    ``lru_cache``, so repeated calls return the same client object.
    """
    mds = MdsSettings()
    session = boto3.session.Session(
        aws_access_key_id=mds.access_key_id,
        aws_secret_access_key=mds.access_key_secret,
    )
    return session.client(service_name='s3', endpoint_url=mds.endpoint)


def create_airport_blacklist(yt_client: YtClient, app_settings: AppSettings) -> AirportBlacklist:
    """Build an ``AirportBlacklist`` from the table paths in ``app_settings``.

    Args:
        yt_client: Client passed through to ``AirportBlacklist`` unchanged.
        app_settings: Source of the blacklist and stations table paths.

    Returns:
        A new ``AirportBlacklist`` instance.
    """
    blacklist_path = app_settings.airport_blacklist_table
    stations_path = app_settings.stations_table
    return AirportBlacklist(
        yt_client=yt_client,
        blacklist_table=blacklist_path,
        stations_table=stations_path,
    )


@click.group()
def main():
    # Root click group: runs before any subcommand and routes INFO-level
    # logging to stdout.
    # NOTE: documented with comments rather than a docstring, because click
    # would surface a docstring as the group's --help text.
    log_options = {'level': logging.INFO, 'stream': sys.stdout}
    logging.basicConfig(**log_options)


def _set_defaults(
    ctx: click.core.Context, _: click.core.Parameter, environment: Optional[Environment]
) -> Optional[Environment]:
    """Click callback for ``--env``: install per-environment option defaults.

    When an environment is given, ``ctx.default_map`` is replaced (not merged)
    with the ``source_table`` and ``output_prefix`` values looked up for that
    environment.

    Args:
        ctx: Click context whose ``default_map`` is overwritten.
        _: The parameter that triggered the callback; unused.
        environment: Parsed ``--env`` value, or ``None``.

    Returns:
        ``environment`` unchanged (``None`` stays ``None``).

    Raises:
        KeyError: If either lookup table has no entry for ``environment``.
    """
    if environment is not None:
        env_defaults = dict(
            source_table=SOURCE_TABLE_BY_ENV[environment],
            output_prefix=OUTPUT_TABLE_PREFIX_BY_ENV[environment],
        )
        ctx.default_map = env_defaults
    return environment


@main.command()
@click.option(
    '--env',
    'environment',
    type=EnumType(Environment, case_sensitive=False),
    required=True,
    is_eager=True,
    callback=_set_defaults,
)
@click.option('--source-table', type=str, required=True)
@click.option('--output-prefix', 'output_prefix', type=str)
@click.option('--national-version', type=click.Choice(list(NationalVersion), case_sensitive=False), required=True)
@click.option('--mds-s3-bucket-name', 's3_bucket', type=str, default='avia-indexer', required=True)
@click.option('--mds-s3-prefix', type=str, default='smart-banners-ad-feed', required=True)
def generate(
    national_version: str,
    environment: Environment,
    output_prefix: Optional[str],
    s3_bucket: str,
    mds_s3_prefix: str,
    source_table: str,
) -> None:
    """Generate every configured feed report and upload each one to S3 as CSV.

    For each report name in the feed config: build the report table with YQL,
    dump that table to S3 as CSV, then send metrics for the uploaded file.
    \f
    Everything after the form feed is hidden from the click ``--help`` output.

    Args:
        national_version: Used as the CSV file name (``<national_version>.csv``).
        environment: Not read here; the eager ``--env`` callback uses it to
            fill defaults for ``source_table`` and ``output_prefix``.
        output_prefix: Prefix under which ``<output_prefix>/<report_name>``
            output tables are written.
        s3_bucket: Destination bucket.
        mds_s3_prefix: Key prefix; each report goes to
            ``<mds_s3_prefix>/<report_name>``.
        source_table: Source table handed to the YQL generator.

    Raises:
        click.UsageError: If no output prefix was supplied or defaulted.
    """
    # Function-scope import keeps this block self-contained.
    import posixpath

    # The option is not required, so fail with a clear CLI error instead of a
    # TypeError from the path join further down.
    if output_prefix is None:
        raise click.UsageError(
            'Output prefix is not set: pass --output-prefix or define a default for the chosen --env.'
        )

    yt_client = create_yt_client()
    yql_client = create_yql_client()
    s3_client = create_s3_client()

    config = get_feed_config()
    app_settings = AppSettings()
    airport_blacklist = create_airport_blacklist(yt_client, app_settings)

    for report_name in config.keys():
        # Table paths and S3 keys are '/'-separated on every OS, so join them
        # with posixpath rather than the platform-dependent os.path.
        yt_output_table = posixpath.join(output_prefix, report_name)
        generator = create_yql_generator(
            report_name=report_name,
            source_table=source_table,
            output_table=yt_output_table,
            yt_client=yt_client,
            yql_client=yql_client,
            airport_blacklist=airport_blacklist,
        )
        # Iterate to exhaustion; the yielded items themselves are not needed.
        # Presumably generate_feed() is lazy and only does its work while
        # being consumed -- confirm against the generator implementation.
        for _ in generator.generate_feed():
            pass

        s3_prefix = posixpath.join(mds_s3_prefix, report_name)
        s3_key = '{}.csv'.format(national_version)
        s3_path = S3Path(key=s3_key, prefix=s3_prefix, bucket=s3_bucket)

        dump_yt_to_s3_csv(
            yt_client=yt_client,
            yt_path=yt_output_table,
            s3_client=s3_client,
            csv_path=s3_path,
            converter=create_smart_banners_converter(),
            validator=None,
        )
        send_file_metrics(
            path=s3_path,
            client=s3_client,
            reporter=get_solomon_reporter(),
            labels={'smart_banners_report_name': report_name},
        )


# Run the click command group when this file is executed as a script.
if __name__ == '__main__':
    main()
