# -*- coding: utf-8 -*-
import time
import requests
import threading
import itertools
from multiprocessing.pool import ThreadPool
from argparse import ArgumentParser

from mpfs.engine.process import setup_anyone_script
setup_anyone_script()

from mpfs.dao.session import Session


# Base address of the MPFS service and the endpoint used for removals.
HOST = 'http://mpfs.disk.yandex.net'
RM_URL = '%s/json/rm' % HOST

# Dry-run is the default; the __main__ block rebinds this from the -r flag.
DRY_RUN = True

# Number of worker threads used by parallel_do.
THREADS_NUM = 10
# Guards printing in thread_save_print.
PRINT_LOCK = threading.Lock()

# Query returning the full path of every file located in the folder tree
# rooted at :dir_path for user :uid.
#
# The recursive CTE starts from the folder whose fid is resolved by
# code.path_to_fid(:dir_path, :uid) and repeatedly joins disk.folders on
# parent_fid, extending the accumulated path with '/' + child name. The final
# SELECT joins disk.files to those folders by parent_fid and yields one
# column, `path` (folder path + '/' + file name).
#
# Bind parameters: :uid, :dir_path.
GET_ALL_FILES = """
WITH RECURSIVE
recurse AS (
    SELECT
        dir.fid, dir.parent_fid, :dir_path as path
    FROM
        disk.folders dir
    WHERE
        dir.uid = :uid AND
        dir.fid = (select fid from code.path_to_fid(:dir_path, :uid))
UNION
    SELECT
        child.fid, child.parent_fid, path || '/' || child.name
    FROM
        recurse parent join disk.folders child on parent.fid = child.parent_fid
    WHERE
        child.uid = :uid
)
SELECT path || '/' || name as path FROM
disk.files f INNER JOIN recurse folder ON f.uid = :uid AND folder.fid = f.parent_fid
"""

# Query returning the path of every folder in the tree rooted at :dir_path
# for user :uid, including the root folder itself.
#
# Uses the same recursive CTE as GET_ALL_FILES: it starts from the folder
# resolved by code.path_to_fid(:dir_path, :uid) and walks disk.folders via
# parent_fid, building each path by appending '/' + child name.
#
# Rows are ordered by path length, longest first. A child's path is always
# longer than its parent's, so children come before their parents.
#
# Bind parameters: :uid, :dir_path.
GET_ALL_FOLDERS = """
WITH RECURSIVE
recurse AS (
    SELECT
        dir.fid, dir.parent_fid, :dir_path as path
    FROM
        disk.folders dir
    WHERE
        dir.uid = :uid AND
        dir.fid = (select fid from code.path_to_fid(:dir_path, :uid))
UNION
    SELECT
        child.fid, child.parent_fid, path || '/' || child.name
    FROM
        recurse parent join disk.folders child on parent.fid = child.parent_fid
    WHERE
        child.uid = :uid
)
SELECT path FROM recurse ORDER BY length(path) DESC
"""

def thread_save_print(msg):
    """Print *msg* while holding PRINT_LOCK.

    Callers that print through this function are serialized by the lock.
    """
    PRINT_LOCK.acquire()
    try:
        print(msg)
    finally:
        PRINT_LOCK.release()


def parallel_do(func, data, threads_num=None):
    """Apply *func* to every item of *data* using a pool of threads.

    Blocks until every item has been processed. Return values of *func* are
    discarded. An exception raised by *func* for an item does not stop the
    run: it is reported through thread_save_print and the remaining items are
    still processed.

    Args:
        func: callable taking a single item.
        data: iterable of items; it is consumed lazily by the pool.
        threads_num: number of worker threads; defaults to THREADS_NUM.
    """
    if threads_num is None:
        threads_num = THREADS_NUM

    def guarded(item):
        # An exception raised inside a worker is only surfaced when the
        # result iterator is read; capture it here so it can be reported
        # instead of vanishing.
        try:
            func(item)
        except Exception as e:
            return item, e
        return None

    pool = ThreadPool(processes=threads_num)
    try:
        # Consuming the iterator is what makes worker failures visible.
        for failure in pool.imap_unordered(guarded, data):
            if failure is not None:
                item, error = failure
                thread_save_print('FAILED: %r: %r' % (item, error))
    finally:
        pool.close()
        pool.join()


def do_request(url, params, retries=3, retry_delay=5, timeout=None):
    """Issue a GET request, retrying on any failure.

    A non-2xx/3xx response counts as a failure (via raise_for_status). After
    each failed attempt a "Retry" message is printed and the call sleeps for
    *retry_delay* seconds, until *retries* retries have been used; the next
    failure is re-raised to the caller.

    Args:
        url: request URL.
        params: query-string parameters passed to requests.get.
        retries: number of retries after the first attempt.
        retry_delay: seconds to sleep between attempts.
        timeout: passed to requests.get; None (the default) sets no timeout.

    Returns:
        The successful requests response object.

    Raises:
        Whatever the last attempt raised once the retries are exhausted.
    """
    retries_left = retries
    while True:
        try:
            resp = requests.get(url, params=params, timeout=timeout)
            resp.raise_for_status()
        except Exception as e:
            if retries_left <= 0:
                raise
            # This runs inside pool workers, so print under the shared lock
            # like every other message in this script.
            thread_save_print("Retry: %r" % e)
            time.sleep(retry_delay)
            retries_left -= 1
        else:
            return resp


def do_rm(data):
    """Log and, unless DRY_RUN is set, remove a single resource.

    *data* is a ``(uid, path)`` pair. A line describing the action is always
    printed; the rm request to RM_URL is sent only when DRY_RUN is false.
    """
    uid, path = data

    if DRY_RUN:
        action = 'dry_rm'
    else:
        action = 'rm'
    thread_save_print('%s: %s:%s' % (action, uid, path.encode('utf8')))

    if DRY_RUN:
        return
    do_request(RM_URL, {'uid': uid, 'path': path})


def remove_all_files_in_dir(uid, dir_path):
    """Remove every file, then every folder, under *dir_path* for *uid*.

    Only paths starting with '/disk/' or '/trash' are accepted; anything else
    raises NotImplementedError. Paths are read from the database with
    GET_ALL_FILES / GET_ALL_FOLDERS and each one is handed to do_rm through
    parallel_do.
    """
    if not dir_path.startswith(('/disk/', '/trash')):
        raise NotImplementedError()

    query_args = {'uid': int(uid), 'dir_path': dir_path}
    session = Session.create_from_uid(uid)

    # Stage 1: all files of the tree, processed in parallel.
    file_rows = session.execute(GET_ALL_FILES, query_args)
    parallel_do(do_rm, ((uid, row['path']) for row in file_rows))

    # Stage 2: folders. Rows arrive longest path first; consecutive rows with
    # the same number of '/' form one batch, and each batch is finished
    # before the next one starts.
    folder_rows = session.execute(GET_ALL_FOLDERS, query_args)

    def slash_count(row):
        return row['path'].count('/')

    for depth, batch in itertools.groupby(folder_rows, key=slash_count):
        # Paths with a single '/' (e.g. '/disk', '/trash') are left alone.
        if depth > 1:
            parallel_do(do_rm, ((uid, row['path']) for row in batch))


if __name__ == '__main__':
    # Command-line entry point. Without -r the script only prints what it
    # would remove (DRY_RUN stays true).
    parser = ArgumentParser(description="Remove all files in dir.")
    parser.add_argument('-u', dest='uid', type=str, required=True, help='uid')
    parser.add_argument('-p', dest='path', type=str, required=True, help='dir path')
    parser.add_argument('-r', dest='remove', action='store_true', help='do rm')
    args = parser.parse_args()

    uid, path, DRY_RUN = args.uid, args.path, not args.remove
    if path.startswith('/disk/'):
        # Explicit check rather than `assert`: asserts are stripped when
        # Python runs with -O, which would silently drop this safety guard.
        if path == '/disk/':
            parser.error("path must not be '/disk/'")
    elif not path.startswith('/trash'):
        raise NotImplementedError()

    remove_all_files_in_dir(uid, path)
