import logging
import os
import shutil
import hashlib
import re
from sandbox import sdk2

from sandbox.projects.common import time_utils as tu
from sandbox.projects.common import error_handlers as eh
from sandbox.projects.common import utils2
from sandbox.sdk2.helpers import subprocess as sp
from sandbox.projects.prs_ops import resources


class PrsOpsComponent(object):

    def __init__(
        self,
        name="prs_ops",
        mode=None,
        mr_server=None,
        save_to=None,
        disable_samohod=None,
        substitute_ranks=None,
        is_queries_full=True,
        args=None,
        customCGI=None,
        selected_slices=None,
        binary=None,
        queries=None,
        ratings=None,
        write_mode=None,
        success_threshold=None,
        tm=None,
        judged_urls=None,
        plus=None,
        task=None,
        requests=None,
        reader=None,
        converter=None,
        mapreduce=None,
        gdtsf=None,
        yt_client=None,
        need_check=True,
    ):
        """Assemble the prs_ops command line and register the post-run checks.

        Every option that is truthy contributes one fragment to
        ``self._cmd_line``; ``start`` later joins the fragments with spaces and
        runs the result through the shell.  Checks that only make sense for a
        given option are appended to ``self._check_list``.

        :param name: component name, used in logger and log resource names
        :param mode: value for ``--mode``
        :param mr_server: value for ``--mr-server`` (used for "mr" write modes)
        :param save_to: value for ``-o`` (used for "mr" write modes)
        :param disable_samohod: adds ``--disable-samohod-on-middle``
        :param substitute_ranks: adds ``--substitute-ranks``
        :param is_queries_full: selects ``queries-full`` or ``queries`` for
            ``--queries-file-format``
        :param args: extra raw command-line fragment appended last
        :param customCGI: value for ``--additional-params`` (single-quoted)
        :param selected_slices: value for ``--selected-slices``
        :param binary: resource with the prs_ops executable (required)
        :param queries: resource passed with ``-i``; its md5 is remembered
        :param ratings: resource passed with ``-r``; its md5 is remembered
        :param write_mode: value for ``--write-mode``
        :param success_threshold: value for ``--success-threshold``
        :param tm: adds ``--gather-text-machine-hits`` and the TMHits check
        :param judged_urls: adds
            ``--gather-text-machine-hits-only-for-judged-urls``
        :param plus: adds ``--prs-plus``
        :param task: owning task, used for process logs and resources
        :param requests: resource passed with ``-c``; its md5 is remembered
        :param reader: resource with the reader tool used by the checks
        :param converter: resource with the converter tool used by the checks
        :param mapreduce: resource with the mapreduce tool used by the checks
        :param gdtsf: adds ``--gather-dynamic-top-sizes-factors`` and its check
        :param yt_client: client whose ``read_table`` is used by the checks
        :param need_check: whether ``__exit__`` runs the registered checks
        :raises Exception: if ``binary`` is not given
        """
        self._need_check = need_check
        self._name = name
        self._task = task
        self._cmd_line = []
        self._check_list = [self.check_features]
        self._environment = {}

        if not binary:
            raise Exception("No binary")
        self._binary = sdk2.ResourceData(binary)
        self._cmd_line.append(str(self._binary.path))

        if queries:
            self._queries = sdk2.ResourceData(queries)
            self._queries_md5 = self.get_md5(str(self._queries.path))
            self._cmd_line.append("-i {}".format(str(self._queries.path)))
            self._check_list.append(self.check_queries)
        else:
            self._queries = None

        if ratings:
            self._ratings = sdk2.ResourceData(ratings)
            self._ratings_md5 = self.get_md5(str(self._ratings.path))
            self._cmd_line.append("-r {}".format(str(self._ratings.path)))
            self._check_list.append(self.check_ratings)
        else:
            self._ratings = None

        if requests:
            self._requests = sdk2.ResourceData(requests)
            self._requests_md5 = self.get_md5(str(self._requests.path))
            # The last four characters of the argument (i.e. of the path) are
            # dropped on purpose; presumably a four-character extension such
            # as ".tsv" that prs_ops appends itself -- TODO confirm.
            self._cmd_line.append("-c {}".format(str(self._requests.path))[:-4])
            self._check_list.append(self.check_requests)
        else:
            self._requests = None

        self._is_queries_full = is_queries_full
        if self._is_queries_full:
            self._cmd_line.append("--queries-file-format queries-full")
        else:
            self._cmd_line.append("--queries-file-format queries")

        self._substitute_ranks = substitute_ranks
        if self._substitute_ranks:
            self._cmd_line.append("--substitute-ranks")

        self._customCGI = customCGI
        if self._customCGI:
            self._cmd_line.append("--additional-params '{}'".format(self._customCGI))

        self._selected_slices = selected_slices
        if self._selected_slices:
            self._cmd_line.append("--selected-slices {}".format(self._selected_slices))

        self._mode = mode
        if self._mode:
            self._cmd_line.append("--mode {}".format(self._mode))

        self._write_mode = write_mode
        self._mr_server = mr_server
        self._save_to = save_to

        if self._write_mode:
            self._cmd_line.append("--write-mode {}".format(self._write_mode))
            # Only inspect the write mode when one was given: the default is
            # None, and a membership test on None raises TypeError.
            if "mr" in self._write_mode:
                self._cmd_line.append("--mr-server {}".format(self._mr_server))
                self._cmd_line.append("-o {}".format(self._save_to))

        self._disable_samohod = disable_samohod
        if self._disable_samohod:
            self._cmd_line.append("--disable-samohod-on-middle")

        self._success_threshold = success_threshold
        if self._success_threshold:
            self._cmd_line.append("--success-threshold {}".format(self._success_threshold))

        self._plus = plus
        if self._plus:
            self._cmd_line.append("--prs-plus")

        self._tm = tm
        if self._tm:
            self._cmd_line.append("--gather-text-machine-hits")
            self._check_list.append(self.check_tm_hits)

        self._judged_urls = judged_urls
        if self._judged_urls:
            self._cmd_line.append("--gather-text-machine-hits-only-for-judged-urls")

        self._gdtsf = gdtsf
        if self._gdtsf:
            self._cmd_line.append("--gather-dynamic-top-sizes-factors")
            self._check_list.append(self.check_gdtsf)

        self._args = args
        if self._args:
            self._cmd_line.append(self._args)

        # Helper tools are optional; they are only dereferenced by the checks.
        self._reader = sdk2.ResourceData(reader) if reader else None
        self._converter = sdk2.ResourceData(converter) if converter else None
        self._mapreduce = sdk2.ResourceData(mapreduce) if mapreduce else None

        # Always define the attribute so that checks reading it never hit
        # AttributeError when no client was supplied.
        self._yt_client = yt_client

    def get_md5(self, filename):
        """Return the hexadecimal MD5 digest of the file at ``filename``.

        The file is opened in binary mode so the digest covers the exact bytes
        on disk (no newline translation or text decoding), and it is hashed in
        chunks so the whole file is never held in memory at once.

        :param filename: path of the file to hash
        :return: MD5 digest as a hexadecimal string
        :raises IOError: if the file cannot be opened or read
        """
        digest = hashlib.md5()
        with open(filename, 'rb') as f:
            # iter() with a sentinel keeps reading until read() returns b"".
            for chunk in iter(lambda: f.read(1024 * 1024), b""):
                digest.update(chunk)
        return digest.hexdigest()

    def start(self):
        """Run the assembled prs_ops command through the shell and wait for it.

        The command-line fragments are joined with spaces; stdout and stderr
        go to the ``run_<name>`` process log of the task.  ``check_call``
        raises if the command exits with a non-zero status.
        """
        command = " ".join(self._cmd_line)
        # An empty environment dict is passed on as None rather than as {}.
        child_env = self._environment if self._environment else None
        logger_name = "run_{}".format(self._name)
        with sdk2.helpers.ProcessLog(self._task, logger=logger_name) as pl:
            # NOTE(review): the command is a shell string that embeds
            # caller-supplied values (e.g. customCGI); keep inputs trusted.
            sp.check_call(command, shell=True, stdout=pl.stdout, stderr=pl.stderr, env=child_env)
        logging.info("Process {} was successfully started".format(self._name))

    def stop(self):
        """Log that the process has finished; nothing is actually stopped."""
        message = "Process {} was successfully finihed".format(self._name)
        logging.debug(message)

    def get_log_resource(self):
        """Move the stderr log of the run into a PRS_OPS_LOG resource.

        The file ``run_<name>.err.log`` from the task log directory is moved
        to the resource path and the resource is marked ready.

        :return: the created log resource
        """
        description = "{}_log".format(self._name)
        resource_file = "{}_log_{}".format(self._name, tu.date_ymdhm(sep="_"))
        log_resource = resources.PRS_OPS_LOG(self._task, description, resource_file)

        logs_dir = str(self._task.log_path())
        err_log_name = "run_{}.err.log".format(self._name)
        log_path = os.path.abspath(os.path.join(logs_dir, err_log_name))
        logging.debug("log_path: {}\n try this: {}\n".format(log_path, os.getcwd()))

        shutil.move(log_path, str(log_resource.path))
        sdk2.ResourceData(log_resource).ready()
        return log_resource

    def check_correctness(self):
        """Run every registered check in registration order.

        A check that raises is logged and reported through
        ``eh.check_failed`` together with the exception text.
        """
        logging.debug("Start checks")
        for run_check in self._check_list:
            try:
                run_check()
            except Exception as exc:
                check_name = str(run_check.__name__)
                eh.log_exception('check {} failed'.format(check_name), exc)
                failure = "prs_ops don't consistent {}".format(exc)
                eh.check_failed(failure)

    def check_queries(self):
        """Raise if the queries file no longer matches the md5 taken in ``__init__``."""
        expected_md5 = self._queries_md5
        actual_md5 = self.get_md5(str(self._queries.path))
        if expected_md5 == actual_md5:
            return
        raise Exception("md5sum of queries was changed after prs_ops's run")

    def check_ratings(self):
        """Raise if the ratings file no longer matches the md5 taken in ``__init__``."""
        expected_md5 = self._ratings_md5
        actual_md5 = self.get_md5(str(self._ratings.path))
        if expected_md5 == actual_md5:
            return
        raise Exception("md5sum of ratings was changed after prs_ops's run")

    def check_requests(self):
        """Raise if the requests file no longer matches the md5 taken in ``__init__``."""
        expected_md5 = self._requests_md5
        actual_md5 = self.get_md5(str(self._requests.path))
        if expected_md5 == actual_md5:
            return
        raise Exception("md5sum of requests was changed after prs_ops's run")

    def check_features(self):
        """Validate the features written by prs_ops for the current write mode.

        Nothing is checked when the mode contains ``CASTING``.

        * ``tsv``: ``features.tsv`` and ``factor_slices.tsv`` from the current
          directory are moved into PRS_OPS_FEATURES / PRS_OPS_FACTOR_SLICES
          resources.  The expected factor count is the sum of ``end - begin``
          over every ``[begin;end)`` range in the factor slices file.  Each
          features line must have non-negative integers in columns 0 and 3,
          an ``http://`` or ``https://`` substring in column 2, and exactly
          ``factor count + 4`` whitespace-separated columns.  Both resources
          are marked ready only after every line has passed.
        * ``mr-proto``: the factor count is computed from the
          ``<save_to>/factor_slices`` table, the features table is piped
          through the mapreduce, converter and reader tools, and every
          reported field count must equal the factor count.
        * any other write mode (including ``mr-tsv`` and ``proto``): no check.

        :raises Exception: when any of the validations above fails
        """
        # ``mode`` defaults to None; a membership test on None would raise
        # TypeError, so a missing mode is treated as "not CASTING".
        if self._mode and "CASTING" in self._mode:
            return

        slice_pattern = re.compile(r"\[(\d+);(\d+)\)")

        if "tsv" == self._write_mode:
            features_path = os.path.abspath("features.tsv")
            logging.info("creating features resource")
            features = resources.PRS_OPS_FEATURES(
                self._task,
                "features", "features__{}".format(tu.date_ymdhm(sep="_"))
            )
            features_data = sdk2.ResourceData(features)
            shutil.move(features_path, str(features_data.path))

            factor_path = os.path.abspath("factor_slices.tsv")
            logging.info("creating factor_slices resource")
            factor = resources.PRS_OPS_FACTOR_SLICES(
                self._task,
                "factor_slices", "factor_slices__{}".format(tu.date_ymdhm(sep="_"))
            )
            factor_data = sdk2.ResourceData(factor)
            shutil.move(factor_path, str(factor_data.path))

            with open(str(factor_data.path), "r") as ff:
                len_fact = sum(
                    int(end) - int(begin)
                    for begin, end in slice_pattern.findall(ff.read())
                )

            with open(str(features_data.path), "r") as f:
                for feature in f:
                    feature_split = feature.split()
                    # Too few columns or a non-numeric key is reported with
                    # the same message as a negative key.
                    try:
                        first_key = int(feature_split[0])
                        second_key = int(feature_split[3])
                    except (ValueError, IndexError):
                        raise Exception("feature key is not int")
                    if first_key < 0 or second_key < 0:
                        logging.debug("fail in {} or {}".format(feature_split[0], feature_split[3]))
                        raise Exception("feature key is not int")
                    url = feature_split[2]
                    if ("http://" not in url) and ("https://" not in url):
                        raise Exception("Don't found http:// or https:// in features url")
                    if len(feature_split) != len_fact + 4:
                        raise Exception("Feature len() != factor slices len({})".format(len_fact))

            factor_data.ready()
            features_data.ready()
        elif "mr-proto" == self._write_mode:
            # The attribute may be absent or None when no client was supplied.
            yt_client = getattr(self, "_yt_client", None)
            if yt_client is None:
                raise Exception("yt_client is required to check mr-proto features")

            factor = yt_client.read_table("{}/factor_slices".format(self._save_to))
            len_fact = sum(
                int(end) - int(begin)
                for row in factor
                for begin, end in slice_pattern.findall(row["value"])
            )
            # NOTE(review): the server name is hard-coded here although
            # ``mr_server`` is passed to the constructor -- confirm intent.
            cmd_line = [
                "{} -read {}/features -server hahn -subkey -lenval".format(str(self._mapreduce.path), self._save_to),
                "| {} mr-proto final".format(str(self._converter.path)),
                "| {} Feature | awk '{{print NF}}' > tmp.txt".format(str(self._reader.path)),
            ]
            with sdk2.helpers.ProcessLog(self._task, logger="check") as pl:
                sp.check_call(
                    " ".join(cmd_line),
                    shell=True,
                    stdout=pl.stdout,
                    stderr=pl.stderr,
                    env=self._environment or None,
                )
            with open("tmp.txt") as f:
                for field_count in f:
                    if int(field_count) != len_fact:
                        raise Exception("len_fact_slices != len features")
        # "mr-tsv", "proto" and unknown write modes are not validated.

    def check_tm_hits(self):
        """Check that text machine hits were gathered for ``mr-proto`` output.

        Nothing is checked when the mode contains ``CASTING`` or when the
        write mode is not ``mr-proto``.  Otherwise the features table is piped
        through the mapreduce, converter and reader tools and the field count
        of every output line is written to ``tmp_tmhits.txt``.  The check
        fails when that file is empty, or when a line with zero fields is the
        first line or directly follows another zero-field line.

        :raises Exception: when the TMHits output is missing or incomplete
        """
        # ``mode`` defaults to None; a membership test on None would raise
        # TypeError, so a missing mode is treated as "not CASTING".
        if self._mode and "CASTING" in self._mode:
            return
        if "mr-proto" != self._write_mode:
            return

        cmd_line = [
            "{} -read {}/features -server hahn -subkey -lenval".format(str(self._mapreduce.path), self._save_to),
            "| {} mr-proto final".format(str(self._converter.path)),
            "| {} TMHits | awk '{{print NF}}' > tmp_tmhits.txt".format(str(self._reader.path)),
        ]
        with sdk2.helpers.ProcessLog(self._task, logger="check_tm") as pl:
            sp.check_call(
                " ".join(cmd_line),
                shell=True,
                stdout=pl.stdout,
                stderr=pl.stderr,
                env=self._environment or None,
            )
        with open("tmp_tmhits.txt") as f:
            content = f.readlines()
        if not content:
            raise Exception("TMHits doesnt exist")

        # The previous count starts at zero, so a leading zero also fails.
        prev_state = 0
        for line in content:
            current_state = int(line)
            if current_state == 0 and prev_state == 0:
                raise Exception("Feature has no TMHits")
            prev_state = current_state

    def check_gdtsf(self):
        """Validate the dynamic top sizes tables under ``save_to``.

        The expected factor count is the sum of ``end - begin`` over every
        ``[begin;end)`` range found in the ``value`` of each row of
        ``<save_to>/dynamic_top_sizes_factor_slices``.  Every row of
        ``<save_to>/dynamic_top_sizes_features`` must then have a
        tab-separated ``value`` with exactly ``factor count + 3`` columns,
        and every column from index 2 onwards must parse as a float.

        :raises Exception: when no yt client is available, a row has the
            wrong number of columns, or a column is not a float
        """
        logging.debug("Start check gdtsf")

        # The attribute may be absent or None when no client was supplied.
        yt_client = getattr(self, "_yt_client", None)
        if yt_client is None:
            raise Exception("yt_client is required for the gdtsf check")

        gdtsf = yt_client.read_table("{}/dynamic_top_sizes_features".format(self._save_to))
        gdtsfs = yt_client.read_table("{}/dynamic_top_sizes_factor_slices".format(self._save_to))

        pattern = re.compile(r"\[(\d+);(\d+)\)")
        len_fact = sum(
            int(end) - int(begin)
            for slices_row in gdtsfs
            for begin, end in pattern.findall(slices_row["value"])
        )

        for row in gdtsf:
            lst = row["value"].split("\t")
            if len(lst) != len_fact + 3:
                logging.debug("gdtsf factors :{}, slices :{}".format(len(lst), len_fact))
                raise Exception("gdtsf error with len")

            for column in lst[2:]:
                try:
                    float(column)
                except ValueError:
                    logging.debug("{} has not float value".format(row["key"]))
                    raise Exception("GDTSF has problem")

        # Logged once after the loop: the message formats the table object,
        # not the current row, so repeating it per row added nothing.
        logging.debug("GDTSF: {}".format(gdtsf))

    def get_environment(self):
        """Return the environment mapping used for child processes."""
        environment = self._environment
        return environment

    def set_environment(self, env):
        """Replace the environment mapping used for child processes.

        :param env: the new environment; stored as is, without copying
        """
        setattr(self, "_environment", env)

    def __enter__(self):
        """Run prs_ops on entering the ``with`` block and return the component.

        If the run raises, the exception is logged, a link to the published
        log resource is added to the task info, and the failure is reported
        through ``eh.check_failed``.
        """
        try:
            self.start()
        except Exception as exc:
            eh.log_exception('Can not run process', exc)
            log_url = utils2.resource_redirect_url(self.get_log_resource().id)
            link = '<a href="{}" target="_blank">prs_ops_log</a>'.format(log_url)
            # The info text contains HTML, hence escaping is switched off.
            self._task.set_info('See {}'.format(link), do_escape=False)
            eh.check_failed("prs_ops didn't run")

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Run the registered checks on leaving the ``with`` block, if enabled.

        The exception arguments are ignored and nothing is returned, so an
        exception raised inside the ``with`` block is not suppressed.
        """
        if not self._need_check:
            return
        self.check_correctness()
