# coding: utf-8

from collections import namedtuple

from .yconfig import Config as YConfig

from .utils import ConnParams

# Pool sizing for a single role; built by load_role() from the
# writePoolSize / writePoolOverflow / readPoolSize / readPoolOverflow
# attributes of a role element.
PoolRoleParams = namedtuple(
    "PoolRoleParams",
    ["write_size", "write_overflow", "read_size", "read_overflow"])

# Pool-wide settings built by load_pool_params(); ``roles`` maps a role
# element's tag to its PoolRoleParams.
PoolParams = namedtuple(
    "PoolParams",
    ["timeout", "ping_period", "fail_ping_period", "roles"])

# One <database> entry, built by load_database_params().
DatabaseParams = namedtuple(
    "DatabaseParams",
    ["id", "name", "write_conn", "read_conns", "host_conns", "pool_params"])

HostPort = namedtuple("HostPort", ["host", "port"])

PoolConfig = namedtuple("PoolConfig", ["master", "slaves"])

# Value of common/grinder/host.
GrinderParams = namedtuple("GrinderParams", ["host"])

# Value of services/editor/config.
EditorParams = namedtuple("EditorParams", ["config"])

# Work directory setting for the releases-notification task.
ReleasesNotificationParams = namedtuple("ReleasesNotificationParams",
                                        ["work_dir"])

# Merge task settings: ``sources`` keyed by each source's id, plus the
# list of layer elements.
MergeParams = namedtuple("MergeParams", ["sources", "layers"])


class VrevisionsRefreshParams(object):
    """Stage tree read from a ``vrevisions-refresh`` config element.

    ``stages`` lists the elements matched by ``stages/stage`` as ``Stage``
    objects, in document order; every ``Stage`` recursively holds the
    ``stage`` elements nested directly inside it as ``children``.
    """

    class Stage(object):
        """A single stage: its ``id`` attribute and its child stages."""

        def __init__(self, id, children):
            self.id = id
            self.children = children

    def __init__(self, vrevisions_refresh_ET):
        def build(node):
            # Read the id before descending so a missing attribute is
            # reported for the outer stage first.
            stage_id = node.attrib["id"]
            kids = [build(sub) for sub in node.xpath("stage")]
            return VrevisionsRefreshParams.Stage(stage_id, kids)

        top_level = vrevisions_refresh_ET.xpath('stages/stage')
        self.stages = [build(node) for node in top_level]


def load_role(el):
    """Build ``PoolRoleParams`` from a role element's pool-size attributes.

    ``writePoolSize``, ``writePoolOverflow``, ``readPoolSize`` and
    ``readPoolOverflow`` are all required (``KeyError`` if absent) and are
    converted with ``int`` (``ValueError`` if not integral).
    """
    attrs = el.attrib
    keys = ("writePoolSize", "writePoolOverflow",
            "readPoolSize", "readPoolOverflow")
    sizes = [int(attrs[key]) for key in keys]
    return PoolRoleParams(*sizes)


def load_pool_params(el):
    """Parse a ``<pools>`` element into ``PoolParams``.

    The integer attributes ``pingPeriod``, ``failPingPeriod`` and
    ``timeout`` are read from *el*. Every child element becomes an entry of
    ``roles``, keyed by its tag and parsed with ``load_role``.

    Raises KeyError if a required attribute is missing and ValueError if
    one is not an integer.
    """
    ping_period = int(el.attrib["pingPeriod"])
    fail_ping_period = int(el.attrib["failPingPeriod"])
    timeout = int(el.attrib["timeout"])
    # Iterating an lxml element also yields comments and processing
    # instructions, whose ``tag`` is a factory function rather than a tag
    # name. They carry no pool attributes, so skip them instead of letting
    # load_role fail with KeyError on an innocent XML comment.
    roles = {rel.tag: load_role(rel) for rel in el if not callable(rel.tag)}
    return PoolParams(timeout, ping_period, fail_ping_period, roles)


def load_database_params(el):
    """Parse a ``<database>`` config element into ``DatabaseParams``.

    The ``id`` attribute is required. If a non-empty ``name`` attribute is
    present, it is passed to every ``ConnParams.from_xml`` call as the
    ``dbname`` override. The ``<write>`` child is optional: ``write_conn``
    is None when it is absent. ``<read>`` and ``<host>`` children may
    repeat. Pool settings are read from the ``<pools>`` child.
    """
    db_id = el.attrib["id"]
    dbname = el.attrib.get("name")
    override_params = {'dbname': dbname} if dbname else {}

    def conn(node):
        # Every connection of this database shares the same overrides.
        return ConnParams.from_xml(node, **override_params)

    write_el = el.find("write")
    write_conn = conn(write_el) if write_el is not None else None
    read_conns = [conn(rel) for rel in el.xpath("read")]
    host_conns = [conn(rel) for rel in el.xpath("host")]
    pool_params = load_pool_params(el.find("pools"))
    return DatabaseParams(db_id, dbname, write_conn, read_conns, host_conns, pool_params)


class Config(object):
    """Application settings read through ``YConfig``.

    Attributes are populated only when ``YConfig`` resolves a config file
    name. Otherwise ``__init__`` returns early and the instance has none of
    them.
    """

    def __init__(self, config_filename=None):
        config = YConfig(config_filename)

        # Temporary hack to fix arcadia import errors: with no resolved config
        # file, skip parsing entirely and leave this instance attribute-less.
        if config.config_filename is None:
            return

        self.xml = config
        self.filename = config.config_filename

        self.project = config.get("common/project")
        databases = [load_database_params(el)
                     for el in config.xpath("common/databases/database")]
        # Keyed by each database's ``id`` attribute.
        self.databases = {d.id: d for d in databases}
        self.grinder_params = GrinderParams(config.get("common/grinder/host"))
        # Indexing [0] fails if the vrevisions-refresh section is absent.
        self.vrevisions_refresh_params = VrevisionsRefreshParams(
            config.xpath('services/tasks/vrevisions-refresh')[0])
        self.editor_params = EditorParams(config.get('services/editor/config'))
        # Relative path, like every other lookup in this constructor. The
        # original had a stray leading '/', which would make an XPath-style
        # lookup absolute from the document root.
        self.releases_notification_params = ReleasesNotificationParams(
            config.get('services/tasks/releases-notification/work-dir'))
        # Sources are keyed by their ``id`` attribute; layers keep document order.
        self.merge_params = MergeParams(
            {source.get('id'): source
             for source in config.xpath('services/tasks/merge/sources/source')},
            config.xpath('services/tasks/merge/layers/layer'))

# Module-wide Config instance; set by init_config(), read by get_config().
_config = None


def init_config(config_filename=None):
    """Create the module-wide ``Config`` on the first call.

    Later calls do nothing: the existing instance is kept and
    *config_filename* is ignored.
    """
    global _config
    if _config is None:
        _config = Config(config_filename)


def get_config():
    """Return the ``Config`` created by ``init_config``.

    Raises AssertionError if ``init_config`` has not been called yet. The
    check is an explicit ``raise`` rather than an ``assert`` statement, so
    it still fires when Python runs with ``-O``.
    """
    if _config is None:
        raise AssertionError("get_config() called before init_config()")
    return _config
