# coding: utf-8
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
from collections import OrderedDict

import boto3
import requests

from travel.library.python.rasp_vault.api import get_secret
from travel.rasp.bus.db import session_scope
from travel.rasp.bus.db.models.matching import PointMatching, PointType, parse_point_key
from travel.rasp.bus.scripts.automatcher.scenarios.base import BaseMatcher
from travel.rasp.bus.scripts.automatcher.policy import TypePolicy

# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)


class UnitikiPointIdUpdate(BaseMatcher):
    """Scenario that carries existing matchings over Unitiki station ID merges.

    On construction the scenario reads the previously processed action ID
    from S3, requests the station merge list from the Unitiki API (passing
    that ID as ``action_id`` when one is stored) and builds ``self.changes``:
    a mapping ``station_id_after -> point key matched to station_id_before``.
    ``_run`` looks points up in that mapping and ``postprocess`` reports the
    counters and stores the newest action ID back to S3.
    """

    name = 'unitiki_update'
    point_type_policy = TypePolicy.TYPE_STATION
    supplier = 'unitiki-new'

    S3_ENDPOINT = 'https://s3.mds.yandex.net'
    S3_BUCKET = 'yandex-bus.automatcher'
    # Secret references resolved through ``get_secret`` in ``__init__``.
    S3_ACCESS_KEY = 'sec-01d956rm0wknnrbhhjsyx0rh56.robot-sputnik-s3-mds-key'
    S3_ACCESS_SECRET_KEY = 'sec-01d956rm0wknnrbhhjsyx0rh56.robot-sputnik-s3-mds-secret'
    # S3 object holding the last processed action ID, and the key it is
    # copied to before being overwritten.
    S3_LAST_ACTION_FN = 'last_action.txt'
    S3_BACKUP_FN = 'last_action_backup.txt'

    UNITIKI_MERGE_LIST_URL = 'http://api.geo.gds.unitiki.com/station/merge/list/'

    def __init__(self, **params):
        """Connect to S3, fetch the merge list and compute ``self.changes``.

        :param params: scenario parameters; forwarded to the base class and
            to ``get_config`` to read the ``common`` section, whose ``dry``
            flag becomes ``self.dry_run``.
        """
        super(UnitikiPointIdUpdate, self).__init__(**params)

        log.info('prepare scenario: %s', self.name)
        common_config = self.get_config(params, 'common')

        self.dry_run = common_config['dry']

        s3_key = get_secret(self.S3_ACCESS_KEY)
        s3_secret_key = get_secret(self.S3_ACCESS_SECRET_KEY)

        self.s3 = self._connect_s3(s3_key, s3_secret_key, self.S3_ENDPOINT)

        log.info('retriving last action ID from S3')
        self.prev_last_action_id = self._get_last_action()
        if not self.prev_last_action_id:
            log.info('no valid last action ID in S3')
        else:
            log.info('got last action ID: %s', self.prev_last_action_id)

        mapping = self._get_mapping()
        point_changes_list = self._fetch_changes()
        log.info("got changes, count: %d", len(point_changes_list))

        # Also sets self.last_action_id and self.report_data as side effects.
        self.changes = self._get_changes(point_changes_list, mapping)
        log.info("new last action ID: %s", self.last_action_id)

    def _run(self, point):
        """Return the point key to assign to ``point``, if a merge targets it.

        :param point: object with a ``supplier_point_id`` string attribute.
        :return: ``(True, point_key)`` when the station ID is a key of
            ``self.changes``, otherwise ``(False, None)``.
        """
        # The first character is dropped before the numeric conversion --
        # presumably a point-type prefix; confirm against parse_point_key.
        unitiki_station_id = int(point.supplier_point_id[1:])
        if unitiki_station_id in self.changes:
            return True, self.changes[unitiki_station_id]
        return False, None

    def _connect_s3(self, access_key, secret_key, endpoint):
        """Create a boto3 S3 client for ``endpoint`` with the given credentials.

        :param access_key: S3 access key ID.
        :param secret_key: S3 secret access key.
        :param endpoint: URL of the S3-compatible endpoint.
        :return: boto3 S3 client.
        """
        session = boto3.session.Session(
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_key,
        )
        # NOTE(review): verify=False disables TLS certificate verification
        # for every S3 call; confirm this is intentional for this endpoint.
        return session.client(service_name='s3', endpoint_url=endpoint, verify=False)

    def _get_mapping(self):
        """Load the supplier's existing matchings from the database.

        City-type supplier points are skipped.

        :return: dict ``int(supplier point id) -> point_key`` (the point key
            is stored as-is, so it may be empty for unmatched points).
        """
        mapping = {}
        with session_scope() as session:
            matchings = session.query(PointMatching.supplier_point_id, PointMatching.point_key).filter(
                PointMatching.supplier_id == self.get_scenario_supplier_id()
            ).all()
            for matching in matchings:
                supp_point_type, supp_point_id = parse_point_key(matching.supplier_point_id)
                if supp_point_type == PointType.CITY:
                    continue
                mapping[int(supp_point_id)] = matching.point_key
        return mapping

    def _fetch_changes(self):
        """Request the station merge list from the Unitiki API.

        ``self.prev_last_action_id`` is sent as the ``action_id`` query
        parameter when it is set.

        :return: the ``data.station_merge_list`` value of the JSON response;
            always a list -- an empty one when ``data`` or the list itself
            is missing or empty.
        """
        params = {}
        if self.prev_last_action_id:
            params['action_id'] = self.prev_last_action_id
        unitiki_result = requests.get(self.UNITIKI_MERGE_LIST_URL, params=params)
        point_changes_data = unitiki_result.json().get('data') or {}
        # The fallback must cover a missing/null 'station_merge_list' as well
        # as missing 'data': the caller takes len() of the result.
        return point_changes_data.get('station_merge_list') or []

    def _get_changes(self, point_changes_list, mapping):
        """Turn raw merge records into matchings for the new station IDs.

        Sets ``self.last_action_id`` to the ``action_id`` of the last record
        (``None`` for an empty list) and ``self.report_data`` to per-outcome
        counters.

        :param point_changes_list: iterable of dicts with ``action_id``,
            ``station_id_before`` and ``station_id_after`` keys.
        :param mapping: dict ``station id -> point key``; mutated in place
            when the new station ID is not present in it.
        :return: dict ``station_id_after -> point key of station_id_before``
            for merges where the new ID is known but has no point key yet.
        """
        changes = {}
        report_data = OrderedDict([
            ('not_found', 0),
            ('found_but_not_matched', 0),
            ('after_id_not_found', 0),
            ('found_but_new_id_matched_too', 0),
            ('found_and_saved', 0),
        ])
        self.last_action_id = None
        for change in point_changes_list:
            self.last_action_id = change['action_id']
            id_before = int(change['station_id_before'])
            id_after = int(change['station_id_after'])

            # Membership is tested on the dict itself so that IDs added
            # below are visible to later records without rebuilding a set.
            if id_before not in mapping:
                report_data['not_found'] += 1
                continue
            point_key_before = mapping[id_before]
            if not point_key_before:
                report_data['found_but_not_matched'] += 1
                continue
            if id_after not in mapping:
                report_data['after_id_not_found'] += 1
                # Remember the key under the new ID so that a later record
                # merging id_after further can still resolve it.
                mapping[id_after] = point_key_before
                continue
            if mapping[id_after]:
                report_data['found_but_new_id_matched_too'] += 1
                continue
            changes[id_after] = point_key_before
            report_data['found_and_saved'] += 1
        self.report_data = report_data
        return changes

    def _set_last_action(self):
        """Store ``self.last_action_id`` in S3, backing up the previous value.

        If the last-action object already exists it is first copied to
        ``S3_BACKUP_FN``. ``self.last_action_id`` is converted to ``str``.

        :return: ``(True, None)`` when the upload responds with HTTP 200,
            ``(False, response)`` otherwise, and ``(False, None)`` when there
            is no valid action ID to store.
        """
        if not self.last_action_id:
            log.error("trying to save invalid last action ID: %s", self.last_action_id)
            # Keep the 2-tuple contract: the caller unpacks the result.
            return False, None
        if self._check_last_action():
            copy_source = {
                'Bucket': self.S3_BUCKET,
                'Key': self.S3_LAST_ACTION_FN
            }
            self.s3.copy(copy_source, self.S3_BUCKET, self.S3_BACKUP_FN)
        self.last_action_id = str(self.last_action_id)
        response = self.s3.put_object(
            Body=self.last_action_id.encode('utf-8'),
            Bucket=self.S3_BUCKET,
            Key=self.S3_LAST_ACTION_FN,
        )
        if response['ResponseMetadata']['HTTPStatusCode'] == 200:
            return True, None
        return False, response

    def _get_last_action(self):
        """Read the stored last action ID from S3.

        :return: the object's content decoded as UTF-8 with surrounding
            whitespace removed, or ``None`` when the object does not exist.
        """
        if not self._check_last_action():
            return None
        response = self.s3.get_object(
            Bucket=self.S3_BUCKET,
            Key=self.S3_LAST_ACTION_FN,
        )
        la = response['Body'].read()
        # Strip so a stray trailing newline does not end up in the
        # 'action_id' query parameter.
        return la.decode('utf-8').strip()

    def _check_last_action(self):
        """Tell whether the last-action object exists in the bucket.

        :return: ``True`` if a key equal to ``S3_LAST_ACTION_FN`` is listed.
        """
        # Restrict the listing to the key itself: without a prefix only the
        # first MaxKeys objects of the whole bucket are returned, and the
        # key may not be among them once the bucket holds other objects.
        response = self.s3.list_objects_v2(
            Bucket=self.S3_BUCKET,
            Prefix=self.S3_LAST_ACTION_FN,
            MaxKeys=10,
        )
        # 'Contents' is absent from the response when nothing matches.
        return any(obj['Key'] == self.S3_LAST_ACTION_FN for obj in response.get('Contents', []))

    def postprocess(self):
        """Append the counters to the report and persist the new action ID.

        The action ID is not stored on a dry run or when no changes were
        received (``self.last_action_id`` is empty).
        """
        log.info('postprocess. prepare stats and update last action ID')
        # self.report is not defined in this class -- presumably provided by
        # BaseMatcher as a list of rows; confirm.
        self.report.append(('cases', 'count'))
        self.report.extend(self.report_data.items())

        if self.dry_run or not self.last_action_id:
            log.info('dry run or empty changes, last action ID will not be updated')
            return
        log.info('updating last action ID')
        is_ok, response = self._set_last_action()
        if is_ok:
            log.info('last action ID updated')
        else:
            log.error('error updating last action ID: %s', response)
