# encoding: utf-8
"""
Command to sync group memberships from Staff API v3.
Update Center syncs staff data OK except GroupMemberships.
For some unidentified reason there is a huge difference between
Maillists table and Staff API data - this command should fix the problem.
"""
import json
import traceback
from urlparse import urljoin

from django.conf import settings
from django.db import transaction
from django.db.transaction import get_connection
from termcolor import cprint
from contextlib import closing
from optparse import make_option
from dateutil.parser import parse as parse_datetime

import requests

from django.core.management import BaseCommand

from mlcore.utils.tvm2 import get_tvm_2_header


# Width of the id window requested from Staff API per call. It is passed as
# `id_step` to get_all_group_memberships(), which uses it both for the
# `id >= X and id < Y` query bounds and as the per-request `_limit`.
GET_GROUP_MEMBERSHIPS_FROM_STAFF_API_ID_STEP = 100000

# Number of rows handed to a single cursor.executemany() call when the
# temporary table is being filled.
CREATE_TEMPORARY_TABLE_INSERTS_BATCH_SIZE = 10000

# Creates the staging table that receives the Staff API snapshot.
# `id` is the membership id taken from the API; (`staff_id`, `group_id`)
# is declared unique.
# NOTE(review): the AUTO_INCREMENT=2238186 table option looks like a leftover
# from a copied SHOW CREATE TABLE output (no column is AUTO_INCREMENT) - confirm.
CREATE_TEMPORARY_TABLE_QUERY = """
CREATE TEMPORARY TABLE `_group_membership_from_staff_api_v3` (
  `id` int(11) NOT NULL,
  `staff_id` int(11) NOT NULL,
  `group_id` int(11) NOT NULL,
  `joined_at` datetime NOT NULL,
  PRIMARY KEY (`id`),
  UNIQUE KEY `staff_id` (`staff_id`,`group_id`),
  INDEX (`staff_id`),
  INDEX (`group_id`)
) ENGINE=InnoDB AUTO_INCREMENT=2238186 DEFAULT CHARSET=utf8
"""

# Parameterized single-row insert into the staging table; the four
# placeholders are (id, staff_id, group_id, joined_at), i.e. the tuple shape
# produced by map_group_memberships_to_row_data().
INSERT_ROW_INTO_TEMPORARY_TABLE_QUERY = """
INSERT INTO `_group_membership_from_staff_api_v3`(`id`, `staff_id`, `group_id`, `joined_at`)
VALUES (%s, %s, %s, %s)
"""

# Deletes every `intranet_groupmembership` row whose `id` has no counterpart
# in the staging table. With an empty staging table this matches every row.
DELETE_GROUP_MEMBERSHIP_ROWS_NOT_PRESENT_IN_STAFF_API_DATA_QUERY = """
DELETE FROM `intranet_groupmembership`
WHERE NOT EXISTS (
    SELECT `_group_membership_from_staff_api_v3`.`id`
    FROM `_group_membership_from_staff_api_v3`
    WHERE `_group_membership_from_staff_api_v3`.`id` = `intranet_groupmembership`.`id`
)
"""

# Copies staging rows into `intranet_groupmembership` when
#   * no row with the same `id` exists there yet, and
#   * the referenced `group_id` exists in `intranet_group`, and
#   * the referenced `staff_id` exists in `intranet_staff`.
# NOTE(review): the INSERT has no explicit column list, so the selected
# columns are matched to the target table by position - confirm that the
# target's column order is (id, staff_id, group_id, joined_at).
INSERT_GROUP_MEMBERSHIP_FROM_STAFF_API_DATA_NOT_PRESENT_IN_DB_QUERY = """
INSERT INTO `intranet_groupmembership`
    SELECT `id`, `staff_id`, `group_id`, `joined_at`
    FROM `_group_membership_from_staff_api_v3`
    WHERE NOT EXISTS (
        SELECT `id`
        FROM `intranet_groupmembership`
        WHERE `intranet_groupmembership`.`id` = `_group_membership_from_staff_api_v3`.`id`
    )
    AND EXISTS (
        SELECT `id`
        FROM `intranet_group`
        WHERE `intranet_group`.`id` = `_group_membership_from_staff_api_v3`.`group_id`
    )
    AND EXISTS (
        SELECT `id`
        FROM `intranet_staff`
        WHERE `intranet_staff`.`id` = `_group_membership_from_staff_api_v3`.`staff_id`
    )
"""

# Row count of the target table; used for the progress output printed
# by print_group_memberships_count().
COUNT_GROUP_MEMBERSHIPS_QUERY = "SELECT COUNT(DISTINCT `id`) FROM `intranet_groupmembership`;"


class StaffAPIClient(object):
    """Minimal HTTP client for the Staff API v3 ``groupmembership`` endpoint.

    The TVM id and the API base URL are read from Django settings
    (``STAFF_API_TVM_ID`` and ``STAFF_API_BASE_URL``) when the client is created.
    """

    def __init__(self):
        self._tvm_id = settings.STAFF_API_TVM_ID
        self._base_url = settings.STAFF_API_BASE_URL

    def get_group_memberships(self, page=None, limit=None, sort=None, query=None, fields=None):
        """Request group memberships and return the ``result`` member of the JSON reply.

        Each argument is sent as the query-string parameter of the same name
        prefixed with an underscore (``_page``, ``_limit``, ``_sort``,
        ``_query``, ``_fields``). Arguments with a falsy value are left out of
        the request.

        :raises requests.HTTPError: when the API answers with a 4xx/5xx status.
        """
        url = urljoin(self._base_url, '/v3/groupmembership')
        candidates = (
            ('_page', page),
            ('_limit', limit),
            ('_sort', sort),
            ('_query', query),
            ('_fields', fields),
        )
        params = {name: value for name, value in candidates if value}

        response = requests.get(url, headers=get_tvm_2_header(self._tvm_id), params=params)
        # Surface HTTP failures as such. Without this check an error reply would
        # only show up as a KeyError/ValueError while reading `result` below,
        # hiding the real cause (status code and URL).
        response.raise_for_status()
        return response.json()['result']


def get_all_group_memberships(id_step, fields='id,person.id,group.id,joined_at', limit=None):
    """Yield GroupMembership objects from Staff API v3, walking the id space in windows.

    The highest membership id is looked up first. After that, consecutive
    half-open id windows ``[k * id_step, (k + 1) * id_step)`` are requested
    one at a time, up to and including the window that contains the highest id.

    :param id_step: width of every id window; also sent as the request limit
        of every window call. Must be positive.
    :param fields: comma-separated field list forwarded to the API.
    :param limit: optional cap on the number of yielded objects; a falsy value
        means "no cap". No further windows are requested once it is reached.
    :raises ValueError: when iteration starts, if ``id_step`` is not positive
        (the window would never advance).
    :raises IndexError: if the API reports no memberships at all. This is left
        unhandled on purpose: an empty feed must abort the run rather than be
        treated as a valid empty snapshot.

    GroupMembership object example:
        {
           "group" : {
              "department" : {
                 "id" : null
              },
              "url" : "fired",
              "name" : "<group name>",
              "is_deleted" : false,
              "id" : 123,
              "type" : "wiki",
              "service" : {
                 "id" : null
              }
           },
           "person" : {
              "uid" : "123123123123",
              "name" : {
                 "first" : {
                    "en" : "Name",
                    "ru" : "<first name in Russian>"
                 },
                 "last" : {
                    "en" : "Surname",
                    "ru" : "<last name in Russian>"
                 }
              },
              "is_deleted" : false,
              "id" : 123456,
              "official" : {
                 "is_robot" : false,
                 "is_dismissed" : true,
                 "is_homeworker" : true,
                 "affiliation" : "external"
              },
              "login" : "hey"
           },
           "id" : 12345678,
           "joined_at" : "2019-10-10T17:00:01.228000+00:00"
        }
    """
    if id_step <= 0:
        raise ValueError("id_step must be positive, got %r" % (id_step,))

    client = StaffAPIClient()
    print_message("Get GroupMemberships from Staff API v3", end='')
    if limit:
        # Leading space keeps the two-part banner readable ("... v3 with limit N").
        print_message(" with limit %s" % limit)
    else:
        print_message("")

    id_gte, id_lt = 0, id_step
    id_max = client.get_group_memberships(page=1, fields='id', limit=1, sort='-id')[0]['id']

    print_extra("max(id) = {}".format(id_max))

    counter = 0
    # `<=` (not `<`): when id_max is an exact multiple of id_step, the record
    # with id == id_max lives in the window that *starts* at id_max.
    while id_gte <= id_max:
        if limit and counter >= limit:
            # Enough records were produced; do not query the remaining windows.
            break

        _query = 'id >= {} and id < {}'.format(id_gte, id_lt)
        batch = client.get_group_memberships(query=_query, page=1, limit=id_step, sort='id', fields=fields)
        print_extra("Get {} records with {} <= id < {}".format(len(batch), id_gte, id_lt))

        for el in batch:
            if limit and counter >= limit:
                break
            yield el
            counter += 1

        id_gte, id_lt = id_lt, id_lt + id_step

    print_message("Got {} group memberships from Staff API".format(counter))


def split_into_lists(iterable, list_size):
    """Lazily regroup ``iterable`` into lists of ``list_size`` items.

    Every yielded list is full except possibly the last one, which holds the
    remainder. Nothing is yielded for an empty iterable. Items are pulled from
    the source only as they are needed for the list currently being built.
    """
    assert list_size > 0
    chunk = []
    for item in iter(iterable):
        chunk.append(item)
        if len(chunk) == list_size:
            yield chunk
            chunk = []
    # Source ran dry in the middle of a chunk: hand out what was collected.
    if chunk:
        yield chunk


def existing_ids(model_cls, ids):
    """Return, as ints, the ids found by an ``id__in`` lookup on ``model_cls``.

    Every candidate is coerced with ``int`` before the lookup, and so is every
    id that comes back from it.
    """
    wanted = []
    for candidate in ids:
        wanted.append(int(candidate))

    found = model_cls.objects.filter(id__in=wanted).values_list('id', flat=True)
    return [int(pk) for pk in found]


def map_group_memberships_to_row_data(d):
    """Convert one GroupMembership dict into an ``(id, staff_id, group_id, joined_at)`` tuple.

    The three ids are coerced with ``int``; ``joined_at`` is parsed with
    ``dateutil``'s parser. Fields are converted in tuple order, so the first
    malformed field determines which exception is raised.
    """
    membership_id = int(d['id'])
    staff_id = int(d['person']['id'])
    group_id = int(d['group']['id'])
    joined_at = parse_datetime(d['joined_at'])
    return membership_id, staff_id, group_id, joined_at


def get_group_memberships_rows():
    """Yield row tuples for every GroupMembership object fetched from Staff API.

    Objects that fail conversion with ``ValueError`` or ``TypeError`` are
    skipped. Only the first such failure is reported (the offending object as
    JSON plus the traceback); later ones are skipped without output.
    """
    already_warned = False
    memberships = get_all_group_memberships(id_step=GET_GROUP_MEMBERSHIPS_FROM_STAFF_API_ID_STEP)
    for membership in memberships:
        try:
            yield map_group_memberships_to_row_data(membership)
        except (ValueError, TypeError):
            if already_warned:
                continue
            warning = "Unable to validate some group membership object from Staff API: %s" % json.dumps(membership)
            cprint(warning, "red", attrs=["bold"])
            print_extra(traceback.format_exc())
            already_warned = True


class Command(BaseCommand):
    """Management command that rebuilds `intranet_groupmembership` from Staff API v3.

    The run has two phases: the API snapshot is loaded into a temporary
    table, then the real table is reconciled against that temporary table
    (stale rows deleted, missing rows inserted).
    """

    # NOTE(review): this option is stored as `start_timestamp` but no method of
    # this class reads `options` - it looks like a placeholder; confirm before
    # relying on it.
    option_list = BaseCommand.option_list + (
        make_option('-s', '--something', dest='start_timestamp', default='bla'),
    )

    def handle(self, *args, **options):
        """Entry point: run the update and report failures on the console.

        Neither KeyboardInterrupt nor any other Exception is re-raised; a
        message (and, for Exception, the traceback) is printed instead.
        """
        try:
            self.update()
        except KeyboardInterrupt:
            cprint("Interrupted", "red")
        except Exception:
            print_message('Failed to update group memberships table')
            print_extra(traceback.format_exc())

    def update(self):
        """Run both phases in order: fill the temporary table, then reconcile."""
        # Order matters: the reconciliation queries read the temporary table.
        self.populate_temporary_table()
        self.update_group_memberships()

    def populate_temporary_table(self):
        """
        Create and populate a temporary table with group memberships data from Staff API v3.
        The temporary table exists until the end of the current session - there is no need
        to explicitly drop it.
        Temporary tables in MySQL exist on session level (not per transaction).
        """
        connection = get_connection()
        with closing(connection.cursor()) as cursor:
            # Report the size of the real table before anything is changed.
            print_group_memberships_count(cursor)
            cursor.execute(CREATE_TEMPORARY_TABLE_QUERY)
            # NOTE(review): the schema name `ml` is hard-coded in this statement -
            # confirm it matches the database the connection actually uses.
            cursor.execute('SHOW CREATE TABLE ml.`_group_membership_from_staff_api_v3`;')
            result = cursor.fetchall()
            # Second column of the first row is printed as the CREATE statement.
            print_message('Temporary table created by statement:\n%s\n' % result[0][1])

            print_message('Insert group_memberships from API into temporary table...')
            # Rows are streamed from the API generator and written in batches of
            # CREATE_TEMPORARY_TABLE_INSERTS_BATCH_SIZE, one executemany() per batch.
            for batch in split_into_lists(get_group_memberships_rows(),
                                          list_size=CREATE_TEMPORARY_TABLE_INSERTS_BATCH_SIZE):
                cursor.executemany(INSERT_ROW_INTO_TEMPORARY_TABLE_QUERY, batch)
                print_extra("Insert %s records" % len(batch))

            cursor.execute("SELECT COUNT(*) FROM `_group_membership_from_staff_api_v3`;")
            print_extra('%s records inserted into temporary table' % cursor.fetchall()[0][0])

    def update_group_memberships(self):
        """
        Remove group memberships not present in data from Staff API, then create all missing rows.
        When a new row is inserted, check for existence of the corresponding staff and group rows
        in the DB and skip the row if some of these entities do not exist yet.
        """
        # The DELETE and the INSERT run inside a single atomic block.
        with transaction.atomic():
            # NOTE(review): assumes get_connection() hands back the same session
            # that populate_temporary_table() used - the temporary table is only
            # visible there; confirm.
            connection = get_connection()
            with closing(connection.cursor()) as cursor:
                print_message('Delete group memberships that are not present in Staff API data...')
                cursor.execute(DELETE_GROUP_MEMBERSHIP_ROWS_NOT_PRESENT_IN_STAFF_API_DATA_QUERY)
                # Intermediate count: rows left after the delete, before the insert.
                print_group_memberships_count(cursor)
                print_message('Insert new records into `intranet_groupmembership`')
                cursor.execute(INSERT_GROUP_MEMBERSHIP_FROM_STAFF_API_DATA_NOT_PRESENT_IN_DB_QUERY)

        # reopen connection to check results
        # (`connection` here is the object bound inside the atomic block above)
        connection.close()
        connection = get_connection()
        with closing(connection.cursor()) as cursor:
            print_group_memberships_count(cursor)
            print_message("Done")


def print_message(msg, end='\n'):
    """Print ``msg`` in yellow via ``cprint``; ``end`` is forwarded unchanged."""
    cprint(msg, color="yellow", end=end)


def print_extra(msg, end='\n'):
    """Print ``msg`` in green via ``cprint``; ``end`` is forwarded unchanged."""
    cprint(msg, color="green", end=end)


def print_group_memberships_count(cursor):
    """Run the count query on ``cursor`` and print the resulting number.

    The value printed is the first column of the first row returned.
    """
    cursor.execute(COUNT_GROUP_MEMBERSHIPS_QUERY)
    rows = cursor.fetchall()
    total = rows[0][0]
    print_extra("`intranet_groupmembership` table contains %s records" % total)
