# -*- coding: utf-8 -*-
import operator

from django.conf import settings
from django.db import connection
from django.db.models import Q
from django.db.models.expressions import RawSQL
from functools import reduce

from events.common_app.utils import get_lang_with_fallback, re_escape


class TranslationQuerySetMixin:
    """QuerySet mixin adding text search over per-language translated fields.

    The model is expected to define ``FIELDS_FOR_TRANSLATION`` (field names
    whose values live in ``translations``) and, on PostgreSQL, a
    ``translations`` column that can be traversed as
    ``translations -> '<field>' ->> '<lang>'``.
    NOTE(review): the ``{field: {lang: value}}`` JSON shape is inferred from
    the SQL below; confirm against the model.

    Each public ``filter_*`` entry point dispatches on ``connection.vendor``.
    PostgreSQL gets raw SQL against the JSON structure. Every other backend
    gets an approximate, substring-based fallback.
    """

    def _get_languages(self):
        """Return the languages from ``get_lang_with_fallback()``, dropping falsy entries."""
        return [lang for lang in get_lang_with_fallback() if lang]

    def _get_translatable_fields(self, *fields):
        """Return the subset of ``fields`` listed in ``model.FIELDS_FOR_TRANSLATION``."""
        return set(fields).intersection(self.model.FIELDS_FOR_TRANSLATION)

    def _get_nontranslatable_fields(self, *fields):
        """Return the subset of ``fields`` NOT listed in ``model.FIELDS_FOR_TRANSLATION``."""
        return set(fields) - set(self.model.FIELDS_FOR_TRANSLATION)

    def _get_sql_statement(self, field, lang, translatable_fields):
        """Build one case-insensitive regex predicate with a single ``%s`` placeholder.

        Translatable fields are read from ``translations`` for ``lang``.
        Any other field is cast to text and ``lang`` is ignored.
        ``field`` and ``lang`` are interpolated directly into the SQL, so they
        must come from trusted code, never from user input.
        """
        if field in translatable_fields:
            return "(lower(translations -> '{field}' ->> '{lang}' {collation}) ~* %s)".format(
                field=field,
                lang=lang,
                collation=settings.PGAAS_COLLATION,
            )
        return "(lower({field}::text {collation}) ~* %s)".format(
            field=field,
            collation=settings.PGAAS_COLLATION,
        )

    def _get_match_statements(self, fields, langs, translatable_fields):
        """Return the OR-able SQL predicates that test one search word against ``fields``.

        A translatable field yields one predicate per language. A
        non-translatable field does not depend on the language, so it yields
        exactly one predicate, even when ``langs`` is empty.
        Order follows ``fields``, with languages inner.
        """
        statements = []
        for field in fields:
            if field in translatable_fields:
                statements.extend(
                    self._get_sql_statement(field, lang, translatable_fields)
                    for lang in langs
                )
            else:
                statements.append(self._get_sql_statement(field, None, translatable_fields))
        return statements

    def filter_translated_fields_postgresql(self, text, *fields):
        """Keep rows where every word of ``text`` starts a word in one of ``fields``.

        Each whitespace-separated word is regex-escaped and prefixed with the
        PostgreSQL word-start anchor ``\\m``. A row qualifies when, for every
        word, at least one field matches (in at least one language, for
        translatable fields). The per-row result is annotated as
        ``suggested``.

        If ``text`` is blank, or there is no column to search, the queryset is
        returned unfiltered, as ``filter_translated_fields_general`` does.
        """
        # Split before escaping: an escaper that also escapes whitespace (as
        # Python's re.escape does) would otherwise leave a trailing backslash
        # on a word, which is an invalid regex.
        # NOTE(review): assumes re_escape escapes character by character.
        words = [re_escape(word) for word in text.split()]
        if not words:
            # RawSQL('') would render invalid SQL.
            return self

        translatable_fields = self._get_translatable_fields(*fields)
        statements = self._get_match_statements(
            fields, self._get_languages(), translatable_fields,
        )
        if not statements:
            # E.g. only translatable fields and no language: an empty "( )"
            # group would be invalid SQL.
            return self

        # The clause text is the same for every word; only the bound
        # parameters differ, one per placeholder.
        clause = '( %s )' % ' OR '.join(statements)
        raw_sql = ' AND '.join([clause] * len(words))
        raw_args = tuple(
            f'\\m{word}'
            for word in words
            for _ in statements
        )
        return self.annotate(suggested=RawSQL(raw_sql, raw_args)).filter(suggested=True)

    def filter_translated_fields_general(self, text, *fields):
        """Approximate ``filter_translated_fields`` on non-PostgreSQL backends.

        Rows are kept if ANY word matches:
        - translatable fields: a case-insensitive substring search for
          ``"<lang>":"<word>`` anywhere in ``translations``;
        - other fields: ``<field>__icontains=<word>``.
        If ``text`` is blank, the queryset is returned unfiltered.

        NOTE(review): unlike the PostgreSQL version, this ORs across words.
        It also does not restrict translatable matches to the requested
        fields, and it presumably relies on ``translations`` being serialized
        with no space after the colon. Confirm the fallback is only used
        where that is acceptable.
        """
        q_list = []
        words = text.split()

        translatable_fields = self._get_translatable_fields(*fields)
        if translatable_fields:
            langs = self._get_languages()
            q_list.extend(
                Q(translations__icontains='"{lang}":"{text}'.format(lang=lang, text=word))
                for word in words
                for lang in langs
            )

        nontranslatable_fields = self._get_nontranslatable_fields(*fields)
        if nontranslatable_fields:
            q_list.extend(
                Q(**{'{field}__icontains'.format(field=field): word})
                for word in words
                for field in nontranslatable_fields
            )

        if not q_list:
            return self
        return self.filter(reduce(operator.or_, q_list))

    def filter_translated_fields(self, text, *fields):
        """Search ``fields`` for the words of ``text``, using the backend-appropriate strategy."""
        if connection.vendor == 'postgresql':
            return self.filter_translated_fields_postgresql(text, *fields)
        return self.filter_translated_fields_general(text, *fields)

    def filter_exact_postgresql(self, text, field):
        """Keep rows whose translation of ``field`` equals ``text`` in any available language.

        ``text`` is passed as a bound parameter. The match is annotated as
        ``suggested``. With no languages, nothing can match, so an empty
        queryset is returned instead of rendering an empty (invalid) RawSQL.
        """
        langs = self._get_languages()
        if not langs:
            return self.none()
        raw_sql = ' OR '.join(
            "(translations -> '{field}' ->> '{lang}' = %s)".format(
                field=field,
                lang=lang,
            )
            for lang in langs
        )
        raw_args = (text,) * len(langs)
        return self.annotate(suggested=RawSQL(raw_sql, raw_args)).filter(suggested=True)

    def filter_exact_general(self, text, field):
        """Approximate ``filter_exact`` on non-PostgreSQL backends.

        Keeps rows whose ``translations`` contains ``"<lang>":"<text>"`` for any
        available language. With no languages, an empty queryset is returned
        (``reduce`` over an empty list would raise ``TypeError``).

        NOTE(review): ``field`` is not used here, so a match in any translated
        field qualifies. ``text`` is not JSON-escaped either.
        """
        langs = self._get_languages()
        if not langs:
            return self.none()
        q_list = [
            Q(translations__contains='"{lang}":"{text}"'.format(lang=lang, text=text))
            for lang in langs
        ]
        return self.filter(reduce(operator.or_, q_list))

    def filter_exact(self, text, field):
        """Exact-match ``text`` against translations of ``field``, per database backend."""
        if connection.vendor == 'postgresql':
            return self.filter_exact_postgresql(text, field)
        return self.filter_exact_general(text, field)
