# -*- coding: utf-8 -*-
from __future__ import unicode_literals

import functools
from collections import Counter, OrderedDict as oDict, namedtuple
from itertools import (
    chain, combinations,
    ifilterfalse, islice)
from multiprocessing.dummy import Pool

import six
from chapi.tools import date_range

from helpers.utils import isstopword
from sql import (
    HITS, STATS, WORD_COUNT,
    QUERY_COUNT, HOST_COUNT,
    DUMMY, QUERY_HOST_COUNT,
    with_percent,
    sql_params, union)


def _runtimecheck(method):
    @functools.wraps(method)
    def wrap(self, *args, **kwargs):
        if not self.tables:
            raise RuntimeError('It seems that there will be no results, call `get` first')
        return method(self, *args, **kwargs)
    return wrap


class Proto(object):
    """Configurable filter that materialises per-day temporary tables and
    reports counts (words, queries, hosts, ...) computed over them.

    The connection object is supplied by the caller; the SQL templates come
    from the project ``sql`` module.
    """
    # Database name used as the prefix of every temporary table name.
    db = 'sins'
    # Config keys stored as sets; `configure` updates them, `config_remove`
    # removes single elements, `serializable_config` turns them into lists.
    setlike = {'accept', 'reject', 'searchids',
               'reject_hosts', 'accept_hosts',
               'pre_reject', 'pre_accept',
               'accept_category', 'reject_category'
               }
    # Config keys stored as lists.
    arraylike = {'replacements'}
    # Array-like keys whose elements must be lists/tuples (stored as tuples).
    need_lists = {'replacements'}
    # Row types returned by the *_count methods.
    Word = namedtuple('Word', 'word,count,fraction')
    Query = namedtuple('Query', 'query,count,fraction')
    Host = namedtuple('Host', 'host,count,fraction')
    # `query_host` holds a (query, host) pair.
    QueryHost = namedtuple('QueryHost', 'query_host,count,fraction')
    # `comb` holds a sorted pair of words.
    WordComb = namedtuple('WordComb', 'comb,count')

    def __init__(self, conn, pool=None, threads=2, **defaults):
        """Store the connection and build the initial configuration.

        :param conn: connection object; Connection from
            https://github.yandex-team.ru/ferres/python-clickhouse-api
        :param pool: prefix for temporary table names; when None a prefix
            derived from ``id(self)`` is used
        :param threads: size of the thread pool used to create tables
        :param defaults: configuration overrides passed to `configure`.
            If a ``json`` key is present it is treated as a path to a JSON
            file whose content is used *instead of* the other keyword options.
        """
        # Bind the connection-related attributes first: `__del__` reads them,
        # so they must exist even if loading the configuration below raises.
        self.conn = conn
        self.pool = pool if pool is not None else 'auto_%d' % id(self)
        self.threads = threads
        self.tables = []
        self.param_range = oDict()

        self.config = dict(accept=set(),
                           reject=set(),
                           pre_reject=set(),
                           pre_accept=set(),
                           accept_hosts=set(),
                           reject_hosts=set(),
                           accept_category=set(),
                           reject_category=set(),
                           replacements=[],
                           searchids={2, 13, 181},
                           start=None,
                           end=None,
                           sample=0.01,
                           delta=20,
                           custom='',
                           )
        if 'json' in defaults:
            import json
            with open(defaults['json']) as f:
                options = json.load(f)
        else:
            options = defaults
        self.configure(append=False, **options)

        # Default period: the day before yesterday .. yesterday.
        if self.config['start'] is None or self.config['end'] is None:
            import datetime
            today = datetime.date.today()
            if self.config['start'] is None:
                self.config['start'] = today - datetime.timedelta(days=2)
            if self.config['end'] is None:
                self.config['end'] = today - datetime.timedelta(days=1)

    def filter_base(self, **kwargs):
        """Render the HITS template with the date left as a placeholder.

        Options are resolved through `_kwargs`; the ``date`` parameter is set
        to the literal ``{topic_date}`` so the result can be formatted with a
        concrete date later.

        :return: the formatted HITS query text
        """
        options = self._kwargs(**kwargs)
        sql_args = sql_params(**options)
        sql_args['date'] = '{topic_date}'
        return HITS.format(**sql_args)

    def uids_filter(self, uid='UserID', **kwargs):
        statement = '''
        {uid} in (
        SELECT UserID from ({subquery})
        )
        '''
        return statement.format(
                uid=uid,
                subquery=self.filter_base(**kwargs))

    def uids_at(self, date):
        statement = ('SELECT UserID from ({subquery})'
                     .format(subquery=self.filter_base())
                     .format(topic_date=date)
                     )
        return self.conn.tmp(statement)

    def topic_filter(self, timestamp='EventTime',
                     uid='UserID',
                     searchphrase='SearchPhrase',
                     **kwargs):
        statement = '''
                NOT empty({searchphrase})
                AND
                ({uid}, {timestamp}) in (
                SELECT (UserID, Times)
                FROM ({subquery}) ARRAY JOIN Times
                )
                '''
        return statement.format(
            uid=uid,
            timestamp=timestamp,
            searchphrase=searchphrase,
            subquery=self.filter_base(**kwargs)
        )

    def _kwargs(self, **kwargs):
        dic = dict()
        dic['accept'] = kwargs.get('accept', self.config['accept'])
        dic['reject'] = kwargs.get('reject', self.config['reject'])
        dic['pre_accept'] = kwargs.get('pre_accept', self.config['pre_accept'])
        dic['pre_reject'] = kwargs.get('pre_reject', self.config['pre_reject'])
        dic['accept_hosts'] = kwargs.get('accept_hosts', self.config['accept_hosts'])
        dic['reject_hosts'] = kwargs.get('reject_hosts', self.config['reject_hosts'])
        dic['searchids'] = kwargs.get('searchids', self.config['searchids'])
        dic['start'] = kwargs.get('start', self.config['start'])
        dic['end'] = kwargs.get('end', self.config['end'])
        dic['sample'] = kwargs.get('sample', self.config['sample'])
        dic['delta'] = kwargs.get('delta', self.config['delta'])
        dic['custom'] = kwargs.get('custom', self.config['custom'])
        dic['replacements'] = kwargs.get('replacements', self.config['replacements'])
        dic['accept_category'] = kwargs.get('accept_category', self.config['accept_category'])
        dic['reject_category'] = kwargs.get('reject_category', self.config['reject_category'])
        return dic

    def get(self, debug=False, **kwargs):
        # get args stage
        dic = self._kwargs(**kwargs)

        # end get args stage
        if not (dic['accept'] or dic['accept_hosts'] or dic['pre_accept']):
            raise ValueError('`accept` should be specified')
        params = sql_params(**dic)
        if debug:
            query = HITS.format(date=dic['start'], **params)
            return query
        self.tables, self.param_range = self._create_daterange_tables(
                QUERY=HITS,
                params=params, start=dic['start'], end=dic['end'])

    def _create_daterange_tables(self, QUERY, params, start, end):
        """Create one temporary table per date between ``start`` and ``end``.

        :param QUERY: query template; formatted with ``date`` plus ``params``
        :param params: template parameters shared by every date
        :param start: first argument passed to ``date_range``
        :param end: second argument passed to ``date_range``
        :return: ``(tables, param_range)`` where ``tables`` is the list of
            table names returned by `_create_tmp` and ``param_range`` is an
            ordered dict ``date -> parameters`` (including the ``tmp`` table
            name for that date)
        """
        param_range = oDict()
        query_date_range = []
        for date in date_range(start, end):
            day_params = dict(date=date,
                              tmp=self._get_tname(QUERY.format(date=date, **params)),
                              **params)
            param_range[date] = day_params
            query_date_range.append(QUERY.format(**day_params))

        # Use the class-level database name instead of repeating the literal.
        self.conn.gcollect(self.db, prefix=self.pool)
        # Single-argument parenthesised print works on Python 2 and 3 alike.
        print('will create {} tables\n'.format(len(query_date_range)))

        pool = Pool(self.threads)
        try:
            tables = pool.map(self._create_tmp, query_date_range)
        finally:
            # Release worker threads even if table creation raised.
            pool.close()
            pool.join()

        # `_check_tmp_consistence` returns the dict it was given, so the
        # mapping itself does not need to be rebuilt.
        for day_params in param_range.values():
            self._check_tmp_consistence(day_params)
        return tables, param_range

    def _get_tname(self, query):
        import hashlib
        query = query
        name = '%s_tmp_%s' % (self.pool, hashlib.sha224(query.encode('utf-8', 'ignore')).hexdigest())
        table = '%s.%s' % (self.db, name)
        return table

    def _create_tmp(self, query):
        tmp_create = """
            CREATE TABLE {table} ENGINE = Log AS
            {select}
"""
        table = self._get_tname(query)
        creation = tmp_create.format(table=table, select=query)

        try:
            self.conn.request(creation, _output=False)
            print 'success create table: {}'.format(table)
        except Exception, e:
            print 'error create table: {}, ex: {}'.format(table, e)
        return table

    @_runtimecheck
    def all_results(self):
        res = dict(
            last=self.last(),
            words=self.word_count(None),
            hosts=self.host_count(None),
            queries=self.query_count(None),
            query_hosts=self.query_host_count(None),
            word_combs=self.word_combinations_count(None)
        )
        return res

    @property
    @_runtimecheck
    def total(self):
        """Number of rows in the union of all temporary tables, as an int."""
        cursor = self.conn.cursor()
        counting_query = 'SELECT count() FROM (%s)' % union(self.tables)
        cursor.execute(counting_query, gen_=False)
        first_row = cursor.fetchone(plain=True)
        return int(first_row[0])

    @_runtimecheck
    def last(self, start=None, stop=None, step=None):
        c = self.conn.cursor()
        query = 'SELECT * FROM {table}'
        c.executeiter(query, [{'table': item['tmp']} for item in self.param_range.values()],
                      gen_=True, trace=False)
        return islice(c.fetchall(), start, stop, step)

    def last_to_tsv(self, out):
        """Write every session row from `last` to ``out`` as tab-separated text.

        One line is written per (phrase, host, time) triple of a session,
        prefixed with the session index, the position inside the session and
        the user id.

        :param out: file-like object handed to ``unicodecsv.writer``
        """
        import unicodecsv as csv
        sessions = self.last()
        writer = csv.writer(out, delimiter=str('\t'))
        writer.writerow(['id', 'i', 'userid', 'query', 'host', 'time'])
        for session_no, session in enumerate(sessions):
            triples = zip(session['Phrases'], session['Hosts'], session['Times'])
            for position, triple in enumerate(triples):
                writer.writerow([
                    session_no,
                    position,
                    session['UserID'],
                    triple[0],
                    triple[1],
                    triple[2],
                ])

    @_runtimecheck
    def print_last(self, limit=20):
        print 'total %d' % self.total
        for item in self.last(stop=limit):
            print 'userid %s' % (item['UserID'])
            print 'Queries'
            for subitem in set(item['Phrases']):
                print "\t '%s'" % subitem.strip()
            print 'Hosts'
            for subitem in set(item['Hosts']):
                print '\t' + subitem

            print '\n'

    @_runtimecheck
    def word_count(self, top=None):
        """Run the WORD_COUNT report over the union of the temporary tables.

        :param top: optional row limit, appended as a ``LIMIT`` clause
        :return: list of ``Word(word, count, fraction)`` built from the
            ``Word``, ``count`` and ``fraction`` columns of each result row
        """
        query = with_percent(WORD_COUNT).format(subquery=union(self.tables))
        if top is not None:
            # Start the clause on its own line so it can never fuse with the
            # last token of the query text.
            query += '\nLIMIT %d' % top
        c = self.conn.cursor()
        c.execute(query)
        return [self.Word(row['Word'], int(row['count']), float(row['fraction']))
                for row in c.fetchall()]

    def word_count_to_tsv(self, out):
        """Write the full `word_count` report to ``out`` as tab-separated text.

        :param out: file-like object handed to ``unicodecsv.writer``
        """
        import unicodecsv as csv
        rows = self.word_count()
        writer = csv.writer(out, delimiter=str('\t'))
        writer.writerow(['word', 'count', 'fraction'])
        writer.writerows(rows)

    @_runtimecheck
    def query_count(self, top=None):
        """Run the QUERY_COUNT report over the union of the temporary tables.

        :param top: optional row limit, appended as a ``LIMIT`` clause
        :return: list of ``Query(query, count, fraction)`` built from the
            ``Query``, ``count`` and ``fraction`` columns of each result row
        """
        query = with_percent(QUERY_COUNT).format(subquery=union(self.tables))
        if top is not None:
            # Start the clause on its own line so it can never fuse with the
            # last token of the query text.
            query += '\nLIMIT %d' % top
        c = self.conn.cursor()
        c.execute(query)
        return [self.Query(row['Query'], int(row['count']), float(row['fraction']))
                for row in c.fetchall()]

    def query_count_to_tsv(self, out):
        """Write the full `query_count` report to ``out`` as tab-separated text.

        :param out: file-like object handed to ``unicodecsv.writer``
        """
        import unicodecsv as csv
        rows = self.query_count()
        writer = csv.writer(out, delimiter=str('\t'))
        writer.writerow(['query', 'count', 'fraction'])
        writer.writerows(rows)

    @_runtimecheck
    def query_host_count(self, top=None):
        """Run the QUERY_HOST_COUNT report over the union of the temporary tables.

        :param top: optional row limit, appended as a ``LIMIT`` clause
        :return: list of ``QueryHost((query, host), count, fraction)`` built
            from the ``Query``, ``Host``, ``count`` and ``fraction`` columns
            of each result row
        """
        query = with_percent(QUERY_HOST_COUNT).format(subquery=union(self.tables))
        if top is not None:
            # Start the clause on its own line so it can never fuse with the
            # last token of the query text.
            query += '\nLIMIT %d' % top
        c = self.conn.cursor()
        c.execute(query)
        return [self.QueryHost((row['Query'], row['Host']),
                               int(row['count']), float(row['fraction']))
                for row in c.fetchall()]

    def query_host_count_to_tsv(self, out):
        """Write the full `query_host_count` report to ``out`` as tab-separated text.

        The (query, host) pair of every row is flattened into two columns.

        :param out: file-like object handed to ``unicodecsv.writer``
        """
        import unicodecsv as csv
        rows = self.query_host_count()
        writer = csv.writer(out, delimiter=str('\t'))
        writer.writerow(['query', 'host', 'count', 'fraction'])
        for row in rows:
            query_host = row[0]
            writer.writerow([query_host[0], query_host[1]] + list(row[1:]))

    @_runtimecheck
    def host_count(self, top=None):
        """Run the HOST_COUNT report over the union of the temporary tables.

        :param top: optional row limit, appended as a ``LIMIT`` clause
        :return: list of ``Host(host, count, fraction)`` built from the
            ``Host``, ``count`` and ``fraction`` columns of each result row
        """
        query = with_percent(HOST_COUNT).format(subquery=union(self.tables))
        if top is not None:
            # Start the clause on its own line so it can never fuse with the
            # last token of the query text.
            query += '\nLIMIT %d' % top
        c = self.conn.cursor()
        c.execute(query)
        return [self.Host(row['Host'], int(row['count']), float(row['fraction']))
                for row in c.fetchall()]

    def host_count_to_tsv(self, out):
        """Write the full `host_count` report to ``out`` as tab-separated text.

        :param out: file-like object handed to ``unicodecsv.writer``
        """
        import unicodecsv as csv
        rows = self.host_count()
        writer = csv.writer(out, delimiter=str('\t'))
        writer.writerow(['host', 'count', 'fraction'])
        writer.writerows(rows)

    @_runtimecheck
    def word_combinations_count(self, top=None):
        """Count how often two words occur together within one session.

        For each session the distinct phrases are split on whitespace, stop
        words (per ``isstopword``) are dropped, and every unordered pair of
        the remaining distinct words is counted once.

        :param top: number of most common pairs to return, None for all
        :return: ``WordComb(comb, count)`` items, most common first
        """
        pair_counts = Counter()
        for session in self.last():
            distinct_queries = set(session['Phrases'])
            tokens = chain.from_iterable(
                query.split() for query in distinct_queries)
            vocabulary = set(ifilterfalse(isstopword, tokens))
            for pair in combinations(vocabulary, 2):
                # Sort so that (a, b) and (b, a) share one counter key.
                pair_counts[tuple(sorted(pair))] += 1
        return map(self.WordComb._make, pair_counts.most_common(top))

    def word_combinations_to_tsv(self, out):
        """Write the full `word_combinations_count` report to ``out`` as
        tab-separated text, one line per word pair.

        :param out: file-like object handed to ``unicodecsv.writer``
        """
        import unicodecsv as csv
        pair_rows = self.word_combinations_count()
        writer = csv.writer(out, delimiter=str('\t'))
        writer.writerow(['word1', 'word2', 'count'])
        for row in pair_rows:
            words = row[0]
            writer.writerow([words[0], words[1]] + list(row[1:]))

    @_runtimecheck
    def daystat(self, trace=False):
        """Execute the STATS query once per entry of ``self.param_range``.

        :param trace: forwarded to ``cursor.executeiter``
        :return: whatever ``cursor.fetchall()`` returns
        """
        cursor = self.conn.cursor()
        per_day_params = self.param_range.values()
        cursor.executeiter(STATS, per_day_params, gen_=False, trace=trace)
        return cursor.fetchall()

    def daystat_to_tsv(self, out):
        """Write the `daystat` rows to ``out`` as tab-separated text.

        A header line with the column names is written first, as in the other
        ``*_to_tsv`` exporters.

        :param out: file-like object handed to ``unicodecsv.DictWriter``
        """
        import unicodecsv as csv
        ds = self.daystat()
        header = ['AllUniques',
                  'FilteredUniques',
                  'NotFilteredUniques',
                  'HitsAtFiltered',
                  'FilteredHits',
                  'HitsAtNotFiltered',
                  'AllHits',
                  'FilteredHitsDivHitsAtFiltered',
                  'FilteredHitsFraction',
                  'FilteredUniquesFraction',
                  'HitsAtFilteredFraction',
                  'Date']
        writer = csv.DictWriter(out, header, delimiter=str('\t'))
        # The field names were previously only used for column order and
        # never emitted, leaving the file without a header line.
        writer.writeheader()
        writer.writerows(ds)

    @_runtimecheck
    def print_word_count(self, top=10):
        print 'count\t%\tword'
        for item in self.word_count(top):
            print '{:5d}\t{:.3f}%\t{}'.format(item[1], item[2]*100, item[0])

    @_runtimecheck
    def print_word_combinations_count(self, top=20):
        print 'count\t%\tcombination'
        for item in self.word_combinations_count(top):
            print '%5d\t({}, {})'.format(item[1], item[0][0], item[0][1])

    @_runtimecheck
    def print_query_host_count(self, top=20):
        print 'count\t%\tcombination'
        for item in self.query_host_count(top):
            print '{:5d}\t{:.3f}%\t({}, {})'.format(item[1], item[2]*100, item[0][0], item[0][1])

    @_runtimecheck
    def print_query_count(self, top=100):
        print 'count\tquery'
        for item in self.query_count(top):
            print '{:5d}\t{:.3f}%\t{}'.format(item[1], item[2]*100, item[0])

    @_runtimecheck
    def print_host_count(self, top=100):
        print 'count\thost'
        for item in self.host_count(top):
            print '{:5d}\t{:.3f}%\t{}'.format(item[1], item[2]*100, item[0])

    @_runtimecheck
    def navig_estimate(self, threshold=3):
        """Sum count and fraction of (query, host) rows whose query text is
        close to the host name.

        Query and host are normalised (runs of whitespace, ``-`` and ``.``
        become a single dot), transliterated with the ``ru`` language pack in
        reversed mode, and compared by edit distance. The query is compared
        both with the full host and with the host without its last
        dot-separated label.

        :param threshold: maximum edit distance still counted as a match
        :return: ``(count, fraction)`` summed over the matching rows of
            `query_host_count`
        """
        import re
        import transliterate
        import editdistance
        pattern = re.compile(r'[\s\-.]+')

        def norm(text):
            return pattern.sub('.', text.strip())

        def to_latin(text):
            return transliterate.translit(
                    text, language_code=str('ru'),
                    reversed=True, strict=False)

        def is_navig(rec):
            query = to_latin(norm(rec[0][0]))
            host = to_latin(norm(rec[0][1]))
            # Drop the last label of the host. The labels must be split on
            # dots first: joining the stripped *string* would interleave a
            # dot between every single character.
            s_host = '.'.join(host.strip('.').split('.')[:-1])
            dist1 = editdistance.eval(query, host)
            dist2 = editdistance.eval(query, s_host)
            return min(dist1, dist2) <= threshold

        fraction = 0
        count = 0
        for record in self.query_host_count(None):
            if is_navig(record):
                count += record[1]
                fraction += record[2]
        return count, fraction

    def __str__(self):
        import unicodedata
        mes = """
        start : {start}
        end   : {end}
        delta : {delta}
        sample: {sample}

        replacements:
                {_replacements}

        searchids:
                {_searchids}

        accept:
            pre accept:
                {_pre_accept}
            accept:
                {_accept}
            hosts:
                {_accept_hosts}
        reject:
            pre reject:
                {_pre_reject}
            reject
                {_reject}
            hosts:
                {_reject_hosts}
        custom:
{custom}
        """

        def format(items):
            return '\n\t\t'.join(map(lambda s: "'%s'" % s, items))
        _searchids = '[' + ','.join(map(six.text_type, self.config['searchids'])) + ']'
        _accept = format(self.config['accept'])
        _pre_accept = format(self.config['pre_accept'])
        _accept_hosts = format(self.config['accept_hosts'])
        _reject = format(self.config['reject'])
        _pre_reject = format(self.config['pre_reject'])
        _reject_hosts = format(self.config['reject_hosts'])
        _replacements = '\n\t\t'.join(map(
                lambda i_p_r: "%2d: '%s'->'%s'" %
                              (i_p_r[0], i_p_r[1][0], i_p_r[1][1]),
                enumerate(self.config['replacements'])))
        res = mes.format(_searchids=_searchids,
                         _accept=_accept,
                         _accept_hosts=_accept_hosts,
                         _reject=_reject,
                         _reject_hosts=_reject_hosts,
                         _replacements=_replacements,
                         _pre_accept=_pre_accept,
                         _pre_reject=_pre_reject,
                         **self.config)
        res = unicodedata.normalize('NFKD', res).encode('utf-8', 'replace')
        return res

    def dump_sins_config(self, path):
        """It is going be like a transfer to sins on YT"""
        import json
        with open(path, 'w') as f:
            json.dump(self.serializable_config, f)

    @property
    def sins_config(self):
        import json
        return json.dumps(self.serializable_config)

    @property
    def serializable_config(self):
        """Usual set is not serializable, it will be a list then
        """
        copy = self.config.copy()
        for key in self.setlike:
            copy[key] = list(copy[key])
        copy['start'] = str(copy['start'])
        copy['end'] = str(copy['end'])
        return copy

    def __del__(self):
        if self.conn:
            self.conn.gcollect('sins', prefix=self.pool)

    def configure(self, append=True, **kwargs):
        """Update the configuration in place.

        Set-like and array-like options are merged into the existing
        containers; every other keyword simply overwrites the config entry.
        Empty (falsy) values for set-like/array-like options are ignored,
        even when ``append`` is false.

        (Original note: code seems to be ugly and there are problems with
        unicode json.)

        :param append: when false, a non-empty value replaces the current
            content of its container instead of extending it
        :raises ValueError: if a set-like/array-like option is a string, or
            an element of a ``need_lists`` option is not a list or tuple
        """
        for key in self.setlike:
            values = kwargs.pop(key, [])
            if isinstance(values, six.string_types):
                raise ValueError('No strings allowed in setlike config')
            if not values:
                continue
            target = self.config[key]
            if not append:
                target.clear()
            target.update(values)

        for key in self.arraylike:
            values = kwargs.pop(key, [])
            if isinstance(values, six.string_types):
                raise ValueError('No strings allowed in arraylike config')
            if not values:
                continue
            target = self.config[key]
            if not append:
                del target[:]
            for element in values:
                if key in self.need_lists:
                    if not isinstance(element, (list, tuple)):
                        raise ValueError('`%s` needs list or tuple' % key)
                    element = tuple(element)
                target.append(element)

        # Whatever is left are plain scalar options.
        self.config.update(kwargs)

    def config_remove(self, **kwargs):
        """Code seems to be ugly"""
        for key in self.setlike:
            value = kwargs.pop(key, [])
            if value:
                for val in value:
                    try:
                        self.config[key].remove(val)
                    except KeyError:
                        import sys
                        sys.stderr.write('No such pattern %s in %s' % (val, key))
        for key in self.arraylike:
            value = kwargs.pop(key, None)
            if isinstance(value, int):
                self.config[key].pop(value)
            elif value is None:
                continue
            else:
                raise ValueError('Please use single int positions for arraylike options')

    def isempty(self, tmp):
        q = 'select count() from %s' % tmp
        c = self.conn.cursor()
        c.execute(q)
        return not bool(c.rowcount)

    def _check_tmp_consistence(self, params):
        if self.isempty(params['tmp']):
            drop = 'drop table {}'
            create = 'CREATE TABLE {} ENGINE=Log AS {}'
            self.conn.request(
                    drop.format(params['tmp']),
                    _output=False)
            self.conn.request(
                    create.format(params['tmp'], DUMMY),
                    _output=False)
        return params
