"""Sandbox task that zeroes (or restores) a YP cluster's weight in L7 heavy balancers via the ITS API."""
import logging
import json
import requests
from sandbox import sdk2

""" Docs: https://docs.yandex-team.ru/nanny/reference/awacs/l7-heavy-balancer """


class MediaItsWeightManagement(sdk2.Task):
    """Move L7 heavy balancer traffic away from a YP cluster through ITS.

    For every balancer in ``Parameters.balancers`` the task sets the weight of
    ``Parameters.yp_cluster`` to 0 in the listed sections and splits 100
    evenly across the other locations. It then asks ITS to switch to the new
    weights version. An empty ``yp_cluster`` restores every location's
    ``default_weight`` instead. If a balancer cannot be switched, all
    balancers touched so far (including the failing one) are restored to
    their default weights and the task fails.
    """

    class Parameters(sdk2.Parameters):
        """Task parameters."""
        # Location to drain; an empty value restores default weights.
        yp_cluster = sdk2.parameters.String('yp_cluster')
        # Yav secret id; the OAuth token is read from its "nanny_oauth" key.
        auth_token = sdk2.parameters.String("nanny_token", do_not_copy=True)
        # Each item is a JSON string:
        # {"balancer_id": ..., "sections": <iterable of section ids>}.
        balancers = sdk2.parameters.List("list_balancer")

    api_url = 'https://ext.its.yandex-team.ru/v2/l7/heavy'
    ca_cert = '/usr/share/yandex-internal-root-ca/YandexInternalRootCA.crt'
    # Seconds to wait for a single API call; without a timeout a stuck
    # connection would hang the task indefinitely.
    request_timeout = 60

    def _request(self, method, balancer_id, endpoint, **kwargs):
        """Call ``{api_url}/{balancer_id}/weights/{endpoint}/`` and check the status.

        Extra keyword arguments (``json``, ``headers``) are passed to
        ``requests.Session.request``.

        Raises:
            requests.HTTPError: the API answered with a 4xx/5xx status.
        """
        resp = self.session.request(
            method,
            '{}/{}/weights/{}/'.format(self.api_url, balancer_id, endpoint),
            verify=self.ca_cert,
            timeout=self.request_timeout,
            **kwargs
        )
        resp.raise_for_status()
        return resp

    def get_value_weights(self, balancer_id):
        """Fetch the balancer's weights document.

        Returns:
            dict with ``'data'`` (decoded JSON body) and ``'etag'`` (raw
            ``ETag`` header, sent back as ``If-Match`` by
            :meth:`post_value_weights`).
        """
        resp = self._request('GET', balancer_id, 'values')
        return {'data': resp.json(), 'etag': resp.headers['ETag']}

    def post_value_weights(self, balancer_id, values):
        """Store ``values['data']`` as a conditional update (``If-Match: values['etag']``).

        Returns:
            The response ``ETag`` with surrounding quotes stripped; it is used
            as the target version for :meth:`push_its`.
        """
        resp = self._request('POST', balancer_id, 'values',
                             json=values['data'],
                             headers={'If-Match': values['etag']})
        return resp.headers['ETag'].strip('"')

    def get_value_its(self, balancer_id):
        """Return ``current_version`` from the balancer's ``its_value`` endpoint."""
        resp = self._request('GET', balancer_id, 'its_value')
        return resp.json()['current_version']

    def push_its(self, balancer_id, current_its_version, current_db_version):
        """Request an ITS switch from ``current_its_version`` to ``current_db_version``.

        Returns:
            The ``current_version`` field of the response.
        """
        resp = self._request('POST', balancer_id, 'its_value',
                             json={'current_version': current_its_version,
                                   'target_version': current_db_version})
        return resp.json()['current_version']

    def change_weights(self, balancer_id, sections, cluster):
        """Rewrite the weights of ``sections`` on one balancer and push them to ITS.

        Args:
            balancer_id: balancer identifier used in the API path.
            sections: ids of sections to update (keys of ``data['sections']``).
            cluster: location to drain (weight 0). Every other location gets
                ``100 // <number of other locations>``. An empty value
                restores each location's ``default_weight``.

        Returns:
            True if the version reported by ITS after the push equals the
            version just written by :meth:`post_value_weights`, else False.
        """
        values = self.get_value_weights(balancer_id)
        for section_id in sections:
            locations = values['data']['sections'][section_id]['locations']
            if not cluster:
                for location in locations.values():
                    location['weight'] = location['default_weight']
                continue
            # Gives 100/50/33 for 2/3/4 locations that include `cluster`, and
            # works for any count instead of leaving `weight` unset (or stale
            # from a previous section). If `cluster` is not in the section,
            # all locations get the same 100 // len(locations).
            others = [dc for dc in locations if dc != cluster]
            weight = 100 // len(others) if others else 0
            for dc, location in locations.items():
                location['weight'] = 0 if dc == cluster else weight
        current_db_version = self.post_value_weights(balancer_id, values)
        current_its_version = self.get_value_its(balancer_id)
        current_its_version = self.push_its(balancer_id, current_its_version, current_db_version)
        return current_db_version == current_its_version

    def _rollback(self, balancers):
        """Best-effort restore of default weights; log and skip balancers that fail."""
        for balancer in balancers:
            try:
                restored = self.change_weights(balancer['balancer_id'], balancer['sections'], '')
            except Exception:
                logging.exception("Rollback failed for %s", balancer['balancer_id'])
                continue
            if not restored:
                logging.error("Rollback of %s: ITS version does not match DB version",
                              balancer['balancer_id'])

    def main(self):
        """Apply ``yp_cluster`` weights to every configured balancer.

        Raises:
            RuntimeError: ITS did not report the new version for a balancer.
            Any error raised by the API calls is re-raised after rollback.
        """
        # Per run, not a class attribute: a class-level list would be shared
        # by every task instance in the same process.
        self.list_added = []
        cluster = self.Parameters.yp_cluster
        # Parse everything first so malformed input fails before any change.
        targets = []
        for raw_balancer in self.Parameters.balancers or []:
            balancer = json.loads(raw_balancer)
            targets.append({'balancer_id': balancer['balancer_id'], 'sections': balancer['sections']})

        self.session = requests.Session()
        try:
            self.session.headers.update({
                'Authorization': 'OAuth {}'.format(
                    sdk2.yav.Secret(self.Parameters.auth_token).data()["nanny_oauth"]),
            })
            for target in targets:
                try:
                    applied = self.change_weights(target['balancer_id'], target['sections'], cluster)
                except Exception:
                    # This balancer's state is unknown: restore it with the rest.
                    self._rollback(self.list_added + [target])
                    raise
                if not applied:
                    # New weights were already POSTed for this balancer, so it
                    # is restored too, not only the earlier ones.
                    self._rollback(self.list_added + [target])
                    message = f"NOT CHANGE ITS WEIGHT FOR {target['balancer_id']}.{target['sections']} DC {cluster}"
                    logging.error(message)
                    raise RuntimeError(message)
                self.list_added.append(target)
        finally:
            self.session.close()

        return True

    def on_execute(self):
        """Task entry point: run :meth:`main`, logging ``RuntimeError`` before propagating it."""
        try:
            self.main()
        except RuntimeError as error:
            logging.error(f"ERROR: {error}\n")
            # Propagate instead of swallowing; otherwise a failed drain would
            # end the task normally.
            raise
