# coding=utf-8

import datetime
import logging

import pandas


# Number of rows requested from the cursor per fetchmany() call in
# Metering.load().
CH_BATCH_SIZE = 1000


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# Slice definitions: every inner tuple lists the dimension-name suffixes that
# are grouped together to build one aggregate.
# Presumably passed to Metering.generate_slices() for internal traffic -- the
# call site is not in this file, TODO confirm.
# In generate_slices() each name is matched against the END of every dimension
# name, so a single entry such as 'dc' may expand to several columns
# (e.g. both 'src_dc' and 'dst_dc'); entries that match no dimension make the
# whole slice be skipped.
SLICES_INTERNAL = (
    ('own',),
    ('dc',),

    ('own', 'dc'),
    ('src_own',),
    ('src_own', 'dc'),
    ('dst_own',),
    ('dst_own', 'dc'),

    ('macro',),
    ('macro', 'dc'),
    ('src_macro',),
    ('src_macro', 'dc'),
    ('dst_macro',),
    ('dst_macro', 'dc'),
)


# Slice definitions in the same format as SLICES_INTERNAL: every inner tuple
# lists dimension-name suffixes that are grouped together.
# Presumably the variant used for external traffic, where 'iface' replaces
# 'dc' as the secondary dimension -- the call site is not in this file,
# TODO confirm.
# The commented-out entries are slices that are currently disabled; the reason
# is not recorded in this file.
SLICES_EXTERNAL = (
    ('own',),

    #('own', 'dc'),
    ('src_own',),
    ('src_own', 'iface'),
    ('dst_own',),
    ('dst_own', 'iface'),

    #('macro',),
    ('macro', 'iface'),
    ('src_macro',),
    ('src_macro', 'iface'),
    ('dst_macro',),
    ('dst_macro', 'iface'),
)


class Metering:
    """Load one hour of rows from a database and aggregate them into slices.

    Every row consists of the configured dimension columns followed by the
    two value columns ``packets`` and ``traf``.  The raw rows are kept in
    ``self.df``; ``generate_slices()`` stores raw rows plus per-slice totals
    in ``self.result``.
    """

    def __init__(self, db, hour, dimensions, query):
        """Store the parameters and immediately load the data.

        Args:
            db: Connection-like object; only ``db.cursor()`` is used, and the
                cursor must provide ``execute``, ``fetchmany`` and ``close``.
            hour: Unix timestamp.  It is substituted into ``query`` as
                ``{ts}`` and converted with ``datetime.fromtimestamp`` (local
                time, naive) for the ``fielddate`` column.
            dimensions: Names of the dimension columns, in the order in which
                the query returns them (before ``packets`` and ``traf``).
            query: ``str.format`` template containing a ``{ts}`` placeholder.
        """
        self.db = db
        self.query = query
        self.dimensions = dimensions
        # fromtimestamp() rejects non-numeric input, so ``hour`` is known to
        # be a number before it is formatted into the query text in load().
        self.hour = datetime.datetime.fromtimestamp(hour)
        self.hour_ts = hour
        self.values = ['packets', 'traf']
        self.df = pandas.DataFrame(columns=list(self.dimensions) + self.values)
        self.result = None
        self.load()

    def load(self):
        """Run the query and (re)build ``self.df`` from its result.

        Rows are fetched in batches of ``CH_BATCH_SIZE``.  Empty-string
        dimension values are replaced by ``'(unknown)'`` and the value
        columns are cast to ``int``.  Calling this again replaces the
        previously loaded data instead of appending to it.
        """
        query = self.query.format(ts=self.hour_ts)

        logger.info(query)

        columns = self.df.columns
        frames = []
        cursor = self.db.cursor()
        try:
            cursor.execute(query)
            while True:
                batch = cursor.fetchmany(CH_BATCH_SIZE)
                if not batch:
                    break
                logger.debug("Fetched %s records", len(batch))
                frames.append(pandas.DataFrame(data=batch, columns=columns))
        finally:
            # Release the cursor even when the query or a fetch fails.
            cursor.close()

        # One concat at the end: DataFrame.append() no longer exists in
        # pandas >= 2.0 and re-copied the whole frame for every batch.
        if frames:
            self.df = pandas.concat(frames, ignore_index=True)
        else:
            self.df = pandas.DataFrame(columns=columns)

        for d in self.dimensions:
            # Assign the result back instead of using inplace=True on a
            # column selection, which does not reliably modify the frame.
            # Dimensions are kept as object dtype so that the '_total_'
            # marker added by generate_slices() can share the column with
            # the original values without converting them.
            self.df[d] = (
                self.df[d]
                .replace(to_replace='', value='(unknown)')
                .astype(object)
            )

        for v in self.values:
            self.df[v] = self.df[v].astype(int)

    def _expand_slice(self, slice_columns):
        """Expand slice names to the dimension columns they match.

        A name matches every dimension that ends with it, so ``'dc'`` matches
        both ``'src_dc'`` and ``'dst_dc'`` if present.  Returns an empty list
        when any name matches no dimension, meaning the slice is skipped.
        """
        cols = []
        for scol in slice_columns:
            matched = [x for x in self.dimensions if x.endswith(scol)]
            if not matched:
                return []
            cols.extend(matched)
        return cols

    def generate_slices(self, slices):
        """Build ``self.result``: the raw rows plus totals for every slice.

        Args:
            slices: Iterable of tuples of dimension-name suffixes (see
                ``_expand_slice``).  Slices that cannot be fully matched
                against ``self.dimensions`` are skipped.

        For each slice the value columns are summed per combination of the
        matched dimensions.  Dimensions that are not part of a slice are set
        to ``'_total_'`` in that slice's rows.  A ``fielddate`` column with
        the formatted hour is added to every row.
        """
        totals = [self.df]
        for s in slices:
            group_cols = self._expand_slice(s)
            if not group_cols:
                continue
            # Sum only the value columns: on pandas >= 2.0 a plain .sum()
            # would also "sum" (concatenate) the remaining string dimensions
            # instead of leaving them missing for the '_total_' fill below.
            totals.append(
                self.df.groupby(group_cols)[self.values].sum().reset_index()
            )

        self.result = pandas.concat(totals, sort=False).fillna('_total_')
        self.result['fielddate'] = self.hour.strftime('%Y-%m-%d %H:%M:%S')
