#!/usr/bin/env python2


from __future__ import print_function

import requests
from requests.packages.urllib3.exceptions import (
    InsecurePlatformWarning, InsecureRequestWarning, SNIMissingWarning
)
# Turn off these three urllib3 warning categories for the whole process so
# they do not interleave with the tab-separated progress output below.
requests.packages.urllib3.disable_warnings(InsecurePlatformWarning)
requests.packages.urllib3.disable_warnings(InsecureRequestWarning)
requests.packages.urllib3.disable_warnings(SNIMissingWarning)

# NOTE(review): setup_admin_script() runs before the remaining mpfs imports;
# presumably those imports depend on it -- confirm before reordering.
import mpfs.engine.process
mpfs.engine.process.setup_admin_script()

import argparse
from collections import defaultdict
from copy import copy
import datetime
import sys
from time import sleep
from urllib2 import HTTPError
import yaml

from mpfs.core.services.djfs_albums import djfs_albums
from mpfs.dao.session import Session
from mpfs.metastorage.postgres.exceptions import EofDetectedError
from mpfs.metastorage.postgres.services import SharpeiUserNotFoundError
from mpfs.common.util import retry_decorator


# Index into the uid list of the next uid to process. Set from the command
# line in main() and advanced by add_batch() after every processed uid.
OFFSET = None
# is_batch_needed() allows a new batch only while fewer than RUNNING_MIN users
# are counted in the 'running' state.
RUNNING_MIN =  20000
# add_batch() processes at most RUNNING_MAX - RUNNING_MIN uids per call.
RUNNING_MAX = 100000


def get_time():
    """Return the current naive ``datetime`` truncated to whole seconds."""
    now = datetime.datetime.now()
    return now.replace(microsecond=0)


def yprint(x, explicit_start=False, explicit_end=False):
    """Print ``x`` serialized with ``yaml.safe_dump``.

    ``explicit_start`` and ``explicit_end`` are forwarded to
    ``yaml.safe_dump`` unchanged. No extra newline is appended after the
    dumped text. Returns whatever ``print`` returns (``None``).
    """
    dumped = yaml.safe_dump(
        x,
        explicit_start=explicit_start,
        explicit_end=explicit_end,
    )
    return print(dumped, end='')


# NOTE(review): presumably retry_decorator re-invokes the call on
# requests.ReadTimeout, up to 3 tries -- confirm against mpfs.common.util.
@retry_decorator(exception=requests.ReadTimeout, tries=3)
def reindex(uid, resetAlbums):
    """Request a faces reindex for one user from the djfs-albums service.

    Sends ``POST /api/v1/albums/faces/reindex`` with ``uid`` and the
    lower-cased string form of ``resetAlbums``.

    Returns:
        A ``(status_code, error)`` tuple. ``error`` is the first line of the
        stack trace from the response body when the status is 400, otherwise
        an empty string. An ``HTTPError`` with code 451 is reported as
        ``(451, '')``.

    Raises:
        HTTPError: for any HTTP error whose code is not 451.
        AssertionError: when the status is neither 200 nor 451 and the error
            is not a ``UserNotInitializedException``; the message is the
            response text.
    """
    not_initialized_prefix = (
        'ru.yandex.chemodan.app.djfs.core.user.UserNotInitializedException:'
    )
    try:
        response = djfs_albums.request(
            'POST',
            '/api/v1/albums/faces/reindex',
            {'uid': uid, 'resetAlbums': str(resetAlbums).lower()},
        )
        status_code = response.status_code
        error = (
            response.json()['error']['stackTrace'].splitlines()[0]
            if response.status_code == 400
            else ''
        )
    except HTTPError as e:
        if e.code == 451:
            status_code = e.code
            error = ''
        else:
            raise
    # Explicit raise instead of ``assert`` so the check is not stripped when
    # the script runs under ``python -O``. ``response`` is always bound here:
    # the HTTPError path only falls through with status 451, which passes.
    acceptable = (
        status_code in {200, 451}
        or error.startswith(not_initialized_prefix)
    )
    if not acceptable:
        raise AssertionError(response.text)
    return (status_code, error)


def check_and_reindex(uid, force_reindex_running):
    """Reindex one user if their indexing state calls for it, logging the outcome.

    Prints tab-terminated fields to stdout (no trailing newline): first the
    uid, then one of:

    * nothing more, when the shard lookup raises ``SharpeiUserNotFoundError``;
    * ``no user_index row`` when the uid has no row in ``disk.user_index``;
    * ``blocked <state>`` for a blocked user;
    * ``<status_code> <error>`` from ``reindex`` when the state is NULL, or
      when it is ``'running'`` and ``force_reindex_running`` is true;
    * ``already <state>`` otherwise.

    Args:
        uid: user id to look up and possibly reindex.
        force_reindex_running: when true, users already in the ``'running'``
            state are reindexed as well. The same flag is passed to
            ``reindex`` as its ``resetAlbums`` argument.
    """
    print(uid, end='\t')
    try:
        session = Session.create_from_uid(uid)
    except SharpeiUserNotFoundError:
        return
    user = (
        session
        .execute('select * from disk.user_index where uid = :uid', {'uid': uid})
        .fetchone()
    )
    if user is None:
        # No matching row: report it and move on instead of failing with
        # AttributeError on ``user.blocked`` and aborting the whole batch.
        print('no user_index row', end='\t')
        return
    state = user.faces_indexing_state
    if user.blocked:
        print('blocked', state, end='\t')
    elif state is None or (force_reindex_running and state == 'running'):
        status_code, error = reindex(uid, force_reindex_running)
        print(status_code, error, end='\t')
    else:
        print('already', state, end='\t')


def add_batch(uids, force_reindex_running):
    """Process the next batch of uids starting at the global ``OFFSET``.

    The batch covers at most ``RUNNING_MAX - RUNNING_MIN`` uids and never
    more than what is left in ``uids`` after ``OFFSET``. For each uid a line
    is printed: timestamp, the output of ``check_and_reindex``, and the new
    ``OFFSET``. ``OFFSET`` is incremented after every uid, so it reflects
    progress even if a later uid raises.
    """
    global OFFSET
    first = OFFSET
    remaining = len(uids) - first
    capacity = RUNNING_MAX - RUNNING_MIN
    last = first + min(remaining, capacity)
    for position in xrange(first, last):
        print(get_time(), end='\t')
        check_and_reindex(uids[position], force_reindex_running)
        OFFSET += 1
        print('OFFSET:', OFFSET)


def get_users_state(session):
    """Fetch all rows of ``disk.user_index`` that have a faces indexing state.

    Runs a single query through ``session.execute`` and returns the result of
    ``fetchall()``. Each row carries, in order: ``uid``,
    ``faces_indexing_state``, ``faces_indexing_state_time`` and ``blocked``.
    """
    query = """
            select
                uid,
                faces_indexing_state,
                faces_indexing_state_time,
                blocked
            from disk.user_index
            where faces_indexing_state is not null;
        """
    cursor = session.execute(query)
    return cursor.fetchall()


def check_status():
    """Collect faces indexing state from every shard, report it, and return counts.

    For every shard session the rows from ``get_users_state`` are folded into
    one ``(time, state)`` entry per uid: a missing time is replaced with the
    current time, a blocked user is counted under the state ``'blocked'``,
    and when a uid shows up more than once the larger ``(time, state)`` tuple
    wins.

    Side effects:
        * progress is printed while the shards are scanned;
        * uids in the ``'running'`` state are written, one per line, to
          ``/tmp/faces.running.<date>.txt`` per state date, and those of the
          earliest date also to ``faces.running.oldest.txt``;
        * a YAML report (state counts, running counts per date, the uid
          lists for dates with fewer than 10 running uids, and the current
          time) is printed and written to ``faces_status.txt``.

    Returns:
        A ``defaultdict(int)`` mapping each state to the number of users.
    """
    def write_uid_lines(path, uid_list):
        # One uid per line, overwriting any previous file at ``path``.
        with open(path, 'w') as out:
            out.writelines(str(uid) + '\n' for uid in uid_list)

    shard_sessions = Session.create_for_all_shards()
    shard_total = len(shard_sessions)
    latest_by_uid = {}
    print('checking shards:', shard_total, end='...')
    for shard_no, shard_session in enumerate(shard_sessions, 1):
        for uid, state, time, blocked in get_users_state(shard_session):
            if time is None:
                time = get_time()
            if blocked:
                state = 'blocked'
            candidate = (time, state)
            known = latest_by_uid.get(uid)
            if known is None:
                latest_by_uid[uid] = candidate
            else:
                latest_by_uid[uid] = max(known, candidate)
        # Countdown of shards still to be scanned.
        print(shard_total - shard_no, end='...')

    state_hist = defaultdict(int)
    running_by_date = defaultdict(int)
    running_uids = defaultdict(list)
    for uid, (time, state) in latest_by_uid.iteritems():
        day = time.date()
        state_hist[state] += 1
        if state != 'running':
            continue
        running_by_date[copy(day)] += 1
        running_uids[copy(day)].append(uid)
    print()

    if running_uids:
        oldest_day = min(running_uids.keys())
        write_uid_lines('faces.running.oldest.txt', running_uids[oldest_day])

    for day, day_uids in running_uids.items():
        write_uid_lines('/tmp/faces.running.{}.txt'.format(day), day_uids)

    # Start from the per-state counts, then add the fixed report sections;
    # the fixed keys take precedence over a state with the same name.
    summary = dict(state_hist)
    summary['running_by_date'] = dict(running_by_date)
    summary['running_uids'] = {
        day: day_uids
        for day, day_uids in running_uids.items()
        if len(day_uids) < 10
    }
    summary['time'] = get_time()
    report = yaml.safe_dump(summary, explicit_start=True, explicit_end=True)

    print(report, end='')
    with open('faces_status.txt', 'w') as status_file:
        status_file.write(report)
    return state_hist


def is_batch_needed(skip_check_status):
    """Decide whether another batch of uids should be submitted now.

    Args:
        skip_check_status: when true, no status check is performed and the
            answer is always ``True``.

    Returns:
        ``True`` when ``skip_check_status`` is set, or when ``check_status``
        counts fewer than ``RUNNING_MIN`` users in the ``'running'`` state.
        ``False`` when the status check fails with ``EofDetectedError`` (the
        error is printed) or enough users are already running.
    """
    if skip_check_status:
        return True
    try:
        state_hist = check_status()
    except EofDetectedError as e:
        # Status could not be read: report it and do not start a batch.
        print(e)
        return False
    return state_hist['running'] < RUNNING_MIN


def main():
    """Command-line entry point: feed uids to the reindexer in batches, forever.

    Reads integer uids (one per line, blank lines ignored) from ``--uid-file``
    and starts at the positional ``offset``. Every 10 minutes:

    * while uids remain, a batch is submitted if ``is_batch_needed`` agrees;
    * once all uids have been processed, only ``check_status`` is run so the
      status report keeps being refreshed.

    The loop never terminates on its own; the current ``OFFSET`` is printed
    on every iteration.
    """
    global OFFSET

    # The text must be passed as ``description``: the first positional
    # parameter of ArgumentParser is ``prog``.
    parser = argparse.ArgumentParser(description='Start faces indexing')
    parser.add_argument('offset', type=int)
    parser.add_argument('--uid-file', default='uids.txt')
    parser.add_argument('--skip-check-status', action='store_true')
    parser.add_argument('--force-reindex-running', action='store_true')
    args = parser.parse_args()

    OFFSET = args.offset
    # Close the file once read, and build a real list: ``len(uids)`` and
    # indexing are needed below.
    with open(args.uid_file) as uid_file:
        uids = [int(line) for line in uid_file if line.strip()]

    while True:
        if OFFSET < len(uids):
            if is_batch_needed(args.skip_check_status):
                add_batch(uids, args.force_reindex_running)
        else:
            try:
                check_status()
            except EofDetectedError as e:
                print(e)
        print(get_time(), 'waiting 10 minutes, OFFSET =', OFFSET)
        sleep(60 * 10)

if __name__ == '__main__':
    main()
