# encoding: UTF-8

import itertools
import multiprocessing.pool

import sqlalchemy as sa
from ws_properties.environ.environment import Environment
from ws_properties.utils.logs import get_logger_for_instance

from appcore.injection import Injected
from dns_hosting.services.migrator.bucket import RecordsBucket
from dns_hosting.services.migrator.loader import ThreadedLoader
from dns_hosting.services.migrator.migration import make_migration
from dns_hosting.services.migrator.qbundle import QueryBundle
from dns_hosting.utils.iterators import ijoin
from dns_hosting.utils.progress import IteratorLogger
from dns_hosting.utils.retry import retry, StandardRetryPolicy
from dns_hosting.utils.text import try_idna_decode


class MigratorService(object):
    """Streams DNS records from the PDD database into the own database.

    Records are read from two sources (PDD and the own slave database),
    grouped into per-domain buckets, turned into migrations by a pool of
    worker processes and applied to the master database by a pool of
    writer threads.
    """

    # Names of the queries loaded into the QueryBundle before migrating.
    QUERIES = (
        'load_all_pdd_records',
        'load_all_own_records',
        'insert_domain',
        'delete_domain',
        'clear_domain',
        'apply_domain_changes',
    )

    environment = Injected('environment')  # type: Environment
    master_engine = Injected('master_engine')  # type: sa.engine.Engine
    slave_engine = Injected('slave_engine')  # type: sa.engine.Engine
    pdd_engine = Injected('pdd_engine')  # type: sa.engine.Engine

    def __init__(
            self,
            workers_number,
            writers_number,
            read_batch_size=10000,
            write_batch_size=10,
            queue_size=5,
            log_interval=5,
    ):
        """Store the tuning parameters of the migration pipeline.

        :param workers_number: size of the process pool that builds
            migrations.
        :param writers_number: size of the thread pool that applies
            migrations.
        :param read_batch_size: passed to every ThreadedLoader
            (presumably the number of rows fetched per read - confirm
            against ThreadedLoader).
        :param write_batch_size: number of migrations applied by a single
            ``_migrate`` call, i.e. inside one ``engine.begin()`` block.
        :param queue_size: passed to every ThreadedLoader (presumably the
            capacity of its internal queue - confirm).
        :param log_interval: passed to every IteratorLogger (presumably
            the progress reporting period - confirm).
        """
        self._logger = get_logger_for_instance(self)
        self.workers_number = workers_number
        self.writers_number = writers_number
        self.read_batch_size = read_batch_size
        self.write_batch_size = write_batch_size
        self.queue_size = queue_size
        self.log_interval = log_interval

    def _group_records(self, records):
        """Group consecutive records of one domain into RecordsBucket objects.

        The ``domain_id``, ``origin`` and ``sync_enabled`` keys are popped
        from every record (the record dicts are mutated) and used as the
        grouping key. Only adjacent records with equal keys end up in the
        same bucket, so the input is expected to arrive ordered by domain.

        :param records: iterable of mutable mappings holding at least the
            ``domain_id``, ``origin``, ``sync_enabled`` and ``id`` keys.
        :return: generator of RecordsBucket instances.
        """
        def bucket_key(record):
            # Popping strips the domain-level fields from the record itself.
            return (
                record.pop('domain_id'),
                record.pop('origin'),
                record.pop('sync_enabled'),
            )

        grouped = itertools.groupby(records, key=bucket_key)
        for (domain_id, origin, sync_enabled), group in grouped:
            yield RecordsBucket(
                int(domain_id),
                try_idna_decode(origin),
                sync_enabled,
                # Materialize the group right away: a groupby group becomes
                # invalid as soon as the outer iterator advances. Rows with
                # id None are dropped (presumably placeholders for domains
                # without records - confirm against the load queries).
                [r for r in group if r['id'] is not None],
            )

    @retry(StandardRetryPolicy(delay=5, max_attempts=3))
    def _migrate(self, task):
        """Apply one batch of migrations inside a single ``engine.begin()``.

        The method receives one positional argument because it is fed
        through ``ThreadPool.imap_unordered``.

        :param task: ``(migrations, engine)`` pair; ``migrations`` is an
            iterable of executable statements in which ``None`` entries
            (batch padding) are skipped, ``engine`` is the engine to run
            them on.
        :raises Exception: whatever the engine raised, after logging the
            migration that was current at the moment of failure.
        """
        migrations, engine = task
        # Reported when the failure happens before any migration is reached.
        migration = '<no migration>'
        try:
            with engine.begin() as conn:
                for migration in migrations:
                    if migration is not None:
                        conn.execute(migration)
        except Exception:
            self._logger.exception('Migration can\'t be applied\n%s', migration)
            raise

    def _run_pipeline(self, worker_pool, writer_pool):
        """Build the lazy load -> make -> apply pipeline and drain it.

        Errors raised while the pipeline is being drained are logged and
        not re-raised; errors raised while it is being assembled propagate
        to the caller.

        :param worker_pool: process pool used to run ``make_migration``.
        :param writer_pool: thread pool used to run ``_migrate``.
        """
        query_bundle = QueryBundle()
        query_bundle.load(self.environment, self.QUERIES)

        pdd_records = ThreadedLoader(
            self.pdd_engine,
            query_bundle.get('load_all_pdd_records'),
            self.read_batch_size,
            self.queue_size,
        )
        pdd_buckets = self._group_records(pdd_records)

        own_records = ThreadedLoader(
            self.slave_engine,
            query_bundle.get('load_all_own_records'),
            self.read_batch_size,
            self.queue_size,
        )
        own_buckets = self._group_records(own_records)

        # Pair up PDD and own buckets by domain id (presumably a merge join
        # over two streams sorted by domain_id - confirm against ijoin).
        buckets = ijoin(
            pdd_buckets,
            own_buckets,
            cmp=lambda b1, b2: cmp(b1.domain_id, b2.domain_id),
        )
        buckets_prg = IteratorLogger(
            buckets,
            'Loading buckets',
            self._logger,
            self.log_interval,
        )

        # Every task is a (joined buckets, query_bundle) pair.
        migrations = worker_pool.imap_unordered(
            make_migration,
            itertools.izip(buckets_prg, itertools.cycle([query_bundle])),
        )
        migrations_prg = IteratorLogger(
            migrations,
            'Making migrations',
            self._logger,
            self.log_interval,
        )

        # Drop falsy results, then chunk the stream: the same iterator is
        # repeated write_batch_size times, so izip_longest pulls that many
        # consecutive migrations into each tuple and pads the last tuple
        # with None (skipped by _migrate).
        non_empty = itertools.ifilter(None, migrations_prg)
        shared_iterators = [iter(non_empty)] * self.write_batch_size
        migration_batches = itertools.izip_longest(
            fillvalue=None,
            *shared_iterators
        )

        results = writer_pool.imap_unordered(
            self._migrate,
            itertools.izip(
                migration_batches,
                itertools.cycle([self.master_engine]),
            ),
        )
        results_prg = IteratorLogger(
            results,
            'Migration applying x%d' % self.write_batch_size,
            self._logger,
            self.log_interval,
        )

        self._logger.info('Migrating PDD domains...')
        try:
            with buckets_prg, migrations_prg, results_prg:
                pdd_records.start()
                own_records.start()
                # Draining the last stage drives the whole lazy pipeline.
                for _ in results_prg:
                    pass
        except Exception:
            self._logger.exception('Migration from PDD failed')
        else:
            self._logger.info('Migration from PDD completed')

    def migrate(self):
        """Run the migration and always release both pools.

        A process pool of ``workers_number`` workers and a thread pool of
        ``writers_number`` writers are created, the pipeline is executed,
        and both pools are terminated and joined afterwards - including
        the case when assembling the pipeline fails.

        Failures while the pipeline runs are logged, not raised; failures
        while it is being set up are raised to the caller.
        """
        worker_pool = None
        writer_pool = None
        try:
            # The process pool is created first, keeping the original order.
            worker_pool = multiprocessing.pool.Pool(self.workers_number)
            writer_pool = multiprocessing.pool.ThreadPool(self.writers_number)
            self._run_pipeline(worker_pool, writer_pool)
        finally:
            for pool in (worker_pool, writer_pool):
                if pool is not None:
                    pool.terminate()
                    pool.join()
