# -*- coding: utf-8 -*-
from datetime import datetime, timedelta
import logging
from os.path import dirname, join
import pandas as pd
import re
import requests
from StringIO import StringIO
from textwrap import dedent
from time import sleep

from make_data import DICTS
from robot import SECRET
from robot.errs import (
    ClickHouseError,
    MaxRetriesError,
    TimeOutError,
    QueryTooSlow,
    QuotaError
)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Maps an attribution prefix to the numeric id that get_regions_condition()
# substitutes into indexOf(TrafficSource.Model, <id>). The empty prefix has an
# entry too, although get_regions_condition() only performs the lookup for
# non-empty prefixes.
# NOTE(review): the keys spell "Trafic" with a single "f"; presumably this
# matches the prefixes supplied by callers -- confirm before renaming.
ATTR_PREFIX_TO_ID = {
    '': 1,
    'LastSignificantTraficSource': 2,
    'FirstTraficSource': 3,
}


def restartable(retries=2, pause=3):
    """Build a decorator that re-runs the wrapped callable when it fails.

    The first positional argument of the wrapped callable must be the task
    object: the wrapper reads ``issue.key``, ``issue.comments`` and
    ``test_mode`` from it for logging and for posting comments.

    Handling per exception type:

    * ``KeyboardInterrupt`` and ``ClickHouseError`` are re-raised untouched.
    * ``QueryTooSlow`` / ``QuotaError`` are logged, posted as an issue comment
      (unless the task is in test mode) and retried after ``exc.timeout``
      seconds.
    * any other ``Exception`` is logged and retried after ``pause`` seconds.

    :param retries: maximum number of attempts.
    :param pause: seconds to wait after an unexpected exception.
    :raises MaxRetriesError: when every attempt has failed.
    """
    # Local import keeps this block self-contained.
    from functools import wraps

    def decorator(func):
        @wraps(func)  # keep the wrapped callable's name and docstring
        def wrapper(*args, **kwargs):
            cpa_task = args[0]
            for try_ in range(retries):
                try:
                    return func(*args, **kwargs)

                except KeyboardInterrupt:
                    # Bare raise preserves the original traceback.
                    raise

                except (QueryTooSlow, QuotaError) as exc:
                    logger.warning('{}:{}'.format(cpa_task.issue.key, exc.message))
                    if not cpa_task.test_mode:
                        cpa_task.issue.comments.create(text=u'%%\n{}\n%%'.format(exc.message))

                    # sleep() rejects negative values; raising from inside
                    # this handler would abort the whole retry loop.
                    delay = max(exc.timeout, 0)

                except ClickHouseError:
                    # Not retried; bare raise preserves the traceback.
                    raise

                except Exception as exc:
                    logger.warning('{}: an exception occurred at attempt #{}'.format(cpa_task.issue.key, try_))
                    logger.warning('{}: {}'.format(cpa_task.issue.key, exc.message))
                    delay = pause

                if try_ == retries - 1:
                    # No attempt follows, so waiting would only postpone
                    # the MaxRetriesError below.
                    break

                logger.warning('{}: restarting in {} seconds...'.format(cpa_task.issue.key, delay))
                sleep(delay)

            raise MaxRetriesError('Max retries has been reached ({})'.format(retries))

        return wrapper

    return decorator


@restartable(retries=60, pause=60)
def clickhouse_request(cpa_task, query, header,
                       host=DICTS['host'],
                       user=SECRET['ClickHouse']['user'], password=SECRET['ClickHouse']['password']):
    """POST ``query`` to ClickHouse and parse the tab-separated reply.

    ``cpa_task`` is not used in the body; it is the first positional argument
    that the ``restartable`` wrapper reads for logging.

    :param query: query text sent as the request body.
    :param header: column names applied to the parsed result.
    :param host: URL the request is posted to.
    :param user: basic-auth user name.
    :param password: basic-auth password.
    :returns: a DataFrame with missing values replaced by ``''``, or ``None``
        when the parsed result is empty.
    :raises TimeOutError: the reply reports a connect timeout.
    :raises QuotaError: the reply reports an exhausted user quota.
    :raises QueryTooSlow: the reply reports a too-slow query.
    :raises ClickHouseError: the reply reports a syntax error.
    :raises Exception: any other reply with a non-200 status code.
    """
    response = requests.post(host, auth=(user, password), data=query, timeout=60)
    body = response.text

    # Known error markers are checked in the reply text before the status code.
    if 'Timeout: connect timed out' in body:
        raise TimeOutError(body.strip())

    if 'Quota for user' in body:
        quota_message = body.strip()
        raise QuotaError(quota_message, timeout=get_quota_timeout(quota_message))

    if 'Query is executing too slow' in body:
        raise QueryTooSlow(body.strip(), timeout=60 * 15)

    if 'Syntax error' in body:
        raise ClickHouseError(body.strip(), query=query)

    if response.status_code != 200:
        raise Exception(body.strip())

    frame = pd.read_csv(StringIO(body), sep='\t', names=header)
    return None if frame.empty else frame.fillna('')


def get_quota_timeout(msg):
    """Return how many seconds to wait before retrying after a quota error.

    The first ``20YY-MM-DD HH:MM:SS`` timestamp found in ``msg`` is taken as
    the moment the wait ends; the result is the time left until then plus a
    60-second margin.

    :param msg: error text of the quota reply.
    :returns: seconds to wait, never less than 60. Falls back to 20 minutes
        when ``msg`` holds no timestamp or the timestamp is not a valid date.
    """
    default_timeout = 60 * 20

    found = re.search(r'20[0-9]{2}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}', msg)
    if not found:
        return default_timeout

    try:
        to_date = datetime.strptime(found.group(), '%Y-%m-%d %H:%M:%S')
    except ValueError:
        # The pattern also matches impossible dates such as 2020-13-45.
        return default_timeout

    # NOTE(review): the +3h shift presumably converts UTC to the timezone the
    # timestamp is reported in -- confirm.
    now_date = datetime.utcnow() + timedelta(hours=3)
    remaining = (to_date - now_date).total_seconds()

    # A timestamp in the past must not produce a negative wait: sleep()
    # rejects negative values.
    return max(remaining, 0) + 60


def get_visits_select(separate_data, attr_prefix):
    """Return the SELECT expressions for the visit metrics.

    :param separate_data: when truthy every metric maps to an empty SQL
        string literal (``''``); otherwise each maps to its ``sum(...)``
        aggregate.
    :param attr_prefix: accepted but not used by this function.
    :returns: dict keyed by ``visits``, ``bounces``, ``views``, ``duration``.
    """
    if separate_data:
        return dict.fromkeys(('visits', 'bounces', 'views', 'duration'), "''")

    return {
        'visits': 'sum(Sign)',
        'bounces': 'sum(Sign * IsBounce)',
        'views': 'sum(Sign * PageViews)',
        'duration': 'sum(Sign * Duration)',
    }


def get_goals_condition(goals):
    """Build an OR-chain of ``has(Goals.ID, <id>)`` checks.

    :param goals: comma-separated goal ids, or a falsy value.
    :returns: one ``has(Goals.ID, ...)`` term per id joined with `` OR ``;
        ``has(Goals.ID, 0)`` when no ids are given.
    """
    goal_ids = [goal.strip() for goal in goals.split(',')] if goals else []
    # Drop empty items (e.g. from "1,,2" or a trailing comma); they would
    # otherwise yield the malformed fragment "has(Goals.ID, )".
    goal_ids = [goal for goal in goal_ids if goal]

    if not goal_ids:
        return 'has(Goals.ID, 0)'

    return ' OR '.join(['has(Goals.ID, %s)' % goal for goal in goal_ids])


def get_domains_condition(domain, domain_id):
    """Build the domain filter: match on URL when present, else on DomainID.

    :param domain: domain text; stripped and lower-cased, then placed inside
        a quoted literal as the second argument of ``match()``.
    :param domain_id: value compared against ``DomainID`` when URL is empty.
    :returns: unicode SQL fragment starting with ``AND``.
    """
    normalized = domain.strip().lower()
    template = u"AND if(URL != '', match(lower(URL), lower('{domain}')), DomainID == {domain_id})"
    return template.format(domain=normalized, domain_id=domain_id)


def get_domains_condition_for_chyt(domain_id):
    """Build the CHYT domain filter comparing ``domainid`` to ``domain_id``.

    :returns: unicode SQL fragment starting with ``AND``.
    """
    condition = u'AND domainid == {}'
    return condition.format(domain_id)


def get_regions_condition(regions, attr_prefix=''):
    """Build the region filter as a single-line SQL fragment.

    The region field is matched against ``regions`` at country, district,
    area and city level.

    :param regions: text placed inside each ``IN (...)`` list; a falsy value
        disables the filter.
    :param attr_prefix: when non-empty, the region is read from
        ``TrafficSource.ClickRegionID`` at the index of the model id looked up
        in ``ATTR_PREFIX_TO_ID``; otherwise ``RegionID`` is used.
    :returns: fragment starting with ``AND``, or ``''`` without regions.
    :raises KeyError: ``attr_prefix`` is non-empty and not a known prefix.
    """
    if not regions:
        return ''

    region_field = 'RegionID'
    if attr_prefix:
        model_id = ATTR_PREFIX_TO_ID[attr_prefix]
        region_field = 'TrafficSource.ClickRegionID[indexOf(TrafficSource.Model, %s)]' % model_id

    # The template is filled in first, then de-indented and flattened.
    return (
        dedent(
            """
            AND (regionToCountry(toUInt32({region_field})) IN ({regions}) OR
            regionToDistrict(toUInt32({region_field})) IN ({regions}) OR
            regionToArea(toUInt32({region_field})) IN ({regions}) OR
            regionToCity(toUInt32({region_field})) IN ({regions}))
            """.format(region_field=region_field, regions=regions)
        )
        .strip()
        .replace('\n', ' ')
    )


def get_regions_condition_for_chyt(regions, attr_prefix=''):
    """Build the CHYT region filter as a single-line SQL fragment.

    ``regionid`` is matched against ``regions`` at country, district, area
    and city level.

    :param regions: text placed inside each ``IN (...)`` list; a falsy value
        disables the filter.
    :param attr_prefix: accepted for signature compatibility; it has no
        effect on the result.
    :returns: fragment starting with ``AND``, or ``''`` without regions.
    """
    # NOTE(review): a "<prefix>Click" value used to be derived from
    # attr_prefix here, but the template never referenced it, so the filter
    # always runs on plain `regionid`. Confirm whether the prefix was meant to
    # select a different column.
    if not regions:
        return ''

    # The template is filled in first, then de-indented and flattened.
    return (
        dedent(
            """
            AND (regionToCountry(toUInt32(regionid)) IN ({regions}) OR
            regionToDistrict(toUInt32(regionid)) IN ({regions}) OR
            regionToArea(toUInt32(regionid)) IN ({regions}) OR
            regionToCity(toUInt32(regionid)) IN ({regions}))
            """.format(regions=regions)
        )
        .strip()
        .replace('\n', ' ')
    )


def get_campaigns_table():
    """Return the single-line INNER JOIN clause for the campaigns table.

    The subquery exposes ``orderid`` (``OrderID`` cast to ``UInt64``) and
    ``CampaignName`` and is joined with ``USING orderid``.
    """
    # The leading spaces inside the pieces are part of the returned text.
    pieces = (
        'INNER JOIN',
        '    (',
        '        SELECT',
        '            CAST(OrderID AS UInt64) AS orderid,',
        '            name AS CampaignName',
        '        FROM `//home/direct/db/campaigns`',
        '    ) AS a',
        'USING orderid',
    )
    return ' '.join(pieces)


def get_query_txt(query_name):
    """Load a query from the ``queries`` directory next to this module.

    :param query_name: file name inside the ``queries`` directory.
    :returns: the file content decoded as UTF-8, de-indented and stripped.
    """
    query_path = join(dirname(__file__), 'queries', query_name)
    with open(query_path) as query_file:
        raw_text = query_file.read()
    return dedent(raw_text.decode('utf8')).strip()
