#!/usr/bin/env python
# coding: utf-8
from contextlib import closing, contextmanager


try:
    from itertools import zip_longest
except ImportError:
    from itertools import izip as zip
    from itertools import izip_longest as zip_longest
from io import BytesIO
import re
import logging

from psycopg2 import IntegrityError

from .connect import make_connection
from .copy_escape import pgcopy, SEP, NULL
from .query_conf import load_from_package


# Module-level logger.
log = logging.getLogger(__name__)

# Named queries loaded by query_conf.load_from_package for this package.
# Entries are accessed as attributes (e.g. ``Q.is_slave``) and passed to
# qexec(), which executes their ``.query`` text.  Presumably the SQL lives in
# a file shipped next to this module -- confirm in query_conf.
Q = load_from_package(__package__, __file__)


def chunks(data, SIZE=10000):
    """Yield successive lists of at most *SIZE* items from the iterable *data*.

    Implements the "grouper" recipe (http://stackoverflow.com/a/312644/3130355).
    ``SIZE`` references to one shared iterator are zipped together, so each
    tuple consumes the next ``SIZE`` items.  The final tuple is padded with a
    private sentinel, and that sentinel is stripped again.  As a result every
    item of *data* is kept, including falsy ones such as ``0``, ``''``,
    ``None`` or ``{}``.

    Raises:
        ValueError: if *SIZE* is less than 1.  Because this is a generator,
            the error is raised on first iteration.
    """
    if SIZE < 1:
        # zip_longest() with no iterables yields nothing, which would
        # silently discard all of *data*.
        raise ValueError('chunk SIZE must be >= 1, got %r' % (SIZE,))
    padding = object()
    it = iter(data)
    for chunk in zip_longest(*[it] * SIZE, fillvalue=padding):
        yield [item for item in chunk if item is not padding]


def simple_insert(table_name, cols):
    """Build an ``INSERT`` statement that uses named (pyformat) placeholders.

    For example, ``simple_insert('t', ['a', 'b'])`` returns
    ``"INSERT INTO t (a,b) VALUES (%(a)s,%(b)s)"``.  That statement is meant
    to be executed with a ``{column: value}`` mapping.  *table_name* and
    *cols* are interpolated verbatim, so they must be trusted identifiers;
    only the values travel as query parameters.

    *cols* may be any iterable, including a one-shot iterator.  It is
    materialised first because it is used twice.
    """
    cols = list(cols)
    placeholders = ','.join(['%({0})s'.format(c) for c in cols])
    return "INSERT INTO {0} ({1}) VALUES ({2})".format(
        table_name,
        ','.join(cols),
        placeholders
    )


def exec_simple_insert(cur, table_name, **kwargs):
    """Insert a single row into *table_name* using cursor *cur*.

    Every keyword argument names a column.  Its value is bound as a query
    parameter through the named placeholders that simple_insert generates.
    """
    statement = simple_insert(table_name, kwargs.keys())
    cur.execute(statement, kwargs)


def _mogrify_codec(cur):
    """Return the Python codec name for bytes produced by ``cur.mogrify``.

    Looks up the connection's PostgreSQL encoding in
    ``psycopg2.extensions.encodings``.  Falls back to UTF-8 when that is not
    possible (unknown encoding, or an object without ``connection``).
    """
    try:
        from psycopg2.extensions import encodings
        return encodings[cur.connection.encoding]
    except (ImportError, AttributeError, KeyError):
        return 'utf-8'


def simple_insert_multi(cur, table_name, data, rows):
    """Build a multi-row ``INSERT`` statement with the values inlined.

    Each item of *data* is a mapping.  For every item, the values of the
    columns named in *rows* are taken in that order and rendered as a SQL row
    literal by ``cur.mogrify``, so the driver escapes the values.
    *table_name* and *rows* are interpolated verbatim and must be trusted
    identifiers.  *rows* may be any iterable; it is materialised because it
    is used once per item.

    Returns the statement as text (``str``).  NOTE: an empty *data* yields
    ``"... VALUES "``, which is not valid SQL.
    """
    rows = list(rows)
    codec = None
    literals = []
    for row in data:
        literal = cur.mogrify('%s', (tuple(row[k] for k in rows),))
        if not isinstance(literal, str):
            # psycopg2 on Python 3 returns bytes from mogrify, which
            # str.join() rejects; decode to text using the connection codec.
            if codec is None:
                codec = _mogrify_codec(cur)
            literal = literal.decode(codec)
        literals.append(literal)
    return "INSERT INTO {0} ({1}) VALUES {2}".format(
        table_name,
        ','.join(rows),
        ','.join(literals)
    )


def describe_cursor(cur):
    """Return the lower-cased column names from ``cur.description``."""
    names = []
    for column in cur.description:
        names.append(column.name.lower())
    return names


def fetch_as_dicts(cur):
    """Lazily yield each row of *cur* as a ``{column_name: value}`` dict.

    Column names come from describe_cursor.  They are read only after the
    first row has arrived, because server-side cursors do not have
    ``cur.description`` initialised before the first fetch.
    """
    columns = None
    for record in cur:
        if not columns:
            columns = describe_cursor(cur)
        yield dict(zip(columns, record))


def fetch_as_objects(cur, Result):
    """Lazily yield ``Result(**row)`` for every row of *cur*.

    Rows are first converted to dicts by fetch_as_dicts, so *Result* must
    accept the lower-cased column names as keyword arguments.
    """
    rows_as_dicts = fetch_as_dicts(cur)
    for fields in rows_as_dicts:
        yield Result(**fields)


# Matches e.g. "CONTEXT:  COPY box, line 323, column name: ...".  The
# non-greedy ".*?" stops at the first "line <N>", so a quoted column value
# that itself contains "line <M>" cannot be picked instead.
_COPY_CONTEXT_LINE_RE = re.compile(
    r'CONTEXT:\s+COPY.*?line\s+(?P<line>\d+)', re.MULTILINE)


def copy_error_line_no(pgerror):
    """Extract the failing COPY input line number from a server error text.

    Args:
        pgerror: the backend error message.  It may be None or empty.

    Returns:
        The 1-based line number from the ``CONTEXT: COPY ...`` part as an
        int, or None when there is no such context.
    """
    if not pgerror:
        return None
    match = _COPY_CONTEXT_LINE_RE.search(pgerror)
    if match is None:
        return None
    return int(match.group('line'))


def get_file_line(fd, line_no):
    """Return line *line_no* (1-based) of the file-like object *fd*.

    The stream is rewound first.  The line is returned exactly as iteration
    over *fd* yields it, so it normally keeps its trailing newline.  If *fd*
    has no such line, an error is logged and None is returned.
    """
    fd.seek(0)
    current = 0
    for text in fd:
        current += 1
        if current == line_no:
            return text
    log.error('line %d not found in %r', line_no, fd)
    return None


def join_line_with_names(line, rows):
    """Render one COPY input line as ``col=value`` pairs for logging.

    *line* may be ``bytes`` or text.  simple_copy builds its COPY input in a
    BytesIO, so lines read back from it are bytes.  Bytes are decoded as
    UTF-8 with replacement, so this diagnostic helper never fails on
    undecodable data.  NOTE(review): this assumes pgcopy emits UTF-8 -- confirm.
    The trailing newline is stripped so it does not end up in the last value.
    """
    if isinstance(line, bytes):
        line = line.decode('utf-8', 'replace')
    line = line.rstrip('\n')
    return ' '.join(
        '%s=%s' % kv for kv in zip(rows, line.split(SEP)))


def _log_copy_integrity_error(exc, accum, rows):
    """Log the COPY input line that triggered *exc*, on a best-effort basis.

    Parses the line number from ``exc.pgerror``, reads that line back from the
    *accum* buffer and logs it together with the column names.  Any failure
    here is only logged as a warning.  That way diagnostics never replace the
    IntegrityError the caller is about to re-raise.
    """
    try:
        # pgerror may be None when the backend message is unavailable.
        line_no = copy_error_line_no(exc.pgerror or '')
        if not line_no:
            return
        copy_line = get_file_line(accum, line_no)
        if copy_line:
            log.exception(
                'IntegrityError at line: %d: %s',
                line_no,
                join_line_with_names(copy_line, rows))
    except Exception as diag_error:  # pylint: disable=W0703
        log.warning('Could not report failing COPY line: %s', diag_error)


def simple_copy(cur, table_name, data, rows):
    """Load *data* into *table_name* with one ``cursor.copy_from`` call.

    Each item of *data* is a mapping.  For each item, the values of the
    columns in *rows* are serialised with ``pgcopy``, in order, and joined by
    ``SEP``.  Records are separated by newlines, and ``NULL`` is passed as
    the null marker.

    Raises:
        IntegrityError: re-raised unchanged.  Before that, the offending input
            line is logged when it can be located.
    """
    rows = list(rows)  # used once per record and again by copy_from
    sep = SEP.encode('utf-8')
    encoded_records = (sep.join(pgcopy(d[col]) for col in rows)
                       for d in data)
    accum = BytesIO(b'\n'.join(encoded_records))
    try:
        cur.copy_from(accum, table_name, sep=SEP, null=NULL, columns=rows)
    except IntegrityError as exc:
        _log_copy_integrity_error(exc, accum, rows)
        raise


def qexec(conn, query, **qargs):
    """Execute *query* on a new cursor of *conn* and return that cursor.

    query: pymdb.query_conf.Query.  Its ``query`` attribute holds the SQL,
    and the keyword arguments are bound as its parameters.
    """
    cursor = conn.cursor()
    sql = query.query
    cursor.execute(sql, qargs)
    return cursor


def simple_executemany(cur, table_name, data, rows=None):
    """Bulk-load the dicts in *data* into *table_name* in COPY-sized chunks.

    When *rows* (the column names) is not given, the keys of ``data[0]`` are
    used.  If *data* is empty or its first item is falsy, nothing is loaded.

    Returns the total number of items handed to simple_copy.
    """
    if not rows:
        first = data[0] if data else None
        if not first:
            return 0
        rows = first.keys()
    total = 0
    for batch in chunks(data):
        total += len(batch)
        simple_copy(cur, table_name, batch, rows)
    return total


def make_dsn(hostname, database, dsn_suffix):
    """Assemble a libpq ``key=value`` connection string.

    Produces ``"host=<hostname> dbname=<database> <dsn_suffix>"``.  Values
    are inserted verbatim, without quoting or escaping.  They must therefore
    not contain spaces or quotes unless the caller has already quoted them.
    """
    host_part = 'host=%s' % hostname
    db_part = 'dbname=%s' % database
    return ' '.join((host_part, db_part, dsn_suffix))


def is_master_database(conn):
    """Return True when the packaged ``is_slave`` query reports a falsy value.

    Runs ``Q.is_slave`` through qexec and negates the first column of its
    first row.  Presumably that query reports whether the server is a replica
    -- confirm in the query config.  The cursor is closed afterwards instead
    of being left open.
    """
    with closing(qexec(conn, Q.is_slave)) as cur:
        return not cur.fetchone()[0]


# Attached to every connection opened by transaction() and
# autocommit_connection() as ``conn.SQL_LINE_LIMIT``.  NOTE(review): its
# consumer is not in this module; presumably it caps how many lines of SQL
# get logged -- confirm against the connection class built by make_connection.
SQL_LINE_LIMIT = 200


def _rollback_quietly(conn):
    """Roll back *conn*, logging a warning instead of raising on failure."""
    try:
        conn.rollback()
    except Exception as rollback_error:  # pylint: disable=W0703
        log.warning('Rollback failed: %s', rollback_error)


@contextmanager
def transaction(dsn):
    """Yield a non-autocommit connection to *dsn* that is used as one transaction.

    The connection is committed when the ``with`` body completes normally.
    If the body or the commit raises an ``Exception``, a rollback is
    attempted (a rollback failure is only logged), and the original
    exception object is re-raised.  The connection is closed on exit in
    every case.
    """
    with closing(make_connection(dsn, autocommit=False)) as conn:
        try:
            conn.SQL_LINE_LIMIT = SQL_LINE_LIMIT
            yield conn
            conn.commit()
        except Exception as error:
            _rollback_quietly(conn)
            raise error


@contextmanager
def readonly_repeatable_read_transaction(dsn):
    """Yield a connection whose transaction is REPEATABLE READ and READ ONLY.

    Builds on transaction(): commit on normal exit, rollback on error, and
    the connection is closed afterwards.
    """
    with transaction(dsn) as conn:
        with closing(conn.cursor()) as cur:
            # transaction() opens the connection with autocommit=False, so
            # psycopg2 sends an implicit BEGIN before this first statement.
            # A second explicit BEGIN would only produce a "transaction
            # already in progress" warning and keep the default isolation.
            # SET TRANSACTION changes the transaction that was just opened;
            # it must run before any query in that transaction.
            cur.execute(
                'SET TRANSACTION ISOLATION LEVEL REPEATABLE READ READ ONLY'
            )
        yield conn


@contextmanager
def autocommit_connection(dsn):
    """Yield an autocommit connection to *dsn*, and close it on exit.

    The module-level SQL_LINE_LIMIT is attached to the connection before it
    is handed out.
    """
    conn = make_connection(dsn, autocommit=True)
    with closing(conn):
        conn.SQL_LINE_LIMIT = SQL_LINE_LIMIT
        yield conn


@contextmanager
def get_connection(dsn=None, conn=None):
    """Yield *conn* if one is given; otherwise yield a new transactional connection to *dsn*.

    A caller-supplied connection is yielded as-is.  It is not committed,
    rolled back or closed here, so the caller keeps ownership.  Without one,
    the connection comes from ``transaction(dsn)``: commit on success,
    rollback on error, close on exit.

    Raises:
        ValueError: if neither *dsn* nor *conn* is provided.
    """
    if conn:
        yield conn
        return
    if not dsn:
        # An explicit raise rather than ``assert``, so the check still runs
        # under ``python -O``.
        raise ValueError('Need either dsn or pg connection')
    with transaction(dsn) as new_conn:
        yield new_conn
