# -*- coding: utf-8 -*-
from collections import OrderedDict
from copy import copy

from mpfs.dao.spec_converter import convert_spec_to_sql


class MongoQueryConverter(object):
    """Translate pymongo-style call arguments into parameterized Postgres SQL.

    Every ``*_to_sql`` method returns a ``(query, params)`` pair: ``query`` is
    an SQL string that references bind parameters as ``:name`` and ``params``
    maps those names to values.  The target table, column names and value
    conversion all come from ``dao_item_cls``.
    """

    def __init__(self, dao_item_cls):
        """
        :param dao_item_cls: DAO item class providing ``postgres_table_obj``,
            ``get_postgres_primary_key()``, ``convert_mongo_key_to_postgres()``,
            ``convert_mongo_value_to_postgres_for_key()`` and
            ``create_from_mongo_dict()``.
        """
        assert dao_item_cls is not None
        self.dao_item_cls = dao_item_cls

    @property
    def table(self):
        """Table object of the DAO item (used for its ``fullname`` and ``columns``)."""
        return self.dao_item_cls.postgres_table_obj

    def find_to_sql(self, spec=None, fields=None, skip=0, limit=0, sort=None):
        """Build a SELECT equivalent to ``collection.find(...)``.

        :param spec: mongo filter; a falsy spec (``None``, ``{}``) selects every row.
        :param fields: projection, resolved by ``_get_columns_for_select``.
        :param skip: rows to skip (``OFFSET``); 0 means no offset.
        :param limit: maximum number of rows (``LIMIT``); 0 means no limit.
        :param sort: list of ``(key, direction)`` pairs, see ``convert_sort_fields``.
        :return: ``(query, params)``.
        """
        columns_for_select = self._get_columns_for_select(fields)
        query = 'SELECT %s FROM %s' % (','.join(columns_for_select), self.table.fullname)

        params = {}
        if spec:
            sql_where, params = self._get_where_key_values(spec)
            query += ' WHERE ' + sql_where

        if sort:
            query += ' ORDER BY ' + self.convert_sort_fields_to_sql(sort)

        if limit:
            query += ' LIMIT %d' % limit
        if skip:
            query += ' OFFSET %d' % skip

        return query, params

    def insert_to_sql(self, doc_or_docs, continue_on_error=False, manipulate=True, contains_id_field=False):
        """Build a (multi-row) INSERT equivalent to ``collection.insert(...)``.

        :param doc_or_docs: a single mongo document or a list of documents.
        :param continue_on_error: append ``ON CONFLICT DO NOTHING`` so conflicting
            rows are skipped instead of failing the whole statement.
        :param manipulate: append ``RETURNING <primary key column>``.
        :param contains_id_field: when False, the ``_id`` (or ``id``) column is
            left out of the inserted column list.
        :return: ``(query, params)``.
        :raises ValueError: if an empty list of documents is given.
        """
        docs = doc_or_docs if isinstance(doc_or_docs, list) else [doc_or_docs]
        if not docs:
            # Without this guard the statement would end in a dangling "VALUES ".
            raise ValueError('insert_to_sql() requires at least one document')

        columns_for_select = self._get_columns_for_select(None, exclude_id_field=not contains_id_field)
        values, columns_to_id_values = self._get_insert_values(columns_for_select, docs)

        # One "(:p0,:p1,...)" tuple per document, placeholders in column order.
        sql_values = ','.join(
            '(%s)' % ','.join(':' + columns_to_id[coll] for coll in columns_for_select)
            for columns_to_id in columns_to_id_values
        )

        query = 'INSERT INTO %s (%s) VALUES %s' % (
            self.table.fullname,
            ','.join(columns_for_select),
            sql_values,
        )

        if continue_on_error:
            query += ' ON CONFLICT DO NOTHING'

        if manipulate:
            query += ' RETURNING %s' % self._get_primary_key_column()

        return query, values

    def remove_to_sql(self, spec_or_id=None, multi=True):
        """Build a DELETE equivalent to ``collection.remove(...)``.

        :param spec_or_id: mongo filter dict, a bare ``_id`` value, or ``None``.
            A falsy filter (``None``, ``{}``) matches every row.
        :param multi: when truthy delete every matching row, otherwise delete
            at most one matching row.
        :return: ``(query, params)``.
        """
        if spec_or_id is None or isinstance(spec_or_id, dict):
            spec = spec_or_id
        else:
            spec = {'_id': spec_or_id}

        query = 'DELETE FROM %s' % self.table.fullname
        params = {}

        # Truthiness (not `multi is True`) so it agrees with _convert_common_update.
        if multi:
            if spec:
                sql_where, params = self._get_where_key_values(spec)
                query += ' WHERE ' + sql_where
        else:
            sql_where, params = self._get_limited_where_for_update_or_delete(spec)
            query += ' WHERE ' + sql_where

        return query, params

    def update_to_sql(self, spec, document, upsert=False, multi=False):
        """Build an UPDATE (or UPDATE-or-INSERT when ``upsert``) for a mongo update call.

        :param spec: mongo filter selecting the rows to update.
        :param document: mongo update document (operator form or plain fields).
        :param upsert: insert a row when the update touches nothing.
        :param multi: update every matching row when truthy, otherwise at most one.
        :return: ``(query, params)``.
        """
        assert isinstance(document, dict)
        if upsert:
            return self._convert_upsert_update(spec, document, multi)
        return self._convert_common_update(spec, document, multi)

    def convert_sort_fields(self, sort):
        """Map mongo ``[(key, direction), ...]`` to ``[(column, 'ASC' | 'DESC'), ...]``.

        Only a direction equal to ``-1`` yields ``'DESC'``; any other value
        yields ``'ASC'``.  ``None`` yields an empty list.
        """
        if sort is None:
            return []
        return [
            (self.dao_item_cls.convert_mongo_key_to_postgres(key), 'DESC' if direction == -1 else 'ASC')
            for key, direction in sort
        ]

    def convert_sort_fields_to_sql(self, sort, append_nulls_last=False, append_nulls_first=False):
        """Render ``sort`` as the body of an ``ORDER BY`` clause.

        :param append_nulls_last: append ``NULLS LAST`` to every sort column.
        :param append_nulls_first: append ``NULLS FIRST`` to every sort column.
        :raises ValueError: if both NULLS options are requested.
        """
        if append_nulls_first and append_nulls_last:
            raise ValueError('append_nulls_last and append_nulls_first are both true')
        sort_fields = self.convert_sort_fields(sort)
        template = '%s %s'
        if append_nulls_last:
            template += ' NULLS LAST'
        if append_nulls_first:
            template += ' NULLS FIRST'
        return ', '.join([template % (coll, direction) for coll, direction in sort_fields])

    def _convert_upsert_update(self, spec, document, multi):
        """Build an UPDATE-or-INSERT statement using data-modifying CTEs.

        The ``found`` CTE runs the UPDATE; ``ins_res`` inserts a row only when
        ``found`` returned nothing.  For a ``{'$set': {...}}`` document the
        inserted row is ``spec`` merged with the ``$set`` fields; any other
        single ``$`` operator raises ``NotImplementedError``.
        """
        insert_doc = document
        if len(document) == 1:
            key, value = next(iter(document.items()))
            if key == '$set':
                insert_doc = copy(spec)
                insert_doc.update(value)
                document = value
            elif key.startswith('$'):
                raise NotImplementedError()  # such operators are not supported here yet

        columns_for_select = self._get_columns_for_select(None)
        values, columns_to_id_values = self._get_insert_values(columns_for_select, [insert_doc])
        columns_to_id = columns_to_id_values[0]

        # On the plain-field SET path the update params are named 'item<N>', which keeps them
        # apart from the insert params named 'value<N>' by _get_insert_values.
        sql_update, update_params = self._convert_common_update(spec, document, multi=multi, set_param_name='item')
        sql_partial_insert = 'INSERT INTO %s (%s)' % (
            self.table.fullname,
            ','.join(columns_for_select)
        )

        sql_values_by_select = 'SELECT %s WHERE NOT EXISTS (SELECT * FROM found)' % \
                               ','.join([':%s' % val for val in columns_to_id.values()])

        query = 'WITH found AS (%s RETURNING 1), ins_res AS (%s (%s)) SELECT * FROM found' % \
                (sql_update, sql_partial_insert, sql_values_by_select)

        params = update_params
        params.update(values)
        return query, params

    def _get_primary_key_column(self):
        """Primary key column of the table, falling back to ``'_id'`` when the DAO reports ``None``."""
        primary_key_coll_name = self.dao_item_cls.get_postgres_primary_key()
        if primary_key_coll_name is None:
            return '_id'
        return primary_key_coll_name

    def _get_limited_where_for_update_or_delete(self, spec):
        """Return a WHERE condition (without the keyword) matching at most one row.

        The primary key must be in a ``LIMIT 1`` subquery over rows matching
        ``spec``; a falsy ``spec`` lets the subquery pick any row.

        :return: ``(sql_where, params)``.
        """
        primary_key_coll_name = self._get_primary_key_column()

        if spec:
            sql_where_subquery, where_params = self._get_where_key_values(spec)
            select_subquery = 'SELECT %s FROM %s WHERE %s LIMIT 1' % \
                              (primary_key_coll_name, self.table.fullname, sql_where_subquery)
        else:
            # Don't rely on convert_spec_to_sql for None/{}: "no filter" needs no WHERE.
            where_params = {}
            select_subquery = 'SELECT %s FROM %s LIMIT 1' % (primary_key_coll_name, self.table.fullname)

        sql_where = '%s IN (%s)' % (primary_key_coll_name, select_subquery)
        return sql_where, where_params

    def _convert_common_update(self, spec, document, multi, set_param_name='value'):
        """Build an UPDATE statement for a mongo update document.

        Single-operator documents supported:
        ``$set``; ``$unset`` (sets the columns to NULL); ``$inc``; and
        ``$push`` / ``$pull`` of one scalar into / from an array column.
        Any other document (several keys, or one non-``$`` key) is treated as a
        plain field -> value mapping for the columns it contains.

        :param multi: update every matching row when truthy, otherwise at most one.
        :param set_param_name: prefix of the SET bind parameters on the plain
            field -> value path.
        :return: ``(query, params)``.
        :raises NotImplementedError: for other ``$`` operators and for
            ``$push``/``$pull`` with several keys or a list value.
        """
        if multi:
            sql_where, where_params = self._get_where_key_values(spec)
        else:
            sql_where, where_params = self._get_limited_where_for_update_or_delete(spec)

        convert_value = self.dao_item_cls.convert_mongo_value_to_postgres_for_key
        sql_set = None
        set_params = {}

        if len(document) == 1:
            key, value = next(iter(document.items()))
            if key == '$set':
                set_columns, set_params = self._bind_set_values(value, convert_value)
                sql_set = ','.join('%s=:%s' % (coll, param) for coll, param in set_columns.items())
            elif key == '$unset':
                # Unset fields become NULL columns; the values given in $unset are ignored.
                set_columns, set_params = self._bind_set_values(
                    value,
                    lambda mongo_key, _: (self.dao_item_cls.convert_mongo_key_to_postgres(mongo_key), None),
                )
                sql_set = ','.join('%s=:%s' % (coll, param) for coll, param in set_columns.items())
            elif key == '$push':
                mongo_key, mongo_value = self._get_single_scalar_operand(value)
                # The scalar is wrapped in a list and unwrapped after conversion -- presumably
                # because the key's converter expects the whole array value.
                pg_key, pg_value = convert_value(mongo_key, [mongo_value])
                set_params['value'] = pg_value[0]
                sql_set = "%(key)s = %(key)s || ARRAY[:value]" % {'key': pg_key}
            elif key == '$pull':
                mongo_key, mongo_value = self._get_single_scalar_operand(value)
                pg_key, pg_value = convert_value(mongo_key, [mongo_value])
                set_params['value'] = pg_value[0]
                # NOTE(review): EXCEPT is a set operation, so besides removing the value it also
                # collapses duplicate remaining elements and does not guarantee their order,
                # unlike mongo's $pull.
                sql_set = "%(key)s = ARRAY(SELECT unnest(%(key)s) EXCEPT SELECT unnest(ARRAY[:value]))" % \
                          {'key': pg_key}
            elif key == '$inc':
                # NOTE(review): "col + n" stays NULL for a NULL column, whereas mongo's $inc on a
                # missing field sets it to n -- confirm the columns incremented this way are NOT NULL.
                set_columns, set_params = self._bind_set_values(value, convert_value)
                sql_set = ','.join('%s=%s+ :%s' % (coll, coll, param) for coll, param in set_columns.items())
            elif key.startswith('$'):
                raise NotImplementedError()  # other operators are not implemented yet, maybe never

        if sql_set is None:
            dao_item = self.dao_item_cls.create_from_mongo_dict(document)
            pg_data = dao_item.get_postgres_representation(skip_missing_fields=True)

            sql_set_strings = []
            present_columns = [c for c in self.table.columns if c in pg_data]
            for num, coll in enumerate(present_columns):
                value_id = '%s%d' % (set_param_name, num)
                set_params[value_id] = pg_data[coll]
                sql_set_strings.append('%s=:%s' % (coll.name, value_id))

            sql_set = ','.join(sql_set_strings)

        query = 'UPDATE %s SET %s WHERE %s' % (self.table.fullname, sql_set, sql_where)
        params = where_params
        params.update(set_params)
        return query, params

    @staticmethod
    def _bind_set_values(mongo_doc, convert):
        """Turn each ``key: value`` of an operator argument into a column and a bind parameter.

        :param convert: callable ``(mongo_key, mongo_value) -> (column, pg_value)``.
        :return: ``(set_columns, set_params)``: column -> param name
            (``'value0'``, ``'value1'``, ...) and param name -> converted value.
        """
        set_columns = {}
        set_params = {}
        for i, (mongo_key, mongo_value) in enumerate(mongo_doc.items()):
            pg_key, pg_value = convert(mongo_key, mongo_value)
            key_id = 'value%d' % i
            set_params[key_id] = pg_value
            set_columns[pg_key] = key_id
        return set_columns, set_params

    @staticmethod
    def _get_single_scalar_operand(operand):
        """Return the only ``(key, value)`` pair of a ``$push``/``$pull`` argument.

        :raises NotImplementedError: for several keys or a list value (not needed so far).
        """
        if len(operand) != 1:
            raise NotImplementedError()  # not needed so far
        mongo_key, mongo_value = next(iter(operand.items()))
        if isinstance(mongo_value, list):
            raise NotImplementedError()  # not needed so far
        return mongo_key, mongo_value

    def _get_columns_for_select(self, fields, exclude_id_field=False):
        """Resolve a mongo projection into a list of column names.

        * ``None`` -> every table column;
        * a dict whose values are all falsy (including ``{}``) -> every column
          except the listed fields;
        * otherwise a list/tuple of fields, or a dict whose truthy-valued keys
          are taken -> exactly those columns.

        :param exclude_id_field: drop ``'_id'`` if present, otherwise ``'id'``.
        """
        if fields is None:
            columns = [coll.name for coll in self.table.columns]
        elif isinstance(fields, dict) and all(not visible for visible in fields.values()):
            # Only exclusions were given: take everything except these fields.
            columns = [coll.name for coll in self.table.columns]
            for field in fields:
                columns.remove(self.dao_item_cls.convert_mongo_key_to_postgres(field))
        else:
            assert isinstance(fields, (list, dict, tuple))
            if isinstance(fields, dict):
                fields = [field for field, visible in fields.items() if visible]
            columns = [self.dao_item_cls.convert_mongo_key_to_postgres(field) for field in fields]

        if exclude_id_field:
            if '_id' in columns:
                columns.remove('_id')
            elif 'id' in columns:
                columns.remove('id')

        return columns

    def _get_where_key_values(self, spec, param_key='param'):
        """Convert a mongo filter into ``(sql_where, params)`` via ``convert_spec_to_sql``.

        :param param_key: prefix passed on for the generated bind parameter names.
        """
        sql, params = convert_spec_to_sql(
            spec,
            self.dao_item_cls.convert_mongo_key_to_postgres,
            self.dao_item_cls.convert_mongo_value_to_postgres_for_key,
            param_key
        )
        return sql, params

    def _get_insert_values(self, columns_for_select, docs):
        """Assign bind-parameter names and values for inserting ``docs``.

        Parameter names are ``value<col>`` for a single document and
        ``value<doc>_<col>`` for several.

        :return: ``(values, columns_to_id_values)``: ``values`` maps param name
            -> postgres value for the fields present in each document;
            ``columns_to_id_values`` holds, per document, an OrderedDict of
            column name -> param name in ``columns_for_select`` order.
        """
        values = {}
        columns_to_id_values = []
        single_doc = len(docs) == 1

        for doc_num, doc in enumerate(docs):
            dao_item = self.dao_item_cls.create_from_mongo_dict(doc)
            pg_data = dao_item.get_postgres_representation(skip_missing_fields=True)

            columns_to_id = OrderedDict()
            for coll_num, coll in enumerate(columns_for_select):
                if single_doc:
                    columns_to_id[coll] = 'value%d' % coll_num
                else:
                    columns_to_id[coll] = 'value%d_%d' % (doc_num, coll_num)

            # NOTE(review): every column in columns_for_select gets a placeholder, but only the
            # fields present in the document get a value; a field whose column was excluded
            # (e.g. '_id' with exclude_id_field) raises KeyError here.
            for coll, pg_value in pg_data.items():
                values[columns_to_id[coll.name]] = pg_value

            columns_to_id_values.append(columns_to_id)

        return values, columns_to_id_values
