import logging

import numpy as np
import torch
import torch.nn as nn
from sklearn.base import BaseEstimator
from torch import optim
from torch.nn import functional as F

logger = logging.getLogger(__name__)


class ArrangerModel(nn.Module, BaseEstimator):
    """Feed-forward scorer that ranks the rows belonging to each ``user`` group.

    Every row is scored by a three-layer perceptron; within a group the scores
    are converted to log-probabilities with a log-softmax.  ``fit`` minimises
    ``-sum(log_prob * normalised_target)`` per batch with Adam.

    Args:
        input_size: Number of features per row (width of the first layer).
        epochs: Number of passes over the data made by ``fit``.
        lr: Learning rate passed to the Adam optimizer.
        top_n: Cut-off of the "intersection" metric logged during ``fit``.
        disable_cuda: When true, stay on CPU even if CUDA is available.
        batch_size: Number of user groups per optimisation step.
    """

    def __init__(self, input_size, epochs=1, lr=1e-3, top_n=3, disable_cuda=False, batch_size=1):
        super(ArrangerModel, self).__init__()
        # BaseEstimator.get_params() reads every constructor argument back via
        # getattr(self, name), so each one must exist as an attribute.
        self.input_size = input_size
        self.disable_cuda = disable_cuda
        self.features = nn.Sequential(
            nn.Linear(input_size, 50),
            nn.ReLU(inplace=True),
            nn.Linear(50, 25),
            nn.ReLU(inplace=True),
            nn.Linear(25, 1)
        )
        self.epochs = epochs
        self.lr = lr
        self.top_n = top_n
        self.batch_size = batch_size
        # Progress is logged roughly every `log_frequency` processed groups.
        self.log_frequency = 5000
        if not disable_cuda and torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')
        self.to(self.device)
        logger.debug('Model on device: %s', self.device.type)

    def iterate_minibatches(self, X, y):
        """Yield ``(x_batch, y_batch)`` lists holding up to ``batch_size`` user groups.

        Groups come from ``X.arggroupby('user')``.  A group is skipped when it
        has a single row or when its targets sum to zero, because neither case
        carries ranking information.  Rows inside a group are ordered by
        descending target so that the best rows occupy the first positions.

        Args:
            X: Project dataset object providing ``drop_columns``,
                ``as_2d_array`` and ``arggroupby``.
            y: Array of targets indexable by the row indices of ``X``.

        Yields:
            Two parallel lists: per-group feature arrays and per-group targets.
        """
        training_data = X.drop_columns(['user']).as_2d_array()
        x_batch, y_batch = [], []
        for _, idx in X.arggroupby('user'):
            if len(idx) == 1 or np.sum(y[idx]) == 0:
                continue  # no useful info
            idx = sorted(idx, key=lambda row: -y[row])  # for accuracy computation
            x_batch.append(training_data[idx])
            y_batch.append(y[idx])
            if len(x_batch) == self.batch_size:
                yield x_batch, y_batch
                x_batch, y_batch = [], []
        # Emit the trailing partial batch too; otherwise up to batch_size - 1
        # groups would be silently excluded from every epoch.
        if x_batch:
            yield x_batch, y_batch

    def _intersec_k(self, prediction):
        """Return the share of the ``top_n`` best-ranked rows located in the first ``top_n`` positions.

        ``prediction`` holds negated log-probabilities of one group, so the
        smallest values correspond to the rows the model ranks highest.
        """
        return np.mean(np.argsort(prediction)[:self.top_n] < self.top_n)

    def fit(self, X, y):
        """Train the network on grouped data.

        Args:
            X: Project dataset object that must contain a ``user`` column.
            y: Array of non-negative targets aligned with the rows of ``X``.

        Returns:
            The fitted estimator (``self``), following scikit-learn convention.
        """
        X.assert_has_columns(['user'])
        self.train()
        logger.debug('Model fit with batch size %d', self.batch_size)

        optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=self.lr)

        for epoch in range(self.epochs):
            loss_log, intersec_log = [], []
            for i, (x_batch, y_batch) in enumerate(self.iterate_minibatches(X, y)):
                optimizer.zero_grad()
                # limits for each element in batch
                index = np.cumsum([0] + [x.shape[0] for x in x_batch])

                data = torch.tensor(np.concatenate(x_batch), dtype=torch.float)
                # np.float was removed in NumPy 1.24; float64 is what it aliased.
                target = np.asarray(np.concatenate(y_batch), dtype=np.float64)
                target = torch.tensor(target / np.sum(target), dtype=torch.float)
                data = data.to(self.device)
                target = target.to(self.device)

                # loss computation and GD
                output = self(data, index)
                loss = -torch.sum(output * target)
                loss.backward()
                optimizer.step()
                loss_log.append(loss.item())

                prediction = -output.detach().cpu().numpy()
                intersec_total = np.mean([
                    self._intersec_k(prediction[left:right])
                    for left, right in zip(index[:-1], index[1:])
                ])
                intersec_log.append(intersec_total)
                if i and (i * self.batch_size) % self.log_frequency < self.batch_size:
                    logger.debug('%s epoch %d train %d loss %.4f intersection %.4f',
                                 self.device.type, epoch, i*self.batch_size, np.mean(loss_log), np.mean(intersec_log))

            if not loss_log:
                # Nothing was trainable; avoid averaging empty lists (nan + warning).
                logger.warning('no trainable user groups on %d epoch', epoch)
                continue
            logger.debug('final quality on %d epoch: loss %.4f intersection %.4f',
                         epoch, np.mean(loss_log), np.mean(intersec_log))
        return self

    def forward(self, X, index=None):
        """Return log-probabilities for the rows of ``X``.

        Args:
            X: Float tensor of row features.
            index: Optional cumulative group boundaries; group ``i`` occupies
                rows ``index[i]:index[i + 1]``.

        Returns:
            1-D tensor of log-softmax values.  The softmax spans all rows when
            ``index`` is ``None`` or ``batch_size`` is 1; otherwise it is
            computed separately inside each group.
        """
        logits = self.features(X).view(-1)
        if index is None or self.batch_size == 1:
            return F.log_softmax(logits, dim=0)
        segments = [
            F.log_softmax(logits[left:right], dim=0)
            for left, right in zip(index[:-1], index[1:])
        ]
        return torch.cat(segments, 0)

    def predict_proba(self, X):
        """Return, for every row of ``X``, its softmax probability within its user group.

        Args:
            X: Project dataset object that must contain a ``user`` column.

        Returns:
            NumPy array of length ``X.shape[0]``.
        """
        X.assert_has_columns(['user'])
        self.eval()

        result = np.zeros(X.shape[0])
        input_data = X.drop_columns(['user']).as_2d_array()
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            for _, idx in X.arggroupby('user'):
                data = torch.tensor(input_data[idx], dtype=torch.float).to(self.device)
                result[idx] = np.exp(self(data).cpu().numpy())
        return result


def to_cuda(*args):
    """Call ``.cuda()`` on every argument and return the results as a tuple.

    The transfer happens immediately, and the tuple can be unpacked, indexed
    or iterated more than once (a lazy generator would move nothing until
    consumed and could be consumed only once).

    Args:
        *args: Objects exposing a ``cuda()`` method (e.g. tensors).

    Returns:
        Tuple of the ``cuda()`` results, in argument order.
    """
    return tuple(arg.cuda() for arg in args)
