import concurrent.futures
import contextlib
import itertools
import logging
import os
import posixpath
import threading
import time

from paramiko import Transport, SFTPClient

from cars.request_aggregator.core import RateLimiter


LOGGER = logging.getLogger(__name__)


class SFTPFileDownloader(object):
    """Read files from a single remote directory over SFTP (paramiko).

    Call ``connect`` before ``iter_files`` / ``open_file``, and ``close`` when
    done so that both the SFTP session and its SSH transport are released.
    """

    def __init__(self, host, port, username, password, *, root_dir):
        """
        :param host: SFTP server host name.
        :param port: SFTP server port.
        :param username: login user name.
        :param password: login password.
        :param root_dir: remote directory listed by ``iter_files`` and used as
            the base for file names passed to ``open_file``.
        """
        self._transport = None
        self._sftp = None

        self._host = host
        self._port = port
        self._username = username
        self._password = password

        self._root_dir = root_dir

    def connect(self):
        """Open the SSH transport, authenticate and start an SFTP session.

        If authentication or SFTP session setup fails, the transport that was
        already opened is closed before the exception propagates.
        """
        transport = Transport((self._host, self._port))
        try:
            transport.connect(username=self._username, password=self._password)
            sftp = SFTPClient.from_transport(transport)
        except Exception:
            transport.close()
            raise
        # Keep a reference to the transport: SFTPClient.close() only closes the
        # SFTP channel, so the transport has to be closed explicitly.
        self._transport = transport
        self._sftp = sftp

    def close(self):
        """Close the SFTP session and its transport.

        Safe to call before ``connect`` and safe to call more than once.
        """
        sftp, self._sftp = self._sftp, None
        transport, self._transport = self._transport, None
        try:
            if sftp is not None:
                sftp.close()
        finally:
            # Always release the transport, even if closing the session failed.
            if transport is not None:
                transport.close()

    def iter_files(self):
        """Yield the entries of ``root_dir`` as produced by ``SFTPClient.listdir_iter``."""
        yield from self._sftp.listdir_iter(self._root_dir)

    def open_file(self, file_name):
        """Open ``file_name`` (relative to ``root_dir``) on the remote server.

        Remote SFTP paths are POSIX paths regardless of the local OS, hence
        ``posixpath.join`` instead of ``os.path.join``.
        """
        file_path = posixpath.join(self._root_dir, file_name)
        return self._sftp.open(file_path)


class SFTPDownloadFileManager(object):
    """Download files from a remote directory in parallel and upload them to MDS.

    Subclasses must implement ``_make_track_entry`` and ``_filter_entry``.
    A track entry must expose ``file_name`` and ``save()``, and is passed to
    the MDS client's ``put_object`` together with the downloaded content.

    Each worker thread lazily creates its own file downloader; those per-thread
    downloaders are closed and forgotten at the end of every ``process`` call.
    """

    def __init__(
            self, root_dir, file_downloader_cls, mds_client_cls,
            *, max_entries_per_s=8, work_thread_count=8, progress_tick_interval=None
    ):
        """
        :param root_dir: passed to ``file_downloader_cls.from_settings``.
        :param file_downloader_cls: class with ``from_settings(root_dir)``; its
            instances must provide ``connect``, ``close``, ``iter_files`` and
            ``open_file``.
        :param mds_client_cls: class with ``from_settings()``; its instances must
            provide ``put_object(track_entry, content)``.
        :param max_entries_per_s: passed to ``RateLimiter``; presumably the
            maximum number of files started per second across all workers
            (TODO confirm against RateLimiter).
        :param work_thread_count: number of worker threads in ``process``.
        :param progress_tick_interval: log progress every N completed entries;
            ``None`` disables progress logging.
        """
        self._root_dir = root_dir

        self._file_downloader_cls = file_downloader_cls

        # Downloader used by the calling thread to list the remote directory.
        # NOTE(review): this one is never closed by this class.
        self._file_downloader = self._make_file_downloader()
        self._mds_client = mds_client_cls.from_settings()

        # thread name -> file downloader owned by that worker thread
        self._thread_file_downloader_collection = {}
        self._rate_limiter = RateLimiter(max_entries_per_s)

        self._max_entries_per_s = max_entries_per_s
        self._work_thread_count = work_thread_count
        self._progress_tick_interval = progress_tick_interval

    def _make_file_downloader(self):
        """Create and connect a new file downloader; log and re-raise on failure."""
        try:
            file_downloader = self._file_downloader_cls.from_settings(self._root_dir)
            file_downloader.connect()
        except Exception:
            LOGGER.exception('cannot make file downloader')
            raise
        return file_downloader

    def process(self, *, re_raise=False):
        """Download and save every entry accepted by ``_filter_entry``.

        :param re_raise: when true, re-raise the first failure observed (in
            completion order). The executor's ``with`` block still waits for
            all already-submitted tasks before the exception propagates.
            When false, failures are only logged.
        """
        try:
            with concurrent.futures.ThreadPoolExecutor(max_workers=self._work_thread_count) as executor:
                # All entries are submitted up front; the executor queues them.
                futures = [
                    executor.submit(self._save_file, track_entry)
                    for track_entry in self._collect_files_to_process()
                ]

                for processed_count, future in enumerate(concurrent.futures.as_completed(futures), start=1):
                    try:
                        future.result()
                    except Exception as exc:
                        LOGGER.exception('future raised an exception: {}'.format(exc))
                        if re_raise:
                            raise

                    if self._progress_tick_interval is not None and not (processed_count % self._progress_tick_interval):
                        LOGGER.info('processed {} entries'.format(processed_count))
        finally:
            # The worker threads no longer exist once the executor has shut
            # down, so their downloaders must be closed and dropped rather than
            # kept around for (and re-closed by) a later ``process`` call.
            self._close_thread_file_downloaders()

    def _close_thread_file_downloaders(self):
        """Close and forget every per-thread downloader, logging close failures."""
        downloaders = list(self._thread_file_downloader_collection.values())
        self._thread_file_downloader_collection.clear()
        for downloader in downloaders:
            try:
                downloader.close()
            except Exception:
                # Keep closing the rest; a close failure must not mask an
                # exception that may be propagating out of ``process``.
                LOGGER.exception('cannot close file downloader')

    def _collect_files_to_process(self):
        """Yield a track entry for each remote file that ``_filter_entry`` accepts."""
        for file_description in self._file_downloader.iter_files():
            track_entry = self._make_track_entry(file_description)

            if self._filter_entry(track_entry):
                yield track_entry

    def _filter_entry(self, track_entry):
        """Return a truthy value if ``track_entry`` should be downloaded."""
        raise NotImplementedError

    def _make_track_entry(self, file_description):
        """Build a track entry from one item yielded by the downloader's ``iter_files``."""
        # use filename, st_size, st_mtime of file_description
        # (presumably a paramiko SFTPAttributes — confirm for custom downloaders)
        raise NotImplementedError

    def _save_file(self, track_entry):
        """Download one remote file, upload it to MDS, then save the track entry.

        Runs in a worker thread. Errors are logged and re-raised so that
        ``process`` sees them through the future.
        """
        self._rate_limiter.wait()

        file_downloader = self._get_thread_file_downloader()

        try:
            with contextlib.closing(file_downloader.open_file(track_entry.file_name)) as f:
                remote_file_content = f.read()

            self._mds_client.put_object(track_entry, remote_file_content)

            # Persisted only after the upload succeeded.
            track_entry.save()
        except Exception:
            LOGGER.exception('error saving file {}'.format(track_entry.file_name))
            raise

    def _get_thread_file_downloader(self):
        """Return the calling thread's downloader, creating it on first use."""
        thread_name = threading.current_thread().name

        # Each worker thread only reads and writes the entry under its own name.
        if thread_name not in self._thread_file_downloader_collection:
            self._thread_file_downloader_collection[thread_name] = self._make_file_downloader()

        file_downloader = self._thread_file_downloader_collection[thread_name]
        return file_downloader
