# coding=utf-8
from collections import defaultdict
import datetime
import json
import logging
import multiprocessing
import os
import random
import re
import time

from sandbox import sdk2
from sandbox.sandboxsdk.environments import PipEnvironment

logger = logging.getLogger(__name__)

# ClickHouse connection defaults.
CH_SERVER = 'mondb.mds.yandex.net'
CH_USER = 'grafana_ro'
CH_DB = 'tacct_db'

# Statface reports for inter-DC (internal) and external traffic.
STAT_REPORT_INTERNAL = 'NOC/test1'
STAT_REPORT_EXTERNAL = 'NOC/ExternalTraffic2'

# TVM client ids: this service, and the Logbroker destination (lbkx cluster).
TVM_OUR_ID = 2020661
TVM_LB_ID = 2001059

# Logbroker destination used for the billing export.
LB_TOPIC = 'yc/yandex/billing-noc'
LB_ENDPOINT = 'lbkx.logbroker.yandex.net'
LB_MAX_CLIENTS_NUMBER = 128  # upper bound on parallel writer processes
LB_SOURCE_ID = "ExportNetworkMetering"

# Default Logbroker auth mode: TVM when True, OAuth otherwise.
USE_TVM = True

# Length of one export period, in seconds. Tunable for debugging only:
# production must keep it at one hour because Statface uploads use the
# hourly scale.
PERIOD = 60 * 60


class ExportNetworkMetering(sdk2.Task):
    """
    Export network metering data from NOC's ClickHouse.

    Hourly traffic data is uploaded to Statface, and monthly 95th-percentile
    usage is sent to YC Billing via Logbroker. This class only declares the
    Sandbox requirements and parameters; the work is done by
    ExportNetworkMeteringImpl.
    """

    class Requirements(sdk2.Requirements):
        # Third-party packages that the implementation imports lazily at run time.
        environments = (
            PipEnvironment('clickhouse_driver'),
            PipEnvironment('python-statface-client'),
        )

    class Parameters(sdk2.Task.Parameters):
        # NOTE: declaration order and `with` nesting below define the task's
        # parameter form; keep them in sync with ExportNetworkMeteringParams.

        ch_server = sdk2.parameters.String(
            'Clickhouse server',
            default_value=CH_SERVER,
            required=True,
        )

        ch_user = sdk2.parameters.String(
            'Clickhouse user',
            default_value=CH_USER,
            required=True,
        )

        ch_db = sdk2.parameters.String(
            'Clickhouse database',
            default_value=CH_DB,
            required=True,
        )

        # False: export a single hour; True: export every hour in
        # [time_from, time_to).
        custom_period = sdk2.parameters.Bool(
            'Export a custom period',
            default_value=False,
            required=True
        )

        with custom_period.value[False]:
            # When empty, the previous full hour is exported.
            # NOTE(review): the implementation uses this value as a Unix
            # timestamp of the period start; confirm that Timedelta yields that.
            hour = sdk2.parameters.Timedelta('Hour to export')

        with custom_period.value[True]:
            time_from = sdk2.parameters.String(
                'Datetime from which to export (either 2020-01-01 or 2020-01-01T12:00:00)',
                required=False,
                default_value='2020-01-01'
            )
            # NOTE(review): this default is computed once, when the module is
            # imported, so it can lag behind "today" in a long-lived process.
            time_to = sdk2.parameters.String(
                'Datetime to which to export (either 2020-01-01 or 2020-01-01T12:00:00)',
                required=False,
                default_value=datetime.datetime.utcnow().date().isoformat()
            )

        process_interdc = sdk2.parameters.Bool(
            'Process interDC (DCI) traffic',
            default_value=True,
            required=True,
        )

        process_external = sdk2.parameters.Bool(
            'Process external traffic',
            default_value=True,
            required=True,
        )

        process_totals = sdk2.parameters.Bool(
            'Also send totals data from Solomon (separately)',
            default_value=True,
            required=True,
        )

        process_traffic = sdk2.parameters.Bool(
            'Process traffic data (bytes sent)',
            default_value=True,
            required=True,
        )

        process_p95 = sdk2.parameters.Bool(
            'Export 95th percentile usage: Russia, Russia-MAN, external',
            default_value=True,
            required=True,
        )

        with process_totals.value[True]:
            solomon_oauth = sdk2.parameters.Vault(
                'Vault item with Solomon OAuth token', required=True,
                default_value='solomon_oauth')

        export_to_statface = sdk2.parameters.Bool(
            'Export data to Statface',
            default_value=True,
            required=True,
        )

        with export_to_statface.value[True]:
            stat_oauth = sdk2.parameters.Vault(
                'Vault item with Stat OAuth token', required=True,
                default_value='statbox_oauth')

        export_to_billing = sdk2.parameters.Bool(
            'Export data to YC Billing via Logbroker',
            default_value=True,
            required=True,
        )

        with export_to_billing.value[True]:
            lb_endpoint = sdk2.parameters.String(
                'Logbroker endpoint',
                default_value=LB_ENDPOINT,
                required=True,
            )

            lb_topic = sdk2.parameters.String(
                'Logbroker topic',
                default_value=LB_TOPIC,
                required=True,
            )

            lb_port = sdk2.parameters.Integer(
                'Logbroker port',
                default_value=2135,
                required=False,
            )

            # Billing requests that we space out sending data
            # because their monitoring lights up every hour.
            # 250 ms pauses between batches give us about
            # 0.5x posting speed.
            billing_throttle = sdk2.parameters.Float(
                'Billing export pause between batches (seconds)',
                default_value=0.25,
                required=False)

            use_tvm = sdk2.parameters.Bool(
                'Use TVM for Logbroker (when off, use OAuth)',
                default_value=USE_TVM,
                required=True,
            )

            # Exactly one of the two credentials below is used, depending on use_tvm.
            with use_tvm.value[True]:
                tvm_secret = sdk2.parameters.Vault(
                    'Vault item with the TVM secret', required=True,
                    default_value='NetworkTrafficMetering TVM secret'
                )
            with use_tvm.value[False]:
                lb_oauth = sdk2.parameters.Vault(
                    'Vault item with Logbroker OAuth token', required=True,
                    default_value='NetworkTrafficMetering Logbroker OAuth'
                )

        dump_data = sdk2.parameters.Bool(
            # Avoid using sdk2.Resource here, much easier on the quota.
            'Dump raw data to log',
            default_value=False,
            required=False
        )

        with sdk2.parameters.Group('Secrets') as vault_block:
            ch_password = sdk2.parameters.Vault(
                'Vault item with CH password', required=True,
                default_value='metering_ch_password')

    def on_execute(self):
        """Run the export by delegating to the Sandbox-independent implementation."""
        impl = ExportNetworkMeteringImpl(self.Parameters)
        impl.on_execute()


class ExportNetworkMeteringParams(object):
    """Flat, attribute-style snapshot of the export configuration.

    ``source`` is either an ``sdk2.Parameters`` instance (a regular Sandbox
    run) or any object with a dict-like ``get`` method, such as a plain dict
    for running the export outside Sandbox. Every known parameter becomes an
    attribute of this object, and Vault parameters are resolved to their
    secret data.
    """

    def __init__(self, source):
        from_sdk2 = isinstance(source, sdk2.Parameters)

        def lookup(name, fallback):
            if not from_sdk2:
                return source.get(name, fallback)
            value = getattr(source, name)
            if isinstance(value, sdk2.VaultItem):
                # Secret parameters hold a vault reference; fetch its payload.
                return value.data()
            # A None sdk2 value is replaced by the fallback below.
            return fallback if value is None else value

        # (name, fallback) pairs. A fallback applies to a dict source that
        # lacks the key, and to an sdk2 parameter whose value is None.
        known_params = (
            ("ch_server", CH_SERVER),
            ("ch_user", CH_USER),
            ("ch_password", ""),
            ("ch_db", CH_DB),
            ("custom_period", False),
            ("hour", None),
            ("time_from", None),
            ("time_to", None),
            ("process_interdc", False),
            ("process_external", False),
            ("process_p95", False),
            ("process_traffic", False),
            ("process_totals", False),
            ("solomon_oauth", None),
            ("export_to_statface", False),
            ("stat_oauth", None),
            ("export_to_billing", False),
            ("lb_endpoint", LB_ENDPOINT),
            ("lb_topic", LB_TOPIC),
            ("lb_port", 2135),
            ("billing_throttle", 0.25),
            ("use_tvm", True),
            ("tvm_secret", None),
            ("lb_oauth", None),
            ("dump_data", False),
        )

        for name, fallback in known_params:
            setattr(self, name, lookup(name, fallback))


class ExportNetworkMeteringImpl(object):
    """Sandbox-independent implementation of the network metering export.

    For every requested period it reads traffic metering data from
    ClickHouse (optionally with Solomon totals), uploads hourly traffic to
    Statface, and sends monthly p95 usage to YC Billing through Logbroker.
    ``params`` may be sdk2 task parameters or a plain dict; both are
    normalized by ExportNetworkMeteringParams.
    """

    def __init__(self, params):
        self.Parameters = ExportNetworkMeteringParams(params)

    def on_execute(self):
        """Connect to ClickHouse and export every requested period in order."""
        self._connect_db()
        for period_start in self._periods():
            self._export_one_period(period_start)

    def _export_one_period(self, period_start):
        """Run all enabled exports for one period.

        :param period_start: Unix timestamp of the period start, as yielded
            by ``_periods``.
        """
        # Project modules are imported lazily to avoid huge recursive imports.
        from sandbox.projects.noc.ExportNetworkMetering.solomon import Solomon

        p = self.Parameters
        fetch_interdc = p.process_traffic and p.process_interdc
        fetch_external = p.process_traffic and p.process_external

        # The Solomon client is only needed when totals are actually fetched.
        solomon = None
        if p.process_totals and (fetch_interdc or fetch_external):
            solomon = Solomon(p.solomon_oauth)

        data_interdc = None
        if fetch_interdc:
            data_interdc = self._fetch_interdc_traffic(period_start, solomon)
        data_external = None
        if fetch_external:
            data_external = self._fetch_external_traffic(period_start, solomon)

        if p.export_to_statface and p.process_traffic:
            logger.debug("Sending data to Statface")
            if p.process_interdc:
                self._to_statface(data_interdc, STAT_REPORT_INTERNAL)
            if p.process_external:
                self._to_statface(data_external, STAT_REPORT_EXTERNAL)
        else:
            logger.info("SKIPPING sending data to Statface")

        # Billing currently receives only p95 usage, so `process_p95` gates it.
        if p.export_to_billing and p.process_p95:
            self._export_p95_to_billing(period_start)
        else:
            logger.info("SKIPPING fetching p95 and sending data to Logbroker (billing)")

        logger.info("Done exporting %s (%s).", period_start,
                    datetime.datetime.fromtimestamp(period_start))

    @staticmethod
    def _period_bounds(period_start):
        """Return (start, end) of the period as naive local-time datetimes."""
        return (datetime.datetime.fromtimestamp(period_start),
                datetime.datetime.fromtimestamp(period_start + PERIOD))

    def _fetch_interdc_traffic(self, period_start, solomon):
        """Fetch inter-DC traffic rows for one period.

        When ``solomon`` is given, Solomon totals for the inter-DC interfaces
        are fetched too; they are currently only dumped to the log.

        :returns: list of row dicts (``DataFrame.to_dict('records')``).
        """
        import sandbox.projects.noc.ExportNetworkMetering.ch as ch
        import sandbox.projects.noc.ExportNetworkMetering.metering as metering

        logger.info("Fetching interdc traffic data")
        metering_interdc = self._fetch_metering(
            ch.QUERY_INTERDC, ch.DIMENSIONS_INTERDC,
            metering.SLICES_INTERNAL,
            period_start)
        data_interdc = metering_interdc.result.to_dict('records')

        data_interdc_totals = None
        if solomon is not None:
            interdc_interfaces = self._get_interdc_interfaces(
                period_start=period_start)
            data_interdc_totals = solomon.get_data(
                interdc_interfaces, *self._period_bounds(period_start))

        if self.Parameters.dump_data:
            logger.debug("DUMPING interdc source data:")
            logger.debug(json.dumps(data_interdc, indent=2))
            if solomon is not None:
                logger.debug("DUMPING interdc totals source data:")
                logger.debug(json.dumps(data_interdc_totals, indent=2))
        return data_interdc

    def _fetch_external_traffic(self, period_start, solomon):
        """Fetch external traffic rows for one period.

        When ``solomon`` is given, Solomon totals for the peers returned by
        ``get_peers`` are fetched too; they are currently only dumped to the log.

        :returns: list of row dicts (``DataFrame.to_dict('records')``).
        """
        import sandbox.projects.noc.ExportNetworkMetering.ch as ch
        import sandbox.projects.noc.ExportNetworkMetering.metering as metering
        from sandbox.projects.noc.ExportNetworkMetering.peers import get_peers

        logger.info("Fetching external traffic data")
        metering_external = self._fetch_metering(
            ch.QUERY_EXTERNAL, ch.DIMENSIONS_EXTERNAL,
            metering.SLICES_EXTERNAL,
            period_start)
        # TODO pandas.DF.to_dict('records') takes 43 seconds!
        data_external = metering_external.result.to_dict('records')

        data_external_totals = None
        if solomon is not None:
            external_interfaces = defaultdict(list)
            for key, value in get_peers():
                external_interfaces[key].append(value)
            data_external_totals = solomon.get_data(
                external_interfaces, *self._period_bounds(period_start))

        if self.Parameters.dump_data:
            logger.debug("DUMPING external source data:")
            logger.debug(json.dumps(data_external, indent=2))
            if solomon is not None:
                logger.debug("DUMPING external totals source data:")
                logger.debug(json.dumps(data_external_totals, indent=2))
        return data_external

    def _export_p95_to_billing(self, period_start):
        """Fetch month-to-date p95 usage and send it to Logbroker (billing).

        All enabled p95 reports are fetched and formatted before anything is
        sent, so a failing query does not leave a partial export behind.

        NOTE: hourly per-slice traffic and Solomon totals used to be sent to
        billing as well (format_interdc / format_external and their *_totals
        variants). That export is disabled; only p95 is sent.
        """
        import sandbox.projects.noc.ExportNetworkMetering.ch as ch
        import sandbox.projects.noc.ExportNetworkMetering.format as billing_format
        import sandbox.projects.noc.ExportNetworkMetering.util as util

        p = self.Parameters
        logger.info("Fetching P95 data for billing")
        month = datetime.datetime.fromtimestamp(period_start).replace(
            day=1, hour=0, minute=0, second=0, microsecond=0)
        next_month = util.next_month(month)

        def fetch_formatted(table, sku, **filters):
            # Fetch one p95 report and format it for billing.
            query_args = dict(table=table, month=month, next_month=next_month)
            query_args.update(filters)
            p95 = self._fetch_p95(
                query_template=ch.QUERY_P95_TEMPLATE, query_args=query_args)
            return billing_format.format_p95(
                p95, sku, month=month, next_month=next_month)

        batches = []
        if p.process_interdc:
            formatted_p95_dci_ru = fetch_formatted(
                ch.INTERDC_TABLE, billing_format.SKU_DCI_RU,
                geo_filter=ch.GEO_FILTER_RU)
            formatted_p95_dci_fi = fetch_formatted(
                ch.INTERDC_TABLE, billing_format.SKU_DCI_FI,
                geo_filter=ch.GEO_FILTER_FI)
            if p.dump_data:
                logger.debug("DUMPING p95 DCI-RU data")
                logger.debug(json.dumps(formatted_p95_dci_ru))
                logger.debug("DUMPING p95 DCI-FI data")
                logger.debug(json.dumps(formatted_p95_dci_fi))
            batches.extend([formatted_p95_dci_ru, formatted_p95_dci_fi])

        if p.process_external:
            formatted_p95_ext = fetch_formatted(
                ch.EXTERNAL_TABLE, billing_format.SKU_EXT,
                macro_filter=ch.SOME_INTERNET)
            if p.dump_data:
                logger.debug("DUMPING p95 EXT data")
                logger.debug(json.dumps(formatted_p95_ext))
            batches.append(formatted_p95_ext)

        logger.info("Sending data to Logbroker (billing)")
        for batch in batches:
            self._to_logbroker(batch)

    @staticmethod
    def _parse_datetime(text):
        """Parse ``yyyy-mm-dd`` or ``yyyy-mm-ddThh:00:00`` into a naive datetime.

        :raises ValueError: on any other format.
        """
        text = text.strip()
        # Because no .fromisoformat in sandboxxy Python2
        # UTC everywhere
        if re.match(r'^[0-9]{4}-[0-9]{2}-[0-9]{2}$', text):
            return datetime.datetime.strptime(text, '%Y-%m-%d')
        if 'T' in text:
            return datetime.datetime.strptime(text, '%Y-%m-%dT%H:00:00')
        raise ValueError('Should be either yyyy-mm-dd or yyyy-mm-ddThh:00:00')

    def _hour(self):
        """Return the single period start exported in non-custom mode.

        This is the ``hour`` parameter when it is set. Otherwise it is the
        start of the previous full PERIOD, aligned on epoch multiples of PERIOD.
        """
        hour = self.Parameters.hour
        if not hour:
            hour = int(time.time()) - PERIOD
            hour -= hour % PERIOD
        return hour

    def _periods(self):
        """Yield the start timestamps of the periods to export.

        Without ``custom_period`` this is the single period from ``_hour``.
        Otherwise it is every PERIOD-aligned start in [time_from, time_to).

        :raises ValueError: if the custom range is empty or not aligned to
            PERIOD. Because this is a generator, the error is raised on the
            first iteration.
        """
        import sandbox.projects.noc.ExportNetworkMetering.util as util
        if not self.Parameters.custom_period:
            yield self._hour()
            return

        # Validate the parsed timestamps rather than the raw parameter
        # strings: comparing the strings breaks, e.g., on surrounding whitespace.
        time_from = util.to_timestamp(
            self._parse_datetime(self.Parameters.time_from))
        time_to = util.to_timestamp(
            self._parse_datetime(self.Parameters.time_to))
        if time_from >= time_to:
            raise ValueError(
                'time_from (%s) must be earlier than time_to (%s)'
                % (self.Parameters.time_from, self.Parameters.time_to))
        if time_from % PERIOD or time_to % PERIOD:
            raise ValueError(
                'time_from and time_to must be aligned to %s seconds' % PERIOD)

        while time_from < time_to:
            yield time_from
            time_from += PERIOD

    def _connect_db(self):
        """Open the ClickHouse DB-API connection stored in ``self._db``."""
        from clickhouse_driver import connect

        server = self.Parameters.ch_server
        user = self.Parameters.ch_user
        password = self.Parameters.ch_password
        database = self.Parameters.ch_db
        logger.debug("Connecting to CH on %s as %s", server, user)
        self._db = connect(
            host=server, user=user, password=password, database=database)

    def _fetch_metering(
            self, query, dimensions, slices_def, period_start):
        """Run a metering query for one period and generate its slices.

        :returns: the ``Metering`` object. Callers read its ``result`` as a
            pandas DataFrame.
        """
        import sandbox.projects.noc.ExportNetworkMetering.metering as met

        logger.info('Hour to export: %s', period_start)
        metering = met.Metering(
            self._db,
            period_start,
            dimensions,
            query
        )
        logger.debug("Done running Metering.__init__")
        metering.generate_slices(slices_def)
        logger.info('%s', metering.result)
        return metering

    def _fetch_p95(self, query_template, query_args):
        """Return monthly p95 usage summed per owning ABC service.

        :param query_template: ``str.format`` template of the p95 query.
        :param query_args: format arguments. They must include the ``month``
            and ``next_month`` datetimes; ``geo_filter`` and ``macro_filter``
            default to empty strings. The dict is not modified.
        :returns: list of ``{"owner": ..., "p95": ...}`` dicts. Macros without
            an attribution are owned by ``"net"``.
        """
        import sandbox.projects.noc.ExportNetworkMetering.ch as ch
        import sandbox.projects.noc.ExportNetworkMetering.util as util
        # TODO Replace with a sandbox task!
        from sandbox.projects.noc.ExportNetworkMetering.attribution import attribution as attribution_tmp

        # Work on a copy so that the caller's arguments stay untouched.
        query_args = dict(query_args)
        for field in ("geo_filter", "macro_filter"):
            query_args.setdefault(field, "")

        month = query_args["month"]
        samples_in_month = util.samples_in_month(month)
        now = datetime.datetime.now()
        zero_p95_reports = False
        if now >= month:
            # Number of 5-minute samples available from the beginning of the month.
            available_samples = (now - month).total_seconds() // 300
            # Report zeros until at least 1/20 (5%) of the month's samples
            # exist. Presumably a p95 over that little data is not meaningful
            # for billing.
            zero_p95_reports = available_samples < samples_in_month // 20

        query_args["month"] = month.date().isoformat()
        query_args["next_month"] = query_args["next_month"].date().isoformat()

        cursor = self._db.cursor()
        cursor.execute("SET max_query_size = 1000000")
        cursor.execute("SET max_memory_usage = 100000000000")
        # NOTE(review): presumably drains the (empty) result of the SETs.
        cursor.fetchmany(1000)

        cursor.execute(query_template.format(**query_args))

        # Invert {abc: [macro, ...]} into {macro: abc}.
        attribution = {
            item: abc
            for abc, items in attribution_tmp.items()
            for item in items
        }
        per_abc_p95_sum = defaultdict(int)

        while True:
            batch = cursor.fetchmany(ch.BATCH_SIZE)
            if not batch:
                break
            logger.debug("Fetched %s p95 records", len(batch))
            for p95, macro_name in batch:
                owner = attribution.get(macro_name, "net")
                # Owners still get a (zero) record while data is insufficient.
                per_abc_p95_sum[owner] += 0 if zero_p95_reports else p95

        return [dict(owner=owner, p95=p95)
                for owner, p95 in per_abc_p95_sum.items()]

    def _get_interdc_interfaces(self, period_start):
        """Return {first column: [second column, ...]} of inter-DC interfaces in use.

        The data comes from ``ch.QUERY_INTERDC_INTERFACES`` for ``period_start``.
        """
        import sandbox.projects.noc.ExportNetworkMetering.ch as ch
        cursor = self._db.cursor()
        cursor.execute(ch.QUERY_INTERDC_INTERFACES.format(ts=period_start))
        logger.debug("Fetching interdc interfaces in current use.")

        res = defaultdict(list)

        while True:
            batch = cursor.fetchmany(ch.BATCH_SIZE)
            if not batch:
                break
            logger.debug("Fetched %s records", len(batch))
            for datum in batch:
                res[datum[0]].append(datum[1])
        return res

    def _to_statface(self, data, report):
        """Upload ``data`` rows to the Statface report ``report`` at the hourly scale."""
        import statface_client
        oauth_token = self.Parameters.stat_oauth

        logger.debug("Statface client init.")
        sfcli = statface_client.StatfaceClient(
            host='upload.stat.yandex-team.ru', oauth_token=oauth_token)

        report = sfcli.get_report(report)
        logger.debug("Uploading data of length %i to Statface.", len(data))
        report.upload_data(scale=statface_client.HOURLY_SCALE, data=data)
        logger.debug("Done Statface.")

    def _make_logbroker_auth(self):
        """Build Logbroker credentials: TVM if ``use_tvm`` is set, OAuth otherwise."""
        import kikimr.public.sdk.python.persqueue.auth as auth
        from tvmauth import BlackboxTvmId as BlackboxClientId
        import tvm2

        assert isinstance(self.Parameters.use_tvm, bool)
        if self.Parameters.use_tvm:
            assert self.Parameters.tvm_secret is not None
            tvm_cli = tvm2.TVM2(
                client_id=TVM_OUR_ID,
                secret=self.Parameters.tvm_secret,
                blackbox_client=BlackboxClientId.Prod
            )
            # Silence the 300-megabyte flood in the logs.
            logging.getLogger('tvm2.sync.thread_tvm2').setLevel(logging.WARNING)
            return auth.TVMCredentialsProvider(tvm_cli, TVM_LB_ID)
        else:
            return auth.OAuthTokenCredentialsProvider(
                self.Parameters.lb_oauth)

    def _to_logbroker(self, data):
        """Send the ``data`` messages to Logbroker from parallel writer processes.

        Messages are split round-robin across up to LB_MAX_CLIENTS_NUMBER
        processes.

        :raises RuntimeError: if any writer process exits with a non-zero
            code. A failed billing send then fails the task instead of being
            lost silently.
        """
        from sandbox.projects.noc.ExportNetworkMetering.logbroker import LogBroker
        # Imported here to save on huge recursive imports; stored on self
        # for use in the worker processes.
        self._LogBroker = LogBroker

        workers = []
        clients_number = min(len(data), LB_MAX_CLIENTS_NUMBER)
        for i in range(clients_number):
            worker = multiprocessing.Process(
                target=self._spawn_logbroker,
                # Because https://logbroker.yandex-team.ru/docs/concepts/resource_model#message-group
                # assigns workers to partitions "randomly" and thus very unevenly, often creating a very
                # overworked partition or two.  In order to prevent it from sticking if it ever happens,
                # we randomize source ids.  Old source-partition mappings  will be GCed after 14 days.
                args=("%s_%04i_%04i" % (LB_SOURCE_ID, random.randint(0, 9999), i),
                      data[i::clients_number],
                      i)
            )
            worker.start()
            workers.append(worker)
            logger.info("Logbroker clients: started %s/%s", i + 1, clients_number)

        failed = []
        for i, worker in enumerate(workers):
            logger.debug("Worker state: %s", worker)
            worker.join()
            logger.info("Logbroker clients: stopped %s/%s", i + 1, clients_number)
            if worker.exitcode != 0:
                failed.append((i, worker.exitcode))
        # Join every worker before failing, so that no process is left running.
        if failed:
            raise RuntimeError(
                "Logbroker writer processes failed (index, exit code): %s" % (failed,))

    def _spawn_logbroker(self, source_id, messages, num):
        """Worker-process entry point: send ``messages`` under ``source_id``.

        The start is delayed in proportion to ``num``, so that the writers
        do not all begin at the same moment.
        """
        logger.info("New process spawned, PID %i", os.getpid())
        lb = self._LogBroker(
            cred_provider=self._make_logbroker_auth(),
            endpoint=self.Parameters.lb_endpoint,
            port=self.Parameters.lb_port,
            topic=self.Parameters.lb_topic,
            source_id=source_id,
            throttle_pause=self.Parameters.billing_throttle,
        )
        logger.info("Will send %i messages to Logbroker.", len(messages))
        sleep_dur = num * self.Parameters.billing_throttle / LB_MAX_CLIENTS_NUMBER
        logger.debug("But before that, will sleep for %2.2f seconds.", sleep_dur)
        time.sleep(sleep_dur)
        lb.send(messages)


