from osgeo import ogr, osr
import os
import json
import subprocess

from .utils import geometry_column_exists, set_search_path

from .yconfig import Config as YConfig

import six
# Python 3 has no xrange; alias it so the loops below work on both versions.
if six.PY3:
    xrange = range


# Key of the geometry entry in JSON objects and in category attribute
# descriptions (it is skipped when creating attribute fields).
GEOM_ATTR = 'geometry'
# Encoding used for text values placed into the JSON written by shape2json.
JSON_ENCODING = 'utf-8'
# Code page used for string values written to / read from shapefiles.
SHAPE_ENCODING = 'cp1251'
# Width assigned to integer shapefile fields (see postgre2shape and
# create_ogr_layer for when it is applied).
SHAPE_MAX_INT_WIDTH = 9
# EPSG code (4326 = WGS 84) of the spatial reference given to created layers.
SHAPE_PROJECTION = 4326

# Temporary point column added to tables that have no geometry column,
# because OGR only exports tables that have one (see postgre2shape).
FAKE_GEOMETRY_COLUMN = 'fake_geom'

# Config file passed to revisionapi via --cfg; resolved once, at import time.
YCONFIG_FILENAME = YConfig().config_filename


def _process_output_as_text(data):
    """Return captured process output as text suitable for an error message.

    ``None`` (stream not captured) becomes ``''``. On Python 3 ``bytes`` are
    decoded as UTF-8 with replacement; on Python 2 ``bytes`` *is* ``str`` and
    the value is returned untouched, preserving the old message contents.
    """
    if data is None:
        return ''
    if isinstance(data, bytes) and bytes is not str:
        return data.decode('utf-8', 'replace')
    return data


def execute_process(process):
    """Wait for ``process`` to finish and return its captured output.

    :param process: a started ``subprocess.Popen`` (or an object exposing
        ``communicate()`` and ``wait()``).
    :returns: ``(stdout, stderr)`` exactly as returned by ``communicate()``.
    :raises Exception: if the exit status is non-zero; the message contains
        both streams as text.
    """
    stdout, stderr = process.communicate()
    # communicate() has already waited for termination, so wait() only
    # reports the exit status here.
    if process.wait() != 0:
        # Convert before concatenating: on Python 3 the pipes yield bytes and
        # "str" + bytes would raise TypeError instead of this error.
        raise Exception("stdout: " + _process_output_as_text(stdout) +
                        "\n stderr: " + _process_output_as_text(stderr))
    return (stdout, stderr)


def revision2json(branch, batch_size, output_file, categories=None):
    """Export ``branch`` to ``output_file`` by running ``revisionapi --cmd=export``.

    :param branch: branch passed as ``--branch``.
    :param batch_size: value for ``--batch-size``; an int or a string
        (converted with ``str()``).
    :param output_file: path passed as ``--output-file``.
    :param categories: optional iterable of categories; each is converted
        with ``str()`` and the results are joined with commas.
    :raises Exception: via ``execute_process`` if the tool exits non-zero.
    """
    args = ['revisionapi',
            '--cfg=' + YCONFIG_FILENAME,
            '--cmd=export',
            '--branch=' + branch,
            # str() so numeric batch sizes work; string input is unchanged.
            '--batch-size=' + str(batch_size),
            '--output-file=' + output_file]
    if categories is not None:
        args.append('--categories=' + ",".join(str(cat) for cat in categories))

    process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    execute_process(process)


def json2revision(branch, json_path, uid):
    """Import the JSON file at ``json_path`` into ``branch`` via ``revisionapi``.

    The import runs as user ``uid`` with ``--ignore-json-ids``.

    :returns: the whitespace-separated tokens of the tool's stdout.
    :raises Exception: via ``execute_process`` if the tool exits non-zero.
    """
    command = [
        'revisionapi',
        '--cfg=' + YCONFIG_FILENAME,
        '--cmd=import',
        '--branch=' + branch,
        '--ignore-json-ids',
        '--path=' + json_path,
        '--user-id=' + str(uid),
    ]
    child = subprocess.Popen(command,
                             stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE)
    captured_stdout, _captured_stderr = execute_process(child)
    return captured_stdout.split()


def add_fake_geom_columns(conn, schema_name, tables):
    """Add a temporary point column to each table lacking a geometry column.

    Tables for which ``geometry_column_exists`` is true are left untouched.
    This function does not commit.

    :param conn: DB connection; assumes a psycopg2-style (pyformat) cursor.
    :param schema_name: schema containing ``tables``.
    :param tables: names of the tables to inspect.
    :returns: the tables that received ``FAKE_GEOMETRY_COLUMN``, in input order.
    """
    fake_geom_tables = []
    curr = conn.cursor()
    try:
        for table in tables:
            if geometry_column_exists(curr, table):
                continue
            # Pass names as query parameters so the driver quotes/escapes them;
            # AddGeometryColumn receives them as text literals, as before.
            curr.execute(
                "SELECT AddGeometryColumn(%s, %s, %s, 4326, 'POINT', 2)",
                (schema_name, table, FAKE_GEOMETRY_COLUMN))
            fake_geom_tables.append(table)
    finally:
        curr.close()
    return fake_geom_tables


def remove_shape_files(shape_dir, fake_geom_tables):
    """Delete the ``.shp`` and ``.shx`` files of each table in ``shape_dir``.

    Other files with the same base name (e.g. ``.dbf``) are kept. A missing
    file raises ``OSError`` from ``os.remove``.
    """
    for table_name in fake_geom_tables:
        base_path = os.path.join(shape_dir, table_name)
        for extension in (".shp", ".shx"):
            os.remove(base_path + extension)


def drop_fake_geom_columns(conn, fake_geom_tables):
    """Drop ``FAKE_GEOMETRY_COLUMN`` from each table in ``fake_geom_tables``.

    Uses the two-argument ``DropGeometryColumn(table, column)`` form, so the
    table is looked up without an explicit schema (PostGIS resolves it in the
    current schema). This function does not commit.

    :param conn: DB connection; assumes a psycopg2-style (pyformat) cursor.
    :param fake_geom_tables: tables returned by ``add_fake_geom_columns``.
    """
    curr = conn.cursor()
    try:
        for table in fake_geom_tables:
            # Driver-side quoting instead of string interpolation.
            curr.execute(
                "SELECT DropGeometryColumn(%s, %s)",
                (table, FAKE_GEOMETRY_COLUMN))
    finally:
        curr.close()


def _copy_layer(src_ds, dst_ds, layer_name, srs):
    """Create shapefile layer ``layer_name`` in ``dst_ds`` and copy all features into it."""
    src_layer = src_ds.GetLayerByName(layer_name)
    if src_layer is None:
        raise Exception("layer '%s' not found" % layer_name)
    src_layer_defn = src_layer.GetLayerDefn()

    dst_layer = dst_ds.CreateLayer(layer_name,
                                   geom_type=src_layer_defn.GetGeomType(),
                                   srs=srs)

    # (source field index, destination field name, is string field)
    field_mapping = []
    for field_indx in xrange(src_layer_defn.GetFieldCount()):
        src_field = src_layer_defn.GetFieldDefn(field_indx)
        field_type = src_field.GetType()
        # DBF field names are limited to 10 characters.
        dst_field_name = src_field.GetName()[:10]

        dst_field = ogr.FieldDefn(dst_field_name, field_type)
        if field_type == ogr.OFTInteger:
            if src_field.GetWidth() == 0:
                dst_field.SetWidth(SHAPE_MAX_INT_WIDTH)
            else:
                # NOTE(review): width reduced by one for integers; intent not
                # visible here — confirm.
                dst_field.SetWidth(src_field.GetWidth() - 1)
        else:
            dst_field.SetWidth(src_field.GetWidth())
        dst_field.SetPrecision(src_field.GetPrecision())
        dst_layer.CreateField(dst_field)

        # The source field type decides what GetField() returns, so it is what
        # decides whether the value needs re-encoding.
        field_mapping.append(
            (field_indx, dst_field_name, field_type == ogr.OFTString))

    dst_layer_defn = dst_layer.GetLayerDefn()
    src_layer.ResetReading()
    src_feature = src_layer.GetNextFeature()
    while src_feature is not None:
        dst_feature = ogr.Feature(feature_def=dst_layer_defn)
        for src_field_indx, dst_field_name, is_string in field_mapping:
            val = src_feature.GetField(src_field_indx)
            if is_string and val:
                # NOTE(review): assumes GetField returns UTF-8 bytes (Python 2
                # bindings); on Python 3 it is str, which has no .decode.
                val = val.decode('utf-8').encode(SHAPE_ENCODING, 'replace')
            dst_feature.SetField(dst_field_name, val)

        dst_feature.SetGeometry(src_feature.GetGeometryRef())
        dst_layer.CreateFeature(dst_feature)
        src_feature = src_layer.GetNextFeature()


def _export_layers(conn_str, schema_name, tables, output_dir):
    """Open the PostgreSQL source and shapefile destination and copy every table."""
    src_ds = ogr.GetDriverByName('PostgreSQL').Open(
        "PG:" + conn_str + " active_schema=" + schema_name)
    if src_ds is None:
        raise Exception("open src failed")

    dst_ds = ogr.GetDriverByName('ESRI Shapefile').CreateDataSource(str(output_dir))
    if dst_ds is None:
        raise Exception("open output failed")

    # Every output layer gets the same spatial reference; build it once.
    prj = osr.SpatialReference()
    prj.ImportFromEPSG(SHAPE_PROJECTION)

    for src_layer_name in tables:
        _copy_layer(src_ds, dst_ds, src_layer_name, prj)

    # Dropping the references closes the data sources, so OGR flushes and
    # closes the shapefiles before the caller deletes any of them.
    dst_ds = None
    src_ds = None


def postgre2shape(conn, conn_str, schema_name, tables, output_dir):
    """Export PostgreSQL tables as ESRI shapefiles into ``output_dir``.

    Tables without a geometry column temporarily get ``FAKE_GEOMETRY_COLUMN``
    so OGR will export them. Their ``.shp``/``.shx`` files are deleted after
    the export, and the column is dropped again even if the export fails.

    :param conn: DB connection used for the schema changes (committed here).
    :param conn_str: connection string for OGR's PostgreSQL driver
        (appended to ``"PG:"``).
    :param schema_name: schema holding ``tables``; set as the search path
        together with ``public``.
    :param tables: names of the tables to export, one layer each.
    :param output_dir: destination directory for the shapefiles.
    :raises Exception: if a data source cannot be opened or a table is not
        found as a layer.
    """
    cursor = conn.cursor()
    try:
        set_search_path(cursor, schema_name + ",public")
    finally:
        cursor.close()

    # ogr library transforms only tables that have a geometry column
    fake_geom_tables = add_fake_geom_columns(conn, schema_name, tables)
    conn.commit()

    try:
        _export_layers(conn_str, schema_name, tables, output_dir)
        remove_shape_files(output_dir, fake_geom_tables)
    finally:
        # Always undo the temporary schema change.
        drop_fake_geom_columns(conn, fake_geom_tables)
        conn.commit()


def postgre2dump(conn, schema_name, tables, output_dir):
    """Dump each table as CSV into a new directory, one file per table.

    :param conn: DB connection whose cursor provides ``copy_expert``
        (psycopg2-style).
    :param schema_name: schema containing ``tables``.
    :param tables: table names; each is written to ``output_dir/<table>``.
    :param output_dir: directory to create; it must not exist yet.
    :raises Exception: if ``output_dir`` already exists.
    """
    if os.path.exists(output_dir):
        raise Exception("output directory '%s' already exists" % output_dir)
    os.makedirs(output_dir)

    curr = conn.cursor()
    try:
        for table in tables:
            # NOTE(review): schema/table names are interpolated into the SQL
            # unescaped; only pass trusted identifiers.
            query = "COPY %s.%s TO STDOUT WITH CSV" % (schema_name, table)
            # Close each dump file right away so its data is flushed even if
            # a later table fails.
            with open(os.path.join(output_dir, table), 'w') as dump_file:
                curr.copy_expert(query, dump_file)
    finally:
        curr.close()


def json_object2ogr_feature(category, obj, layer, categories_map):
    """Append the JSON object ``obj`` to ``layer`` as a new feature.

    Each layer field ``f`` is taken from ``obj['attributes']['<category>:f']``,
    falling back to ``categories_map[category][f]['default']``. String-typed
    values are encoded to SHAPE_ENCODING (unencodable characters replaced),
    and every value is passed to ``SetField`` through ``str()``. The geometry
    is built from ``obj[GEOM_ATTR]``.

    :raises Exception: if an attribute is absent and has no default.
    """
    layer_defn = layer.GetLayerDefn()
    feature = ogr.Feature(feature_def=layer_defn)
    obj_attrs = obj['attributes']

    for idx in range(layer_defn.GetFieldCount()):
        fdefn = layer_defn.GetFieldDefn(idx)
        fname = fdefn.GetNameRef()
        ftype = fdefn.GetType()
        key = category + ':' + fname

        if key in obj_attrs:
            field_value = obj_attrs[key]
        elif 'default' in categories_map[category][fname]:
            field_value = categories_map[category][fname]['default']
        else:
            raise Exception("attribute '%s' is not found" % (key))

        if ftype == ogr.OFTString:
            # NOTE(review): on Python 3 encode() returns bytes and str(bytes)
            # below yields "b'...'"; this path looks Python 2-only — confirm.
            field_value = field_value.encode(SHAPE_ENCODING, 'replace')

        feature.SetField(fname, str(field_value))

    geometry_text = json.dumps(obj[GEOM_ATTR])
    feature.SetGeometry(ogr.CreateGeometryFromJson(geometry_text))

    layer.CreateFeature(feature)


def create_ogr_layer(ds, name, geometry_type, attrs):
    """Create layer ``name`` in ``ds`` with one field per non-geometry attribute.

    :param ds: OGR data source to create the layer in.
    :param name: layer name.
    :param geometry_type: OGR geometry type constant for the layer.
    :param attrs: mapping attribute name -> description dict whose ``'type'``
        is the name of an ``ogr`` field-type constant (e.g. ``'OFTString'``).
        The ``GEOM_ATTR`` entry is skipped.
    :returns: the created layer, in the SHAPE_PROJECTION spatial reference.
    """
    prj = osr.SpatialReference()
    prj.ImportFromEPSG(SHAPE_PROJECTION)

    layer = ds.CreateLayer(name, geom_type=geometry_type, srs=prj)

    for attr_name, attr_descr in six.iteritems(attrs):
        if attr_name == GEOM_ATTR:
            continue
        # getattr instead of eval: resolve the ogr constant by name without
        # executing text taken from the category description.
        attr_type = getattr(ogr, attr_descr['type'])
        field_defn = ogr.FieldDefn(attr_name, attr_type)
        if attr_type == ogr.OFTInteger:
            field_defn.SetWidth(SHAPE_MAX_INT_WIDTH)
        layer.CreateField(field_defn)

    return layer


def json2shape(input_file, output_dir, categories_map):
    """Write the objects of a revision JSON file into one shapefile layer per category.

    An object goes into category ``cat``'s layer when its attributes contain
    ``'cat:<cat>' == '1'``.

    :param input_file: path of a UTF-8 JSON file with an ``'objects'`` mapping.
    :param output_dir: destination passed to the ESRI Shapefile driver.
    :param categories_map: category -> attribute descriptions; each must have
        ``['geometry']['type']`` naming an ``ogr`` geometry-type constant.
    :raises Exception: if the output cannot be created, or listing every
        object that failed to convert.
    """
    import io

    dst_ds = ogr.GetDriverByName('ESRI Shapefile').CreateDataSource(str(output_dir))
    if dst_ds is None:
        raise Exception("open output failed")

    layers = {}
    for cat, cat_attrs in six.iteritems(categories_map):
        # getattr instead of eval: no execution of configuration text.
        geometry_type = getattr(ogr, cat_attrs['geometry']['type'])
        layers[cat] = create_ogr_layer(dst_ds, cat, geometry_type, cat_attrs)

    # json.load() no longer accepts ``encoding`` (removed in Python 3.9), so
    # decode explicitly; the with-block also closes the file.
    with io.open(input_file, encoding='utf-8') as json_file:
        json_data = json.load(json_file)

    errors = []
    for key, obj in six.iteritems(json_data['objects']):
        obj_attrs = obj['attributes']
        for cat, layer in six.iteritems(layers):
            if obj_attrs.get('cat:' + cat) != '1':
                continue
            try:
                json_object2ogr_feature(cat, obj, layer, categories_map)
            except Exception as ex:
                errors.append("object '%s' contains error: %s\n" % (key, ex))

    if errors:
        raise Exception(''.join(errors))


def ogr_feature2attributes(category, feature):
    """Convert the fields of an OGR feature into a revision attribute dict.

    Keys are ``'<category>:<lower-cased field name>'``. Values are each field's
    string form; integer/real values have surrounding whitespace stripped.
    ``'cat:<category>'`` is set to ``'1'``.
    """
    feature_defn = feature.GetDefnRef()
    fields_count = feature_defn.GetFieldCount()

    res = dict()
    for field_indx in range(fields_count):
        field_defn = feature_defn.GetFieldDefn(field_indx)
        field_name = field_defn.GetNameRef().lower()
        field_type = field_defn.GetType()
        value = feature.GetFieldAsString(field_indx)

        if field_type == ogr.OFTString and isinstance(value, bytes):
            # Raw shapefile bytes: decode from the shapefile code page. On
            # Python 2 re-encode to UTF-8 bytes as before; on Python 3 keep
            # text so json can serialize it (str has no .decode there).
            value = value.decode(SHAPE_ENCODING)
            if six.PY2:
                value = value.encode(JSON_ENCODING, 'replace')
        if field_type in (ogr.OFTInteger, ogr.OFTReal):
            value = value.strip()
        res[category + ':' + field_name] = value
    res['cat:' + category] = '1'
    return res


def shape2json(shape_path, json_path, attributes):
    """Convert every feature of every layer in a shapefile source to revision JSON.

    Objects get sequential ids ``'id0'``, ``'id1'``, ...; each layer's name is
    used as the category of its features.

    :param shape_path: path opened with the ESRI Shapefile driver.
    :param json_path: output file; written only if every feature converts.
    :param attributes: stored unchanged under the top-level ``'attributes'`` key.
    :raises Exception: if the input cannot be opened, or listing every
        feature that failed to convert.
    """
    src_ds = ogr.GetDriverByName('ESRI Shapefile').Open(shape_path)
    if src_ds is None:
        raise Exception("open input failed")

    objects = {}
    errors = []
    cur_id = 0
    for layer_indx in range(src_ds.GetLayerCount()):
        layer = src_ds.GetLayer(layer_indx)
        layer_name = layer.GetName()
        layer.ResetReading()
        feature = layer.GetNextFeature()
        while feature is not None:
            try:
                geometry = feature.GetGeometryRef()
                if geometry is None:
                    raise Exception("feature has no geometry")
                objects['id' + str(cur_id)] = {
                    'attributes': ogr_feature2attributes(layer_name, feature),
                    # Parse the GeoJSON text with json, never eval().
                    'geometry': json.loads(geometry.ExportToJson()),
                }
                cur_id += 1
            except Exception as ex:
                # DumpReadable() prints to stdout and returns None, so identify
                # the feature by its FID instead.
                errors.append(
                    "feature '%s' from layer '%s' contains error: %s\n" %
                    (feature.GetFID(), layer_name, ex))

            feature = layer.GetNextFeature()

    if errors:
        raise Exception(''.join(errors))

    res = {'attributes': attributes, 'objects': objects}
    with open(json_path, 'w') as result_file:
        # JSON_ENCODING ('utf-8') is already json.dumps' default on Python 2,
        # and Python 3's json.dumps rejects an ``encoding`` argument.
        result_file.write(json.dumps(res))
