import logging
import sys
import threading
import time
import queue
import requests
from argparse import ArgumentParser
from multiprocessing.pool import ThreadPool

from ticket_parser2.api.v1 import TvmClient, TvmClientStatus, TvmApiClientSettings
from yt.wrapper import YtClient

from travel.hotels.lib.python3.yt import ytlib
from travel.library.python.tools import replace_args_from_env
from travel.hotels.lib.python3.cli.cli import auto_progress_reporter
from travel.hotels.lib.python3.utils.throttler import Throttler


LOG = logging.getLogger(__name__)

# Column spec of the table that Runner creates at --yt-path-uploaded:
# one row per image that the avatars service accepted.
_RESULT_TABLE_COLUMNS = {
    "ExternalUrl": "string",
    "AvatarsImg": "any",
    "RawAvatarsResponse": "string",
}
result_table_schema = ytlib.schema_from_dict(_RESULT_TABLE_COLUMNS)

# Column spec of the table that Runner creates at --yt-path-errors:
# one row per image for which the avatars service returned a non-OK response.
_ERRORS_TABLE_COLUMNS = {
    "ExternalUrl": "string",
    "ResponseCode": "uint32",
    "ErrorDescription": "string",
    "RawAvatarsResponse": "string",
}
errors_table_schema = ytlib.schema_from_dict(_ERRORS_TABLE_COLUMNS)


class ThreadSafeCounter(object):
    """Counter starting at zero whose reads and increments share one lock."""

    def __init__(self):
        self._lock = threading.Lock()
        self._value = 0

    def get(self):
        """Return the current value of the counter."""
        with self._lock:
            current = self._value
        return current

    def inc(self):
        """Increase the counter by one."""
        with self._lock:
            self._value = self._value + 1


class YtDumper:
    """Appends rows to a YT table from a background thread, in batches.

    Rows passed to ``add`` go through a queue bounded by ``batch_size``; the
    worker thread collects them and writes them with ``write_table`` in
    append mode.  A ``None`` item in the queue tells the worker to flush
    what is left and exit.
    """

    def __init__(self, name, proxy, token, path, batch_size):
        self.name = name
        self.path = path
        self.batch_size = batch_size
        self.retry_delay_sec = 1
        self.yt_client = YtClient(proxy=proxy, token=token)
        self.q = queue.Queue(batch_size)
        self.th = None

    def start(self):
        """Create and start the worker thread."""
        worker = threading.Thread(target=self._run)
        self.th = worker
        worker.start()

    def add(self, item):
        """Put a row on the queue; blocks while the queue is full."""
        self.q.put(item)

    def report_finish_and_wait(self):
        """Send the end-of-stream marker and wait for the worker to exit."""
        LOG.info(f"Stopping YtDumper ({self.name}): start")
        self.q.put(None)
        LOG.info(f"Stopping YtDumper ({self.name}): put queue marker")
        self.th.join()
        LOG.info(f"Stopping YtDumper ({self.name}): done")

    def _run(self):
        """Worker loop: gather rows until the marker arrives, then flush."""
        pending = []
        for row in iter(self.q.get, None):
            pending.append(row)
            if len(pending) < self.batch_size:
                continue
            # On failure the rows stay in `pending`, so the write is
            # attempted again once the next row arrives.
            if self._upload_batch(pending):
                pending = []
        # The marker has been received: keep trying until the tail is written.
        while pending:
            if self._upload_batch(pending):
                pending = []

    def _upload_batch(self, batch):
        """Append ``batch`` to the table.

        Returns True on success.  On any exception the error is logged, the
        thread sleeps for ``retry_delay_sec`` and False is returned.
        """
        try:
            self.yt_client.write_table(self.yt_client.TablePath(self.path, append=True), batch, raw=False)
        except Exception:
            LOG.error(f'Failed to save data to yt on path {self.path} ({self.name})', exc_info=True)
            time.sleep(self.retry_delay_sec)
            return False
        return True


class AvatarsUploader:
    """Sends image links to the avatars service from a background thread.

    Links passed to ``add`` go through a bounded queue to a dispatcher
    thread.  The dispatcher calls the throttler before each submission and
    hands the link to a thread pool whose workers each own a ``requests``
    session.  Accepted images are forwarded to ``yt_dumper_uploaded``;
    non-OK responses are forwarded to ``yt_dumper_errors``.  The TVM service
    ticket is requested once, in the constructor, and reused for every
    request.
    """

    def __init__(self,
                 yt_dumper_uploaded,
                 yt_dumper_errors,
                 queue_limit,
                 max_rps,
                 max_concurrency,
                 max_errors,
                 avatars_put_base_url,
                 avatars_get_base_url,
                 avatars_namespace,
                 avatars_tvm_id,
                 tvm_client_id,
                 tvm_client_secret,
                 avatars_timeout_seconds):
        self.yt_dumper_uploaded: YtDumper = yt_dumper_uploaded
        self.yt_dumper_errors: YtDumper = yt_dumper_errors
        self.q = queue.Queue(queue_limit)
        self.th = None
        # Must exist before the pool is created: pool workers run
        # _init_session, which stores a session on this thread-local.
        self.sessions = threading.local()
        self.throttler = Throttler(max_rps)
        self.requests_pool = ThreadPool(max_concurrency, self._init_session)
        self.errors_counter = ThreadSafeCounter()
        self.uploaded_images_counter = ThreadSafeCounter()
        self.max_errors = max_errors
        self.avatars_put_base_url = avatars_put_base_url
        self.avatars_get_base_url = avatars_get_base_url
        self.avatars_namespace = avatars_namespace
        self.timeout = avatars_timeout_seconds  # seconds

        tvm_settings = TvmApiClientSettings(
            self_client_id=tvm_client_id,
            enable_service_ticket_checking=True,
            enable_user_ticket_checking=False,
            self_secret=tvm_client_secret,
            dsts={'avatars': avatars_tvm_id},
        )
        tvm_client = TvmClient(tvm_settings)
        if tvm_client.status != TvmClientStatus.Ok:
            raise Exception("tvm client has bad status: " + TvmClient.status_to_string(tvm_client.status))

        self.ticket = tvm_client.get_service_ticket_for('avatars')

    def start(self):
        """Create and start the dispatcher thread."""
        dispatcher = threading.Thread(target=self._run)
        self.th = dispatcher
        dispatcher.start()

    def add(self, link):
        """Queue a link for upload; blocks while the queue is full."""
        self.q.put(link)

    def report_finish_and_wait(self):
        """Send the end-of-stream marker and wait for the dispatcher to exit."""
        LOG.info("Stopping AvatarsUploader: start")
        self.q.put(None)
        LOG.info("Stopping AvatarsUploader: put queue marker")
        self.th.join()
        LOG.info("Stopping AvatarsUploader: done")

    def is_too_many_errors(self):
        """Return True once the error count is strictly above ``max_errors``."""
        error_count = self.errors_counter.get()
        return error_count > self.max_errors

    def get_error_count(self):
        """Return the number of failed uploads counted so far."""
        return self.errors_counter.get()

    def get_uploaded_images_count(self):
        """Return the number of images accepted by avatars so far."""
        return self.uploaded_images_counter.get()

    def _init_session(self):
        """Pool-worker initializer: give the current thread its own session."""
        self.sessions.session = requests.session()

    def _run(self):
        """Dispatcher loop: submit queued links to the pool until the marker."""
        queued_links = iter(self.q.get, None)
        for link in auto_progress_reporter(queued_links, name='Adding image to avatars'):
            if not self.is_too_many_errors():
                self.throttler.delay_before_next_call()
                self.requests_pool.apply_async(self._do_upload, (link,))
            # Past the error limit the loop keeps consuming (and dropping)
            # links so that producers blocked in .put() are released.
        LOG.info("AvatarsUploader: closing requests pool")
        self.requests_pool.close()
        LOG.info("AvatarsUploader: joining requests pool")
        self.requests_pool.join()
        LOG.info("AvatarsUploader: _run() done")

    def _do_upload(self, link):
        """Ask avatars to fetch ``link`` and record the outcome.

        Runs on a pool worker.  Any exception is logged and counted as an
        error; in that case nothing is written to the errors dumper.
        """
        try:
            put_url = f'{self.avatars_put_base_url}/put-{self.avatars_namespace}'
            resp = self.sessions.session.get(
                put_url,
                headers={'X-Ya-Service-Ticket': self.ticket},
                params={'url': link},
                timeout=self.timeout,
            )
            if not resp.ok:
                self._record_rejected(link, resp)
                return

            raw_resp = resp.text
            resp_data = resp.json()
            self.yt_dumper_uploaded.add({
                'ExternalUrl': link,
                'AvatarsImg': self._make_avatars_img(resp_data),
                'RawAvatarsResponse': raw_resp,
            })
            self.uploaded_images_counter.inc()
        except Exception:
            LOG.error("Exception while putting image", exc_info=True)
            self.errors_counter.inc()

    def _make_avatars_img(self, resp_data):
        """Build the AvatarsImg value from a parsed successful response."""
        url_template = f'{self.avatars_get_base_url}/get-{self.avatars_namespace}/{resp_data["group-id"]}/{resp_data["imagename"]}/%s'
        sizes = []
        for size_name, size in resp_data['sizes'].items():
            sizes.append({
                'Height': size['height'],
                'Width': size['width'],
                'Size': size_name,
            })
        return {
            'UrlTemplate': url_template,
            'Sizes': sizes,
        }

    def _record_rejected(self, link, resp):
        """Log a non-OK response, pass it to the errors dumper and count it."""
        LOG.warning(f"Failed to put image (status_code={resp.status_code}): {resp.text}")
        try:
            description = resp.json().get('description', '')
        except Exception:
            # The body is not usable JSON; store an empty description.
            description = ''
        self.yt_dumper_errors.add({
            'ExternalUrl': link,
            'ResponseCode': resp.status_code,
            'ErrorDescription': description,
            'RawAvatarsResponse': resp.text,
        })
        self.errors_counter.inc()


class Runner:
    """Finds room photo links in feed tables and uploads unseen ones to avatars.

    Wires together one AvatarsUploader and two YtDumper instances (uploaded
    images and errors), feeds the uploader with links that are not yet
    present in either result table, and shuts the workers down at the end.
    """

    def __init__(self, args):
        """Build the YT client, both dumpers and the uploader from parsed CLI args."""
        self.yt_client = YtClient(proxy=args.yt_proxy, token=args.yt_token)
        self.yt_dumper_uploaded = YtDumper('uploaded', args.yt_proxy, args.yt_token, args.yt_path_uploaded, args.yt_batch_size)
        self.yt_dumper_errors = YtDumper('errors', args.yt_proxy, args.yt_token, args.yt_path_errors, args.yt_batch_size)
        self.avatars_uploader = AvatarsUploader(self.yt_dumper_uploaded,
                                                self.yt_dumper_errors,
                                                args.yt_batch_size,
                                                args.max_avatars_rps,
                                                args.max_avatars_concurrency,
                                                args.max_avatars_errors,
                                                args.avatars_put_base_url,
                                                args.avatars_get_base_url,
                                                args.avatars_namespace,
                                                args.avatars_tvm_id,
                                                args.avatars_tvm_client_id,
                                                args.avatars_tvm_client_secret,
                                                args.avatars_timeout_seconds)
        self.max_images_to_upload = args.max_images_to_upload
        self.feeds_path = args.feeds_path
        self.max_avatars_errors = args.max_avatars_errors
        self.time_limit_seconds = args.time_limit_seconds
        self.args = args

    def run(self):
        """Run the whole upload pass.

        Creates the result tables if they are missing, loads the links that
        are already recorded, starts the workers, queues new links and waits
        for the workers to finish.

        Raises:
            SystemExit: with code 1 when, after all uploads have finished,
                the avatars error count is above the configured limit.
            Exception: whatever the feed processing raised; the workers are
                stopped before the exception propagates.
        """
        start_time = time.time()

        mappings_path = self.args.yt_path_uploaded
        self.yt_client.create("table", mappings_path, attributes={"schema": result_table_schema}, recursive=True,
                              ignore_existing=True)

        errors_path = self.args.yt_path_errors
        self.yt_client.create("table", errors_path, attributes={"schema": errors_table_schema}, recursive=True,
                              ignore_existing=True)

        existing_links = self._get_existing_links(mappings_path) | self._get_existing_links(errors_path)
        self.yt_dumper_uploaded.start()
        self.yt_dumper_errors.start()
        self.avatars_uploader.start()

        # The worker threads block on their queues until they receive the
        # end marker, so they must be stopped even when feed processing
        # fails; otherwise the process would never exit.
        try:
            uploaded_images_count, spent_time_seconds = self._enqueue_new_links(existing_links, start_time)

            if uploaded_images_count >= self.max_images_to_upload:
                LOG.warning(f"Reached image limit ({self.max_images_to_upload} images), stopping")

            if spent_time_seconds > self.time_limit_seconds:
                LOG.warning(f"Process took too long ({int(spent_time_seconds)} seconds), stopping")
        finally:
            self._stop_workers()

        LOG.info(f"Tried to upload {uploaded_images_count} images, uploaded {self.avatars_uploader.get_uploaded_images_count()} images")

        # Checked only after the workers are done, so that failures of
        # uploads that were still queued or in flight count towards the
        # exit status.
        if self.avatars_uploader.is_too_many_errors():
            LOG.error(f"Reached avatars error limit ({self.avatars_uploader.get_error_count()} errors), stopping")
            sys.exit(1)

    def _enqueue_new_links(self, existing_links, start_time):
        """Queue every not-yet-known photo link until a limit is hit.

        ``existing_links`` is extended in place with each queued link.
        Returns ``(queued_count, seconds_since_start_time)``; the elapsed
        time is refreshed only when a link is queued.
        """
        rooms_column = 'roomTypes'
        uploaded_images_count = 0
        spent_time_seconds = 0

        feed_table_names = list(self.yt_client.list(self.feeds_path))
        LOG.info(f"Found feed tables: {feed_table_names}")

        for name in feed_table_names:
            full_name = ytlib.join(self.feeds_path, name)
            LOG.info(f"Processing feed table: {full_name}")
            table_path = self.yt_client.TablePath(full_name, columns=[rooms_column])
            total_count = self.yt_client.row_count(table_path)
            LOG.info(f"Total rows count: {total_count}")
            for row in auto_progress_reporter(self.yt_client.read_table(table_path),
                                              name=f'Photos from feed table {name}', total=total_count):
                try:
                    if row.get(rooms_column) is None:
                        continue
                    # A set, so a link repeated within one row is queued once.
                    current_links = {photo['link'] for room in row[rooms_column] for photo in room['photos']}
                    for link in current_links:
                        if link in existing_links:
                            continue
                        self.avatars_uploader.add(link)
                        existing_links.add(link)
                        uploaded_images_count += 1
                        spent_time_seconds = time.time() - start_time
                        if (uploaded_images_count >= self.max_images_to_upload
                                or self.avatars_uploader.is_too_many_errors()
                                or spent_time_seconds > self.time_limit_seconds):
                            return uploaded_images_count, spent_time_seconds
                except Exception:
                    LOG.error(f"Failure on {row}")
                    raise

        return uploaded_images_count, spent_time_seconds

    def _stop_workers(self):
        """Stop the uploader first, then both dumpers, even if one stop fails."""
        # The uploader goes first: its pool workers still add rows to the
        # dumpers until the pool has been joined.
        try:
            self.avatars_uploader.report_finish_and_wait()
        finally:
            try:
                self.yt_dumper_uploaded.report_finish_and_wait()
            finally:
                self.yt_dumper_errors.report_finish_and_wait()

    def _get_existing_links(self, path):
        """Return the set of ``ExternalUrl`` values stored in the table at ``path``."""
        external_url_column = 'ExternalUrl'
        table_path = self.yt_client.TablePath(path, columns=[external_url_column])
        total_count = self.yt_client.row_count(table_path)
        return {row[external_url_column] for row in
                auto_progress_reporter(self.yt_client.read_table(table_path),
                                       name='Loading known links', total=total_count)}


def main():
    """Configure logging, parse the command-line options and run the upload.

    Arguments are taken from ``replace_args_from_env()`` and handed to
    ``Runner``.
    """
    logging.basicConfig(level=logging.INFO, format="%(asctime)-15s | %(module)s | %(levelname)s | %(message)s",
                        stream=sys.stdout)
    logging.getLogger('yt.packages.urllib3.connectionpool').setLevel(logging.WARNING)

    parser = ArgumentParser()
    parser.add_argument('--yt-proxy', default='hahn')
    parser.add_argument('--yt-token', required=True)
    parser.add_argument('--yt-path-uploaded', required=True)
    parser.add_argument('--yt-path-errors', required=True)
    parser.add_argument('--yt-batch-size', default=5000, type=int)
    parser.add_argument('--feeds-path', required=True)
    parser.add_argument('--avatars-get-base-url', required=True)
    parser.add_argument('--avatars-put-base-url', required=True)
    parser.add_argument('--avatars-namespace', default='travel-rooms')
    parser.add_argument('--avatars-tvm-id', required=True, type=int)
    parser.add_argument('--avatars-tvm-client-id', required=True, type=int)
    parser.add_argument('--avatars-tvm-client-secret', required=True)
    # Without an explicit type a value given on the command line stays a str,
    # and it is later passed as the `timeout=` of an HTTP request, which
    # must be a number.
    parser.add_argument('--avatars-timeout-seconds', default=30, type=float)
    parser.add_argument('--max-avatars-rps', default=1, type=int)
    parser.add_argument('--max-avatars-concurrency', default=1, type=int)
    parser.add_argument('--max-images-to-upload', default=2, type=int)
    parser.add_argument('--max-avatars-errors', default=100, type=int)
    parser.add_argument('--time-limit-seconds', default=3600, type=int)
    args = parser.parse_args(args=replace_args_from_env())
    Runner(args).run()


# Run the upload only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
