# -*- coding: utf-8 -*-
import abc
import io
import logging
import os
import typing

import time
from PIL import Image
from pydantic import BaseModel
from selenium import webdriver

from travel.avia.revise.extractor.choices import ResultChoices, ShownResultChoices
from travel.avia.revise.extractor.error import TicketChangedException
from travel.avia.revise.lib.selenium_lib import NO_ACCESS_EXCEPTIONS

log = logging.getLogger(__name__)


class ExtractionError:
    """Container pairing a captured exception with the screenshots taken around it.

    ``exc_info`` is only ever indexed at position 1, which is treated as the
    exception instance (presumably a ``sys.exc_info()`` triple -- confirm
    against callers).
    """

    def __init__(self, exc_info, screenshots: list['ScreenShot']):
        """Store the exception info and the associated screenshots as-is.

        :param exc_info: indexable whose element ``[1]`` is the exception.
        :param screenshots: screenshots collected for this failure.
        """
        self.screenshots = screenshots
        self.exc_info = exc_info

    @property
    def error(self):
        """The exception object, i.e. ``exc_info[1]``."""
        exception = self.exc_info[1]
        return exception

    def is_unexpected_error(self):
        """Return ``False`` for a ``TicketChangedException``, ``True`` for anything else."""
        if isinstance(self.exc_info[1], TicketChangedException):
            return False
        return True


class ScreenShot(BaseModel):
    """Pydantic model describing a single screenshot and its encoded image data."""

    # Label identifying the screenshot; also the only field shown by repr().
    name: str
    # Encoded image bytes; None is an accepted value.
    image: typing.Optional[bytes]
    # Image format label, 'jpeg' unless overridden.
    type_: str = 'jpeg'
    # Optional error text attached to the screenshot.
    error: typing.Optional[str] = None  # TODO: remove

    def __repr__(self):
        """Short representation that shows only the screenshot name."""
        return 'ScreenShot {!r}'.format(self.name)


def convert_png_to_jpeg(png: bytes) -> bytes:
    """Re-encode image bytes as JPEG.

    The image is opened from memory with PIL, converted to ``'RGB'`` mode and
    saved in JPEG format into an in-memory buffer.

    :param png: encoded source image bytes.
    :return: the JPEG-encoded bytes.
    """
    source_stream = io.BytesIO(png)
    rgb_image = Image.open(source_stream).convert('RGB')
    with io.BytesIO() as jpeg_stream:
        rgb_image.save(jpeg_stream, format='JPEG')
        return jpeg_stream.getvalue()


class ScreenShotMakerError(Exception):
    """Exception type reserved for failures while making a screenshot."""


class IScreenShotMaker(abc.ABC):
    """Abstract interface every screenshot maker implements."""

    @abc.abstractmethod
    def make(self, name: str, debug: bool = False, screenshots_folder: str = './screenshots') -> ScreenShot:
        """Produce a screenshot labelled *name*.

        :param name: label for the resulting screenshot.
        :param debug: flag for implementations that support debug output.
        :param screenshots_folder: folder for that debug output.
        :return: the resulting ``ScreenShot``.
        """


class DefaultScreenShotMaker(IScreenShotMaker):
    """Screenshot maker backed by ``driver.get_screenshot_as_png()``.

    Failures are not raised: they are logged and returned as a ``ScreenShot``
    with ``image=None`` and the error text in ``error``.
    """

    def __init__(self, driver: webdriver.Remote):
        """Keep the driver that will be asked for screenshots.

        :param driver: WebDriver session used for capturing.
        """
        self._driver = driver

    def make(self, name: str, debug: bool = False, screenshots_folder: str = './screenshots') -> ScreenShot:
        """Capture a screenshot and return it re-encoded as JPEG.

        :param name: label of the screenshot; also the debug file's base name.
        :param debug: when true, the captured PNG is additionally written to
            ``<screenshots_folder>/<name>.png``.
        :param screenshots_folder: folder for the debug PNG; created if missing.
        :return: a ``ScreenShot`` with JPEG bytes, or one with ``image=None``
            and ``error`` set when anything in the capture path fails.
        """
        try:
            png = self._driver.get_screenshot_as_png()
            if debug:
                os.makedirs(screenshots_folder, exist_ok=True)
                # Dump the bytes we already have instead of asking the driver
                # for a second screenshot: one round-trip less, and the file is
                # guaranteed to match the image that is returned.
                self._dump_png(png, os.path.join(screenshots_folder, '{}.png'.format(name)))
            return ScreenShot(image=convert_png_to_jpeg(png), name=name)
        except Exception as e:
            log.exception('making screenshots')
            return ScreenShot(image=None, name=name, error=str(e))  # TODO: raise error instead

    @staticmethod
    def _dump_png(png: bytes, path: str) -> None:
        """Write PNG bytes to *path*; a failed write is logged, never raised."""
        try:
            with open(path, 'wb') as dump_file:
                dump_file.write(png)
        except OSError:
            # The debug dump is best effort and must not spoil the screenshot itself.
            log.warning('can not save debug screenshot to %s', path, exc_info=True)


class FullPageScreenShotMaker(DefaultScreenShotMaker):
    """Screenshot maker that resizes the window to the document's scroll size first."""

    def make(self, name: str, debug: bool = False, screenshots_folder: str = './screenshots') -> ScreenShot:
        """Resize the window to the page's scroll size, capture, then restore the size.

        :param name: label of the screenshot.
        :param debug: forwarded to ``DefaultScreenShotMaker.make``.
        :param screenshots_folder: forwarded to ``DefaultScreenShotMaker.make``.
        :return: the ``ScreenShot`` produced by the parent implementation.
        """
        original_size = self._driver.get_window_size()
        required_width = self._driver.execute_script('return document.body.parentNode.scrollWidth')
        required_height = self._driver.execute_script('return document.body.parentNode.scrollHeight')
        self._driver.set_window_size(required_width, required_height)
        try:
            # Forward the debug options: they used to be dropped here, so a
            # full-page screenshot never produced a debug dump.
            return super().make(name, debug=debug, screenshots_folder=screenshots_folder)
        finally:
            # Restore the window size whatever the capture outcome was.
            self._driver.set_window_size(original_size['width'], original_size['height'])


def get_with_screenshot(driver, url, sleep=5, debug=False) -> ScreenShot:
    """Open *url*, wait, and take a screenshot named ``'first'``.

    :param driver: driver used both to load the page and to capture it.
    :param url: address passed to ``driver.get``.
    :param sleep: seconds to wait after ``driver.get`` before capturing.
    :param debug: forwarded to ``DefaultScreenShotMaker.make``.
    :return: the ``ScreenShot`` returned by the maker.
    """
    driver.get(url)
    # Fixed pause between loading the page and capturing it.
    time.sleep(sleep)
    maker = DefaultScreenShotMaker(driver=driver)
    return maker.make('first', debug=debug)


def filter_dict(d: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
    """Return a new dict without the entries whose value is ``None``.

    Only ``None`` is dropped; other falsy values (``0``, ``''``, ``[]``) stay.
    The input mapping is not modified.
    """
    kept = {}
    for key, value in d.items():
        if value is None:
            continue
        kept[key] = value
    return kept


class ErrorHandler:
    """Builds report dicts (``result``, ``shown_result``, ``description``) from errors."""

    @classmethod
    def on_fetch_review(cls, error):
        """Build the report for an error raised while fetching a review.

        Errors matching ``NO_ACCESS_EXCEPTIONS`` map to the generic ``error``
        choices, ``TicketChangedException`` maps to ``ticket_changed``, and
        anything else to ``can_not_fetch_review_error``.
        """
        # NO_ACCESS_EXCEPTIONS is checked first, so it wins if both checks match.
        if isinstance(error, NO_ACCESS_EXCEPTIONS):
            result, shown_result = ResultChoices.error, ShownResultChoices.error
        elif isinstance(error, TicketChangedException):
            result, shown_result = ResultChoices.ticket_changed, ShownResultChoices.ticket_changed
        else:
            result = ResultChoices.can_not_fetch_review_error
            shown_result = ShownResultChoices.can_not_fetch_review_error

        return cls.get_report(result, shown_result, error)

    @classmethod
    def on_fetch_price(cls, error):
        """Build the report for an error raised while fetching a price."""
        return cls.get_report(
            ResultChoices.can_not_fetch_price_error,
            ShownResultChoices.can_not_fetch_price_error,
            error,
        )

    @classmethod
    def get_report(cls, result, shown_result, error):
        """Assemble the report dict; the description is derived from *result* and *error*."""
        description = cls.get_description(result, error)
        report = {
            'result': result,
            'shown_result': shown_result,
            'description': description,
        }
        return report

    @classmethod
    def get_description(cls, choice, error):
        """Return ``'<Choice as words>: <error text>'``.

        The choice is capitalized and its underscores are replaced by spaces,
        e.g. ``'ticket_changed'`` becomes ``'Ticket changed'``.
        """
        readable_choice = choice.capitalize().replace('_', ' ')
        return f'{readable_choice!s}: {str(error)}'
