# -*- coding: utf-8 -*-
import logging
import dateutil

from django.conf import settings
from django.utils import timezone
from yt.wrapper import YtClient, YtError
from yt.packages.requests.exceptions import Timeout

from events.celery_app import app
from events.data_sources.models import TableMeta, TableRow
from events.data_sources.utils import get_table_name, get_table_id
from events.surveyme.models import SurveyQuestion

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Row-count ceiling for a source table: sync_yt_table raises SyncTableTooLarge
# when the table's reported row_count exceeds this value.
MAX_TABLE_SIZE = 100000
# batch_size handed to bulk_update / bulk_create when writing TableRow records.
MAX_BATCH_SIZE = 100


def _parse_json(s, default=None):
    """Decode *s* as JSON, falling back to *default* on any failure.

    No exception escapes from the decoding step: malformed text, an
    unsupported input type, or any other error yields *default*.
    """
    import json

    try:
        parsed = json.loads(s)
    except Exception:
        return default
    return parsed


def _get_cluster_name(path):
    """Return the first non-empty '/'-separated component of *path*.

    Returns None when *path* contains no non-empty component.
    """
    return next((part for part in path.split('/') if part), None)


def _get_table_path(query):
    """Return the first element of ``query['path']``.

    Returns None when the key is missing or its value is falsy
    (for example an empty list).
    """
    paths = query.get('path')
    return paths[0] if paths else None


def _get_config():
    """Build the config mapping handed to ``YtClient``.

    The mapping carries a single entry, ``'token'``, read from
    ``settings.YT_TOKEN`` on every call.
    """
    config = {}
    config['token'] = settings.YT_TOKEN
    return config


def _get_metadata(cluster_name, table_path, attributes=None):
    """Fetch selected attributes of a YT node.

    Args:
        cluster_name: value passed to ``YtClient`` to select the cluster.
        table_path: path of the node to inspect.
        attributes: attribute names to request; when falsy, defaults to
            ``modification_time``, ``row_count`` and ``schema``.

    Returns:
        The ``attributes`` of the object returned by ``client.get``, or
        None when the lookup returned None or failed with ``YtError`` or
        ``Timeout``.
    """
    attributes = attributes or ['modification_time', 'row_count', 'schema']
    client = YtClient(cluster_name, config=_get_config())
    try:
        metadata = client.get(table_path, attributes=attributes)
    except (YtError, Timeout) as e:
        # Still best-effort (callers get None), but leave a trace so that a
        # transient failure can be told apart from a missing table.
        logger.warning(
            'Could not get metadata for %s on %s: %s', table_path, cluster_name, e,
        )
        return None
    if metadata is None:
        return None
    return metadata.attributes


def _decode_yt_string(buf):
    """Turn raw bytes read from YT into properly decoded text.

    The bytes are decoded, re-encoded with ``raw_unicode_escape`` (which
    maps every character below U+0100 back to a single byte) and decoded
    again.
    NOTE(review): presumably this undoes YT's JSON output representing each
    byte of a UTF-8 string as its own character -- confirm against YT docs.
    """
    one_char_per_byte = buf.decode()
    original_bytes = one_char_per_byte.encode('raw_unicode_escape')
    return original_bytes.decode()


def _get_table_data(cluster_name, table_path):
    """Read every row of a YT table as parsed JSON.

    Args:
        cluster_name: value passed to ``YtClient`` to select the cluster.
        table_path: path of the table to read.

    Returns:
        A list with one entry per item yielded by ``read_table``; an entry
        is None when that item could not be parsed as JSON. Returns None
        (not an empty list) when reading fails with ``YtError`` or
        ``Timeout``.
    """
    client = YtClient(cluster_name, config=_get_config())
    try:
        # The comprehension stays inside the try block: the table is consumed
        # while iterating, so read errors can surface mid-iteration.
        return [
            _parse_json(_decode_yt_string(it))
            for it in client.read_table(table_path, 'json', raw=True)
        ]
    except (YtError, Timeout) as e:
        # Still best-effort (callers get None), but no longer silent.
        logger.warning(
            'Could not read table %s on %s: %s', table_path, cluster_name, e,
        )
        return None


def get_table_name_from_params(params):
    """Resolve the table referenced by a question's data-source params.

    Args:
        params: either a mapping or its JSON text. The first entry of
            ``params['filters']`` with a non-blank string ``'value'`` is
            treated as the table URL.

    Returns:
        Whatever ``get_table_name`` returns for that URL (callers unpack it
        as ``(cluster_name, table_path)``), or ``(None, None)`` when no
        usable URL is found. Malformed input (``None``, non-mapping JSON,
        non-list filters, non-mapping filter entries, non-string values) is
        treated as "no URL" instead of raising.
    """
    if isinstance(params, str):
        params = _parse_json(params, {})
    if not isinstance(params, dict):
        return None, None

    filters = params.get('filters')
    if not isinstance(filters, (list, tuple)):
        return None, None

    for it in filters:
        if not isinstance(it, dict):
            continue
        value = it.get('value')
        url = value.strip() if isinstance(value, str) else ''
        if url:
            return get_table_name(url)
    return None, None


def get_external_tables():
    """Collect the distinct YT tables referenced by externally published surveys.

    Looks at questions whose data source is ``'yt_table_source'`` and whose
    survey is published externally and not deleted, reading through the
    ``settings.DATABASE_ROLOCAL`` database alias.

    Returns:
        A set of ``(cluster_name, table_path)`` tuples; questions whose
        params do not resolve to both parts are left out.
    """
    yt_questions = SurveyQuestion.objects.using(settings.DATABASE_ROLOCAL).filter(
        param_data_source='yt_table_source',
        survey__is_published_external=True,
        survey__is_deleted=False,
    )
    raw_params = yt_questions.values_list('param_data_source_params', flat=True)
    resolved = (get_table_name_from_params(params) for params in raw_params)
    return {
        (cluster_name, table_path)
        for cluster_name, table_path in resolved
        if cluster_name and table_path
    }


class SyncError(Exception):
    """Base class for the table-sync failures raised in this module."""


class SyncTableDoesNotExist(SyncError):
    """Raised when no metadata could be obtained for the source table."""


class SyncTableIncorrectSchema(SyncError):
    """Raised when the source table's schema fails the sync's column check."""


class SyncTableTooLarge(SyncError):
    """Raised when the source table has more rows than the sync allows."""


def sync_yt_table(cluster_name, table_path):
    """Mirror the rows of a YT table into local ``TableRow`` records.

    The table is validated through its metadata, skipped when the local copy
    is newer than the table's ``modification_time``, and otherwise read in
    full and diffed against the stored rows: rows missing from the table are
    deleted, changed rows are updated and new rows are created. Finally the
    ``TableMeta`` record is stamped with the current time.

    Args:
        cluster_name: YT cluster holding the table.
        table_path: path of the table on that cluster.

    Returns:
        True when a sync was performed; False when there was nothing to
        sync or the table contents could not be read.

    Raises:
        SyncTableDoesNotExist: no metadata could be fetched for the table.
        SyncTableIncorrectSchema: the schema is non-empty and contains
            neither an ``id`` nor a ``name`` column.
        SyncTableTooLarge: the table reports more than ``MAX_TABLE_SIZE`` rows.
    """
    # A bare `import dateutil` does not guarantee that the `parser` submodule
    # is loaded, so import it explicitly instead of relying on some other
    # module having imported it first.
    from dateutil import parser as date_parser

    table_id = get_table_id(cluster_name, table_path)
    metadata = _get_metadata(cluster_name, table_path)
    if not metadata:
        raise SyncTableDoesNotExist(f'Table {table_id} does not exist')

    schema = {column['name'] for column in metadata.get('schema') or []}
    # NOTE(review): this only rejects schemas lacking BOTH columns, although
    # rows are keyed by 'id' below -- confirm whether `or` was intended.
    if schema and ('id' not in schema) and ('name' not in schema):
        raise SyncTableIncorrectSchema(f'Table {table_id} has incorrect schema {schema}')

    row_count = metadata.get('row_count', 0)
    if row_count > MAX_TABLE_SIZE:
        raise SyncTableTooLarge(f'Table {table_id} is too large with {row_count} rows count')

    table_meta, created = TableMeta.objects.get_or_create(table_identifier=table_id)
    last_modified = date_parser.parse(metadata['modification_time'])
    # The None check keeps a meta record whose timestamp was never set from
    # raising TypeError on the comparison; such a table is simply synced.
    if (
        not created
        and table_meta.modification_time is not None
        and table_meta.modification_time > last_modified
    ):
        logger.info('Table %s has nothing to sync', table_id)
        return False

    table_data = _get_table_data(cluster_name, table_path)
    if table_data is None:
        # None signals a failed read, not an empty table (that would be []).
        logger.warning('Table %s could not be read', table_id)
        return False

    # NOTE(review): a row that failed JSON parsing arrives here as None and
    # makes `row.get` raise, aborting the sync before anything is written.
    external_data = {row.get('id'): row for row in table_data}
    local_data = {
        row.source_id: row
        for row in TableRow.objects.filter(table_identifier=table_id)
    }

    rows_to_delete = [
        row.pk
        for source_id, row in local_data.items()
        if source_id not in external_data
    ]
    rows_to_create = []
    rows_to_update = []
    for source_id, row in external_data.items():
        name = row.get('name') or ''
        parent = row.get('parent') or ''
        display_name = row.get('display_name')

        local_row = local_data.get(source_id)
        if not local_row:
            rows_to_create.append(TableRow(
                table_identifier=table_id,
                source_id=source_id,
                text=name,
                display_text=display_name,
                filter_by=parent,
            ))
            continue

        current = (local_row.text, local_row.filter_by, local_row.display_text)
        if current != (name, parent, display_name):
            local_row.text = name
            local_row.filter_by = parent
            local_row.display_text = display_name
            rows_to_update.append(local_row)

    if rows_to_delete:
        TableRow.objects.filter(pk__in=rows_to_delete).delete()
    if rows_to_update:
        fields = ('text', 'filter_by', 'display_text')
        TableRow.objects.bulk_update(rows_to_update, fields=fields, batch_size=MAX_BATCH_SIZE)
    if rows_to_create:
        TableRow.objects.bulk_create(rows_to_create, ignore_conflicts=True, batch_size=MAX_BATCH_SIZE)

    table_meta.modification_time = timezone.now()
    table_meta.save()

    return True


@app.task(default_retry_delay=5, max_retries=3)
def sync_external_table_with_db(source_name, param_data_source_params):
    """Celery task: sync the YT table referenced by one question's params.

    Args:
        source_name: not used by this task; kept in the signature for
            existing callers.
        param_data_source_params: data-source params (mapping or JSON text)
            from which the table is resolved.

    Nothing is raised to the caller: ``SyncError`` is logged as a warning,
    any other exception is logged with its traceback. When the params do
    not resolve to a table the task does nothing.
    """
    cluster_name, table_path = get_table_name_from_params(param_data_source_params)
    if not (cluster_name and table_path):
        return
    try:
        sync_yt_table(cluster_name, table_path)
    except SyncError as e:
        # Logger.warn is a deprecated alias; warning() is the supported name.
        logger.warning('%s', e)
    except Exception:
        logger.exception('Unexpected error while processing table %s %s', cluster_name, table_path)


@app.task
def sync_external_tables_for_all_questions():
    """Celery task: sync every table returned by ``get_external_tables``.

    Failures are isolated per table so one bad table does not stop the
    rest: ``SyncError`` is logged as a warning, any other exception is
    logged with its traceback, and the loop moves on.
    """
    for cluster_name, table_path in get_external_tables():
        try:
            sync_yt_table(cluster_name, table_path)
        except SyncError as e:
            # Logger.warn is a deprecated alias; warning() is the supported name.
            logger.warning('%s', e)
        except Exception:
            logger.exception('Unexpected error while processing table %s %s', cluster_name, table_path)
