import binascii
import datetime
import itertools
import json
import os
import re
import shutil
import tempfile
import zipfile
from collections import namedtuple, defaultdict

from django.contrib.postgres.fields import JSONField
from django.db import models, transaction
from django.template.loader import render_to_string
from django.utils import timezone
from django.utils.text import slugify
from six import string_types, text_type, iteritems, BytesIO, python_2_unicode_compatible

from . import json_template


def pretty_json(text):
    '''Parse a JSON string and re-serialize it for display in a form.

    The result is indented by four spaces, has its keys sorted, and keeps
    non-ASCII characters unescaped.
    '''
    parsed = json.loads(text)
    return json.dumps(parsed, indent=4, sort_keys=True, ensure_ascii=False)


@python_2_unicode_compatible
class CandidateUrl(models.Model):
    """A single candidate of an Entry, stored as its own row.

    Rows are produced by ``Entry._extract_bannerids``, which splits the
    candidate's ``internal-url`` into ``internal_url_prefix`` and ``url`` and
    moves the ``__textauthor`` / ``__product`` keys into dedicated columns.
    """
    url = models.CharField(max_length=100, unique=True, null=True)
    internal_url_prefix = models.CharField(max_length=100, null=True)
    list = models.ForeignKey('Entry')
    author = models.CharField(max_length=30, null=True)
    product = models.CharField(max_length=100, null=True)
    json = JSONField(default=dict)

    def __str__(self):
        return u'[{} -> {}]'.format(self.url, self.list.slug)

    def as_dict(self):
        """Rebuild the candidate dict, folding the dedicated columns back in.

        Returns ``{'malformed content': <json>}`` when the stored JSON is not
        a dict.
        """
        payload = self.json
        if not isinstance(payload, dict):
            return {'malformed content': payload}

        result = payload.copy()
        prefix = self.internal_url_prefix or ''
        suffix = self.url or ''
        internal_url = prefix + '/' + suffix
        # A bare '/' means that both parts were empty; leave the key out then.
        if len(internal_url) > 1:
            result['internal-url'] = internal_url
        if self.author:
            result["__textauthor"] = self.author
        if self.product:
            result["__product"] = self.product
        return result

    def terse(self):
        """Return compact indented JSON without the bulky keys.

        The ``filter``, ``aux-data`` and ``text-subst`` keys are dropped.
        """
        hidden = ('filter', 'aux-data', 'text-subst')
        clone = {}
        for key, value in iteritems(self.json):
            if key in hidden:
                continue
            clone[key] = value
        dumped = json.dumps(clone, indent=4, ensure_ascii=False, sort_keys=True)
        return dumped.replace('{\n ', '{').replace('\n}', '}')


@python_2_unicode_compatible
class Entry(models.Model):
    """A named list of candidates whose content is kept as JSON text."""
    slug = models.CharField(max_length=50, unique=True)
    json = models.TextField()
    lastchanged = models.DateTimeField()
    changedby = models.CharField(max_length=50)
    bannerid = models.IntegerField(default=0)

    def __str__(self):
        return self.slug

    def save(self, *args, **kwargs):
        """Save the entry, stamping ``lastchanged`` with now if it is unset."""
        if not self.lastchanged:
            self.lastchanged = timezone.now()
        return super(Entry, self).save(*args, **kwargs)

    def is_locked(self):
        """Tell whether a ``lock`` attribute is reachable on this entry."""
        return hasattr(self, 'lock')

    def _extract_bannerids(self):
        """Yield an unsaved CandidateUrl for every dict candidate.

        The ``internal-url``, ``__textauthor`` and ``__product`` keys are
        removed from the candidate dict and stored in dedicated fields; the
        remaining dict becomes the row's ``json``.
        """
        for candidate in self.candidates():
            if not isinstance(candidate, dict):
                continue
            internal_url = candidate.pop('internal-url', '')
            prefix, _, banner = internal_url.partition('/')
            author = candidate.pop("__textauthor", '')
            product = candidate.pop("__product", None)
            yield CandidateUrl(
                list=self,
                url=banner or None,
                internal_url_prefix=prefix or None,
                author=author,
                product=product,
                json=candidate,
            )

    def reindex_urls(self):
        """Replace this entry's CandidateUrl rows inside one transaction."""
        with transaction.atomic():
            CandidateUrl.objects.filter(list=self).delete()
            CandidateUrl.objects.bulk_create(self._extract_bannerids())

    def urlid(self):
        """Return the slug without a trailing ``_ru``, ``_tr`` or ``_uk``."""
        head, separator, tail = self.slug.rpartition('_')
        if separator and tail in ('ru', 'tr', 'uk'):
            return head
        return self.slug

    def candidates(self):
        """Return the parsed JSON content of the entry."""
        return json.loads(self.json)

    def set_candidates(self, lst):
        """Store ``lst`` as compact, key-sorted JSON text (does not save)."""
        self.json = json.dumps(
            lst, separators=(',', ':'), sort_keys=True, ensure_ascii=False)

    def pretty_json(self):
        '''Returns indented JSON in Unicode, suitable for plugging in a form.'''
        return pretty_json(self.json)

    def flat_json(self):
        '''Returns single-line, key-sorted JSON text for candidates.txt.'''
        parsed = json.loads(self.json)
        return json.dumps(parsed, sort_keys=True, ensure_ascii=False)


class EditLog(models.Model):
    """Record of a login touching a list slug at a point in time."""
    login = models.CharField(max_length=50)
    timestamp = models.DateTimeField()
    list_slug = models.CharField(max_length=50, db_index=True)

    class Meta:
        index_together = [["login", "timestamp"]]

    @classmethod
    def add(cls, login, slug):
        """Create a log row for ``login`` and ``slug`` stamped with now."""
        now = timezone.now()
        cls.objects.create(login=login, list_slug=slug, timestamp=now)


# Result row of BannerTemplate.get_variables: ``v`` is the template variable
# name, ``lines`` is the newline-joined text of the values found for it.
Substitution = namedtuple('Substitution', 'v lines')


@python_2_unicode_compatible
class BannerTemplate(models.Model):
    """A JSON template containing ``{{variable}}`` placeholders."""
    slug = models.SlugField(unique=True)
    json = models.TextField()

    def __str__(self):
        return self.slug

    def parse_json(self):
        """Return the parsed template JSON."""
        return json.loads(self.json)

    def get_variables(self, lst):
        """Collect the template's variables and their values already in use.

        ``lst`` is either a list of candidates or an object with a
        ``candidates()`` method. Every placeholder other than ``banner_id``
        becomes a variable. When a template value consists of exactly one
        placeholder, the key holding it (``aux-data__<key>`` for keys nested
        under ``aux-data``) is remembered as a location, and the values found
        at that location in the existing candidates are gathered.

        Returns a list of ``Substitution(v, lines)`` tuples, where ``lines``
        is the newline-joined text of the distinct values.
        """
        template = self.parse_json()
        var_names = defaultdict(set)

        def use_kv(var_names, k, v):
            if isinstance(v, string_types):
                for var_ref, var_name, typename in re.findall(
                        '({{([a-z][a-z0-9_]+)(:[a-z]+)?}})', v):
                    if var_name != 'banner_id':
                        # Touching the defaultdict registers the variable
                        # even when no location can be derived for it.
                        location_set = var_names[var_name]
                        if v == var_ref:
                            location_set.add(k)

        def use_kv_dict(var_names, j):
            if isinstance(j, dict):
                for k, v in iteritems(j):
                    use_kv(var_names, k, v)
                aux_data = j.get('aux-data')
                if isinstance(aux_data, dict):
                    for k, v in iteritems(aux_data):
                        use_kv(var_names, 'aux-data__' + k, v)

        if isinstance(template, list):
            for item in template:
                use_kv_dict(var_names, item)
        else:
            use_kv_dict(var_names, template)

        if isinstance(lst, list):
            existing = lst
        else:
            existing = lst.candidates()

        # Best effort: candidates without the location, candidates that are
        # not subscriptable and unhashable values are skipped.
        skippable = (KeyError, IndexError, TypeError)
        subs = {}
        for var_name, locations in iteritems(var_names):
            values = set()
            for location in locations:
                is_aux = location.startswith('aux-data__')
                aux_key = location.partition('__')[2]
                for c in existing:
                    try:
                        if is_aux:
                            values.add(c['aux-data'][aux_key])
                        else:
                            values.add(c[location])
                    except skippable:
                        pass
            subs[var_name] = values
        # six.text_type instead of the Python-2-only ``unicode`` builtin.
        return [Substitution(k, '\n'.join(text_type(l) for l in v))
                for k, v in iteritems(subs)]

    def replace(self, var_values):
        """Substitute ``var_values`` into the template.

        ``var_values`` is anything ``dict()`` accepts. Returns a generator of
        indented JSON strings, one per resulting candidate; a non-list result
        is treated as a single candidate.
        """
        myself = self.parse_json()
        result = json_template.substitute(myself, dict(var_values))
        if not isinstance(result, list):
            result = [result]
        return (json.dumps(candidate,
                indent=4, ensure_ascii=False, sort_keys=True)
                for candidate in result)


class Lock(models.Model):
    """Lock held on a single Entry (at most one per entry)."""
    # One-to-one link to the locked entry; Entry.is_locked() probes for the
    # ``lock`` attribute on the entry.
    entry = models.OneToOneField(Entry)
    # Set automatically when the row is first created.
    created = models.DateTimeField(auto_now_add=True)
    author = models.CharField(max_length=50)


@python_2_unicode_compatible
class Editor(models.Model):
    """A login together with a comma-separated list of slugs."""
    login = models.CharField(max_length=50)
    allowed_slugs = models.CharField(max_length=300, null=True, blank=True)

    def __str__(self):
        return self.login

    def get_allowed_slugs(self):
        """Return ``allowed_slugs`` split on commas and stripped.

        Returns an empty list when the field is empty or NULL.
        """
        raw = self.allowed_slugs
        if not raw:
            return []
        return [piece.strip() for piece in raw.split(',')]


class Log(models.Model):
    """Free-form text log line with an indexed timestamp."""
    text = models.TextField()
    # Not auto-populated: the creator must supply the timestamp.
    timestamp = models.DateTimeField(db_index=True)


def get_list_sources(mappings):
    """Map every list name to the set of entry names that feed it.

    Each item of ``mappings`` is either a string, which names a list fed by
    the entry of the same name, or a dict of ``entry name -> list names``.
    Returns a ``defaultdict(set)``.
    """
    sources_by_list = defaultdict(set)
    for mapping in mappings:
        if isinstance(mapping, string_types):
            sources_by_list[mapping].add(mapping)
            continue
        for entry_name, list_names in iteritems(mapping):
            for list_name in list_names:
                sources_by_list[list_name].add(entry_name)
    return sources_by_list


def get_list_sources_from_db():
    """Build the list-sources mapping from the ``all_keys`` entry.

    Returns an empty dict unless exactly one entry with that slug exists.
    """
    matches = Entry.objects.filter(slug='all_keys')
    if len(matches) != 1:
        return {}
    return get_list_sources(matches[0].candidates())


@python_2_unicode_compatible
class Snapshot(models.Model):
    """A saved snapshot of all lists, backed by a zip archive on disk."""
    class Meta:
        ordering = ['-created']
    # Filesystem path of a zip archive that contains ``snapshot.txt``.
    text = models.TextField()
    author = models.CharField(max_length=50, null=True)
    # Either plain text or a JSON object (see get_comment / stats).
    comment = models.TextField()
    trie_name = models.CharField(max_length=100)
    created = models.DateTimeField(db_index=True)
    task_id = models.IntegerField(null=True)

    def __str__(self):
        return u'{} by {}'.format(self.trie_name, self.author)

    def as_dict(self):
        """Return the snapshot content as a ``{name: json text}`` dict.

        ``snapshot.txt`` is read as tab-separated lines; only lines with
        exactly three fields are used, the second being the key and the third
        the value. Returns an empty dict when the archive file does not exist
        or the member cannot be read.
        """
        if not os.path.exists(self.text):
            return {}
        with zipfile.ZipFile(self.text, 'r') as f:
            try:
                lines = f.read("snapshot.txt").decode(
                    'utf8', errors='replace'
                ).strip().split('\n')
            except Exception:
                # Best effort: an unreadable member counts as an empty
                # snapshot, but KeyboardInterrupt/SystemExit still propagate.
                return {}
        tsv = (line.split('\t') for line in lines)
        return dict((t[1], t[2]) for t in tsv if len(t) == 3)

    def zipped(self):
        """Return the bytes of a zip with one pretty-printed file per list.

        Returns ``b"sorry"`` when the snapshot has no content.
        """
        d = self.as_dict()
        if not d:
            return b"sorry"
        s = BytesIO()
        # The context manager closes the archive even if an entry fails.
        with zipfile.ZipFile(s, 'w', zipfile.ZIP_DEFLATED) as z:
            for name, value in iteritems(d):
                name = (slugify(name) + '.txt')
                formatted_json = pretty_json(value).encode('utf-8')
                z.writestr(name, formatted_json)
        return s.getvalue()

    def write_tsv(self, receiver):
        """Feed every list of the snapshot to ``receiver`` and collect stats.

        ``receiver`` must provide ``write_kv(name, candidates, mode=...)`` and
        ``write_qs(list_names, trie_name)``. The ``all_keys`` item of the
        snapshot defines which entries feed which list. Returns a dict with
        ``list_count`` and per-list ``lists`` statistics as returned by
        ``write_kv``.
        """
        special_lists = ('LOAD_CONFIG',)
        stats = {'lists': {}}
        d = self.as_dict()
        mappings = json.loads(d.get('all_keys', '[]'))
        list_sources = get_list_sources(mappings)
        stats['list_count'] = len(list_sources)
        for l, sources in iteritems(list_sources):
            all_candidates = []
            for source in sources:
                source_json = json.loads(d.get(source, '[]'))
                if l not in special_lists:
                    all_candidates.extend(source_json)
                else:
                    # Special lists are not concatenated: the content of the
                    # source replaces whatever was collected so far.
                    all_candidates = source_json
            mode = 'LOAD_CONFIG' if l in special_lists else 'default'
            list_stats = receiver.write_kv(
                l.encode('utf-8'), all_candidates, mode=mode
            )
            stats['lists'][l] = list_stats

        receiver.write_qs(list_sources.keys(), self.trie_name)
        return stats

    def get_comment(self):
        """Return the ``comment`` key of a JSON comment, else the raw text."""
        try:
            return json.loads(self.comment)['comment']
        except (ValueError, TypeError, KeyError):
            return self.comment

    def stats(self):
        """Return the comment parsed as JSON.

        Falls back to ``{'comment': <raw text>, 'lists': {}}`` when the
        comment is not valid JSON.
        """
        try:
            return json.loads(self.comment)
        except (ValueError, TypeError):
            return {'comment': self.comment, 'lists': {}}

    def print_stats(self, stats=None):
        """Render ``atom/snapshot_stats.txt`` and return it UTF-8 encoded.

        ``stats`` defaults to ``self.stats()``. The template receives a copy
        in which ``lists`` is a list of the per-list dicts, each extended with
        its ``name``; the dict passed in is left untouched.
        """
        stats = stats or self.stats()
        context = dict(stats)
        context['lists'] = [
            dict(lstats, name=l) for l, lstats in iteritems(stats['lists'])
        ]
        return render_to_string('atom/snapshot_stats.txt', context).encode('utf-8')


@python_2_unicode_compatible
class ProductQuota(models.Model):
    """Share bounds and multipliers for one product within a production list."""
    production_list = models.CharField(max_length=100)
    product = models.CharField(max_length=30)
    share_min = models.FloatField(default=0.0)
    share_max = models.FloatField(default=1.0)
    bandits_multiplier = models.FloatField(default=1.0)
    fml_multiplier = models.FloatField(default=1.0)
    updated = models.DateTimeField(default=datetime.datetime(2016,1,1))

    class Meta:
        index_together = [['production_list', 'product']]

    def __str__(self):
        # Unicode format string, like every other __str__ in this module: a
        # byte-string format raises UnicodeEncodeError on Python 2 as soon as
        # a field holds non-ASCII text.
        return u'{}/{}: {}<=p<={} (b*{}, f*{})'.format(
            self.production_list, self.product, self.share_min, self.share_max,
            self.bandits_multiplier, self.fml_multiplier)


class StaleException(RuntimeError):
    """Raised by Deployment when it is superseded or its state is out of sync."""

class TimeoutException(RuntimeError):
    """Runtime error type for signalling a timeout."""

@python_2_unicode_compatible
class Deployment(models.Model):
    """One deployment of a Snapshot, tracked through a ``state`` string."""
    class Meta:
        ordering = ['-created']
        get_latest_by = 'created'
    author = models.CharField(max_length=50)
    comment = models.CharField(max_length=300)
    created = models.DateTimeField(db_index=True)
    is_rollback = models.BooleanField(default=False)
    state = models.CharField(max_length=20, default='new')
    snapshot = models.ForeignKey(Snapshot)
    task_id = models.IntegerField(null=True, blank=True)

    def __str__(self):
        return u'{} @{} ({})'.format(self.author, self.created, self.state)

    @classmethod
    def newest(klass):
        """Return the first deployment in default (newest first) order."""
        return klass.objects.first()

    def add_log(self, message):
        """Append ``message`` to this deployment's DeployLog."""
        DeployLog.add(self, message)

    def can_be_forced(self):
        """Tell whether this is a non-rollback deployment in an early state."""
        return (not self.is_rollback and
                self.state in ['new', 'deploy_testing', 'shooting'])

    def cleanup(self):
        """Remove the temporary directory and clear the reference to it."""
        # NOTE(review): no ``temp_directory`` field is declared on this model
        # in this file - confirm where it is defined.
        if self.temp_directory:
            # ``shutil`` is imported at module level; the original file used
            # it without importing it.
            shutil.rmtree(self.temp_directory)
            self.temp_directory = None
            self.save(update_fields=['temp_directory'])

    def check_deployment(deployment):
        """Raise StaleException if a newer deployment has superseded this one.

        ``deployment`` is the instance itself (it takes the place of
        ``self``). Deployments created less than an hour ago are never
        considered stale.
        """
        if timezone.now() - deployment.created < datetime.timedelta(hours=1):
            return
        latest = Deployment.objects.latest('created')
        if latest.id != deployment.id:
            raise StaleException("Superseded by deployment {}".format(latest))

    def state_transition(deployment, old_state, new_state):
        """Move the deployment from ``old_state`` to ``new_state``.

        ``deployment`` is the instance itself (it takes the place of
        ``self``). The update is a single filtered UPDATE, so it only succeeds
        if the stored state still equals ``old_state``.

        Raises StaleException when the deployment is superseded or when the
        stored state did not match.
        """
        deployment.check_deployment()
        rows = Deployment.objects.filter(
                pk=deployment.id, state=old_state).update(state=new_state)
        if rows != 1:
            raise StaleException("Deployment {}: state out of sync".format(deployment.id))
        deployment.state = new_state


class DeployLog(models.Model):
    """Timestamped message attached to a Deployment, ordered by time."""
    class Meta:
        index_together = [['deployment', 'timestamp']]
        ordering = ['timestamp']
    deployment = models.ForeignKey(Deployment)
    timestamp = models.DateTimeField(db_index=True)
    text = models.CharField(max_length=300)

    @classmethod
    def add(cls, dep, message):
        """Create a log row for deployment ``dep`` stamped with now."""
        now = timezone.now()
        cls.objects.create(deployment=dep, timestamp=now, text=message)


class GlobalVar(models.Model):
    """Named global variable whose value is stored as JSON text."""
    name = models.SlugField(db_index=True)
    value = models.CharField(max_length=100)
    author = models.CharField(max_length=50, null=True)
    modified = models.DateTimeField()

    def get_value(self):
        """Return the decoded value, or None when nothing is stored."""
        if self.value:
            return json.loads(self.value)
        return None

    @classmethod
    def set_value(klass, name, value, author):
        """Create or update the variable ``name`` with the JSON of ``value``."""
        gv, created = klass.objects.get_or_create(name=name, defaults={
            'author': author, 'modified': timezone.now()})
        gv.value = json.dumps(value, ensure_ascii=False)
        gv.author = author
        gv.modified = timezone.now()
        gv.save()

    @classmethod
    def get(klass, name):
        """Return the variable ``name``, or None if it does not exist."""
        try:
            return klass.objects.get(name=name)
        except klass.DoesNotExist:
            return None

    @classmethod
    def _require_value(klass, name, expected):
        """Raise RuntimeError unless ``name`` exists with value ``expected``."""
        v = klass.get(name)
        # A missing variable has no value to report; the original code called
        # get_value() on None here and died with AttributeError instead of
        # raising the intended RuntimeError.
        current = v.get_value() if v is not None else None
        if v is None or current != expected:
            raise RuntimeError("GlobalVar.xchg: oldvalue {}, expected {}".format(
                current, expected))

    @classmethod
    def xchg(klass, name, oldvalue, newvalue=None, author=''):
        """
        Check that the variable ``name'' has value==oldvalue.
        If newvalue is supplied, set value=newvalue inside a transaction
        after re-checking the old value.

        Raises RuntimeError when the variable does not exist or holds a
        different value.
        """
        klass._require_value(name, oldvalue)
        if newvalue is None or newvalue == oldvalue:
            return
        # NOTE(review): the re-check below reads the row without locking it
        # (plain objects.get) - confirm this is sufficient for concurrent
        # writers.
        with transaction.atomic():
            klass._require_value(name, oldvalue)
            klass.set_value(name, newvalue, author)


class StatDatum(models.Model):
    """Per-day click, show and install counters for one candidate."""
    candidate = models.CharField(max_length=50)
    date = models.DateField(null=False, db_index=True)
    clicks = models.IntegerField()
    shows = models.IntegerField()
    installs = models.IntegerField(null=True)

    def info(self):
        """Return the CandidateUrl whose ``url`` equals ``candidate``, or None."""
        try:
            return CandidateUrl.objects.get(url=self.candidate)
        except CandidateUrl.DoesNotExist:
            return None

    def ctr(self):
        """Return clicks per show as a percentage, or 'NA' without shows."""
        if self.shows == 0:
            return 'NA'
        return self.clicks * 100.0 / self.shows

    def install_rate(self):
        """Return installs per show as a percentage.

        Returns 'NA' when there are no shows or when ``installs`` is NULL;
        the column is nullable and multiplying None raised TypeError.
        """
        if self.shows == 0 or self.installs is None:
            return 'NA'
        return self.installs * 100.0 / self.shows

    class Meta:
        index_together = [
            ['candidate', 'date', 'clicks', 'shows', 'installs']
        ]


@python_2_unicode_compatible
class JsonStatDatum(models.Model):
    """Timestamped JSON value stored under a key."""
    timestamp = models.DateTimeField(null=True)
    key = models.CharField(max_length=100)
    json = JSONField(default=dict)

    def __str__(self):
        return u'{}: {}: {}'.format(self.timestamp, self.key, self.json)

    class Meta:
        index_together = [['timestamp', 'key']]

    @staticmethod
    def add(key, value, timestamp=None):
        """Create and return a row; a falsy ``timestamp`` means now."""
        when = timestamp if timestamp else timezone.now()
        return JsonStatDatum.objects.create(timestamp=when, key=key, json=value)

    # TODO: purge old statdata.


@python_2_unicode_compatible
class AutoQuota(models.Model):
    """Target show rate and current multiplier for a production list."""
    production_list = models.CharField("List", max_length=100, unique=True, db_index=True)
    target_show_rate = models.FloatField()
    current_multiplier = models.FloatField(default=1.0)
    updated = models.DateTimeField(null=True)

    def __str__(self):
        return u'{}={} (x{})'.format(self.production_list, self.target_show_rate,
                                     self.current_multiplier)

    def graph(self):
        """Build the data for a 640x100 graph of the last 48 hours.

        Uses the ShowHistogram rows of this production list whose period
        starts within the window. Returns a dict with:

        * ``target`` - graph height scaled by ``1 - target_show_rate``;
        * ``datapoints`` - one dict per period start with ``x``/``w``
          (horizontal position and width), ``y`` (share of the
          ``success=False`` bin total) and ``zeros`` (share of the first
          bin), both scaled to the graph height;
        * ``latest`` - the per-bin sums of the most recent period, without
          the first bin, as ``show`` (success) and ``noshow`` lists.
        """
        start_time = timezone.now() - datetime.timedelta(hours=48)
        all_hist = ShowHistogram.objects.filter(period_start__gte=start_time,
                production_list=self.production_list)
        # period_start -> {True: total, False: total, '0': first-bin total}
        sums = defaultdict(lambda: defaultdict(int))
        period_lengths = {}
        for h in all_hist:
            bins = h.bins()
            current_sums = sums[h.period_start]
            current_sums[h.success] += sum(bins)
            current_sums['0'] += bins[0] # zeros
            period_lengths[h.period_start] = h.period_duration

        if sums:
            latest_timestamp = max(sums)
            latest_bins = {}
            for h in all_hist:
                if h.period_start != latest_timestamp:
                    continue
                if h.success in latest_bins:
                    # Add bin by bin to what was collected for this outcome.
                    this = latest_bins[h.success]
                    for idx, b in enumerate(h.bins()):
                        this[idx] += b
                else:
                    latest_bins[h.success] = h.bins()
        else:
            latest_bins = {True: [0, 0], False: [0, 0]}

        graph_width = 640.0
        graph_height = 100.0

        def dt_to_graph(dt):
            return graph_width * dt.total_seconds() / (3600*48)

        def ratio_to_graph(n, d):
            # The small epsilon avoids division by zero for empty periods.
            return graph_height * n / (d + 0.000001)

        return {
            'target': ratio_to_graph(1.0 - self.target_show_rate, 1.0),
            'datapoints': [{
                'zeros': ratio_to_graph(v['0'], v[True] + v[False]),
                'x': dt_to_graph(k - start_time),
                'w': dt_to_graph(period_lengths[k]),
                'y': ratio_to_graph(v[False], v[True] + v[False])}
                # six.iteritems, as in the rest of the module: dicts have no
                # .iteritems() method on Python 3.
                for k, v in iteritems(sums)],
            'latest': {
                'bins': {
                    'show': latest_bins.get(True, [0])[1:],
                    'noshow': latest_bins.get(False, [0])[1:]},
            },
        }

    def multiply(self, m):
        """Scale ``current_multiplier`` by ``m`` and stamp ``updated``.

        The instance is not saved.
        """
        self.current_multiplier *= m
        self.updated = timezone.now()


@python_2_unicode_compatible
class ShowHistogram(models.Model):
    """Histogram of one period for a client/list/fml combination."""
    client = models.CharField(max_length=30)
    subclient = models.CharField(max_length=30)
    production_list = models.CharField(max_length=100)
    fml = models.CharField(max_length=100)

    period_start = models.DateTimeField(db_index=True)
    period_duration = models.DurationField()
    success = models.BooleanField()
    # Whitespace-separated integers, one per bin (see bins()).
    histogram = models.TextField()

    class Meta:
        index_together = [
            ['production_list', 'period_start'],
        ]

    def __str__(self):
        # The original returned None, which makes str(obj) raise TypeError.
        return u'{} @{} (+{}) success={}'.format(
            self.production_list, self.period_start, self.period_duration,
            self.success)

    def bins(self):
        """Return the histogram as a new list of ints."""
        return [int(v) for v in self.histogram.split()]

    def bin_sum(self):
        """Return the sum of all bins."""
        return sum(self.bins())


class ShowMeasurement(models.Model):
    """A show rate and a multiplier recorded at a point in time."""
    # Not auto-populated: the creator must supply the timestamp.
    timestamp = models.DateTimeField()
    show_rate = models.FloatField()
    multiplier = models.FloatField()


@python_2_unicode_compatible
class AuthToken(models.Model):
    """Token for a login; the key is the primary key."""
    key = models.CharField("Key", max_length=40, primary_key=True)
    login = models.CharField("Login", max_length=30)
    created = models.DateTimeField("Created")

    class Meta:
        verbose_name = "Token"
        verbose_name_plural = "Tokens"

    def save(self, *args, **kwargs):
        """Save the token, filling in a missing key and creation time."""
        self.key = self.key or self.generate_key()
        self.created = self.created or timezone.now()
        return super(AuthToken, self).save(*args, **kwargs)

    def generate_key(self):
        """Return 20 random bytes from ``os.urandom`` as 40 hex characters."""
        random_bytes = os.urandom(20)
        return binascii.hexlify(random_bytes).decode()

    def __str__(self):
        return self.key
