# -*- coding: utf-8 -*-
import base64
import datetime
import json
import logging
from typing import Optional, Union, Any, Dict
from urllib.parse import quote

from celery import Celery
from pydantic import BaseModel, Extra, validator

from travel.avia.revise import settings
from travel.avia.revise.extractor.choices import ShownResultChoices
from travel.avia.revise.extractor.extract import get_extractor, ExtractedInfo, Extractor
from travel.avia.revise.extractor.report import ErrorHandler, ExtractionError
from travel.avia.revise.lib.selenium_lib import get_driver
from travel.avia.revise.revise_task.review import (
    ReviewWriter, get_prices_diffs, get_result_status,
)
from travel.avia.revise.revise_task.review import ReviseReport

logger = logging.getLogger(__name__)
# Module-wide writer that do_review uses to emit each finished report;
# constructed once, at import time.
review_writer = ReviewWriter(logger)

# Maximum length of a report's 'description' field. preprocess_revise_report
# truncates longer values so that, with LONG_DESCRIPTION_SUFFIX appended, the
# result is exactly this many characters long.
DESCRIPTION_LENGTH_RESTRICTION = 255
LONG_DESCRIPTION_SUFFIX = '...'


class Settings:
    """Celery configuration object (old-style upper-case setting names).

    Uses an SQS-protocol broker at ``settings.SQS_ENDPOINT`` and routes the
    ``do_review`` task to the queue named ``settings.SQS_QUEUE_NAME``.
    """

    # The credentials are percent-encoded with safe='' because secret keys may
    # contain characters such as '/', '@', ':' or '%' that would otherwise
    # break URL parsing; the Celery SQS broker docs recommend exactly this
    # (kombu's ``safequote``). ``str()`` keeps the previous '%s' rendering for
    # non-string values.
    # NOTE(review): if the configured keys are already percent-encoded, drop
    # the quoting to avoid double encoding.
    BROKER_URL = 'sqs://%s:%s@%s' % (
        quote(str(settings.SQS_ACCESS_KEY), safe=''),
        quote(str(settings.SQS_SECRET_KEY), safe=''),
        settings.SQS_ENDPOINT,
    )

    BROKER_TRANSPORT_OPTIONS = {
        'is_secure': False,
        'region': 'yandex',
    }
    # Send the 'do_review' task to the configured SQS queue.
    CELERY_ROUTES = {
        'do_review': {
            'queue': settings.SQS_QUEUE_NAME,
        },
    }
    # Declare that queue (name -> options mapping, no extra options).
    CELERY_QUEUES = {
        settings.SQS_QUEUE_NAME: {},
    }


# Celery application for this worker; its broker, routes and queues come from
# the Settings class above (evaluated once, at import time).
app = Celery('revise')
app.config_from_object(Settings())


class RedirectParams(BaseModel):
    """Target handed to the extractor: a URL plus optional POST data."""

    # URL passed to the extractor by parse_data.
    url: str
    # Optional POST data, forwarded to the extractor as ``post_params``.
    post: Optional[Dict[str, Any]]

    class Config:
        # Unknown keys are accepted and kept rather than rejected.
        extra = Extra.allow


class Task(BaseModel):
    """A single review request, parsed from the raw 'do_review' message body.

    Only a few fields drive the review itself (partner, redirect_params,
    price_value, shown_price_value). parse_data copies every field of the
    task, including unknown extra keys, into the resulting report.
    """

    # Partner code; selects the driver and the extractor.
    partner: str
    # Where the extractor is pointed; may arrive as a JSON-encoded string
    # (decoded by fix_redirect_params below).
    redirect_params: RedirectParams
    # Reference price that the revised (extracted) price is compared against.
    price_value: float
    # Price shown to the user, if any; when None the report's shown_result is
    # ShownResultChoices.shown_price_is_absent.
    shown_price_value: Optional[float]

    # The fields below are not interpreted in this module; parse_data carries
    # them into the report after a JSON round-trip (task.json() -> json.loads).
    hit_time: Optional[datetime.datetime]
    review_time: Optional[datetime.datetime]
    order_content: Any
    user_info: Any
    price_currency: Any
    price_unixtime: Optional[datetime.datetime]
    query_source: Optional[str]
    utm_source: Any
    utm_campaign: Any
    utm_medium: Any
    utm_content: Any
    wizard_redir_key: Any
    wizard_flags: Any
    point_from: Any
    point_to: Any
    date_forward: datetime.date
    date_backward: Optional[datetime.date]
    adults: Optional[int]
    children: Optional[int]
    infants: Optional[int]
    national_version: Optional[str]
    klass: Optional[str]
    shown_price_currency: Any
    shown_price_unixtime: Optional[datetime.datetime]

    @validator('redirect_params', pre=True)
    def fix_redirect_params(cls, value: Any) -> Any:
        """Decode ``redirect_params`` from a JSON string before field validation.

        Non-string values are passed through unchanged.
        """
        if isinstance(value, str):
            return json.loads(value)
        return value

    class Config:
        # Unknown message keys are accepted and kept (they end up in the report).
        extra = Extra.allow


@app.task(name='do_review')
def do_review(raw_data: bytes, message_id: str):
    """Celery task 'do_review': review one message and hand the report to the writer.

    :param raw_data: JSON body that is parsed into a ``Task``.
    :param message_id: message identifier, used in log lines and passed on to
        ``review_writer.write_review``.
    """
    logger.info('task received, id: %s, content: %r', message_id, raw_data)
    revise_report = parse_data(Task.parse_raw(raw_data))
    logger.info('Message parsed, sending to stat, id: %s', message_id)
    review_writer.write_review(review=revise_report, message_id=message_id)


def make_info_report(task: Task, parsed_info: ExtractedInfo) -> ReviseReport:
    """Build a revise report from a successful extraction.

    The extracted price is compared with ``task.price_value`` and, when
    present, with ``task.shown_price_value``.

    :param task: the review task being processed.
    :param parsed_info: the extractor's non-error result.
    :return: the report dict; or, if serializing ``meta``/``price`` or reading
        ``price.value`` raises, whatever ``ErrorHandler.on_fetch_price``
        returns for that exception.
    """
    try:
        # ``meta`` is preferred; fall back to the price object when meta is falsy.
        revise_payload = json.dumps(parsed_info.meta or parsed_info.price, ensure_ascii=False)
        revised_price = parsed_info.price.value
    except Exception as exc:
        logger.exception('Can not fetch price. Data: "%r", "%r"', parsed_info.price, parsed_info.meta)
        return ErrorHandler.on_fetch_price(exc)

    report = {'revise_data': revise_payload, 'price_revise': revised_price}

    diff_abs, diff_rel = get_prices_diffs(revised_price, task.price_value)
    report['price_diff_abs'] = diff_abs
    report['price_diff_rel'] = diff_rel
    report['result'], report['description'] = get_result_status(diff_abs)

    if task.shown_price_value is None:
        report['shown_result'] = ShownResultChoices.shown_price_is_absent
    else:
        shown_abs, shown_rel = get_prices_diffs(revised_price, task.shown_price_value)
        report['shown_price_diff_abs'] = shown_abs
        report['shown_price_diff_rel'] = shown_rel
        # Only the status is kept for the shown price; its description is dropped.
        report['shown_result'], _shown_description = get_result_status(shown_abs)

    if parsed_info.screenshots is not None:
        report['screenshots'] = [base64.b64encode(shot.image) for shot in parsed_info.screenshots]

    return report


def make_error_report(error: ExtractionError) -> ReviseReport:
    """Build a revise report for an extraction that failed on every attempt.

    :param error: the last ``ExtractionError`` returned by the extractor.
    :return: the report produced by ``ErrorHandler.on_fetch_review`` for
        ``error.error``, with base64-encoded screenshots attached when the
        error carries any.
    """
    revise_report = ErrorHandler.on_fetch_review(error.error)
    # Same guard as make_info_report applies to ExtractedInfo.screenshots:
    # iterating None would raise TypeError and lose the whole error report.
    if error.screenshots is not None:
        revise_report['screenshots'] = [base64.b64encode(s.image) for s in error.screenshots]
    return revise_report


def extract_with_retries(
    extractor: Extractor, url: str, post_params: Optional[Dict[str, Any]] = None, tries: int = 3
) -> Union[ExtractedInfo, ExtractionError]:
    """Call ``extractor(url, post_params)`` up to ``tries`` times.

    Each failed attempt (an ``ExtractionError`` result) is logged as a warning
    with its ``exc_info``. The first non-error result is returned immediately.

    :param extractor: callable performing a single extraction attempt.
    :param url: URL to extract from.
    :param post_params: optional POST data forwarded to the extractor.
    :param tries: maximum number of attempts; must be a positive int.
    :return: the first non-error result, or the last ``ExtractionError`` if
        every attempt failed.
    :raises ValueError: if ``tries`` is not a positive integer.
    """
    # An explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let tries <= 0 silently return None.
    if not isinstance(tries, int) or tries <= 0:
        raise ValueError('tries must be a positive integer, got %r' % (tries,))

    result = None
    for _ in range(tries):
        result = extractor(url, post_params)
        if not isinstance(result, ExtractionError):
            return result
        logger.warning('Extracting error occurred', exc_info=result.exc_info)
    # Every attempt failed: hand back the last error so callers can report it.
    return result


def preprocess_revise_report(data: ReviseReport) -> ReviseReport:
    """Truncate an over-long ``description`` field of a revise report.

    If ``data['description']`` is non-empty and longer than
    DESCRIPTION_LENGTH_RESTRICTION characters, it is cut and
    LONG_DESCRIPTION_SUFFIX is appended, so that the new value is exactly
    DESCRIPTION_LENGTH_RESTRICTION characters long.

    Note: ``data`` is modified in place, and the same object is returned.
    """
    description_field_name = 'description'
    description = data.get(description_field_name)

    if description and len(description) > DESCRIPTION_LENGTH_RESTRICTION:
        # Lazy %-formatting, driven by the constant so the message cannot go stale.
        logger.info(
            'Description field is longer than %d characters. Cutting it: %s',
            DESCRIPTION_LENGTH_RESTRICTION, description,
        )
        kept_length = DESCRIPTION_LENGTH_RESTRICTION - len(LONG_DESCRIPTION_SUFFIX)
        data[description_field_name] = description[:kept_length] + LONG_DESCRIPTION_SUFFIX

    return data


def parse_data(task: Task) -> ReviseReport:
    """Run extraction for ``task`` and return the merged, normalised revise report.

    The driver obtained from ``get_driver`` is held only while extraction
    runs. The resulting report (error or info) is then overlaid with every
    field of ``task``: task values win on key collisions, and
    ``redirect_params`` is stored as a JSON string. Finally the report passes
    through ``preprocess_revise_report``.
    """
    redirect = task.redirect_params
    with get_driver(partner_code=task.partner) as driver:
        outcome = extract_with_retries(
            extractor=get_extractor(task.partner, driver),
            url=redirect.url,
            post_params=redirect.post,
        )

    report = (
        make_error_report(outcome)
        if isinstance(outcome, ExtractionError)
        else make_info_report(task, outcome)
    )

    # Round-trip through JSON so every task value is JSON-compatible.
    task_fields = json.loads(task.json())
    task_fields['redirect_params'] = json.dumps(task_fields['redirect_params'])
    report.update(task_fields)
    return preprocess_revise_report(report)
