# coding: utf-8



import os
import logging

from at.aux_.entries import models
from at.common import dbswitch


# Layout shared by every handler: timestamp, severity, message.
FORMAT = '%(asctime)-15s %(levelname)s %(message)s'
formatter = logging.Formatter(FORMAT)

# Work on the root logger and drop whatever handlers were installed before,
# so that only the two handlers configured below receive records.
log = logging.getLogger()
log.handlers = []

# One handler for the console stream, one for a file in /tmp named after
# this script.
stream_handler = logging.StreamHandler()
file_handler = logging.FileHandler('/tmp/%s.log' % os.path.basename(__file__))

stream_handler.setFormatter(formatter)
log.addHandler(stream_handler)

file_handler.setFormatter(formatter)
log.addHandler(file_handler)

# Only warnings and above are emitted.
log.setLevel(logging.WARNING)


def count(type):
    """Count the ``EntryXmlContent`` rows whose ``wf_version`` is ``'1'``.

    When ``type`` equals ``'posts'`` only rows with ``comment_id = 0`` are
    counted; for any other value only rows with ``comment_id != 0`` are.

    Returns the first column of the single result row.
    """
    base_query = """
SELECT count(*)
FROM `EntryXmlContent`
WHERE wf_version = '1'
    """
    # Pick the extra WHERE clause up front, then build the final statement.
    comment_filter = (
        '\nAND comment_id = 0' if type == 'posts' else '\nAND comment_id != 0'
    )
    query = base_query + comment_filter

    with dbswitch.root_rw_session() as conn:
        first_row = conn.execute(query).fetchone()
        return first_row[0]


def get_all_entry_ids():
    """Return the ids of every ``EntryXmlContent`` row with ``wf_version`` ``'1'``.

    Each id is a ``(feed_id, post_no, comment_id)`` tuple; the result is a
    list of such tuples in the order the query returned them.
    """
    query = """
SELECT content.feed_id, content.post_no, content.comment_id
FROM EntryXmlContent as content
WHERE wf_version = '1'
    """
    with dbswitch.root_rw_session() as conn:
        rows = conn.execute(query).fetchall()
        # Turn the driver's row objects into plain tuples.
        return [tuple(row) for row in rows]


def get_model(id):
    """Load one entry through ``models.load_entry``.

    ``id`` is an iterable whose items are passed to ``models.load_entry`` as
    positional arguments.
    """
    entry = models.load_entry(*id)
    return entry


def convert_using_model(id=None, entry=None):
    """Copy an entry's ``body_original`` into its ``body`` and save it.

    If ``entry`` is given it is used directly; otherwise the entry is loaded
    with ``get_model(id)``.
    """
    target = get_model(id) if entry is None else entry
    target.body = target.body_original
    target.save()


def chunked(iterable, chunk_size=1000):
    """Yield the items of ``iterable`` as consecutive lists.

    Every yielded list holds exactly ``chunk_size`` items, except possibly
    the last one, which holds whatever remains. Nothing is yielded for an
    empty iterable.
    """
    batch = []
    for element in iterable:
        batch.append(element)
        if len(batch) != chunk_size:
            continue
        # The batch is full: hand it out and start a fresh list.
        yield batch
        batch = []
    # Flush the trailing, partially filled batch.
    if len(batch) > 0:
        yield batch


def convert_all_parallel(limit=None, chunk_size=1000):
    """Convert all entries returned by ``get_all_entry_ids`` in parallel.

    The id list is cut into contiguous slices, one per worker process, and
    every slice is handed to ``convert_ids`` through ``Pool.map``.

    :param limit: passed unchanged to every ``convert_ids`` call, so it caps
        the number of entries handled *per worker*, not overall.
    :param chunk_size: passed unchanged to every ``convert_ids`` call.
    :return: list with one ``convert_ids`` result per worker, ordered by
        worker index.
    """
    import multiprocessing
    import math

    processes = min(16, multiprocessing.cpu_count())
    print('processes: %s' % processes)
    # The pool is created before the ids are fetched, as before.
    # NOTE(review): presumably so workers start before the parent opens a
    # DB session - confirm before reordering.
    pool = multiprocessing.Pool(processes=processes)

    try:
        all_ids = get_all_entry_ids()
        total_count = len(all_ids)
        # Round up so that `processes` slices always cover the whole list;
        # trailing slices may be empty.
        ids_in_job = int(math.ceil(total_count / float(processes)))
        args = [
            (
                all_ids[ids_in_job * process_id: ids_in_job * (process_id + 1)],
                process_id,
                limit,
                chunk_size,
            )
            for process_id in range(processes)
        ]
        results = pool.map(convert_ids, args)
    except BaseException:
        # Do not leave worker processes behind when fetching the ids or a
        # worker fails (or the run is interrupted).
        pool.terminate()
        raise
    else:
        pool.close()
    finally:
        pool.join()

    return results


def convert_all(limit=None, chunk_size=1000):
    """Convert all entries returned by ``get_all_entry_ids`` in this process.

    Builds a single job with process id ``0`` and returns whatever
    ``convert_ids`` returns for it.
    """
    job = (get_all_entry_ids(), 0, limit, chunk_size)
    return convert_ids(job)


def convert_ids(mega_arg):
    """Convert a batch of entries and report the ones that failed.

    All inputs are packed into one argument so the function can be passed
    directly to ``Pool.map``.

    :param mega_arg: sequence ``(ids[, process_id[, limit[, chunk_size]]])``.
        Omitted trailing items default to ``process_id=0``, ``limit=None``
        and ``chunk_size=1000``. ``ids`` must support ``len()``.
    :return: list of ``(entry, error_message)`` tuples, one for every entry
        whose conversion raised an ``Exception``.
    :raises RuntimeError: if ``mega_arg`` holds fewer than 1 or more than 4
        items.
    """
    arg_count = len(mega_arg)
    if not 1 <= arg_count <= 4:
        raise RuntimeError('expected 1 to 4 args, got %d' % arg_count)

    # Pad the optional tail with defaults for whichever items were omitted.
    defaults = (0, None, 1000)
    ids = mega_arg[0]
    process_id, limit, chunk_size = (
        tuple(mega_arg[1:]) + defaults[arg_count - 1:]
    )

    skipped = []

    counter = 0
    total_count = len(ids)
    # No point loading more entries per chunk than will ever be handled.
    if limit and limit < chunk_size:
        chunk_size = limit

    for chunk in chunked(ids, chunk_size):
        entries = models.load_entries_by_ids(chunk)
        for entry in entries:
            # Progress is reported before the entry is handled, so it is
            # based on the number of entries already processed.
            handled_part = counter / float(total_count)
            percent = '[%s] %.2f%%' % (process_id, handled_part * 100)
            print(percent, entry._web_url)
            try:
                convert_using_model(entry=entry)
            except Exception as exc:
                # Best effort: record the failure and carry on.
                log.exception('Fail')
                skipped.append((entry, str(exc)))

            # Deliberately outside any ``finally``: returning from a
            # ``finally`` block would discard an in-flight KeyboardInterrupt
            # or SystemExit raised during the conversion.
            counter += 1
            if limit and counter >= limit:
                print('limit %s' % limit)
                return skipped
    return skipped
