package ru.yandex.direct.web.core.security.csrf;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.util.Objects;
import java.util.function.Supplier;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

import ru.yandex.direct.utils.HashingUtils;

/**
 * Creates and validates stateless CSRF tokens bound to a user id.
 *
 * <p>A token is {@link HashingUtils#encode64Ya(byte[])} applied to 12 big-endian bytes:
 * <pre>
 *   [0..4)  (int) issueTime XOR signLow
 *   [4..8)  signHigh  - upper 32 bits of the signature
 *   [8..12) signLow   - lower 32 bits of the signature
 * </pre>
 * The signature is the first 8 bytes of {@code HashingUtils.getMd5Hash} over
 * {@code "<issueTime>:<uid>:<secret>"}. {@code issueTime} is in seconds, as returned by the
 * time supplier. On validation, the issue time is recovered by XOR-ing the first word with
 * {@code signLow}, and the signature is recomputed and compared.
 */
@Component
public class CsrfValidator {
    /** Length in bytes of a decoded token: three 32-bit words. */
    public static final int CORRECT_MIME_DECODED_LENGTH = Integer.BYTES * 3;
    /** Default token lifetime in seconds (24 hours). */
    public static final long DEFAULT_CSRF_TOKEN_LIFETIME = 24L * 60 * 60;

    private final String secret;
    private final Supplier<Long> currentTimeSecondsSupplier;
    private final long csrfTokenLifetime;

    /**
     * Creates a validator with an explicit lifetime and time source.
     *
     * @param secret                     secret phrase mixed into every signature; must not be null
     * @param csrfTokenLifetime          maximum accepted token age, in seconds (exclusive)
     * @param currentTimeSecondsSupplier source of the current time, in seconds
     * @throws NullPointerException if {@code secret} or {@code currentTimeSecondsSupplier} is null
     */
    public CsrfValidator(String secret, long csrfTokenLifetime, Supplier<Long> currentTimeSecondsSupplier) {
        // A null secret would be hashed as the literal "null", which makes tokens forgeable.
        this.secret = Objects.requireNonNull(secret, "secret");
        this.csrfTokenLifetime = csrfTokenLifetime;
        this.currentTimeSecondsSupplier =
                Objects.requireNonNull(currentTimeSecondsSupplier, "currentTimeSecondsSupplier");
    }

    /**
     * Spring constructor: uses {@link #DEFAULT_CSRF_TOKEN_LIFETIME} and the system clock.
     *
     * @param secret value of the {@code csrf.secret_phrase} property
     */
    @Autowired
    public CsrfValidator(@Value("${csrf.secret_phrase}") String secret) {
        this(secret, DEFAULT_CSRF_TOKEN_LIFETIME, () -> System.currentTimeMillis() / 1000);
    }

    /**
     * Issues a CSRF token for {@code uid}, stamped with the current time from the time supplier.
     *
     * @param uid user id the token is bound to
     * @return the encoded token
     */
    public String createCsrfToken(long uid) {
        long time = currentTimeSecondsSupplier.get();

        long sign = urlHash(time, uid);
        int signLow = (int) sign;
        int signHigh = (int) (sign >>> Integer.SIZE);
        // Only the low 32 bits of the time are stored, masked with the low half of the signature.
        int maskedTime = (int) time ^ signLow;

        ByteBuffer bb = ByteBuffer.allocate(CORRECT_MIME_DECODED_LENGTH).order(ByteOrder.BIG_ENDIAN);
        bb.putInt(maskedTime).putInt(signHigh).putInt(signLow);
        return HashingUtils.encode64Ya(bb.array());
    }

    /**
     * Checks that {@code csrfToken} was issued for {@code uid} by a validator with the same
     * secret and is younger than the configured lifetime.
     *
     * @param csrfToken token previously produced by {@link #createCsrfToken(long)}
     * @param uid       user id the token is expected to be bound to
     * @return {@code true} if the signature matches and the token is not expired
     * @throws CsrfValidationFailureException if the token is null, cannot be decoded, or has the
     *                                        wrong decoded length
     */
    public boolean checkCsrfToken(String csrfToken, long uid) throws CsrfValidationFailureException {
        if (csrfToken == null) {
            throw new CsrfValidationFailureException("csrf token is null");
        }
        byte[] mimeDecoded;
        try {
            mimeDecoded = HashingUtils.decode64Ya(csrfToken);
        } catch (IllegalArgumentException ex) {
            throw new CsrfValidationFailureException(ex);
        }
        if (mimeDecoded.length != CORRECT_MIME_DECODED_LENGTH) {
            throw new CsrfValidationFailureException(String.format("wrong mime-decoded length: got %d, expect %d",
                    mimeDecoded.length, CORRECT_MIME_DECODED_LENGTH));
        }

        ByteBuffer bb = ByteBuffer.wrap(mimeDecoded);
        long maskedTime = Integer.toUnsignedLong(bb.getInt());
        // signHigh and signLow together form the original 64-bit signature.
        long sign = bb.getLong();
        long signLow = Integer.toUnsignedLong((int) sign);
        // NOTE: the time is stored as an unsigned 32-bit value, so times beyond 2^32 - 1 seconds
        // would not round-trip.
        long timeFromToken = maskedTime ^ signLow;

        long expectedSign = urlHash(timeFromToken, uid);
        // A negative age (token "from the future", e.g. clock skew between hosts) is accepted.
        return expectedSign == sign && (currentTimeSecondsSupplier.get() - timeFromToken) < csrfTokenLifetime;
    }

    /**
     * Signs {@code time} and {@code uid} with the secret.
     *
     * @return the first 8 bytes of the MD5 hash of {@code "<time>:<uid>:<secret>"} as a big-endian long
     */
    private long urlHash(long time, long uid) {
        // Plain concatenation is locale-independent. String.format("%d") would localize digits
        // under locales whose zero digit is not '0', and US_ASCII would then turn them into '?'.
        String src = time + ":" + uid + ":" + secret;
        // NOTE: US_ASCII replaces every non-ASCII character of the secret with '?'. This is kept
        // as-is because changing the encoding would change every signature.
        byte[] md5Bytes = HashingUtils.getMd5Hash(src.getBytes(StandardCharsets.US_ASCII));
        return ByteBuffer.wrap(md5Bytes, 0, Long.BYTES).getLong();
    }
}
