#!/usr/bin/env python
# coding=utf-8

import os
import json
import copy
import tempfile
import subprocess
import logging
import operator
from collections import OrderedDict, defaultdict

logger = logging.getLogger(__name__)


def is_close(a, b, rel_tol=1e-09, abs_tol=0.0):
    """Tell whether two numbers are approximately equal.

    The numbers are considered close when their absolute difference does
    not exceed the larger of ``abs_tol`` and ``rel_tol`` scaled by the
    larger magnitude of the two operands.

    :param a: first number
    :param b: second number
    :param rel_tol: tolerance relative to the larger magnitude of a and b
    :param abs_tol: minimal absolute tolerance (matters near zero)
    :return: True when the values are within tolerance, False otherwise
    """
    difference = abs(a - b)
    largest_magnitude = max(abs(a), abs(b))
    allowed = max(rel_tol * largest_magnitude, abs_tol)
    return difference <= allowed


def fix_path(path):
    """Expand a leading ``~`` in ``path`` and return its real path.

    :param path: file system path, possibly starting with ``~``
    :return: the result of ``os.path.realpath`` applied to the expanded path
    """
    expanded = os.path.expanduser(path)
    return os.path.realpath(expanded)


class BasicModifier(object):
    """Base class of all modifiers: holds ``data`` and does JSON I/O."""

    class OperationTypes(object):
        """Namespace of operation constants; empty here, subclasses override it."""

    def __init__(self, data):
        """Remember ``data`` as the object the modifier works on."""
        self.data = data

    def get_operation(self, name):
        """Return the attribute called ``name`` of ``self.OperationTypes``.

        The lookup goes through ``operator.attrgetter``; an unknown name
        raises ``AttributeError``.
        """
        resolve = operator.attrgetter(name)
        return resolve(self.OperationTypes)

    def load_json(self, path):
        """Replace ``self.data`` with the JSON document read from ``path``."""
        with open(path, 'r') as source:
            self.data = json.load(source)

    def dump_json(self, path):
        """Write ``self.data`` to ``path`` as JSON indented by two spaces."""
        with open(path, 'w') as target:
            json.dump(self.data, target, indent=2)


class ScriptModifier(BasicModifier):
    """Modifier that replaces ``data`` with the JSON printed by a shell command.

    ``script_path`` is a command line executed through the shell. Two
    optional placeholders are substituted with ``str.format`` before the
    command runs:

    * ``{current_data_file_path}`` -- resolved path of a temporary file
      holding the current ``data`` dumped as JSON;
    * ``{data_dir}`` -- the resolved ``data_dir`` (only substituted when
      ``data_dir`` is non-empty).

    The command's standard output is parsed as JSON (objects are loaded as
    ``OrderedDict``) and stored in ``self.data``.
    """

    def __init__(self, data, script_path, data_dir=''):
        """Run the script and store its parsed output in ``self.data``.

        :param data: current data, handed to the script through a temp file
            when the ``{current_data_file_path}`` placeholder is used
        :param script_path: shell command line, optionally with placeholders
        :param data_dir: directory substituted for ``{data_dir}``
        :raises subprocess.CalledProcessError: the command exited non-zero
        :raises ValueError: the command output is not valid JSON
        """
        super(ScriptModifier, self).__init__(data=data)

        script_format_dict = {}
        temp_file_path = ""
        try:
            if "{current_data_file_path}" in script_path:
                temp_file = tempfile.NamedTemporaryFile(delete=False)
                temp_file_path = temp_file.name
                # Only the name is needed: dump_json re-opens the file by
                # path, so release our own handle instead of leaking it.
                temp_file.close()
                script_format_dict['current_data_file_path'] = fix_path(temp_file_path)
                self.dump_json(temp_file_path)

            if data_dir and "{data_dir}" in script_path:
                script_format_dict['data_dir'] = fix_path(data_dir)

            # Formatting is skipped when there is nothing to substitute, so
            # commands without placeholders may contain literal braces.
            if script_format_dict:
                script_path = script_path.format(**script_format_dict)

            logger.debug("\nlaunching script:\n{}".format(script_path))
            # NOTE: the command goes through the shell (shell=True), so
            # script_path must come from trusted configuration only.
            output = subprocess.check_output(script_path, shell=True)
            self.data = json.loads(output, object_pairs_hook=OrderedDict)
        finally:
            # Remove the temp file even when the script or the JSON parsing
            # fails; previously it was left behind on any error.
            if temp_file_path:
                os.unlink(temp_file_path)


class DataModifier(BasicModifier):
    """Write extra values into resource sections of the data rows.

    ``data`` is iterated as a sequence of dict rows; every row is expected
    to carry an ``invnum`` key. The modifier description has the shape
    ``{section: {append_key: append_val}}`` where ``append_val`` is either
    a single number or a dict ``{res_key: number}``. It is taken from the
    ``data_modifier`` keyword or, when that is empty, from
    ``data_settings[data_variable]`` (``data_settings`` may be loaded from
    ``data_settings_path``).

    Recognised keyword arguments: ``operation`` (``add``, ``multiply`` or
    ``allocate``; default ``add``), ``value_proportion``, ``data_dir``,
    ``data_settings``, ``data_settings_path``, ``data_variable``,
    ``data_modifier``, ``target_data_section`` and ``filter``.
    """

    class OperationTypes(object):
        add = 1
        multiply = 2
        allocate = 3

    def __init__(self, data, **kwargs):
        """Configure the modifier from ``kwargs`` and apply it to ``data`` in place."""
        super(DataModifier, self).__init__(data=data)

        self.operation = self.get_operation(kwargs.get("operation", "add"))
        self.value_proportion = kwargs.get("value_proportion", False)
        self.data_dir = kwargs.get("data_dir", '')
        self.data_settings = kwargs.get("data_settings", {})
        self.data_variable = kwargs.get("data_variable")
        self.target_data_section = kwargs.get("target_data_section")
        # str.startswith accepts a string or a tuple of strings, never a
        # list, hence the list -> tuple conversion of the filter values.
        raw_filter = kwargs.get("filter", {})
        self.filter = {key: tuple(val) if isinstance(val, list) else val for key, val in raw_filter.items()}

        self.data_settings_path = kwargs.get("data_settings_path")
        if self.data_settings_path and not self.data_settings:
            path = self.data_settings_path
            if "{data_dir}" in path and self.data_dir:
                path = path.format(data_dir=fix_path(self.data_dir))
            with open(path, 'r') as f:
                self.data_settings = json.load(f)

        self.data_modifier = kwargs.get("data_modifier", {})
        if not self.data_modifier:
            self.data_modifier = self.data_settings.get(self.data_variable, {})

        # First pass: find the rows rejected by the filter and count, per
        # top-level key, how many accepted rows contain that key.
        valid_res_count = defaultdict(int)
        skip_rows = set()
        for data_row in self.data:
            if not self._matches_filter(data_row):
                skip_rows.add(data_row['invnum'])
                continue
            for res_type in data_row:
                valid_res_count[res_type] += 1

        # Second pass: apply the modifier to every accepted row. Rows are
        # identified by invnum, so rows sharing a rejected invnum are skipped.
        for data_row in self.data:
            if data_row['invnum'] in skip_rows:
                continue
            for res_type, res_value in self.data_modifier.items():
                if res_type in data_row:
                    self.process_row(data_row, res_type, res_value, valid_res_count.get(res_type, 0.0))

    def _matches_filter(self, data_row):
        """Return True when every filtered field of the row has a required prefix."""
        for filter_key, filter_vals in self.filter.items():
            value = data_row.get(filter_key)
            # A row without the filtered field cannot match; treat it as
            # rejected instead of failing on None.startswith.
            if value is None or not value.startswith(filter_vals):
                return False
        return True

    def process_row(self, data_row, append_type, append_values, hosts_len):
        """Apply one section of the modifier description to one row.

        :param data_row: the row to modify in place
        :param append_type: name of the row section the values are based on;
            the section maps ``res_key`` to a dict of numeric values
        :param append_values: ``{append_key: append_val}`` for that section
        :param hosts_len: number of accepted rows containing the section,
            used as the divisor of the ``allocate`` operation
        """
        # Per-resource totals of the source section, plus whatever is
        # already stored for the resource in the target section.
        row_values = {k: sum(v.values()) for (k, v) in data_row[append_type].items()}
        if self.target_data_section:
            for k, v in data_row.get(self.target_data_section, {}).items():
                # NOTE(review): raises KeyError when the target section holds
                # a resource key that the source section lacks -- confirm
                # whether such rows can occur.
                row_values[k] += sum(v.values())
        row_sum = sum(row_values.values())
        for append_key, append_val in append_values.items():
            # A dict value addresses individual resources; a scalar value is
            # spread over all resources of the section.
            filter_keys = append_val.keys() if isinstance(append_val, dict) else []
            for res_key, res_values in data_row[append_type].items():
                if filter_keys and res_key not in filter_keys:
                    continue
                add_value = append_val[res_key] if filter_keys else append_val
                # The "add" operation uses the configured value unchanged.
                if self.operation == self.OperationTypes.multiply:
                    # Store the delta that scales the current total by add_value.
                    add_value = row_values[res_key] * add_value - row_values[res_key]
                elif self.operation == self.OperationTypes.allocate:
                    # float() keeps this a true division on Python 2 as well,
                    # where int / int would silently truncate.
                    add_value = float(add_value) / hosts_len if hosts_len else 0.0

                if self.target_data_section:
                    if self.target_data_section not in data_row:
                        data_row[self.target_data_section] = OrderedDict()
                    if res_key not in data_row[self.target_data_section]:
                        data_row[self.target_data_section][res_key] = OrderedDict()
                    target_append_block = data_row[self.target_data_section][res_key]
                else:
                    target_append_block = res_values

                if filter_keys:
                    target_append_block[append_key] = add_value
                elif self.value_proportion:
                    # Share proportional to the resource total; an all-zero
                    # row falls back to an even split.
                    coefficient = float(row_values[res_key]) / row_sum if row_sum else 1.0 / len(row_values)
                    target_append_block[append_key] = add_value * coefficient
                else:  # simple proportion
                    target_append_block[append_key] = float(add_value) / len(row_values)


class Converter(BasicModifier):
    """Merge values of one two-level section of each row into another.

    Every rule in the ``convert_rules`` keyword is a dict with ``from`` and
    ``to`` entries, each a dotted ``"section.subsection"`` path. For each
    row, the entries of ``row[from_section][from_subsection]`` are added to
    ``row[to_section][to_subsection]``: values of keys already present are
    summed with ``+=``, new keys are copied. Missing target containers are
    created as ``OrderedDict``.
    """

    def __init__(self, data, **kwargs):
        """Apply all conversion rules to ``data`` in place.

        :param data: iterable of dict rows, each with an ``invnum`` key
        :param kwargs: ``convert_rules`` -- list of ``{"from": ..., "to": ...}``
        :raises ValueError: a rule path is not of the form ``"a.b"``
        """
        super(Converter, self).__init__(data=data)

        self.convert_rules = kwargs.get("convert_rules", [])

        # The paths do not depend on the row: split them once up front.
        parsed_rules = [
            (rule["from"].split("."), rule["to"].split("."))
            for rule in self.convert_rules
        ]

        for data_row in self.data:
            for (from_k1, from_k2), (to_k1, to_k2) in parsed_rules:
                # A missing subsection is handled like a missing section:
                # report it and leave the row untouched for this rule,
                # instead of failing with KeyError after the target
                # containers were already created.
                if from_k1 not in data_row or from_k2 not in data_row[from_k1]:
                    logging.error('no source for inv: {}'.format(data_row["invnum"]))
                    continue
                target = data_row.setdefault(to_k1, OrderedDict()).setdefault(to_k2, OrderedDict())
                for k, v in data_row[from_k1][from_k2].items():
                    if k in target:
                        target[k] += v
                    else:
                        target[k] = v


class PriceDistributor(BasicModifier):
    """Correct quota costs towards target prices and redistribute the excess.

    The ``quota_price`` keyword is a dict with two entries:

    * ``target_price`` -- ``{resource: price}``; for each resource the row's
      quota costs (``quota_costs`` plus existing ``quota_cost_operations``)
      are compared with ``price * total quota`` and the difference is
      written, negated, as ``hdd_ssd_cost_correction`` of that resource;
    * ``leftovers_to`` -- ``{resource: coefficient}``; the summed difference
      is written to these resources multiplied by the coefficient.

    An error is logged when the coefficients do not distribute the whole
    difference. ``elementary_resources_price`` is stored but not used by
    this class.
    """

    CORRECTION_KEY = "hdd_ssd_cost_correction"

    def __init__(self, data, **kwargs):
        """Apply the price correction to every row of ``data`` in place."""
        super(PriceDistributor, self).__init__(data=data)

        self.quota_price = kwargs.get("quota_price", {})
        self.elementary_resources_price = kwargs.get("elementary_resources_price", {})

        if self.quota_price:
            for data_row in self.data:
                if "quota" not in data_row or "quota_costs" not in data_row:
                    logging.error("invnum {}, section quota and quota_costs must present".format(data_row["invnum"]))
                    continue
                self._distribute_row(data_row)

    @staticmethod
    def _operations_block(data_row, res_name):
        """Return ``row["quota_cost_operations"][res_name]``, creating it if absent."""
        # The section is read as optional when summing the costs, so writing
        # must not assume that it (or the resource entry) already exists.
        section = data_row.setdefault("quota_cost_operations", OrderedDict())
        return section.setdefault(res_name, OrderedDict())

    def _distribute_row(self, data_row):
        """Write the cost corrections of one row."""
        exceed_costs = 0.0
        for target_res, target_value in self.quota_price["target_price"].items():
            total_res_quota = sum(data_row["quota"][target_res].values())
            total_res_quota_costs = sum(data_row["quota_costs"][target_res].values())
            total_res_quota_costs += sum(data_row.get("quota_cost_operations", {}).get(target_res, {}).values())

            no_quota = is_close(total_res_quota, 0.0)
            if no_quota and is_close(total_res_quota_costs, 0.0):
                # Nothing allocated and nothing charged: no correction.
                continue
            if no_quota:
                logging.warning("invnum {}, no {} quota, taking all costs".format(data_row["invnum"], target_res))
                delta = total_res_quota_costs
            else:
                delta = total_res_quota_costs - target_value * total_res_quota
            self._operations_block(data_row, target_res)[self.CORRECTION_KEY] = -delta
            exceed_costs += delta

        left_costs = exceed_costs
        for leftover_res, leftover_coefficient in self.quota_price["leftovers_to"].items():
            delta = exceed_costs * leftover_coefficient
            self._operations_block(data_row, leftover_res)[self.CORRECTION_KEY] = delta
            left_costs -= delta
        if not is_close(left_costs, 0.0):
            logging.error("invnum {}, non distributed leftover {}".format(data_row["invnum"], left_costs))


class SegmentModifier(BasicModifier):
    """Split rows between segments and optionally dump one file per segment.

    Keyword arguments:

    * ``operation`` -- ``split`` (default) or ``redefine``. Only
      ``redefine`` transforms the data in the constructor;
    * ``filter`` -- ``{key: value}``; with ``redefine`` only rows whose
      fields equal all the given values are processed;
    * ``segmentation`` -- ``{segment_name: {res_name: coefficient}}``; the
      special value ``'rest'`` stands for what is left of 1 after the
      coefficients of the segments listed before it;
    * ``split_by_files`` / ``split_depth`` -- control ``dump_json``.
    """

    class OperationTypes(object):
        split = 1
        redefine = 2

    def __init__(self, data, **kwargs):
        """Read the settings and, for ``redefine``, re-segment ``data`` in place."""
        super(SegmentModifier, self).__init__(data=data)

        self.operation = self.get_operation(kwargs.get("operation", "split"))
        self.split_by_files = kwargs.get("split_by_files", False)
        self.filter = kwargs.get("filter", OrderedDict())
        self.split_depth = kwargs.get("split_depth", 0)
        self.segmentation = kwargs.get("segmentation", OrderedDict())

        if self.operation == self.OperationTypes.redefine:
            self._redefine_segments()

    def _redefine_segments(self):
        """Replace every row matching the filter with one scaled copy per segment."""
        new_data = []
        # Walk backwards so that deleting a row does not shift the indexes
        # of the rows still to be visited.
        for index in reversed(range(len(self.data))):
            data_row = self.data[index]
            if self.filter and any(data_row.get(key) != val for key, val in self.filter.items()):
                continue
            new_data.extend(self._split_row(data_row))
            del self.data[index]

        self.data.extend(new_data)

    def _split_row(self, data_row):
        """Return the per-segment copies of ``data_row`` with scaled values."""
        new_rows = []
        rest_segmentation = OrderedDict()
        for segment_name, segmentation in self.segmentation.items():
            if segmentation == 'rest':
                segmentation = rest_segmentation
            else:
                # Track what remains of each resource for the 'rest' segment.
                for res_name, coefficient in segmentation.items():
                    if res_name in rest_segmentation:
                        rest_segmentation[res_name] -= coefficient
                    else:
                        rest_segmentation[res_name] = 1 - coefficient

            new_row = copy.deepcopy(data_row)
            for resources in new_row.values():
                if not isinstance(resources, dict):
                    continue
                # list(): entries are reassigned while the dict is walked.
                for res_name, res_values in list(resources.items()):
                    if res_name not in segmentation:
                        continue
                    factor = segmentation[res_name]
                    if isinstance(res_values, dict):
                        for key in list(res_values.keys()):
                            res_values[key] *= factor
                    else:
                        resources[res_name] *= factor
            new_row['segment'] = segment_name
            # Store a snapshot: the 'rest' accumulator keeps changing when
            # further segments follow, which would otherwise alter the
            # coefficients recorded for this row after the fact.
            new_row['segmentation'] = copy.deepcopy(segmentation)
            new_rows.append(new_row)
        return new_rows

    def dump_json(self, path):
        """Write the data to ``path``, or to one file per segment.

        With ``split_by_files`` set, ``path`` is a template containing
        ``{segment}``; rows are grouped by their ``segment`` value, cut to
        the first ``split_depth`` dot-separated parts when ``split_depth``
        is non-zero, and missing output directories are created.
        """
        if not self.split_by_files:
            super(SegmentModifier, self).dump_json(path=path)
            return

        split_data = defaultdict(list)
        for element in self.data:
            segment = element["segment"]
            segment_parts = segment.split('.')
            if self.split_depth and len(segment_parts) >= self.split_depth:
                segment = '.'.join(segment_parts[:self.split_depth])
            split_data[segment].append(element)

        for segment, segment_elements in split_data.items():
            segment_path = path.format(segment=segment)
            segment_dir = os.path.dirname(segment_path)
            # dirname is '' for a bare file name, and os.makedirs('') raises;
            # such files simply go to the current directory.
            if segment_dir and not os.path.exists(segment_dir):
                os.makedirs(segment_dir)
            with open(segment_path, 'w') as f:
                json.dump(segment_elements, f, indent=2)


def by_name(name, *args, **kwargs):
    """Call the module-level object called ``name`` and return the result.

    The object is looked up in this module's globals and invoked with the
    remaining positional and keyword arguments.

    :param name: name of a callable defined in (or imported into) this module
    :return: whatever the callable returns
    :raises KeyError: no module-level object is called ``name``
    """
    module_namespace = globals()
    return module_namespace[name](*args, **kwargs)
