# coding: utf8

import Queue
import logging
import os
import threading
import time
import traceback

from StringIO import StringIO
from contextlib import contextmanager
from datetime import datetime, timedelta
from optparse import OptionParser, Option
from random import shuffle

import psycopg2.extras
import pycurl


class Segment:
    """One cached segment, built from a mapping-like database row.

    The row must provide the keys ``mulca_id``, ``quality``, ``index``,
    ``duration``, ``mds_key`` and ``last_access_time``; each one is copied
    to an attribute of the same name.
    """

    def __init__(self, row):
        # Attribute names mirror the row keys one-to-one.
        for column in ('mulca_id', 'quality', 'index', 'duration', 'mds_key', 'last_access_time'):
            setattr(self, column, row[column])

    def __str__(self):
        # Rendered as "<mulca_id>/<quality>/<duration>s#<index>".
        parts = (self.mulca_id, self.quality, self.duration, self.index)
        return '%s/%s/%.2fs#%d' % parts


class MdsKey:
    """An MDS key together with its creation time, built from a mapping-like row.

    The row must provide the keys ``mds_key`` and ``creation_time``.
    """

    def __init__(self, row):
        key, created_at = row['mds_key'], row['creation_time']
        self.mds_key = key
        self.creation_time = created_at

    def __str__(self):
        # The string form is the key itself.
        key = self.mds_key
        return key


class DaoIterator:
    """Iterator yielding successive non-empty batches fetched through a fresh DAO.

    On every step a new DAO is obtained from ``dao_factory`` and used as a
    context manager for the duration of a single fetch. The arguments for the
    next fetch are derived from the batch that was just returned, so an
    interrupted iteration resumes from the last successfully returned batch.

    :param dao_factory: zero-argument callable returning a context manager
        whose ``__enter__`` yields the DAO.
    :param fetch_items: callable ``fetch_items(dao, *args)`` returning a batch.
    :param build_next_args: callable mapping a non-empty batch to the argument
        tuple of the next fetch.
    :param initial_args: argument tuple for the first fetch.
    """

    def __init__(self, dao_factory, fetch_items, build_next_args, initial_args):
        self._dao_factory = dao_factory
        self._args = initial_args
        self._fetch_items = fetch_items
        self._build_next_args = build_next_args

    def __iter__(self):
        return self

    def next(self):
        """Return the next batch; raise ``StopIteration`` on an empty batch."""
        # Use the factory passed to the constructor. Looking up a module-level
        # ``dao_factory`` name only works when the script's globals happen to
        # define one.
        with self._dao_factory() as dao:
            items = self._fetch_items(dao, *self._args)

        # The DAO is no longer needed once the batch has been fetched.
        if not items:
            raise StopIteration

        self._args = self._build_next_args(items)
        return items

    # Python 3 spelling of the iterator protocol.
    __next__ = next


class DaoException(Exception):
    """Raised by the DAO when a statement affects more rows than expected."""


class StreamingDao:
    """Data access object for the ``segment_cache`` and ``discarded_mds_keys`` tables.

    Two psycopg2 connections are held: SELECTs go through the slave
    connection, DELETE/INSERT statements go through the master connection.
    The object is a context manager; leaving the ``with`` block closes both
    connections.
    """

    # Default lower bound for the time-based pagination arguments.
    DATE_IN_PAST = datetime.strptime('1983', '%Y')
    _KEYS = ['mulca_id', 'quality', 'index', 'duration', 'mds_key', 'last_access_time']

    def __init__(self, master_conn_str, slave_conn_str, expired_segments_max_count, discarded_keys_max_count,
                 max_last_access_time):
        """Open both connections.

        :param master_conn_str: connection string for write statements.
        :param slave_conn_str: connection string for SELECT statements.
        :param expired_segments_max_count: LIMIT of ``select_expired_segments``.
        :param discarded_keys_max_count: LIMIT of ``select_discarded_keys``.
        :param max_last_access_time: segments accessed strictly before this
            moment are selected by ``select_expired_segments``.
        """
        self._master_conn_str = master_conn_str
        self._slave_conn_str = slave_conn_str
        self._conn = psycopg2.connect(master_conn_str)
        try:
            self._slave_conn = psycopg2.connect(slave_conn_str)
        except Exception:
            # Do not leak the master connection when the slave is unreachable.
            self._conn.close()
            raise
        self._expired_segments_max_count = expired_segments_max_count
        self._discarded_keys_max_count = discarded_keys_max_count
        self._max_last_access_time = max_last_access_time

    @staticmethod
    def factory(master_conn_str, slave_conn_str, expired_segments_max_count, discarded_keys_max_count,
                max_last_access_time):
        """Return a zero-argument callable creating a new ``StreamingDao`` per call."""
        def create_dao():
            return StreamingDao(master_conn_str, slave_conn_str, expired_segments_max_count, discarded_keys_max_count,
                                max_last_access_time)
        return create_dao

    def select_expired_segments(self, min_last_access_time=DATE_IN_PAST):
        """Return up to ``expired_segments_max_count`` segments ordered by last access time.

        Only rows with ``min_last_access_time < last_access_time <
        max_last_access_time`` are returned.

        NOTE(review): the lower bound is strict and there is no tie-breaker
        column, so rows sharing the ``last_access_time`` of the last row of a
        batch are not returned by the following page.
        """
        start_time = time.time()
        with self._slave_cursor() as cursor:
            cursor.execute('SELECT mulca_id, quality, index, duration, mds_key, last_access_time FROM segment_cache'
                           ' WHERE last_access_time < %s AND last_access_time > %s'
                           ' ORDER BY last_access_time LIMIT %s',
                           [self._max_last_access_time, min_last_access_time, self._expired_segments_max_count])
            query_duration = time.time() - start_time

            iterate_start_time = time.time()
            segments = [Segment(row) for row in cursor]
            iterate_duration = time.time() - iterate_start_time

        logging.info('Selected expired segments, count = %d, query time = %.2f, iterate time = %.2f',
                     len(segments), query_duration, iterate_duration)
        return segments

    def select_discarded_keys(self, mds_key='', min_creation_time=DATE_IN_PAST):
        """Return up to ``discarded_keys_max_count`` discarded keys after the given position.

        Rows are ordered by ``(creation_time, mds_key)`` and only rows that
        sort strictly after ``(min_creation_time, mds_key)`` are returned.
        """
        start_time = time.time()
        with self._slave_cursor() as cursor:
            # Keyset pagination must compare the same tuple the result is
            # ordered by. Testing the two columns independently
            # (creation_time >= X AND mds_key > Y) skips every later row whose
            # key happens to sort before the previous batch's last key.
            cursor.execute('SELECT * FROM discarded_mds_keys WHERE (creation_time, mds_key) > (%s, %s)'
                           ' ORDER BY creation_time, mds_key LIMIT %s',
                           [min_creation_time, mds_key, self._discarded_keys_max_count])
            query_duration = time.time() - start_time

            fetch_start_time = time.time()
            mds_keys = [MdsKey(row) for row in cursor]
            fetch_duration = time.time() - fetch_start_time

        logging.info('Select from discarded_keys, count = %d, query time = %.2f, fetch time = %.2f',
                     len(mds_keys), query_duration, fetch_duration)
        return mds_keys

    def move_to_discarded(self, segment):
        """Delete ``segment`` from the cache and record its MDS key as discarded.

        Both statements run in one transaction on the master connection.

        :returns: ``True`` when exactly one row was deleted and the key was
            inserted, ``False`` when no row matched.
        :raises DaoException: when more than one row was deleted (the
            transaction is rolled back).
        """
        with self._new_cursor() as cursor:
            self._delete_from_segment_cache(cursor, segment)
            count = cursor.rowcount
            if count > 1:
                raise DaoException('Too many records updated')

            if count == 1:
                cursor.execute('INSERT INTO discarded_mds_keys (mds_key) VALUES (%s)', [segment.mds_key])
                return True
            else:
                return False

    def delete_from_segment_cache(self, segment):
        """Delete ``segment`` from ``segment_cache`` and commit."""
        with self._new_cursor() as cursor:
            self._delete_from_segment_cache(cursor, segment)

    @staticmethod
    def _delete_from_segment_cache(cursor, segment):
        # last_access_time is part of the condition, so a row whose access
        # time changed since it was selected is not matched.
        cursor.execute('DELETE FROM segment_cache'
                       ' WHERE (mulca_id, quality, index, duration) = (%s, %s, %s, %s) AND last_access_time = %s',
                       [segment.mulca_id, segment.quality, segment.index, segment.duration, segment.last_access_time])

    def delete_from_discarded(self, mds_key):
        """Delete ``mds_key`` from ``discarded_mds_keys``.

        :returns: ``True`` when exactly one row was deleted, ``False`` when
            none was.
        :raises DaoException: when more than one row was deleted (the
            transaction is rolled back).
        """
        with self._new_cursor() as cursor:
            cursor.execute('DELETE FROM discarded_mds_keys WHERE mds_key = %s', [mds_key])
            count = cursor.rowcount
            if count > 1:
                raise DaoException('Too many records updated')

            return count == 1

    def close(self):
        """Close both connections; a failure on one does not skip the other."""
        for conn in (self._conn, self._slave_conn):
            try:
                conn.close()
            except psycopg2.Error as e:
                logging.error('Could not close connection: %s %s', e.pgcode, e.pgerror)

    @contextmanager
    def _slave_cursor(self):
        """Yield a named ``DictCursor`` on the slave connection and close it afterwards."""
        cursor = self._slave_conn.cursor('cursor' + str(time.time()), cursor_factory=psycopg2.extras.DictCursor)
        try:
            yield cursor
        finally:
            try:
                cursor.close()
            except psycopg2.Error as e:
                # Closing is best effort; it must not mask an error raised by the query.
                logging.warning('Could not close slave cursor: %s', e)

    @contextmanager
    def _new_cursor(self):
        """Yield a master cursor; commit on success, roll back on any exception."""
        cursor = self._conn.cursor()
        try:
            yield cursor
        except BaseException:
            # The rollback lives in a helper so that an error handled inside
            # it cannot replace the exception re-raised here.
            self._rollback_quietly()
            raise
        else:
            self._conn.commit()
        finally:
            cursor.close()

    def _rollback_quietly(self):
        """Roll back the master transaction, logging instead of raising on failure."""
        try:
            self._conn.rollback()
        except psycopg2.Error as e:
            logging.warning('Could not roll back transaction: %s', e)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()


class MdsClient:
    """Small pycurl-based client that issues delete requests to MDS.

    A single curl handle is configured in the constructor and reused for
    every request made through this client.
    """

    # Response codes callers treat as a successful delete.
    OK_DELETE_RESULTS = [200, 404]
    _DISK_NS = 'video-disk'

    def __init__(self, url, auth, connect_timeout, timeout):
        """
        :param url: base URL of the MDS endpoint.
        :param auth: value sent in the ``Authorization`` header.
        :param connect_timeout: connect timeout, milliseconds.
        :param timeout: total request timeout, milliseconds.
        """
        self._url = url
        self._curl = self._init_curl(auth, connect_timeout, timeout)

    @staticmethod
    def factory(url, auth, connect_timeout):
        """Return a callable that builds a client for a given total ``timeout``."""
        return lambda timeout: MdsClient(url, auth, connect_timeout, timeout)

    @staticmethod
    def _init_curl(auth, connect_timeout, timeout):
        """Create a curl handle with the auth header and both timeouts applied."""
        handle = pycurl.Curl()
        settings = (
            (pycurl.HTTPHEADER, ['Authorization: %s' % auth]),
            (pycurl.NOSIGNAL, 1),
            (pycurl.CONNECTTIMEOUT_MS, connect_timeout),
            (pycurl.TIMEOUT_MS, timeout),
        )
        for option, value in settings:
            handle.setopt(option, value)
        return handle

    def delete(self, key):
        """Request deletion of ``key`` and return ``(response_code, response_body)``."""
        body_sink = StringIO()
        handle = self._setup_curl('delete', key)
        handle.setopt(handle.WRITEDATA, body_sink)
        handle.perform()
        status = handle.getinfo(pycurl.RESPONSE_CODE)
        body = body_sink.getvalue()
        self._log_curl_stats(key, body)
        return status, body

    def _setup_curl(self, action, key):
        """Point the shared handle at ``<url>/<action>-<namespace>/<key>`` and return it."""
        target = '%s/%s-%s/%s' % (self._url, action, self._DISK_NS, key)
        self._curl.setopt(pycurl.URL, target)
        return self._curl

    def _log_curl_stats(self, key, resp_body):
        """Log the timing counters and response of the last transfer at debug level."""
        handle = self._curl
        counters = (
            pycurl.TOTAL_TIME,
            pycurl.NAMELOOKUP_TIME,
            pycurl.CONNECT_TIME,
            pycurl.REDIRECT_TIME,
            pycurl.PRETRANSFER_TIME,
            pycurl.STARTTRANSFER_TIME,
            pycurl.RESPONSE_CODE,
        )
        values = [handle.getinfo(counter) for counter in counters]
        values.append(resp_body)
        logging.debug('Curl stats for mds key = %s - time: '
                      'total = %.3f, namelookup = %.3f, connect = %.3f, redirect = %.3f, '
                      'pretransfer = %.3f, starttransfer = %.3f; response code = %d, body = %s',
                      key, *values)

    def close(self):
        """Release the curl handle."""
        self._curl.close()


class Cleaner:
    """Worker that takes items from a queue and deletes them from the DB and MDS.

    Queue items are either ``Segment`` instances (moved from the segment cache
    to the discarded keys and then deleted from MDS) or objects exposing an
    ``mds_key`` attribute (deleted from MDS and from the discarded keys).
    The cleaner is a context manager; leaving the ``with`` block closes the
    MDS client and the DAO. Calling the instance runs ``clean()``.
    """

    def __init__(self, queue, mds, dao, max_duration, proceed):
        """
        :param queue: queue the items are taken from.
        :param mds: MDS client providing ``delete(key)`` and ``close()``.
        :param dao: DAO providing the delete/move operations and ``close()``.
        :param max_duration: ``timedelta`` after which ``clean()`` returns.
        :param proceed: event-like object; ``clean()`` runs while it is set.
        """
        self._queue = queue
        self._mds = mds
        self._dao = dao
        self._stop_time = datetime.utcnow() + max_duration
        self._proceed = proceed

    @staticmethod
    def factory(mds_client_factory, dao_factory, max_duration):
        """Return ``create_cleaner(queue, proceed, timeout)`` building a fresh cleaner."""
        def create_cleaner(queue, proceed, timeout):
            return Cleaner(queue, mds_client_factory(timeout), dao_factory(), max_duration, proceed)
        return create_cleaner

    def clean(self):
        """Process queue items until ``proceed`` is cleared or the time budget is spent."""
        logging.info('Starting cleaning')
        while self._proceed.is_set() and datetime.utcnow() < self._stop_time:
            try:
                segment_or_mds_key = self._queue.get(timeout=1)
            except Queue.Empty:
                # Nothing to do yet; pause before polling again.
                time.sleep(0.5)
                continue

            if segment_or_mds_key is None:
                continue

            if isinstance(segment_or_mds_key, Segment):
                self._try_delete_segment_cache(segment_or_mds_key)
            else:
                self._try_delete_mds_key(segment_or_mds_key.mds_key)

    def _try_delete_segment_cache(self, segment):
        """Remove ``segment`` from the cache and, if it has an MDS key, delete that key too.

        DB errors are logged and the segment is skipped.
        """
        start_time = time.time()
        if segment.mds_key is None:
            # Nothing is stored in MDS for this segment: only the row goes away.
            try:
                self._dao.delete_from_segment_cache(segment)
                logging.info('Deleted segment with empty mds key %s, done in %.2f', segment, time.time() - start_time)
            except psycopg2.Error as e:
                logging.error('DB error while deleting segment cache with empty MDS key %s: %s', segment, e)
            return

        try:
            discarded = self._dao.move_to_discarded(segment)
            logging.info('Segment MDS key moved to discarded %s: %s, done in %.2f', segment.mds_key, discarded, time.time() - start_time)
        except (psycopg2.Error, DaoException) as e:
            # DaoException (unexpected row count) is handled like a DB error
            # so that it cannot kill the worker thread.
            logging.error('DB error while moving to discarded key %s: %s', segment.mds_key, e)
            return

        if discarded:
            self._try_delete_mds_key(segment.mds_key)

    def _try_delete_mds_key(self, mds_key):
        """Delete ``mds_key`` from MDS, then remove it from the discarded keys.

        The key stays in the discarded keys when MDS answers with a code
        outside ``MdsClient.OK_DELETE_RESULTS`` or the request fails.
        """
        start_time = time.time()
        try:
            delete_result, delete_response = self._mds.delete(mds_key)
        except pycurl.error as e:
            logging.error('Could not delete from MDS key %s: %s', mds_key, e)
            return

        if delete_result not in MdsClient.OK_DELETE_RESULTS:
            logging.error('Could not delete from MDS key %s: got status = %s with message "%s"',
                          mds_key, str(delete_result), delete_response)
            return

        db_start_time = time.time()
        try:
            deleted_from_discarded = self._dao.delete_from_discarded(mds_key)
        except (psycopg2.Error, DaoException) as e:
            # This is a database call, so database errors (not pycurl errors)
            # are what has to be handled here.
            logging.error('Could not delete from discarded key %s: %s', mds_key, e)
            return

        if not deleted_from_discarded:
            logging.error('Could not delete from discarded key %s: DELETE query did not affect any rows', mds_key)
            return

        mds_time = db_start_time - start_time
        db_time = time.time() - db_start_time
        total_time = time.time() - start_time
        if delete_result == 200:
            logging.info(
                'Successfully deleted and removed from discarded key %s, done in: mds = %.2fs, db = %.2fs, total = %.2fs',
                mds_key,
                mds_time,
                db_time,
                total_time
            )
        else:
            logging.info(
                'Got already deleted key - just removing from discarded key %s, done in: mds = %.2fs, db = %.2fs, total = %.2fs',
                mds_key,
                mds_time,
                db_time,
                total_time
            )

    def close(self):
        """Close the MDS client and the DAO; a failure of one does not skip the other."""
        try:
            self._mds.close()
        except Exception:
            logging.exception('Could not close MDS client')

        try:
            self._dao.close()
        except Exception:
            logging.exception('Could not close DAO')

    def __call__(self, *args, **kwargs):
        self.clean()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()


class QueuerAndCleaner:
    """Runs one cleaning pass: a producer fills a queue that worker threads drain.

    :param concurrency: number of worker threads to start.
    :param cleaner_factory: callable ``(queue, proceed, timeout)`` returning a
        cleaner usable as a context manager and callable to run it.
    :param dao_factory: zero-argument callable creating the DAO used to fetch
        batches.
    :param fill_queue_size: the producer waits while the queue holds more
        items than this.
    :param wait_between_fills: seconds to sleep after each batch is queued and
        before a retry after a DB error.
    """

    def __init__(self, concurrency, cleaner_factory, dao_factory, fill_queue_size, wait_between_fills):
        self._concurrency = concurrency
        self._cleaner_factory = cleaner_factory
        self._dao_factory = dao_factory
        self._fill_queue_size = fill_queue_size
        self._wait_between_fills = wait_between_fills

    def do_cleaning(self, fetch_items, build_next_args, initial_args, timeout):
        """Fetch every batch, queue its items and wait until workers processed them.

        :param fetch_items: callable ``fetch_items(dao, *args)`` returning a batch.
        :param build_next_args: callable mapping a batch to the next fetch arguments.
        :param initial_args: arguments of the first fetch.
        :param timeout: passed through to ``cleaner_factory``.
        :raises psycopg2.Error: when fetching keeps failing after the retries.
        :raises RuntimeError: when no worker thread is left to drain the queue.
        """
        threads = []
        queue = Queue.Queue()
        proceed = threading.Event()
        proceed.set()

        def workers_alive():
            return any(thread.is_alive() for thread in threads)

        def clean():
            while proceed.is_set():
                try:
                    with self._cleaner_factory(queue, proceed, timeout) as cleaner:
                        cleaner()

                        if proceed.is_set():
                            time.sleep(5)
                except psycopg2.Error as e:
                    logging.error('DB error while cleaning %s', e)
                    if proceed.is_set():
                        # Back off instead of retrying in a tight loop while
                        # the database is unavailable.
                        time.sleep(1)

        def cons_and_start_cleaner_with_delay():
            # Workers are started one second apart.
            time.sleep(1)
            thread = threading.Thread(target=clean)
            thread.start()
            return thread

        def wait_for_queue_room():
            while queue.qsize() > self._fill_queue_size:
                if not workers_alive():
                    # Without consumers the queue never shrinks; fail instead
                    # of waiting forever.
                    raise RuntimeError('All cleaner threads have stopped, the queue cannot be drained')
                time.sleep(0.1)

        def fill_queue():
            iterator = DaoIterator(self._dao_factory, fetch_items, build_next_args, initial_args)
            retries = 0
            while True:
                try:
                    for items in iterator:
                        start_time = time.time()
                        # A successful fetch resets the retry budget.
                        retries = 0
                        wait_for_queue_room()

                        shuffle(items)
                        for item in items:
                            queue.put(item)
                        time.sleep(self._wait_between_fills)
                        logging.info('Done iterating %.2fs', time.time() - start_time)
                    break
                except psycopg2.Error as e:
                    if retries >= 3:
                        raise

                    retries += 1
                    logging.warning('DB Error while filling queue, retrying: %s', e)
                    time.sleep(self._wait_between_fills)

        def shutdown():
            # Let the workers finish what is queued, but do not wait for a
            # queue nobody is consuming any more.
            while not queue.empty() and workers_alive():
                time.sleep(1)
            proceed.clear()
            for thread in threads:
                thread.join()

        try:
            # Started inside the try block so that workers which are already
            # running get stopped if start-up is interrupted.
            for _ in range(self._concurrency):
                threads.append(cons_and_start_cleaner_with_delay())

            fill_queue()
        finally:
            shutdown()


def select_discarded_keys(dao, mds_key, min_time):
    """Fetch the next batch of discarded keys from ``dao``.

    Adapter with the ``fetch_items(dao, *args)`` shape: forwards ``mds_key``
    and ``min_time`` positionally to ``dao.select_discarded_keys``.
    """
    batch = dao.select_discarded_keys(mds_key, min_time)
    return batch


def select_expired_segments(dao, min_time):
    """Fetch the next batch of expired segments from ``dao``.

    Adapter with the ``fetch_items(dao, *args)`` shape: forwards ``min_time``
    positionally to ``dao.select_expired_segments``.
    """
    batch = dao.select_expired_segments(min_time)
    return batch


if __name__ == "__main__":
    # Script entry point: pick the per-environment settings, parse the command
    # line, then run two cleaning passes - expired segments first, previously
    # discarded MDS keys second.
    #
    # NOTE: the names below are intentionally module-level globals; other code
    # in this file may look some of them up by name (e.g. ``dao_factory``).
    logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] [%(threadName)s] %(message)s', datefmt='%Y-%m-%d %H:%M:%S')

    settings_per_env = {
        'dev': {
            'db-host': 'pgaas-test.mail.yandex.net',
            'db-host-master': 'c-3931c612-757b-46ad-be56-b9d71f02ac23.rw.db.yandex.net',
            'db-host-slave': 'c-3931c612-757b-46ad-be56-b9d71f02ac23.ro.db.yandex.net',
            'db-name': 'disk_videostreaming_testing_pg',
            'mds-url': 'http://storage-int.mdst.yandex.net:1111',
            'expire_time_delta': timedelta(days=1),
            'batch-size': 100,
            'concurrency': 1
        },
        'testing': {
            'db-host': 'pgaas.mail.yandex.net',
            'db-host-master': 'c-3931c612-757b-46ad-be56-b9d71f02ac23.rw.db.yandex.net',
            'db-host-slave': 'c-3931c612-757b-46ad-be56-b9d71f02ac23.ro.db.yandex.net',
            'db-name': 'disk_videostreaming_testing_pg',
            'mds-url': 'http://storage-int.mdst.yandex.net:1111',
            'expire_time_delta': timedelta(days=1),
            'batch-size': 100,
            'concurrency': 1
        },
        'production': {
            'db-host-master': 'c-7f617f46-0474-4b01-a0c7-98bb3e0beeb0.rw.db.yandex.net',
            'db-host-slave': 'c-7f617f46-0474-4b01-a0c7-98bb3e0beeb0.ro.db.yandex.net',
            'db-name': 'disk_streaming',
            'mds-url': 'http://storage-int.mds.yandex.net:1111',
            'expire_time_delta': timedelta(days=7),
            'batch-size': 20000,
            'concurrency': 20
        }
    }

    env_opt_spec = Option(
        '-e', '--env',
        action='store',
        dest='env',
        type='string',
        default=None,
        help='Environment'
    )

    # First pass: only the environment is needed, because the defaults of the
    # remaining options depend on it. Errors about options this parser does
    # not know are deliberately ignored; the full parser below reports them.
    class MyOptionParser(OptionParser):
        def error(self, msg):
            pass
    env_opts, _ = MyOptionParser(option_list=[env_opt_spec]).parse_args()
    if not env_opts.env or env_opts.env not in settings_per_env:
        # Covers both a missing and an unknown environment name.
        print('Environment is not specified or unknown, expected one of: %s' % ', '.join(sorted(settings_per_env)))
        raise SystemExit(1)

    settings = settings_per_env[env_opts.env]

    opt_spec_list = (
        env_opt_spec,
        Option(
            '-b', '--batch-size',
            action='store',
            dest='batch_size',
            type='int',
            default=settings['batch-size'],
            help='Batch size'
        ),
        Option(
            '-c', '--concurrency',
            action='store',
            dest='concurrency',
            type='int',
            default=settings['concurrency'],
            help='Concurrency'
        ),
        Option(
            '--db-pass',
            action='store',
            dest='db_pass',
            type='string',
            default=os.environ.get('DB_PASS'),
            help='Database password'
        ),
        Option(
            '--wait-between-fills',
            action='store',
            dest='wait_between_fills',
            type='int',
            default=5,
            help='Time to wait between queue fills'
        ),
        Option(
            '--fill-queue-size',
            action='store',
            dest='fill_queue_size',
            type='int',
            default=30000,
            help='Fill queue size'
        ),
        Option(
            '--mds-auth',
            action='store',
            dest='mds_auth',
            type='string',
            default=os.environ.get('MDS_AUTH'),
            help='Mds Authorization header value'
        ),
        Option(
            '--mds-timeout',
            action='store',
            dest='mds_timeout',
            type='int',
            default=700,
            help='Mds total timeout'
        ),
        Option(
            '--mds-long-timeout',
            action='store',
            dest='mds_long_timeout',
            type='int',
            default=3000,
            help='Mds total long timeout'
        ),
        Option(
            '--mds-connect-timeout',
            action='store',
            dest='mds_connect_timeout',
            type='int',
            default=500,
            help='Mds connect timeout'
        )
    )
    opts, args = OptionParser(option_list=opt_spec_list).parse_args()

    conn_str_pattern = "host='{db-host-master}' port=6432 dbname='{db}' user='disk_streaming' password='{password}' sslmode='require'"
    conn_str_pattern_slave = "host='{db-host-slave}' port=6432 dbname='{db}' user='disk_streaming' password='{password}' sslmode='require'"

    master_db_conn_str = conn_str_pattern.format(password=opts.db_pass, db=settings['db-name'], **settings)
    slave_db_conn_str = conn_str_pattern_slave.format(password=opts.db_pass, db=settings['db-name'], **settings)

    # Segments last accessed before this moment are considered expired.
    max_last_access_time = datetime.utcnow() - settings['expire_time_delta']
    dao_factory = StreamingDao.factory(master_db_conn_str, slave_db_conn_str, opts.batch_size, opts.batch_size,
                                       max_last_access_time)
    mds_client_factory = MdsClient.factory(settings['mds-url'], opts.mds_auth, opts.mds_connect_timeout)
    cleaner_factory = Cleaner.factory(mds_client_factory, dao_factory, timedelta(minutes=10))

    queuer_and_cleaner = QueuerAndCleaner(opts.concurrency, cleaner_factory, dao_factory, opts.fill_queue_size, opts.wait_between_fills)
    # Pass 1: expired segments, paginated by the last access time of the
    # previous batch.
    queuer_and_cleaner.do_cleaning(
        select_expired_segments,
        lambda segments: (segments[-1].last_access_time,),
        (StreamingDao.DATE_IN_PAST,),
        opts.mds_timeout
    )
    # Pass 2: keys left in the discarded table, paginated by the key and
    # creation time of the previous batch; uses the longer MDS timeout.
    queuer_and_cleaner.do_cleaning(
        select_discarded_keys,
        lambda mds_keys: (mds_keys[-1].mds_key, mds_keys[-1].creation_time),
        ('', StreamingDao.DATE_IN_PAST),
        opts.mds_long_timeout
    )
