# coding: utf-8

import io
import os.path
import pkgutil
import re

from six import binary_type

from .arcadia import is_arcadia
from .tools import strip_q


class QueryConfError(Exception):
    """Base class for problems found in a queries file.

    Syntax errors, forbidden Oracle bind names and duplicate query names
    all derive from it.
    """


class QueryConfSyntaxError(QueryConfError):
    """The queries file is malformed.

    For example, a query line appears before any ``-- name:`` tag.
    """


class DangerousOraBind(QueryConfError):
    """An Oracle query uses a bind name listed in ``BAD_ORA_VARS``."""


class QueryOverrideError(QueryConfError):
    """The same query name is defined more than once in one file."""


class Query(object):
    """A single named SQL query loaded from a queries file.

    :param query: SQL text, with bind placeholders already rewritten for
        the target driver.
    :param args: iterable of bind names used by the query, or ``None``.
        It is stored as a sorted tuple, so equality ignores the order the
        names were collected in.
    :param line_no: line of the ``-- name:`` tag that started the query.
        It is informational only and takes no part in equality or hashing.
    """

    def __init__(self, query, args, line_no):
        self.query = query
        self.args = tuple(sorted(args or []))
        self.line_no = line_no

    def __eq__(self, other):
        # Let Python fall back to its default handling for foreign types,
        # instead of failing with AttributeError on ``other.query``.
        if not isinstance(other, Query):
            return NotImplemented
        return (self.query, self.args) == (other.query, other.args)

    def __ne__(self, other):
        # Python 2 (supported through ``six``) does not derive ``!=`` from
        # ``__eq__``; without this, equal queries would compare as unequal.
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return equal
        return not equal

    def __str__(self):
        return self.query

    def __hash__(self):
        # Hashing only the text is consistent with ``__eq__``: equal
        # queries necessarily have equal ``query`` strings.
        return hash(self.query)

    def __repr__(self):
        return '{0.__class__.__name__}({0.query},{0.args})'.format(self)

# A comment line of the queries file: optional leading whitespace, an SQL
# ``--`` comment marker, then optionally a ``name: <identifier>`` tag whose
# identifier is exposed as the ``name`` group.
# At least two dashes are required: a single ``-`` begins ordinary SQL such
# as ``- discount`` on a continuation line, and must not be swallowed as a
# comment.
_comment_re = re.compile(r'\s*-{2,}\s*(name:\s*(?P<name>\w+))?')


def _split_onto_queries(fd):
    """Group the lines of a queries file into named query bodies.

    Lines of ``fd`` may be ``str`` or UTF-8 ``bytes``. Trailing whitespace
    is stripped and blank lines are skipped. A comment line carrying
    ``-- name: <name>`` opens a new section; other comment lines are
    dropped.

    Yields ``(name, body_lines, start_line_no)`` for each section.
    ``start_line_no`` is the 1-based line of the section's ``name`` tag.

    :raises QueryConfSyntaxError: if a non-comment line appears before the
        first ``name`` tag.
    """
    name = None
    first_line_no = None
    body = []
    for number, raw in enumerate(fd, 1):
        text = raw.rstrip()
        if not text:
            continue
        if isinstance(text, binary_type):
            text = text.decode('utf-8')
        marker = _comment_re.match(text)
        if marker is None:
            # Ordinary SQL line: it must belong to some named section.
            if not name:
                raise QueryConfSyntaxError(
                    'Expect "-- name:" tag before query definition,'
                    ' line: %d, got %r, fd: %r' % (number, text, fd)
                )
            body.append(text)
            continue
        tag = marker.group('name')
        if not tag:
            # Plain comment without a name tag: ignore it.
            continue
        if name:
            yield name, body, first_line_no
        # Start a fresh list so the list just yielded is never mutated.
        name, first_line_no, body = tag, number, []
    if name:
        yield name, body, first_line_no

# Identifiers of the supported bind styles; pass one as ``style`` to load().
ORACLE, POSTGRE = 'oracle', 'postgre'

# Bind names (colon included, lower case) that ora_var_style() refuses.
# Presumably because UID is an Oracle reserved word -- TODO confirm.
BAD_ORA_VARS = (':uid',)


def ora_var_style(match):
    """Keep an Oracle ``:name`` bind as-is, after rejecting forbidden names.

    :param match: a ``_vars_re`` match; its ``var`` group is the bind,
        including the leading colon.
    :returns: the bind text unchanged, since Oracle uses ``:name`` natively.
    :raises DangerousOraBind: if the lower-cased bind is in ``BAD_ORA_VARS``.
    """
    bind = match.group('var')
    if bind.lower() not in BAD_ORA_VARS:
        return bind
    raise DangerousOraBind(
        "Don't use %s name for oracle queries" % bind
    )


def pg_var_style(match):
    """Turn a ``:name`` bind into the pyformat placeholder ``%(name)s``.

    :param match: a regex match whose ``var`` group is the bind text.
    """
    template = '%({0})s'
    bare_name = match.group('var').lstrip(':')
    return template.format(bare_name)

# Bind-rewriting callable for each supported ``style``; load() picks one
# and hands it to translate_binds().
VARS_STYLES = {
    ORACLE: ora_var_style,
    POSTGRE: pg_var_style,
}
# A ``:name`` bind (group ``var``, colon included) plus the one character
# before it (group ``sep``). That character may be neither a colon nor an
# ASCII letter/digit, which keeps PostgreSQL casts (``x::int``) and
# time-like text (``12:30``) from being read as binds. One consequence is
# that a bind at the very first character of the query is not recognised.
# ``\w`` already covers ``_`` and the pattern has no ``^``/``$`` anchors,
# so no extra character class or MULTILINE flag is needed.
_vars_re = re.compile(r'(?P<sep>[^:a-zA-Z0-9])(?P<var>:\w+)')


def translate_binds(query, styler):
    """Rewrite every ``:name`` bind in ``query`` using ``styler``.

    :param query: SQL text.
    :param styler: callable that takes a ``_vars_re`` match and returns the
        replacement for its ``var`` group (e.g. ``pg_var_style``).
    :returns: ``(translated_query, names)``. ``names`` holds each distinct
        bind name without its colon, in no particular order (it comes from
        a set).
    """
    seen = set()

    def _replace(match):
        seen.add(match.group('var'))
        # Re-emit the separator character that the pattern consumed.
        return match.group('sep') + styler(match)

    translated = _vars_re.sub(_replace, query)
    names = tuple(bind.lstrip(':') for bind in seen)
    return translated, names


class QueryNotFound(AttributeError):
    """Raised by ``QueriesHolder`` when a query name is not defined.

    It subclasses ``AttributeError``, so ``hasattr`` and three-argument
    ``getattr`` keep working on a holder.
    """


class QueriesHolder(object):
    """Attribute-style container for named queries.

    Each key of ``queries`` becomes an instance attribute, so
    ``holder.get_user`` returns ``queries['get_user']``. Looking up an
    undefined name raises ``QueryNotFound``.

    NOTE(review): a query named like a method (e.g. ``list_queries``) would
    shadow that method, because instance attributes win over methods.
    """

    def __init__(self, queries):
        self.__dict__.update(queries)

    def list_queries(self):
        """Return the names of all held queries as a list."""
        return list(self.__dict__)

    def __eq__(self, other):
        # Return NotImplemented for foreign types instead of asserting.
        # Comparing with e.g. None then yields False rather than raising,
        # and the check no longer disappears under ``python -O``.
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self.__dict__ == other.__dict__

    def __ne__(self, other):
        # Python 2 (supported through ``six``) does not derive ``!=`` from
        # ``__eq__``; spell it out so both operators agree.
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return equal
        return not equal

    def __repr__(self):
        return '{0}({1})'.format(
            self.__class__.__name__,
            ','.join(self.list_queries())
        )

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails.
        raise QueryNotFound(
            'query %s not found' % name
        )

    def __len__(self):
        return len(self.list_queries())


def load(fd, style=POSTGRE):
    """Parse named queries from ``fd`` into a ``QueriesHolder``.

    :param fd: iterable of ``str``/``bytes`` lines. Every query is
        introduced by a ``-- name: <name>`` comment line.
    :param style: key of ``VARS_STYLES`` (``ORACLE`` or ``POSTGRE``) that
        selects how ``:name`` binds are rewritten.
    :returns: ``QueriesHolder`` mapping each name to a ``Query``. Each query
        body is built in this order: lines joined with ``\\n``; trailing
        ``;`` removed; ``%`` doubled; binds translated; the result passed
        through ``strip_q``.
    :raises QueryOverrideError: if the same name is defined twice.
    :raises AssertionError: for an unsupported ``style``.

    NOTE(review): ``%`` is doubled for every style, including ORACLE, whose
    ``:name`` binds do not rely on ``%``-formatting. Confirm the Oracle
    execution path un-escapes it; otherwise a literal like ``'100%'``
    reaches the database as ``'100%%'``.
    """
    assert style in VARS_STYLES, \
        'Unsupported var style %s' % style
    translate = VARS_STYLES[style]
    queries = {}
    for name, lines, start in _split_onto_queries(fd):
        previous = queries.get(name)
        if previous is not None:
            raise QueryOverrideError(
                'Try override %s line: %d already defined at line %d'
                % (name, start, previous.line_no))
        raw_sql = '\n'.join(lines).rstrip(';')
        # Escape literal percent signs before binds become %(name)s.
        escaped = raw_sql.replace('%', '%%')
        sql, binds = translate_binds(escaped, translate)
        queries[name] = Query(strip_q(sql), binds, start)
    return QueriesHolder(queries)


def read_file(path):
    """Read from arcadia resources if arcadia python is used, read from FS otherwise.

    Both branches decode the content as UTF-8 and return text.

    :param path: resource key (Arcadia) or filesystem path.
    :raises AssertionError: under Arcadia, when ``resfs_read`` returns an
        empty or falsy payload (the assert is skipped under ``python -O``).
    """
    if is_arcadia():
        try:
            from library.python.runtime.entry_points import __res as res
        except ImportError:
            from library.python.runtime_py3.entry_points import __res as res
        raw = res.resfs_read(path.encode('utf-8'))
        # Check the raw payload before decoding. A falsy result (presumably
        # None for a missing resource -- confirm resfs_read semantics) then
        # trips this message instead of an AttributeError on ``.decode``.
        assert raw, 'Resource not found: %s' % path
        return raw.decode('utf-8')
    # Decode explicitly as UTF-8, like the resource branch above, rather
    # than depending on the process locale (plain ``open`` in Python 3).
    with io.open(path, encoding='utf-8') as fd:
        return fd.read()


def load_from_file(filename, style=POSTGRE):
    """Read ``filename`` via ``read_file`` and parse it with ``load``.

    :param filename: path or resource key of the queries file.
    :param style: bind style forwarded to ``load``.
    """
    lines = read_file(filename).split('\n')
    return load(lines, style)


def load_from_package(pkg, code_filename, style=POSTGRE):
    """Load the ``.sql`` companion of a module through package ``pkg``.

    Only the base name of ``code_filename`` is used. ``pkg/sub/mod.py``
    becomes ``mod.sql``, which is read with
    ``pkgutil.get_data(pkg, 'mod.sql')``, i.e. through the package's loader
    rather than directly from the filesystem.

    :param pkg: importable package name.
    :param code_filename: module file whose extension is replaced by
        ``.sql``.
    :param style: bind style forwarded to ``load``.
    :raises IOError: if ``pkgutil.get_data`` returns ``None``, meaning the
        package cannot be located or its loader lacks ``get_data``.
    """
    sql_filename = find_queries_file(os.path.basename(code_filename))
    data = pkgutil.get_data(pkg, sql_filename)
    if data is None:
        # Without this check the failure surfaces later as the confusing
        # "'NoneType' object has no attribute 'split'".
        raise IOError(
            'Cannot read %s from package %r' % (sql_filename, pkg))
    return load(data.split(b'\n'), style)


def find_queries_file(code_filename):
    """Return ``code_filename`` with its extension replaced by ``sql``.

    Any directory part is kept, and a name without an extension just gets
    ``.sql`` appended (``mod.py`` -> ``mod.sql``).
    """
    stem, _ext = os.path.splitext(code_filename)
    return os.path.extsep.join((stem, 'sql'))


def load_from_my_file(code_filename, style=POSTGRE):
    """Load the ``.sql`` file that sits next to ``code_filename``.

    For example, ``load_from_my_file(__file__)`` in ``mod.py`` reads
    ``mod.sql`` from the same directory.
    """
    sql_path = find_queries_file(code_filename)
    return load_from_file(sql_path, style)
