# coding: utf-8

from __future__ import unicode_literals

import calendar
import logging
import tempfile
from copy import copy
from operator import itemgetter

import requests

import sandbox.common.types.client as ctc
import sandbox.common.types.misc as ctm
from sandbox import sdk2
from sandbox.common.types import task as ctt
from sandbox.projects.common import binary_task

MAX_ROWS_IN_REQUEST = 100000


class FetchJanglesPasswords(binary_task.LastBinaryTaskRelease, sdk2.Task):
    """Sandbox task that copies device password records from Jangles into a YT table.

    The task forwards the Jangles port to localhost over SSH, requests the
    password report over HTTP, encrypts the secret fields with a public key
    taken from a YAV secret and inserts the resulting rows into a YT table.
    """
    # upload to sandbox via ./fetch_jangles_passwords upload --attr 'name=FetchJanglesPasswords' and release resource
    # Value of the `name` attribute used in `binary_executor_query`.
    TASKS_RESOURCE_NAME = 'FetchJanglesPasswords'
    # Local port the Jangles port is forwarded to in `on_execute`.
    LOCAL_PORT = 8080
    # Cache slot for the public key object loaded by `pswd_key_object`.
    _pswd_key_object = None

    class Requirements(sdk2.Task.Requirements):
        # Execution requirements declared for this task.
        client_tags = ctc.Tag.LINUX_BIONIC
        dns = ctm.DnsType.DNS64
        cores = 1  # vCores
        ram = 1024  # Mb
        disk_space = 2 * 1024  # presumably Mb as well, like `ram` -- TODO confirm

        class Caches(sdk2.Requirements.Caches):
            pass  # means that the task does not use any shared caches

    # noinspection DuplicatedCode
    class Parameters(binary_task.LastBinaryReleaseParameters):
        # Input parameters of the task, grouped by purpose.
        # NOTE(review): some String parameters below pass `default_value=` while
        # the others pass `default=` -- confirm that sdk2 accepts both spellings.

        # Connection settings passed to SshClient in `on_execute`.
        with sdk2.parameters.Group('SSH'):
            user = sdk2.parameters.String('SSH user', default='yandex')
            # noinspection SpellCheckingInspection
            host = sdk2.parameters.String('SSH host', default='changhong.jangles.yandex')
            port = sdk2.parameters.Integer('SSH port', default=8543)

        # Query arguments sent to the Jangles passwords report.
        with sdk2.parameters.Group('Filters'):
            # A value > 0 is used as is; otherwise `calculate_ts_param` falls back
            # to the newest `bake_timestamp` stored in the target table.
            after_timestamp = sdk2.parameters.Float('after timestamp', default=0.0)
            device_model = sdk2.parameters.String('Device model', default='yandexmodule_2')

        # Destination table and the credentials used to access it.
        with sdk2.parameters.Group('YT'):
            with sdk2.parameters.RadioGroup('YT Cluster') as yt_cluster:
                yt_cluster.values.hahn = yt_cluster.Value(value='hahn', default=True)
                yt_cluster.values.arnold = yt_cluster.Value(value='arnold')

            yt_path = sdk2.parameters.String('Target table', required=True)
            # noinspection SpellCheckingInspection
            yt_secret = sdk2.parameters.YavSecret('YAV secret YT token', default='sec-01d2ffwrdbwyj37zkj4r8zegsn')
            yt_secret_key = sdk2.parameters.String('Key to extract from YAV secret',
                                                   default_value='robot-quasar-yt-token')

        # Secret holding both the SSH private key and the public password key.
        with sdk2.parameters.Group('YAV'):
            # noinspection SpellCheckingInspection
            yav_secret = sdk2.parameters.YavSecret('YAV secret with ssh key',
                                                   default='sec-01f7txcd9ctaxng3y77skaecp0')
            yav_secret_key = sdk2.parameters.String('Key to extract from YAV secret',
                                                    default_value='yandex.key-txt')
            yav_public_pswd_key = sdk2.parameters.String('Key to extract public password key from YAV secret',
                                                         default_value='public_pswd.pem')

    @property
    def binary_executor_query(self):
        """Query dict matching resources released as stable under this task's resource name."""
        resource_attrs = {
            'released': ctt.ReleaseStatus.STABLE,
            'name': self.TASKS_RESOURCE_NAME,
        }
        return {"attrs": resource_attrs}

    @property
    def pswd_key_object(self):
        """Public key object built from the PEM in ``password_key``.

        The key is loaded on first access and stored in ``_pswd_key_object``,
        so later accesses return the same object without re-reading the secret.
        """
        if self._pswd_key_object is not None:
            return self._pswd_key_object

        from cryptography.hazmat.backends import default_backend
        from cryptography.hazmat.primitives import serialization

        pem_bytes = self.password_key.encode()
        self._pswd_key_object = serialization.load_pem_public_key(pem_bytes, backend=default_backend())
        return self._pswd_key_object

    def encrypt(self, message):
        """Encrypt ``message`` with the public password key.

        The text is encoded with ``str.encode()`` and encrypted using OAEP
        padding (MGF1 and main hash both SHA-256, no label).

        :param message: text to encrypt.
        :return: whatever the key object's ``encrypt`` returns (ciphertext).
        """
        from cryptography.hazmat.primitives import asymmetric, hashes

        public_key = self.pswd_key_object
        payload = message.encode()
        oaep_padding = asymmetric.padding.OAEP(
            mgf=asymmetric.padding.MGF1(algorithm=hashes.SHA256()),
            algorithm=hashes.SHA256(),
            label=None,
        )
        return public_key.encrypt(payload, oaep_padding)

    def patch_fields(self, devices):
        """Convert raw device records into rows ready for upload.

        For every record a shallow copy is made, ``bake_timestamp`` is parsed
        from its ISO string into a UTC epoch float, and the ``encryption``
        entry is replaced by flat ``password``, ``chip_id``, ``dvuk`` and
        ``jtag`` fields. ``password`` and any non-null ``dvuk``/``jtag`` are
        encrypted; missing or null ``dvuk``/``jtag`` become empty strings.
        Records whose ``encryption`` is ``None`` are logged and dropped.

        :param devices: iterable of dicts with ``bake_timestamp`` and ``encryption`` keys.
        :return: list of patched dicts sorted by ascending ``bake_timestamp``.
        """
        from dateutil.parser import isoparse
        from dateutil.tz import UTC

        ts_key = 'bake_timestamp'

        def to_epoch(naive_utc):
            # timegm() works on whole seconds, so the microseconds are added back.
            whole_seconds = calendar.timegm(naive_utc.timetuple())
            return whole_seconds + float(naive_utc.microsecond) / 1000000

        def encrypt_optional(secrets, name):
            # Missing and null values are both treated as "nothing to encrypt".
            if secrets.get(name) is None:
                return None
            return self.encrypt(secrets[name])

        rows = []
        for device in devices:
            row = copy(device)
            naive_utc = isoparse(device[ts_key]).astimezone(UTC).replace(tzinfo=None)
            row[ts_key] = to_epoch(naive_utc)

            secrets = row.pop('encryption')
            if secrets is None:
                logging.info("Skipping device with device_uuid = %s", row['device_uuid'])
                continue

            password = self.encrypt(secrets['password'])
            dvuk = encrypt_optional(secrets, 'dvuk')
            jtag = encrypt_optional(secrets, 'jtag')
            chip_id = secrets['chip_id']

            # NOTE(review): str() of the encrypt() result depends on the
            # interpreter's bytes/str semantics -- confirm the stored format.
            row['password'] = str(password)
            row['chip_id'] = str(chip_id)
            row['dvuk'] = str(dvuk) if dvuk else ''
            row['jtag'] = str(jtag) if jtag else ''
            rows.append(row)

        return sorted(rows, key=itemgetter(ts_key))

    @property
    def last_ts_query(self):
        """Select query returning the largest ``bake_timestamp`` in the target table."""
        query_template = "bake_timestamp FROM [{}] ORDER BY bake_timestamp DESC LIMIT 1"
        return query_template.format(self.Parameters.yt_path)

    def get_last_row(self):
        """Return the row produced by ``last_ts_query``, or ``None`` if there is none.

        The target table is mounted (synchronously) before the query is run.
        """
        import yt.wrapper as yt

        rpc_client = yt.YtClient(
            self.Parameters.yt_cluster,
            config={'backend': 'rpc'},
            token=yt.config['token'],
        )
        rpc_client.mount_table(self.Parameters.yt_path, sync=True)

        rows = list(rpc_client.select_rows(self.last_ts_query))
        if not rows:
            return None
        return rows[0]

    def upload_to_yt(self, data):
        """Insert ``data`` into the target YT table, creating the table if needed.

        If the table does not exist it is created as a dynamic table with
        ``DEVICE_PASSWORD_SCHEMA``. The table is then mounted and the rows are
        inserted in batches of at most ``MAX_ROWS_IN_REQUEST`` rows each.

        :param data: list of row dicts to insert; an empty list inserts nothing.
        """
        import yt.wrapper as yt

        client = yt.YtClient(self.Parameters.yt_cluster,
                             config={'backend': 'rpc'},
                             token=yt.config['token'])
        table = self.Parameters.yt_path
        if not yt.exists(table):
            from quasar.manufacturing.lib import DEVICE_PASSWORD_SCHEMA
            client.create('table', table, force=True, attributes={'schema': DEVICE_PASSWORD_SCHEMA, 'dynamic': True})
        client.mount_table(table, sync=True)

        # One code path for all sizes: every batch is sent with the same format,
        # and stepping by the batch size never produces an empty trailing batch.
        for start in range(0, len(data), MAX_ROWS_IN_REQUEST):
            batch = data[start:start + MAX_ROWS_IN_REQUEST]
            client.insert_rows(table, batch, format='json')

    def get_passwords_url(self, ts):
        """Build the URL of the passwords report on the locally forwarded port.

        :param ts: lower timestamp bound, sent as the ``after`` query argument
            formatted with ``'{:f}'``.
        :return: URL string; it is also written to the log at INFO level.
        """
        from furl import furl

        url = furl('http://localhost')
        # Use the same port that on_execute forwards the Jangles port to,
        # rather than repeating the number here.
        url.port = self.LOCAL_PORT
        # noinspection PyPropertyAccess
        url.path = 'api/v1/report/passwords/'
        url.args['after'] = '{:f}'.format(ts)
        url.args['device_model'] = self.Parameters.device_model

        logging.info('requested url: %s', url.url)
        return url.url

    def calculate_ts_param(self):
        """Pick the lower timestamp bound for the report request.

        A positive ``after_timestamp`` parameter wins. Otherwise, if the target
        table exists and ``get_last_row`` returns a row, that row's
        ``bake_timestamp`` is used. Failing both, the (non-positive) parameter
        value is returned unchanged.
        """
        import yt.wrapper as yt

        requested_ts = float(self.Parameters.after_timestamp)
        if requested_ts > 0:
            return requested_ts

        table_exists = yt.exists(self.Parameters.yt_path)
        last_row = self.get_last_row() if table_exists else None
        if last_row:
            return last_row['bake_timestamp']
        return requested_ts

    def init_yt(self):
        """Point the global ``yt.wrapper`` config at the chosen cluster and set its token."""
        import yt.wrapper as yt

        params = self.Parameters
        yt.config.set_proxy(params.yt_cluster)

        secret_values = params.yt_secret.data()
        yt.config['token'] = secret_values[params.yt_secret_key]

    @property
    def ssh_key(self):
        """Value stored under ``yav_secret_key`` in the YAV secret (written out as the SSH key file)."""
        secret_values = self.Parameters.yav_secret.data()
        return secret_values[self.Parameters.yav_secret_key]

    @property
    def password_key(self):
        """Value stored under ``yav_public_pswd_key`` in the YAV secret (PEM text of the public key)."""
        secret_values = self.Parameters.yav_secret.data()
        return secret_values[self.Parameters.yav_public_pswd_key]

    def write_key_to_file(self, filename):
        """Overwrite ``filename`` with the contents of ``ssh_key``.

        :param filename: path of the file to write.
        """
        with open(filename, 'w') as key_file:
            key_text = self.ssh_key
            key_file.write(key_text)
            # Explicit flush kept from the original, ahead of the implicit close.
            key_file.flush()

    def on_execute(self):
        """Fetch new password records from Jangles and upload them to YT.

        Steps: configure YT, write the SSH key to a temporary file, forward the
        Jangles port to ``LOCAL_PORT``, request the passwords report, then
        patch the returned records and insert them into the target table.
        The SSH connection is always terminated, even when the request fails.

        :raises Exception: if the local connection fails or Jangles answers
            with an HTTP error status.
        """
        from sandbox.projects.quasar.jangles_lib import SshClient
        self.init_yt()

        # noinspection PyUnusedLocal
        devices = None
        with tempfile.NamedTemporaryFile(delete=True) as ssh_key_file:
            # The key is written through a second handle opened by name; the
            # temporary file is removed when this `with` block exits.
            self.write_key_to_file(ssh_key_file.name)
            ssh = SshClient(self.Parameters.host, self.Parameters.port, self.Parameters.user, ssh_key_file.name)

            try:
                ssh.forward_jangles_port_to(self.LOCAL_PORT)
                ts = self.calculate_ts_param()
                url = self.get_passwords_url(ts)

                response = requests.get(url)
                logging.debug('jangles response code: %d', response.status_code)
                if response.status_code == 400:
                    logging.debug('client error, details: %s', response.json())

                response.raise_for_status()
                devices = response.json()['content']
            except requests.exceptions.ConnectionError:
                raise Exception('localhost connection error, jangles is down?')
            except requests.exceptions.HTTPError:
                raise Exception('invalid jangles response code')
            finally:
                ssh.terminate_connect()

        # Treat an empty list like a missing one: there is nothing to upload.
        if not devices:
            logging.info('no new data')
            return

        yt_data = self.patch_fields(devices)
        self.upload_to_yt(yt_data)

        # Placeholders now match the arguments; the uploaded count can be lower
        # than the received count because patch_fields drops some records.
        logging.info('New %d passwords received, %d uploaded', len(devices), len(yt_data))
