# coding: utf-8
from __future__ import print_function
from contextlib import closing
import os
import sys
from argparse import ArgumentParser
import re
import logging

import psycopg2
import sqlparse


# Logger named after this module.
log = logging.getLogger(name=__name__)


def termcode(num):
    """Build the ANSI "select graphic rendition" escape sequence for ``num``."""
    esc = '\x1b'
    return esc + '[%sm' % num

# ANSI colour escape sequences for terminal output.
RED_COLOR, GREEN_COLOR, RESET_COLOR = map(termcode, (31, 32, 0))

# Coloured status labels; RESET_COLOR restores the default colour afterwards.
OK = '{0}OK{1}'.format(GREEN_COLOR, RESET_COLOR)
FAIL = '{0}FAIL{1}'.format(RED_COLOR, RESET_COLOR)


def hl_error(err):
    """Return ``err`` wrapped in red colour codes, resetting the colour after it."""
    return ''.join((RED_COLOR, err, RESET_COLOR))


class MalformedStatement(RuntimeError):
    """Raised by Statement when a statement's text fails an encoding check.

    Attributes:
        code: the offending statement text.
        line_no: 0-based line of the first offending character, or None when
            the position is unknown.
    """

    def __init__(self, message, code, line_no):
        super(MalformedStatement, self).__init__(message)
        self.line_no = line_no
        self.code = code


class Statement(object):
    """One SQL statement to execute, together with the file it came from.

    The constructor validates the encoding of ``code``: the text must be
    ASCII unless one of its lines contains the
    ``/* pgmigrate-encoding: utf-8 */`` modeline, in which case it only has
    to be valid UTF-8. A failing check raises MalformedStatement.

    ``code`` may be a byte string (what ``open().read()`` gives on Python 2)
    or a text string (Python 3).
    """

    def __init__(self, filename, add_firstline, code):
        self.filename = filename
        # When truthy, ``name`` also shows the start of the first SQL line.
        self.add_firstline = add_firstline
        self.code = code
        if not self.has_unicode_modeline():
            self.check_ascii_only()

    def has_unicode_modeline(self):
        """Return True if any line of ``code`` contains the UTF-8 modeline.

        Raises:
            MalformedStatement: a byte-string ``code`` is not valid UTF-8
                (``line_no`` is None).
        """
        code = self.code
        try:
            # The builtin ``unicode`` only exists on Python 2; decoding bytes
            # explicitly behaves the same there and also works on Python 3.
            text = code.decode('utf-8') if isinstance(code, bytes) else code
        except UnicodeDecodeError as exc:
            raise MalformedStatement(
                "got non ascii for {0}: {1}".format(
                    self.filename, exc
                ),
                self.code,
                line_no=None)
        return any(
            u'/* pgmigrate-encoding: utf-8 */' in line
            for line in text.split(u'\n'))

    def check_ascii_only(self):
        """Raise MalformedStatement if ``code`` contains a non-ASCII character.

        The exception's ``line_no`` is the 0-based line of the first
        offending character.
        """
        code = self.code
        is_bytes = isinstance(code, bytes)
        try:
            if is_bytes:
                code.decode('ascii')
            else:
                code.encode('ascii')
        except (UnicodeDecodeError, UnicodeEncodeError) as exc:
            # Both errors report the offending offset in ``start`` (bytes for
            # decode, characters for encode); the newlines before it give the
            # 0-based line number.
            newline = b'\n' if is_bytes else u'\n'
            line_no = code.count(newline, 0, exc.start)
            raise MalformedStatement(
                "got non ascii for {0}".format(
                    self.filename,
                ),
                self.code,
                line_no)

    @property
    def name(self):
        """Label shown in the progress output.

        The file name alone or, when ``add_firstline`` is set, the file name
        followed by the statement's first line, cut before its first ``(``
        and truncated to 51 characters.
        """
        if not self.add_firstline:
            return self.filename
        # split() keeps the whole text when there is no newline; slicing up to
        # find('\n') would get -1 there and chop off the last character.
        firstline = self.code.split('\n', 1)[0]
        firstline = firstline.split('(', 1)[0]
        firstline = firstline[:51]
        return '{0} {1}'.format(self.filename, firstline)


def get_separated_statements(dir_path):
    """Yield one Statement per ``*.sql`` regular file directly in ``dir_path``.

    Entries are visited in sorted name order and the extension check is
    case-insensitive. Each file's whole content becomes a single Statement
    named by its path alone (no first-line suffix).
    """
    for entry in sorted(os.listdir(dir_path)):
        path = os.path.join(dir_path, entry)
        if not entry.lower().endswith('.sql'):
            continue
        if not os.path.isfile(path):
            continue
        with open(path) as handle:
            content = handle.read()
        yield Statement(path, None, content)


def get_statements_from_file(file_path):
    """Split the SQL file at ``file_path`` with sqlparse and yield Statements.

    Each parsed statement is stripped of surrounding whitespace; empty ones
    are skipped. The resulting Statements carry ``add_firstline=True`` so
    their name includes the start of the SQL.
    """
    with open(file_path) as source:
        # parsestream reads lazily, so iteration must stay inside the with.
        stripped = (str(token).strip() for token in sqlparse.parsestream(source))
        for sql in stripped:
            if not sql:
                continue
            yield Statement(file_path, add_firstline=True, code=sql)


def print_code_with_error(code, line_no, line_replace, context_lines):
    """Print ``code`` around one offending line, with that line highlighted.

    Args:
        code: the SQL text the error refers to.
        line_no: 1-based number of the offending line, the numbering used by
            PostgreSQL's ``LINE N:`` error prefix.
        line_replace: text to print (highlighted via hl_error) in place of
            the offending line, or None to print that line from ``code``.
        context_lines: how many lines to show before and after the offending
            line; zero or a negative value shows none.
    """
    code_lines = code.split('\n')
    index = max(line_no - 1, 0)  # 0-based position of the offending line
    broken_line = line_replace
    if broken_line is None:
        # line_no is 1-based, so the line itself sits at index line_no - 1.
        broken_line = code_lines[index]
    context_lines = max(context_lines, 0)
    # Explicit start indices: ``seq[-0:]`` is the whole list, so slicing with
    # ``[-context_lines:]`` would print every preceding line for a context of 0.
    leading = code_lines[max(index - context_lines, 0):index]
    trailing = code_lines[index + 1:index + 1 + context_lines]
    print('\n'.join(leading + [hl_error(broken_line)] + trailing))


# psycopg2 error classes that main() reports (server message plus the
# offending SQL) before stopping; any other database error propagates.
PRINTABLE_ERRORS = (
    psycopg2.ProgrammingError,
    psycopg2.NotSupportedError,
    psycopg2.DataError,
)


def main():
    """Load SQL statements from the given sources and execute them in order.

    Each SOURCE is either a directory (every ``*.sql`` file in it, in sorted
    order, becomes one statement) or a file (split into statements with
    sqlparse). All statements run on one connection inside a single
    ``with conn`` transaction. The first statement that raises one of
    PRINTABLE_ERRORS is reported with the surrounding SQL, the transaction is
    rolled back, and the run stops.

    Returns:
        0 when every statement executed, 1 when a source was malformed or a
        statement failed. A missing source makes argparse exit the process
        (status 2) via ``parser.error``.
    """
    parser = ArgumentParser(
        description='install ddl',
    )
    parser.add_argument(
        'dsn',
        metavar='DSN',
        type=str,
    )
    parser.add_argument(
        'source',
        metavar='SOURCE',
        nargs='+',
        help='sql file or dir with sql files'
    )
    parser.add_argument(
        '--context',
        type=int,
        default=3,
        metavar='NUM',
        help='print NUM lines of leading and trailing context surrounding error'
    )
    parser.add_argument(
        '--print-notices',
        action='store_true',
    )
    args = parser.parse_args()

    statements = []
    for src in args.source:
        if not os.path.exists(src):
            # parser.error() prints usage and exits the process itself.
            parser.error("source {0} don't exists".format(src))
        try:
            if os.path.isdir(src):
                statements += list(get_separated_statements(src))
            else:
                statements += list(get_statements_from_file(src))
        except MalformedStatement as exc:
            # str(exc) rather than the deprecated (removed in Python 3)
            # BaseException.message attribute.
            print(str(exc))
            # line_no is 0-based and may be 0 (first line); None means the
            # position is unknown.
            if exc.line_no is not None:
                broken_line = exc.code.split('\n')[exc.line_no]
                # print_code_with_error expects a 1-based line number, the
                # same convention as PostgreSQL's "LINE N:" used below.
                print_code_with_error(
                    exc.code,
                    exc.line_no + 1,
                    repr(broken_line).strip("'"),
                    args.context)
            return 1

    # Width of the name column; ``or [0]`` keeps max() from raising when the
    # sources yielded no statements at all.
    name_width = max([len(s.name) for s in statements] or [0])
    format_line = '{0:<3} {1:<%d} {2}' % name_width

    with closing(psycopg2.connect(args.dsn)) as conn:
        # psycopg2's connection context manager commits when the block exits
        # normally and rolls back when it exits with an exception.
        with conn:
            cur = conn.cursor()
            for i, st in enumerate(statements, 1):
                try:
                    cur.execute(st.code)
                    print(format_line.format(i, st.name, OK))
                except PRINTABLE_ERRORS as exc:
                    print(format_line.format(i, st.name, FAIL))
                    message = str(exc).strip()
                    print(message)
                    at_line = re.search(r'LINE (\d+): (.*)', message, re.M)
                    if at_line is None:
                        print(st.code)
                    else:
                        print_code_with_error(
                            st.code,
                            int(at_line.group(1)),
                            at_line.group(2),
                            args.context)
                    # A plain return would make ``with conn`` COMMIT. Some
                    # errors (e.g. an empty query) are raised client-side and
                    # leave the transaction usable, so earlier statements
                    # would be committed despite the failure.
                    conn.rollback()
                    return 1
                finally:
                    if args.print_notices:
                        # conn.notices may be empty or hold several messages;
                        # print them all, then clear so none repeats.
                        for notice in conn.notices:
                            print(notice.strip())
                        del conn.notices[:]
    return 0


if __name__ == '__main__':
    # main() returns the process exit status.
    exit_status = main()
    sys.exit(exit_status)
