import numpy as np
import yt.wrapper as yt
try:
    from crypta.lib.nirvana.email_organization.vectorize import (
        vectorize_one,
        MAX_LEN,
        DICT_SIZE,  # noqa
        DOMAIN_DICT_SIZE
    )
except ImportError:
    import sys
    sys.path.append('vectorize/pymodule')
    from vectorize import (MAX_LEN, vectorize_one)
    from vectorize import DOMAIN_DICT_SIZE, DICT_SIZE  # noqa


def make_output(login, domain, is_org):
    """Pack model inputs and the target into a ``(features, target)`` pair.

    Args:
        login: value fed to the model input named ``login_input``.
        domain: value fed to the model input named ``domain_input``.
        is_org: target value, returned unchanged as the second element.

    Returns:
        A 2-tuple ``({'login_input': login, 'domain_input': domain}, is_org)``.
    """
    features = {'login_input': login, 'domain_input': domain}
    return features, is_org


def generator_wrapper(path, batch_size=25):
    """Yield training batches read from a YT table, cycling over it forever.

    Each pass re-reads ``path`` with ``yt.read_table`` and turns every row
    whose e-mail splits into exactly two parts on ``'@'`` into one sample
    via ``vectorize_one``. Samples are grouped into batches of
    ``batch_size``. The last, possibly smaller, batch of a pass is yielded
    at the end of that pass, and the next pass starts a fresh batch.

    Args:
        path: YT table path. Rows are read with ``raw=False`` and must
            provide the ``'email'`` and ``'is_org'`` columns.
        batch_size: maximum number of samples per batch (must be >= 1).

    Yields:
        ``make_output(X_login, X_domain, y)`` values, where ``X_login`` has
        shape ``(k, MAX_LEN)``, ``X_domain`` has shape
        ``(k, DOMAIN_DICT_SIZE)`` and ``y`` has shape ``(k,)``. All three
        are ``float16`` and ``1 <= k <= batch_size``.

    Raises:
        ValueError: (when iteration starts) if ``batch_size < 1``. Also
            raised if a full pass over the table produced no usable rows;
            without this the generator would loop forever without yielding.
    """
    if batch_size < 1:
        raise ValueError('batch_size must be >= 1, got %r' % (batch_size,))

    def _new_batch():
        # Fresh buffers for every batch, so arrays already handed to the
        # caller are never overwritten afterwards.
        return (np.zeros((batch_size, MAX_LEN), dtype=np.float16),
                np.zeros((batch_size, DOMAIN_DICT_SIZE), dtype=np.float16),
                np.zeros(batch_size, dtype=np.float16))

    while True:
        X_login, X_domain, y = _new_batch()
        batch_idx = 0
        rows_in_pass = 0
        for row in yt.read_table(path, raw=False):
            email = row['email']
            # Skip malformed addresses: we need exactly one login part and
            # one domain part.
            if len(email.split('@')) != 2:
                continue
            X_login[batch_idx], X_domain[batch_idx] = vectorize_one(email)
            y[batch_idx] = row['is_org']
            batch_idx += 1
            rows_in_pass += 1
            if batch_idx == batch_size:
                # Emit as soon as the batch is full. Never emit unfilled
                # (all-zero) buffers.
                yield make_output(X_login, X_domain, y)
                X_login, X_domain, y = _new_batch()
                batch_idx = 0
        if rows_in_pass == 0:
            raise ValueError(
                'table %r yielded no rows with a valid email' % (path,))
        # Flush the partial tail of this pass. Skip it if the pass ended
        # exactly on a batch boundary.
        if batch_idx:
            yield make_output(X_login[:batch_idx], X_domain[:batch_idx],
                              y[:batch_idx])
