from __future__ import absolute_import
from __future__ import division
from celery import shared_task
from django.utils import timezone
from django.db.models import Count, Max
from django.forms.models import model_to_dict
from django.conf import settings
from django.core.cache import cache
from .deploy import rollout, send_notifications, upload_to_sandbox, SandboxApi, create_deployment, log_tb
from .models import (Snapshot, Deployment, ShowHistogram, AutoQuota, ProductQuota,
        StaleException, TimeoutException)
from .nanny import wait_for_version
from .shoot import trie_deployed, get_empty_rate
from .statface import get_stats
from .quota import update_product_shares
import time
import datetime
import calendar
import subprocess
import requests
import urllib
import tempfile
import pickle
import os
import os.path
import gzip
import re
from six import print_

"""
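Deployment states:
    active: new -> enqueued -> uploading -> releasing  (stub path currently used by two_step_rollout)
            new -> enqueued -> uploading -> test_releasing -> shooting -> deploy  (disabled two-step path)
    final:  success | failure | superseded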
"""


def check_deployment(deployment):
    """Raise StaleException if `deployment` has been superseded.

    Deployments younger than one hour are never treated as stale. Older
    ones must still be the most recently created Deployment; otherwise a
    StaleException naming the newer deployment is raised.
    """
    age = timezone.now() - deployment.created
    if age >= datetime.timedelta(hours=1):
        newest = Deployment.objects.latest('created')
        if newest.id != deployment.id:
            raise StaleException("Superseded by deployment {}".format(newest))


def state_transition(deployment, old_state, new_state):
    """Move `deployment` from `old_state` to `new_state` in the database.

    The UPDATE is filtered on the current state, so it only succeeds if the
    stored row is still in `old_state`. Raises StaleException when the
    deployment is stale (see check_deployment) or when exactly one row was
    not updated. On success the in-memory `deployment.state` is updated too.
    """
    check_deployment(deployment)
    candidates = Deployment.objects.filter(pk=deployment.id, state=old_state)
    updated_rows = candidates.update(state=new_state)
    if updated_rows == 1:
        deployment.state = new_state
        return
    raise StaleException("Deployment {}: state out of sync".format(deployment.id))
    deployment.state = new_state


def send_mail(deployment):
    """Send notifications about the snapshot attached to `deployment`."""
    snapshot = deployment.snapshot
    send_notifications(snapshot)


def get_ammo():
    """Return the path of the local gzipped ammo file, fetching it if absent.

    NOTE(review): get_new_ammo may return without creating the file (when no
    matching Sandbox resource is found), in which case the returned path
    does not exist.
    """
    path = '/tmp/atomsearch.ammo.gz'
    if os.path.exists(path):
        return path
    get_new_ammo(path)
    return path


def shoot_task(deployment, login):
    """Compare empty-result rates of the testing and production environments.

    The ammo file is replayed once against testing and once against
    production via get_empty_rate. The result is True when
    min(rate) / max(rate) is at least 0.95, and False otherwise. When either
    rate is below 0.001 the comparison is not meaningful and False is
    returned; previously this path returned an implicit None without logging.

    `login` is unused; it is kept for interface compatibility.
    """
    deployment.add_log('ready for shooting')
    ammo = get_ammo()
    # Each replay gets its own handle, which is closed afterwards (the
    # previous version leaked both gzip file objects).
    with gzip.open(ammo) as testing_ammo:
        t_rate = get_empty_rate(testing_ammo, True, 1000)
    with gzip.open(ammo) as production_ammo:
        p_rate = get_empty_rate(production_ammo, False, 1000)
    if t_rate < 0.001 or p_rate < 0.001:
        deployment.add_log("empty rates too low to compare: testing {:.3}, production {:.3}"
                           .format(t_rate, p_rate))
        return False
    ratio = min(t_rate, p_rate) / max(t_rate, p_rate)
    deployment.add_log("empty rates: testing {:.3}, production {:.3}, ratio = {:.2}"
                       .format(t_rate, p_rate, ratio))
    if ratio < 0.95:
        return False
    return True


def get_querysearch_timestamp(branch):
    """Return the 'Timestamp' of the first 'Data' entry querysearch reports for `branch`.

    Returns None when get_qs_request raises, or when the JSON response has
    no (or empty) 'Data'. Errors from decoding the response still propagate.
    """
    from .deploy import get_qs_request
    try:
        rq = get_qs_request(branch)
    except Exception:
        # Best effort: wait_deploy polls this in a retry loop, so a failed
        # request just means "not available yet". Unlike the former bare
        # `except:`, this does not swallow KeyboardInterrupt / SystemExit.
        return None
    data = rq.json().get('Data')
    if not data:
        return None
    return data[0]['Timestamp']


def wait_deploy(deployment, testing):
    """Block until the deployment's trie is live on querysearch and on Atom.

    Phase 1 polls querysearch for `branch` (at most 120 attempts). The pause
    between attempts grows from 11 s to 129 s. Phase 2 polls trie_deployed
    every 15 s (at most 60 attempts). check_deployment runs before every
    pause, so a superseded deployment aborts with StaleException. Running
    out of attempts in either phase raises TimeoutException.
    """
    snapshot = deployment.snapshot
    if testing:
        branch = 'atom-test-candidates'
    else:
        branch = 'atom-candidates'

    for remaining in range(119, -1, -1):
        if get_querysearch_timestamp(branch) == int(snapshot.trie_name):
            break
        check_deployment(deployment)
        if remaining == 0:
            raise TimeoutException('timeout waiting for deployment to querysearch')
        time.sleep(130 - remaining)
    deployment.add_log('{} live on querysearch'.format(branch))

    for remaining in range(59, -1, -1):
        if trie_deployed(branch, int(snapshot.trie_name)):
            break
        check_deployment(deployment)
        if remaining == 0:
            raise TimeoutException('timeout waiting for Atom deployment')
        time.sleep(15)


@shared_task
def upload_to_sandbox_task(deployment_id):
    """Celery task: upload the deployment's data to Sandbox.

    Returns the `(task_id, stats)` pair produced by upload_to_sandbox. On
    failure the traceback is logged, a message is added to the deployment
    log, and the exception is re-raised.
    """
    deployment = Deployment.objects.get(pk=deployment_id)
    deployment.add_log('uploading data to Sandbox')
    try:
        task_id, stats = upload_to_sandbox(deployment, settings.SANDBOX_TOKEN)
    except Exception as e:
        log_tb()
        # Format the exception itself: `e.message` is deprecated, is empty
        # for exceptions built with several args, and is gone in Python 3.
        deployment.add_log('Sandbox upload failed: {}'.format(e))
        raise
    deployment.add_log('Sandbox task id: {}'.format(task_id))
    return task_id, stats


@shared_task
def two_step_rollout(snapshot_id, login, deployment_id):
    """Celery task: drive a deployment from 'new' to a final state.

    Only the "stub" path is live:
    1. Upload to Sandbox.
    2. Pass a deadline of now + 1200 s to wait_for_version.
    3. Mark the deployment 'success' (and send notifications) or 'failure'.

    The original two-step path (test release, shooting, production deploy)
    sits after an early ``return`` and is unreachable.

    Args:
        snapshot_id: not used in this function (the visible callers pass None).
        login: only used by the unreachable shooting/rollout path.
        deployment_id: primary key of the Deployment to drive.

    Exceptions raised inside the try block set ``deployment.state``:
    StaleException -> 'superseded', TimeoutException /
    CalledProcessError / any other Exception -> 'failure'. Other
    exceptions are re-raised. The ``finally`` clause persists the state field.
    """
    deployment = Deployment.objects.get(pk=deployment_id)
    deployment.add_log('started, trie id {}'.format(deployment.snapshot.trie_name))
    # Start the Sandbox upload asynchronously right away so it can run while
    # this deployment waits for its turn below.
    upload = upload_to_sandbox_task.delay(deployment_id)

    # NOTE(review): this transition is outside the try block, so a
    # StaleException here propagates without the state being saved.
    state_transition(deployment, 'new', 'enqueued')

    # Serialize deployments: wait until no other deployment created in the
    # 40 minutes up to this one's creation time is still in a non-final state.
    # NOTE(review): no upper bound; the window is fixed relative to
    # deployment.created, so an earlier deployment stuck in an active state
    # inside that window keeps this loop spinning.
    while True:
        previous = (Deployment.objects.filter(created__lte=deployment.created,
                created__gt=(deployment.created - datetime.timedelta(minutes=40)))
                .exclude(pk=deployment.pk)
                .exclude(state__in=['success', 'failure', 'superseded']))
        if not previous:
            break
        time.sleep(20)

    try:
        state_transition(deployment, 'enqueued', 'uploading')
        # Blocks until the upload task finishes; presumably re-raises the
        # task's exception on failure (Celery AsyncResult.get default).
        task_id, stats = upload.get()
        # NOTE(review): `sandbox` is only used by the unreachable code below.
        sandbox = SandboxApi(settings.SANDBOX_TOKEN)

        # stub: deploy to stable immediately
        state_transition(deployment, 'uploading', 'releasing')
        version = wait_for_version(task_id, time.time() + 1200, deployment)
        if version:
            deployment.add_log('Version {} on production machines.'.format(version))
            state_transition(deployment, 'releasing', 'success')
            send_notifications(deployment.snapshot, stats)
        else:
            deployment.add_log('Timeout waiting for Atom deploy.')
            state_transition(deployment, 'releasing', 'failure')
        # Returning from inside `try` still runs the `finally` save below.
        return

        # ---- Unreachable: disabled two-step path (test release, shooting,
        # then production deploy). ----
        state_transition(deployment, 'uploading', 'test_releasing')
        deployment = Deployment.objects.get(pk=deployment_id)
        if not deployment.is_rollback:
            sandbox.release(task_id, to='testing', subject='trie {} by {}'.format(
                deployment.snapshot.trie_name, deployment.author))
            wait_deploy(deployment, True)
        state_transition(deployment, 'test_releasing', 'shooting')

        deployment = Deployment.objects.get(pk=deployment_id)
        good = deployment.is_rollback or shoot_task(deployment, login)
        if not good:
            # maybe it has been re-forced
            deployment = Deployment.objects.get(pk=deployment_id)
            good = good or deployment.is_rollback

        if good:
            state_transition(deployment, 'shooting', 'deploy')
            stats = rollout(deployment.snapshot, login, False)
            stats['comment'] = deployment.comment
            wait_deploy(deployment, False)
            state_transition(deployment, 'deploy', 'success')
            deployment.add_log('done')
            send_notifications(deployment.snapshot, stats)
        else:
            state_transition(deployment, 'shooting', 'failure')

    except StaleException:
        # Raised by check_deployment / state_transition when a newer
        # deployment exists or the stored state changed underneath us.
        log_tb()
        deployment.state = 'superseded'
    except TimeoutException as e:
        log_tb()
        # NOTE(review): `.message` is deprecated (removed in Python 3);
        # '{}'.format(e) would be the portable spelling.
        deployment.add_log(e.message)
        deployment.state = 'failure'
    except subprocess.CalledProcessError as e:
        log_tb()
        deployment.add_log('External program failure: return code {}, output:\n{}'
                           .format(e.returncode, (e.output or b'').decode('utf-8')))
        deployment.state = 'failure'
    except Exception as e:
        log_tb()
        deployment.add_log('E: {}'.format(e))
        deployment.state = 'failure'
        raise
    finally:
        # Persist only the state column, whatever path was taken above.
        deployment.save(update_fields=['state'])


@shared_task
def update():
    """Celery task: refresh statistics by calling ``get_stats(2)``.

    NOTE(review): the meaning of the argument 2 is defined in .statface.
    """
    stats_arg = 2
    get_stats(stats_arg)


@shared_task
def purge_old_snapshots():
    """Celery task: delete old deployments, orphaned snapshots and old histograms.

    Deletes:
    - Deployments created more than 62 days ago.
    - Snapshots older than 62 days that no remaining Deployment references.
      Deployments are purged first, so snapshots they freed go in the same run.
    - ShowHistogram rows whose period started more than 7 days ago.
    """
    retention_cutoff = timezone.now() - datetime.timedelta(days=62)
    Deployment.objects.filter(created__lt=retention_cutoff).delete()
    orphaned_snapshots = (Snapshot.objects
                          .defer('text')
                          .annotate(Count('deployment'))
                          .filter(created__lt=retention_cutoff,
                                  deployment__count=0))
    orphaned_snapshots.delete()

    histogram_cutoff = timezone.now() - datetime.timedelta(days=7)
    stale_histograms = ShowHistogram.objects.filter(period_start__lt=histogram_cutoff)
    stale_histograms.delete()


def get_new_ammo(where):
    """Download the 'atomsearch' ammo resource from Sandbox to the path `where`.

    Queries the Sandbox API for up to 20 READY resources of type
    MOBILESEARCH_ATOM_AMMO. The first one described as
    'ammunition atomsearch.ammo.gz' is used. If none matches, the function
    returns without creating `where`.

    The download goes to a temporary file in the same directory, which is
    then renamed onto `where`, so readers never see a partially written
    file. The temporary file is always removed if the download or rename
    fails.
    """
    response = requests.get('https://sandbox.yandex-team.ru/api/v1.0/resource',
                            params=dict(limit=20, type='MOBILESEARCH_ATOM_AMMO',
                                        state='READY')).json()
    ammos = [i for i in response['items']
             if i['description'] == 'ammunition atomsearch.ammo.gz']
    if not ammos:
        return
    url = ammos[0]['http']['proxy']
    # mkstemp instead of NamedTemporaryFile(delete=True): the file is renamed
    # away below, and the wrapper's close/__del__ would then try to unlink a
    # path that no longer exists.
    fd, tmp_path = tempfile.mkstemp(
        prefix='atomsearch.ammo.', suffix='.gz', dir=os.path.dirname(where))
    os.close(fd)
    try:
        urllib.urlretrieve(url, filename=tmp_path)
        os.rename(tmp_path, where)
    finally:
        # After a successful rename tmp_path is gone; otherwise drop the
        # partial download.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def read_histogram(period, table_name):
    """Load show histograms from a MapReduce table into ShowHistogram rows.

    Args:
        period: two-item sequence [start, end] of datetimes. Each row gets
            period_start=start and period_duration=end - start + 300 s.
        table_name: MapReduce table read with ``mapreduce-dev -read ... -subkey``.

    Each non-empty output line must end up with 7 tab-separated fields.
    If it has fewer, the first field is additionally split on '/' (into at
    most 3 parts). Fields map to:
    - 0: client
    - 1: subclient
    - 2: production_list
    - 3: fml
    - 4: success (integer flag)
    - 5: not stored
    - 6: histogram

    Records with the wrong field count, or with any of fields 0-5 longer
    than 100 characters, are printed and skipped. Valid records are
    inserted with one bulk_create.
    """
    print_("reading histogram:", table_name)
    records = subprocess.check_output(['/Berkanavt/mapreduce/bin/mapreduce-dev',
            '-server', 'sakura00.search.yandex.net',
            '-read', table_name, '-subkey']).split('\n')
    print_('records:', len(records))
    histograms = []
    for e, r in enumerate(records):
        if not r:
            continue
        tabs = r.split('\t')
        if len(tabs) < 7:
            # NOTE(review): presumably the leading columns can arrive
            # '/'-joined in a single field; split them back out.
            tabs[:1] = tabs[0].split('/', 2)
        if len(tabs) != 7:
            # print_ instead of the Py2-only print statement, consistent with
            # the rest of the module.
            print_('Bad record:', r)
            continue
        bad_column = next((i for i in range(6) if len(tabs[i]) > 100), None)
        if bad_column is not None:
            print_('Record #{}: Invalid column #{}:'.format(e, bad_column), tabs[bad_column])
            continue
        histograms.append(ShowHistogram(client=tabs[0], subclient=tabs[1],
            production_list=tabs[2], fml=tabs[3], success=bool(int(tabs[4])),
            histogram=tabs[6], period_start=period[0],
            period_duration=period[1] - period[0] + datetime.timedelta(seconds=300)))
    print_('adding {} histograms'.format(len(histograms)))
    ShowHistogram.objects.bulk_create(histograms)


LOCK_EXPIRE = 60 * 5  # lock lifetime in seconds (5 minutes)


def mutex(task):
    """Decorator that skips `task` while another invocation holds its lock.

    The lock is the cache key 'task-mutex-<function name>'. It is taken with
    cache.add, which only stores the key if it is absent, and expires after
    LOCK_EXPIRE seconds. If the lock is already held, a message is printed
    and None is returned without running the task. Otherwise the task runs
    and the lock is deleted afterwards, even if the task raised.

    NOTE(review): if a run outlives LOCK_EXPIRE, the lock expires and a
    second run may start. The first run's cache.delete then releases the
    second run's lock.
    """
    from functools import wraps

    @wraps(task)
    def mtask(*args, **kwargs):
        lock_id = 'task-mutex-' + task.__name__
        acquired = cache.add(lock_id, 'true', LOCK_EXPIRE)
        if acquired:
            try:
                return task(*args, **kwargs)
            finally:
                cache.delete(lock_id)
        print('Lock active, exiting')
        return None

    return mtask


@shared_task
@mutex
def update_histograms(since=None):
    """Celery task: import new histogram tables, refresh RTMR stats, retune quotas.

    Args:
        since: Unix timestamp (any real number) or a datetime. Only tables
            whose period start is strictly later are imported. When falsy, it
            defaults to the latest ShowHistogram.period_start (or 0). It is
            never allowed to be earlier than two days ago.

    Tables under the 'atom-quota/' prefix whose names contain
    '<start>-300-<end>-300' are read with read_histogram. update_rtmr_stats
    then always runs. If any table was imported, adjust_multipliers is
    called with the latest imported period start.
    """
    import numbers

    if not since:
        agr = ShowHistogram.objects.all().aggregate(Max('period_start'))
        since = agr['period_start__max'] or 0
    # numbers.Real also accepts Py2 `long` (rejected by isinstance(_, int)).
    if not isinstance(since, numbers.Real):
        # utctimetuple(): aware datetimes are converted to UTC first, so
        # timegm() yields the correct Unix timestamp.
        since = calendar.timegm(since.utctimetuple())
    # Computed from time.time(): the former strftime('%s') is a glibc-only
    # extension that ignores tzinfo and interprets the time as local time.
    two_days_ago_ts = int(time.time() - datetime.timedelta(days=2).total_seconds())
    if since < two_days_ago_ts:
        since = two_days_ago_ts
    print_("reading histograms since:", since)
    tables = subprocess.check_output(['/Berkanavt/mapreduce/bin/mapreduce-dev',
            '-server', 'sakura00.search.yandex.net',
            '-list', '-prefix', 'atom-quota/']).split('\n')
    timerange_re = re.compile(r'([0-9]+)-300-([0-9]+)-300')
    new_tables = []
    for table in tables:
        timerange = timerange_re.search(table)
        if not timerange or int(timerange.group(1)) <= since:
            continue
        period = [datetime.datetime.fromtimestamp(int(t), timezone.utc)
                  for t in timerange.groups()]
        print_('period:', period)
        read_histogram(period, table)
        new_tables.append(period[0])

    update_rtmr_stats()

    if new_tables:
        adjust_multipliers(max(new_tables))


def update_rtmr_stats():
    """Call update_product_shares for every distinct ProductQuota.production_list.

    Best effort: a failure for one production list is reported (message plus
    log_tb traceback) and does not stop the remaining lists. The exception
    was previously swallowed silently.
    """
    production_lists = {pq.production_list for pq in ProductQuota.objects.all()}
    for pl in production_lists:
        try:
            update_product_shares(pl)
        except Exception:
            print_('update_product_shares failed for', pl)
            log_tb()


def adjust_multipliers(period_start, tolerance=0.01, min_events=20):
    """Rescale AutoQuota multipliers toward their target show rate.

    For each AutoQuota, the ShowHistogram rows of `period_start` for its
    production list are summed with bin_sum(): successful rows give `shows`,
    failed rows give `fails`.

    - Quotas with fewer than `min_events` events in total are skipped.
    - If the observed rate shows / (shows + fails) falls outside
      target_show_rate +/- `tolerance`, the quota is multiplied by
      target / rate and saved. A rate of 0 is replaced by 0.1 to avoid
      dividing by zero.
    - If any quota changed, a new 'robot' deployment is created and handed
      to two_step_rollout.

    Args:
        period_start: period_start of the histograms to evaluate.
        tolerance: allowed absolute deviation from the target rate (default 0.01).
        min_events: minimum shows + fails required to act (default 20).
    """
    updated = False
    for aq in AutoQuota.objects.all():
        print_(aq)
        all_hist = ShowHistogram.objects.filter(period_start=period_start,
                production_list=aq.production_list)
        shows = sum(h.bin_sum() for h in all_hist.filter(success=True))
        fails = sum(h.bin_sum() for h in all_hist.filter(success=False))
        print_(aq, '=> shows:', shows, 'fails:', fails)
        total = shows + fails
        if total < min_events:
            continue  # not enough data to go on
        rate = shows * 1.0 / total
        target = aq.target_show_rate
        # Both directions apply the same correction; the two formerly
        # duplicated branches are merged here.
        if rate < target - tolerance or rate > target + tolerance:
            aq.multiply(target / (rate or 0.1))
            updated = True
            aq.save()

    if updated:
        print_('updated weights, redeploying')
        d = create_deployment('robot', 'regular deploy')
        two_step_rollout.delay(None, 'robot', d.pk)


def redeploy_latest():
    """Re-deploy the snapshot of the most recent successful deployment.

    Creates a new Deployment (author 'robot', state 'new') for that snapshot
    and schedules two_step_rollout for it.
    """
    last_good = Deployment.objects.filter(state='success').latest('created')
    snapshot = last_good.snapshot
    fresh = Deployment.objects.create(
        created=timezone.now(),
        snapshot=snapshot,
        author='robot',
        state='new',
    )
    two_step_rollout.delay(None, 'robot', fresh.pk)
