import json
import logging
import time
from collections import Counter, defaultdict

import psycopg2
import requests
import random

from logbroker_processors.processors import Processor

log = logging.getLogger(__name__)


# Two-character marker (backslash + "N") used for a missing/NULL field value.
NULL = "\\N"
SEP = "\t"
EOL = "\n"
# Seconds: pause between connection attempts and HTTP timeout for host lookup.
SLEEP = 1

# Number of distinct tuples collected for a table before a batch is sent.
DEFAULT_BATCH_SIZE = 100

# Multi-row upsert template. ``values_template`` is expected to be a
# comma-separated list of ``%s`` placeholders (one per row); the other
# fields are substituted with ``str.format`` before parameter binding.
UPSERT_SQL = (
    "INSERT INTO {table_name} VALUES {values_template} "
    "ON CONFLICT ({table_index_fields}) "
    "DO UPDATE SET last_dt = '{last_dt}' "
    "WHERE {table_name}.last_dt != '{last_dt}'"
)

# Incoming record key -> database column name.
FIELDS_MAPPING = {
    'uid': 'uid',
    'date': 'last_dt',
    'module': 'module',
    'userAgent': 'user_agent',
}

# Table name -> ordered columns that form one inserted row.
TABLES_FIELDS = {
    'history.user_activity': ('uid', 'module', 'last_dt'),
    'history.user_agent_activity': ('uid', 'user_agent', 'last_dt'),
}

# Column name -> values for which the whole record is dropped.
FILTER_VALUES = {
    'module': ('fastsrv', 'settings'),
}

# Table name -> columns of the ON CONFLICT target.
TABLES_INDEX_FIELDS = {
    'history.user_activity': 'uid, module',
    'history.user_agent_activity': 'uid, user_agent',
}

# Incoming record keys that are sanitized and stored.
SAVE_FIELDS = FIELDS_MAPPING.keys()


class ActivitydbProcessor(Processor):
    """Accumulate activity records and upsert them into PostgreSQL.

    Records are passed one by one to :meth:`process`, normalized, filtered
    and grouped by date in memory.  :meth:`flush` writes the accumulated
    records to every table listed in ``TABLES_FIELDS`` using multi-row
    ``INSERT ... ON CONFLICT`` statements, one transaction per date.

    Options read from ``opts``:

    * ``cond_group`` -- group name appended to the hosts-lookup URL;
    * ``conn_string`` -- connection parameters appended to ``host=<host>``;
    * ``batch_size`` -- optional, defaults to ``DEFAULT_BATCH_SIZE``.
    """

    def __init__(self, **opts):
        """Store options and block until a primary database host is found."""
        self.opts = opts
        self.data_accumulator = defaultdict(list)
        self.max_ts = None
        # both handles exist from the start so they can always be closed
        self.conn = None
        self.cur = None
        self.batch_size = int(opts.get('batch_size', DEFAULT_BATCH_SIZE))
        while not self._connect_to_master():
            time.sleep(SLEEP)

    def _close_connection(self):
        """Close the current cursor and connection, ignoring any errors."""
        # Close each handle separately: a failing cursor close must not
        # prevent the connection from being closed.
        for handle in (getattr(self, 'cur', None),
                       getattr(self, 'conn', None)):
            if handle is None:
                continue
            try:
                handle.close()
            except Exception:
                # best effort: the handle may already be closed or broken
                pass
        self.cur = None
        self.conn = None

    def _connect_to_master(self):
        """Find the host that is not in recovery and connect to it.

        The list of candidate hosts is fetched over HTTP (one host per
        line).  Each host is probed with ``SELECT pg_is_in_recovery()``;
        the first one that answers false is kept in ``self.conn`` /
        ``self.cur``.

        Returns:
            True if a connection to a primary host was established,
            False otherwise (including a failed hosts lookup).
        """
        url = 'http://c.yandex-team.ru/api/groups2hosts/{cond_group}'.format(
            cond_group=self.opts['cond_group'])
        try:
            response = requests.get(url, timeout=SLEEP)
        except requests.RequestException as exc:
            # callers retry in a loop, so report the failure, do not raise
            log.error(
                'Can not get hosts list from %s: %s',
                url,
                exc,
                exc_info=True)
            return False
        # NOTE(review): the HTTP status is deliberately not checked here
        # (raise_for_status() was disabled in the original code); a
        # non-host response body simply fails the connection attempts.

        for host in response.text.splitlines():
            host = host.strip()
            if not host:
                continue

            # drop whatever was opened before (previous host or old session)
            self._close_connection()

            try:
                self.conn = psycopg2.connect(
                    'host=%s %s' % (host, self.opts['conn_string']))
                self.conn.autocommit = False
                self.cur = self.conn.cursor()
                self.cur.execute('SELECT pg_is_in_recovery();')
                if not self.cur.fetchone()[0]:
                    log.info('Connected to database primary host: %s', host)
                    return True
            except Exception as exc:
                log.error(
                    'Can not connect to database host %s: %s',
                    host,
                    exc,
                    exc_info=True)

        # no primary found: do not keep a replica/broken connection around
        self._close_connection()
        return False

    @staticmethod
    def __normalize(data):
        """Sanitize a raw record in place and derive its date fields.

        Missing ``SAVE_FIELDS`` keys are set to ``NULL``.  Present values
        are stripped of non-ASCII characters, newlines, tabs, the
        backslash-dot sequence and backslashes.  ``date`` (taken as
        milliseconds since the epoch) is replaced by a local ``YYYY-MM-DD``
        string, and ``timestamp`` receives the ``time.asctime`` form of the
        same moment; the current time is used when ``date`` is unusable.

        Returns:
            The same dict with ``uid`` cast to int, or None when ``uid``
            is not made of digits.
        """
        # set missing required fields to NULL
        for key in set(SAVE_FIELDS).difference(data):
            data[key] = NULL

        # sanitize required fields
        for key in SAVE_FIELDS:
            try:
                if data[key] == NULL:
                    continue
                value = ''.join(ch for ch in data[key] if ord(ch) < 128)
                for junk in ('\n', '\t', '\\.', '\\'):
                    value = value.replace(junk, '')
                data[key] = value
            except Exception as exc:
                # keep the value untouched, the record is still processed
                log.error(
                    'Error while processing "%s": %s',
                    data[key],
                    exc,
                    exc_info=True)

        try:
            timestamp = time.localtime(int(data['date']) // 1000)
        except (KeyError, TypeError, ValueError, OverflowError, OSError):
            # missing, non-numeric or out-of-range date: fall back to now
            timestamp = time.localtime()

        # we need only date in activity tables
        data['date'] = time.strftime("%Y-%m-%d", timestamp)
        # also save timestamp for report
        data['timestamp'] = time.asctime(timestamp)

        data['state'] = NULL

        if not unicode(data['uid']).isdigit():
            return None

        # cast type to int (needed for tuples deduplication)
        data['uid'] = int(data['uid'])

        return data

    def process(self, header, data):
        """Normalize one record and add it to the in-memory accumulator.

        Records without a numeric ``uid`` and records matching
        ``FILTER_VALUES`` are dropped.  ``header`` is not used.

        Returns:
            Always True.
        """
        data = self.__normalize(data)
        if data is None:
            return True

        # track the newest record timestamp seen since the last flush
        record_ts = time.strptime(data['timestamp'])
        if self.max_ts is None or record_ts > self.max_ts:
            self.max_ts = record_ts

        processed_data = {}
        for lb_key, pg_key in FIELDS_MAPPING.items():
            processed_data[pg_key] = data[lb_key]

        # filter undesired values
        for field, filter_values in FILTER_VALUES.items():
            if field not in processed_data:
                continue
            if processed_data[field] in filter_values:
                return True

        # group by date, see insert query for details
        self.data_accumulator[data['date']].append(processed_data)

        return True

    def execute_insert_values(self, tables_values, last_dt):
        """Insert collected tuples into provided tables.

        Args:
            tables_values: mapping of table name to an iterable of row
                tuples; tables with no rows are skipped.
            last_dt: date string substituted into the upsert statement.
                It is formatted into the SQL text, not bound as a
                parameter, so it must not come from untrusted input.

        Returns:
            Counter mapping table name to the cursor's reported row count.
        """
        rows_updated = Counter()
        for table_name, values in tables_values.items():
            rows = tuple(values)
            if not rows:
                # an empty VALUES clause would be a syntax error
                continue
            # form template for VALUES() clause: one placeholder per row
            values_template = ', '.join(['%s'] * len(rows))
            sql_str = self.cur.mogrify(
                UPSERT_SQL.format(
                    table_name=table_name,
                    values_template=values_template,
                    table_index_fields=TABLES_INDEX_FIELDS[table_name],
                    last_dt=last_dt), rows)

            log.debug('Executing sql for table "%s" and date "%s": %s',
                      table_name, last_dt, sql_str)
            self.cur.execute(sql_str)
            rows_updated[table_name] += self.cur.rowcount

        return rows_updated

    def _flush_date(self, log_date, log_date_data):
        """Send all records of one date in batches (without committing).

        Returns:
            Counter mapping table name to the number of affected rows.
        """
        rows_updated = Counter()
        values = defaultdict(set)
        batch_num = 1

        for log_date_line in log_date_data:
            for table, fields in TABLES_FIELDS.items():
                table_tuple = tuple(log_date_line[f] for f in fields)
                # filter tuple due to NULL value (only for current table)
                if any(f == NULL for f in table_tuple):
                    continue
                values[table].add(table_tuple)

            # any() is safe when nothing has been collected yet, e.g. when
            # a record has NULL in every table's tuple
            if any(len(v) >= self.batch_size for v in values.values()):
                log.debug('Sending %s batch for date "%s"',
                          batch_num, log_date)
                rows_updated.update(
                    self.execute_insert_values(values, log_date))
                values = defaultdict(set)
                batch_num += 1

        if values:
            log.debug('Sending remaining lines for date "%s"', log_date)
            rows_updated.update(
                self.execute_insert_values(values, log_date))

        return rows_updated

    def flush(self, force=True):
        """Write all accumulated records to the database.

        Each date is sent and committed separately and removed from the
        accumulator once committed.  On any error the method sleeps for a
        random interval, reconnects to the primary host and retries the
        dates that are still pending; it does not return until everything
        has been written.

        Args:
            force: accepted for interface compatibility; not used.

        Returns:
            Always True.
        """
        log.info('Starting flush')
        lines_accumulated = sum(
            len(lines) for lines in self.data_accumulator.values())
        max_ts = time.asctime(
            self.max_ts) if self.max_ts is not None else 'None'
        log.info('%d lines to flush, max timestamp is %s.',
                 lines_accumulated, max_ts)
        if lines_accumulated == 0:
            return True

        is_flushed = False
        # counts only committed work, accumulated across retry attempts
        rows_updated = Counter()
        while not is_flushed:
            try:
                log.debug('Preparing commit')
                log.info('Collected dates %s', list(self.data_accumulator))
                # iterate over a snapshot: committed dates are deleted
                # from the accumulator inside the loop
                pending = list(self.data_accumulator.items())
                for log_date, log_date_data in pending:
                    log.info('Collected %s lines for log date "%s"',
                             len(log_date_data), log_date)
                    date_rows = self._flush_date(log_date, log_date_data)
                    self.conn.commit()
                    rows_updated.update(date_rows)
                    del self.data_accumulator[log_date]

                log.debug('Committed successfully')
                log.info('Lines processed: %s, total rows updated: %s',
                         lines_accumulated, sum(rows_updated.values()))
                log.debug('Tables row updated: %s', json.dumps(rows_updated))

                self.data_accumulator = defaultdict(list)
                self.max_ts = None
                is_flushed = True
            except Exception as exc:
                log.error(
                    'Unexpected error while sending data: %s',
                    exc,
                    exc_info=True)

                # random delay before retrying; per the original author
                # this is meant to prevent a deadlock
                time.sleep(random.randint(1, 15))

                while not self._connect_to_master():
                    time.sleep(random.randint(1, 5))

        log.debug('Finished flush')
        return is_flushed