# -*- coding: utf-8 -*-
import logging
from itertools import product
from os import remove
from os.path import basename, dirname, join
from multiprocessing.dummy import Pool
from zipfile import ZipFile
from datetime import datetime as dt
from dateutil.relativedelta import relativedelta
import yaml

from robot import PATTERN
from yql_worker import yql_worker
from other_worker import get_period_from_yt, get_staff_position, read_yt_tables

logger = logging.getLogger(__name__)


class TaskWorker(object):
    def __init__(self, test_mode=False, issue=None):
        super(TaskWorker, self).__init__()
        if not test_mode:
            issue.transitions['start_progress'].execute()

        logger.info('{}: init TaskWorker'.format(issue.key))
        self.test_mode = test_mode
        self.issue = issue
        self._parse_issue_description()
        self._validate_task()

        self._set_created_by()

        self.workers = []
        self.pptx_path = None

    def _parse_issue_description(self):
        logger.info('{}: parse issue description'.format(self.issue.key))

        description = PATTERN.match(self.issue.description).groupdict()
        params = yaml.safe_load(description.get('params'))

        self.categories = [cat.strip() for cat in description.get('categories', '').split('\n')]
        regions = params.get('regions')
        self.regions = [u'Россия'] if regions in [None, 'None', ''] else [reg.strip() for reg in regions.split(',')]
        if len(self.regions) != 1 and u'Россия' not in self.regions:
            if len(self.regions) == 4:
                self.regions.insert(0, u'Россия')
            else:
                self.regions.insert(0, u', '.join(self.regions))

        self.query_params = get_query_params(self.issue, params, description.get('competitors'))

    def _set_created_by(self):
        created_by = self.issue.createdBy
        self.created_by = {
            'name': u'%s %s' % (created_by.firstName, created_by.lastName),
            'position': get_staff_position(created_by.login)
        }

    def _validate_task(self):
        logger.info('{}: validate task'.format(self.issue.key))

        errs = []
        yt_brands, yt_cats = read_yt_tables()
        show_data = False

        if self.query_params['client'] not in yt_brands:
            show_data = True if not show_data else show_data

            logger.error(u'{}: bad client {}'.format(self.issue.key, self.query_params['client']))
            errs.append(
                u'* Клиент !!{}!! не найден в исходных данных'.format(self.query_params['client'])
            )

        if self.query_params['competitors'] != '':
            for comp in self.query_params['competitors'].split('\n'):
                if comp not in yt_brands:
                    show_data = True if not show_data else show_data
                    logger.error(u'{}: bad competitor {}'.format(self.issue.key, comp))
                    errs.append(u'* Конкурент !!{}!! не найден в исходных данных'.format(comp))
        if errs:
            errs += ['']

        for cat in self.categories:
            if cat not in yt_cats:
                show_data = True if not show_data else show_data

                logger.error(u'{}: bad category {}'.format(self.issue.key, cat))
                errs.append(u'* Категория !!{}!! не найдена в исходных данных'.format(cat))
        if errs:
            errs += ['']

        if self.query_params['first_date'] > self.query_params['last_date']:
            logger.error(u'{}: invalid period {} - {}'
                         .format(self.issue.key, self.query_params['first_date'], self.query_params['last_date']))
            errs.append(
                u'* Период !!{} - {}!! указан некорректно: начало периода не может быть позже окончания'
                .format(self.query_params['first_date'], self.query_params['last_date'])
            )

        first_valid_date, last_valid_date = get_period_from_yt()
        if self.query_params['first_date'] < first_valid_date or self.query_params['last_date'] > last_valid_date:
            logger.error(u'{}: unavailable period {} - {}'.format(
                self.issue.key, self.query_params['first_date'], self.query_params['last_date']))
            errs.append(
                u'* Период !!{} - {}!!, указаный в задании, недоступен. Данные доступны с {} по {} влючительно'
                .format(self.query_params['first_date'],
                        self.query_params['last_date'],
                        first_valid_date,
                        last_valid_date
                        )
            )

        if self.query_params['comp_num'] < 3:
            logger.error(u'{}: few competitors ({})'.format(self.issue.key, self.query_params['comp_num']))
            errs.append(
                u'* Количество конкурентов для сравнения (!!{}!!) должно быть не меньше 3х, '
                u'есть вероятность раскрыть данные конкурентов'
                .format(self.query_params['comp_num'])
            )
        if errs:
            errs += ['']

        if errs:
            if show_data:
                errs += [u'Актуальный список клиентов и категорий (({} тут)).'.format(
                    u'https://wiki.yandex-team.ru/Sales/Mediaplanning/Aboutus/vipplanners/template/btz/'
                    u'#spiskiaktualnyxbrendovikategorijj')
                ]
            errs += ['']
            self.issue.comments.create(
                text=u'В задании найдены ошибки:\n{}\n{}'.format(
                    u'\n'.join(errs), u'Чтобы перезапустить задачу, поправьте ошибки в теле тикета и переоткройте его.'
                ),
                summonees=[self.issue.createdBy.login, 'vbatraev']
            )
            if not self.test_mode:
                self.issue.transitions['need_info'].execute()

            raise TaskWorkerError('{}: task isn\'t validated'.format(self.issue.key))

        else:
            logger.info('{}: task validated'.format(self.issue.key))

    def run_categories(self):
        logger.info('{}: run categories'.format(self.issue.key))
        if not self.test_mode:
            self.issue.comments.create(text='Запуск расчетов.')

        pool = Pool(5)
        for task_result in pool.imap_unordered(yql_worker, self._generate_query_params()):
            if task_result['status'] == 'OK':
                worker = task_result['worker']
                self.workers.append(worker)

            else:
                pool.terminate()

                self.drop_results()
                if not self.test_mode:
                    self.issue.transitions['need_info'].execute()
                raise TaskWorkerError('{}: terminate pool!'.format(self.issue.key))

    def get_worker(self, cat, region):
        for worker in self.workers:
            if worker.cat == cat and worker.region == region:
                return worker
        return None

    def _generate_query_params(self):
        for id_, (cat, region) in enumerate(product(self.categories, self.regions), 1):
            query_params = self.query_params.copy()

            query_params['id'] = id_
            query_params['category'], query_params['region'] = cat, region
            yield {
                'issue': self.issue,
                'query_params': query_params
            }

    def paste_results(self):

        if self.test_mode:
            return

        fc_text, lc_text = [], []
        for worker in self.workers:
            if worker.few_competitors:
                fc_text += worker.few_competitors

            if worker.less_competitors:
                lc_text += [u'* {}, {}'.format(worker.cat, worker.region)]

        if fc_text:
            fc_text = u'<{В некоторых категориях недостаточно конкурентов:\n%s\n}>\n' % u'\n'.join(fc_text) + \
                      u'Значения конкурентов будут обнулены на графиках в презентации. ' \
                      u'Полные данные можно посмотреть в соответствующей выгрузке.'

        if lc_text:
            lc_text = u'<{В некоторых категориях количество конкурентов меньше, чем указано ' \
                      u'в задании:\n%s\n}>\n' % u'\n'.join(lc_text)

        text = u'\n\n'.join([text for text in [fc_text, lc_text, u'Результаты расчетов и презентация готовы.'] if text])

        results_path = self._zip_results()
        self.issue.comments.create(text=text,
                                   attachments=[results_path],
                                   summonees=[self.issue.createdBy.login])

        remove(results_path)
        self.issue.transitions['close'].execute(resolution='fixed')

        self.drop_results()
        logger.info('{}: successfully completed!'.format(self.issue.key))

    def _zip_results(self):
        logger.info('{}: zip results'.format(self.issue.key))
        zip_path = join(dirname(dirname(__file__)), '{}.zip'.format(self.issue.key))
        zip_file = ZipFile(zip_path, 'w')

        for cat_worker in self.workers:
            zip_file.write(cat_worker.xlsx_path,
                           arcname=basename(cat_worker.xlsx_path))

        zip_file.write(self.pptx_path, arcname=basename(self.pptx_path))

        zip_file.close()
        return zip_path

    def drop_results(self):
        logger.info('{}: drop temp results'.format(self.issue.key))
        for cat_worker in self.workers:
            remove(cat_worker.xlsx_path)

        if self.pptx_path:
            remove(self.pptx_path)


class TaskWorkerError(Exception):
    def __init__(self, msg=None):
        self.msg = msg
    pass


def get_query_params(issue, params, competitors):
    first_date, last_date = get_period(params.get('period'))

    competitors = '' if competitors in [None, 'None', ''] else u'\n'.join(
        [comp.strip() for comp in competitors.split('\n') if comp.strip()])

    comp_num = params.get('comp_num')
    if comp_num in [None, 'None', '']:
        comp_num = 5

    only_search = u'нет' if params.get('partners') in (u'Да', u'Yes') else u'да'

    currency = params.get('currency', u'рубль')
    currency = u'рубль' if currency in [None, 'None', ''] else currency
    q_currency = 'rub' if currency == u'рубль' else 'ye'

    vat = params.get('vat', u'Да')
    vat = u'Да' if vat in [None, 'None', ''] else vat
    vat = True if vat in (u'Да', u'Yes') else False

    query_params = {
        'login': issue.createdBy.login,
        'issue': issue.key,
        'client': params.get('client').strip(),
        'competitors': competitors,
        'category': None,
        'region': None,
        'id': None,
        'first_date': first_date,
        'last_date': last_date,
        'comp_num': comp_num,
        'only_search': only_search,
        'q_currency': q_currency,
        'currency': currency,
        'vat': vat
    }
    logger.info(u'{}: {}'.format(issue.key, query_params))
    return query_params


def get_period(period):
    if period in [None, 'None', '']:
        period = None

    if period:
        first_date, last_date = [p.strip() for p in period.split('-')]

    else:

        first_valid_date, last_valid_date = get_period_from_yt()
        last_date = dt.strptime(unicode(last_valid_date), '%Y%m')
        first_date = last_date - relativedelta(months=2)

        if first_date.month in [1, 4, 7, 10]:
            pass
        elif first_date.month in [2, 5, 8, 11]:
            first_date = first_date - relativedelta(months=1)
        else:
            first_date = first_date - relativedelta(months=2)
        last_date = first_date + relativedelta(months=2)

        first_date = first_date.strftime('%Y%m')
        last_date = last_date.strftime('%Y%m')

    return int(first_date), int(last_date)
