#include <yamail/data/serialization/json_writer.h>
#include <yamail/data/deserialization/json_reader.h>
#include <mail_getter/AttachShieldCrypto.h>
#include <butil/butil.h>
#include <butil/crypt.h>
#include <sys/time.h>

#include <vector>
#include <stdexcept>

#include <boost/lexical_cast.hpp>

// Payload of a "temporary" sid as it is serialized to / parsed from JSON in
// this file: the stid/hid pair of the part plus a timestamp (ts). Packer
// stores its expiration time in ts; Unpacker rejects the sid once the current
// time is greater than ts.
BOOST_FUSION_DEFINE_STRUCT((mail_getter)(part_id), TimedTemporary,
    (std::time_t, ts)
    (std::string, stid)
    (std::string, hid)
)

// Payload of a "single_message_part" sid as it is serialized to / parsed from
// JSON in this file: uid, mid and a single hid plus a timestamp (ts). Packer
// stores its expiration time in ts; Unpacker rejects the sid once the current
// time is greater than ts.
BOOST_FUSION_DEFINE_STRUCT((mail_getter)(part_id), TimedSingleMessagePart,
    (std::time_t, ts)
    (std::string, uid)
    (std::string, mid)
    (std::string, hid)
)

// Payload of a "multiple_message_part" sid as it is serialized to / parsed
// from JSON in this file: uid, mid and a list of hids plus a timestamp (ts).
// Packer stores its expiration time in ts; Unpacker rejects the sid once the
// current time is greater than ts.
BOOST_FUSION_DEFINE_STRUCT((mail_getter)(part_id), TimedMultipleMessagePart,
    (std::time_t, ts)
    (std::string, uid)
    (std::string, mid)
    (std::vector<std::string>, hids)
)

// Encrypted sid envelope that is serialized to / parsed from JSON in this
// file. It carries the ids of the AES and HMAC keys that were used, and the
// base64-encoded IV, ciphertext (sidBase64) and HMAC. The HMAC is computed
// over the raw (not base64) IV bytes followed by the raw ciphertext bytes.
BOOST_FUSION_DEFINE_STRUCT((mail_getter), aesSid,
    (std::string, aesKeyId)
    (std::string, hmacKeyId)
    (std::string, ivBase64)
    (std::string, sidBase64)
    (std::string, hmacBase64)
)

namespace {

using std::string;
using std::vector;

// Returns a copy of str in which every '+' is replaced by '*'
// (workaround for Opera 10-0 and some other browsers).
// All other characters are copied unchanged.
std::string encodeChars4Opera(const std::string& str)
{
    std::string encoded;
    encoded.reserve(str.size());
    for (const char ch : str) {
        encoded.push_back(ch == '+' ? '*' : ch);
    }
    return encoded;
}

// Inverse of the '+' -> '*' substitution: returns a copy of str in which
// every '*' is replaced by '+'. All other characters are copied unchanged.
std::string decodeChars4Opera(const std::string& str)
{
    std::string decoded;
    decoded.reserve(str.size());
    for (const char ch : str) {
        decoded.push_back(ch == '*' ? '+' : ch);
    }
    return decoded;
}

// Separator used to find the end of the format tag at the start of a sid and
// to split the fields of an old-format sid.
const char delimiter{':'};

// Format tags that are written in front of the serialized payload.
const std::string temporaryFormat{"temporary"};
const std::string singleMessagePartFormat{"single_message_part"};
const std::string multipleMessagePartFormat{"multiple_message_part"};
// Tag of the outer, encrypted envelope.
const std::string aesSidFormat{"aes_sid"};

// Payload format detected from the tag at the start of a sid.
// The numeric values are spelled out because they are printed in
// IllformedSid error messages.
enum class SidFormat {
    old = 0,                  // no known tag
    temporary = 1,            // "temporary" tag
    singleMessagePart = 2,    // "single_message_part" tag
    multipleMessagePart = 3,  // "multiple_message_part" tag
    aesSid = 4,               // "aes_sid" tag (encrypted envelope)
};

// Determines the payload format of a sid from the tag that precedes the first
// delimiter.
//
// Returns the SidFormat whose tag is exactly equal to the text before the
// first delimiter, or SidFormat::old when that text is not a known tag.
// Throws mail_getter::attach_sid::IllformedSid when the sid contains no
// delimiter at all.
SidFormat getSidFormat(const std::string& sid) {
    const auto end = sid.find(delimiter);
    if (end == std::string::npos) {
        throw mail_getter::attach_sid::IllformedSid("Unknown sid format: " + sid);
    }

    // The tag must match completely. A prefix comparison would classify
    // "t:..." as temporary and ":..." (empty tag) as aes_sid, and callers
    // that skip tag.size() + 1 characters would then read past the payload.
    const auto hasTag = [&sid, end] (const std::string& tag) {
        return sid.compare(0, end, tag) == 0;
    };

    if (hasTag(aesSidFormat)) {
        return SidFormat::aesSid;
    }
    if (hasTag(temporaryFormat)) {
        return SidFormat::temporary;
    }
    if (hasTag(singleMessagePartFormat)) {
        return SidFormat::singleMessagePart;
    }
    if (hasTag(multipleMessagePartFormat)) {
        return SidFormat::multipleMessagePart;
    }
    return SidFormat::old;
}

} // namespace

namespace mail_getter::attach_sid {

using Keys = attach_sid::Keys;
using namespace crypto;

// Splits a decrypted old-format sid of the form "<ts><delimiter><mid><delimiter><hid>".
//
// The timestamp is the text before the first delimiter, hid is the text after
// the last delimiter and mid is everything in between (so mid may itself
// contain delimiters). Results are written to ts, mid and hid.
//
// Throws InvalidSid when fewer than two delimiters are present or when the
// timestamp text cannot be converted to std::time_t.
void parseSid(const std::string &decryptedSid, std::time_t &ts, std::string &mid, std::string &hid) {
    const auto firstDelim = decryptedSid.find(delimiter);
    const auto lastDelim = decryptedSid.rfind(delimiter);

    const bool hasTwoDelimiters = firstDelim != string::npos && lastDelim != firstDelim;
    if (!hasTwoDelimiters) {
        throw InvalidSid("invalid sid: invalid data");
    }

    const bool tsParsed = boost::conversion::try_lexical_convert(decryptedSid.substr(0, firstDelim), ts);
    if (!tsParsed) {
        throw InvalidSid("invalid sid: invalid timestamp format");
    }

    const auto midBegin = firstDelim + 1;
    const auto midLength = lastDelim - firstDelim - 1;
    mid = decryptedSid.substr(midBegin, midLength);
    hid = decryptedSid.substr(lastDelim + 1);
}

// Encrypts data with keys.aesKey and iv and wraps the result into an aesSid
// envelope.
//
// The envelope holds the ids of both keys plus the base64-encoded IV,
// ciphertext and HMAC. The HMAC is computed with keys.hmacKey over the raw IV
// bytes followed by the raw ciphertext bytes.
aesSid encrypt(std::string_view data, const Keys& keys, const blob& iv) {
    // Views the bytes of a blob as a std::string (copying them).
    const auto toString = [] (const blob& bytes) {
        return std::string(reinterpret_cast<const char *>(bytes.data()), bytes.size());
    };

    const blob cipherText = aesEncrypt(data, keys.aesKey, iv);
    const std::string rawIv = toString(iv);
    const std::string rawCipherText = toString(cipherText);

    const std::string encodedIv = encode_base64(rawIv);
    const std::string encodedCipherText = encode_base64(rawCipherText);
    const std::string encodedHmac = encode_base64(hmac(rawIv + rawCipherText, keys.hmacKey));

    return aesSid{keys.aesKeyId, keys.hmacKeyId, encodedIv, encodedCipherText, encodedHmac};
}

// Returns the current time in whole seconds as reported by gettimeofday().
// Throws std::runtime_error ("timestamp error: <strerror text>") when
// gettimeofday() fails.
std::time_t timestamp() {
    struct timeval now;
    const int rc = gettimeofday(&now, nullptr);
    if (rc < 0) {
        // Read errno before anything else gets a chance to overwrite it.
        const int err = errno;
        std::string message = "timestamp error: ";
        message += std::strerror(err);
        throw std::runtime_error(message);
    }
    return now.tv_sec;
}

// Packs an old-style part id.
//
// The plaintext is "<expirationTime><delimiter><stid><delimiter><hid>" with no
// format tag. It is encrypted into an aesSid envelope, the envelope is
// serialized to JSON and prefixed with the aes_sid tag, and the result is
// base64-encoded with '+' replaced by '*'.
std::string Packer::operator()(const part_id::Old& partId) const {
    using yamail::data::serialization::toJson;

    string plain = boost::lexical_cast<string>(_expirationTime);
    plain += delimiter;
    plain += partId.stid;
    plain += delimiter;
    plain += partId.hid;

    const auto envelope = encrypt(plain, _keys, _iv);
    const std::string tagged = aesSidFormat + ":" + toJson(envelope).str();

    return encodeChars4Opera(encode_base64(tagged));
}

// Packs a temporary part id.
//
// The part id is combined with the expiration time into a TimedTemporary,
// serialized to JSON and prefixed with the "temporary" tag. That plaintext is
// encrypted into an aesSid envelope, the envelope is serialized to JSON and
// prefixed with the aes_sid tag, and the result is base64-encoded with '+'
// replaced by '*'.
std::string Packer::operator()(part_id::Temporary partId) const {
    using yamail::data::serialization::toJson;

    const part_id::TimedTemporary payload {
        _expirationTime,
        std::move(partId.stid),
        std::move(partId.hid)
    };
    const std::string plain = temporaryFormat + ":" + toJson(payload).str();

    const auto envelope = encrypt(plain, _keys, _iv);
    const std::string tagged = aesSidFormat + ":" + toJson(envelope).str();

    return encodeChars4Opera(encode_base64(tagged));
}

// Packs a single-message-part id.
//
// The part id is combined with the expiration time into a
// TimedSingleMessagePart, serialized to JSON and prefixed with the
// "single_message_part" tag. That plaintext is encrypted into an aesSid
// envelope, the envelope is serialized to JSON and prefixed with the aes_sid
// tag, and the result is base64-encoded with '+' replaced by '*'.
std::string Packer::operator()(part_id::SingleMessagePart partId) const {
    using yamail::data::serialization::toJson;

    const part_id::TimedSingleMessagePart payload {
        _expirationTime,
        std::move(partId.uid),
        std::move(partId.mid),
        std::move(partId.hid)
    };
    const std::string plain = singleMessagePartFormat + ":" + toJson(payload).str();

    const auto envelope = encrypt(plain, _keys, _iv);
    const std::string tagged = aesSidFormat + ":" + toJson(envelope).str();

    return encodeChars4Opera(encode_base64(tagged));
}

// Packs a multiple-message-part id.
//
// The part id is combined with the expiration time into a
// TimedMultipleMessagePart, serialized to JSON and prefixed with the
// "multiple_message_part" tag. That plaintext is encrypted into an aesSid
// envelope, the envelope is serialized to JSON and prefixed with the aes_sid
// tag, and the result is base64-encoded with '+' replaced by '*'.
std::string Packer::operator()(part_id::MultipleMessagePart partId) const {
    using yamail::data::serialization::toJson;

    const part_id::TimedMultipleMessagePart payload {
        _expirationTime,
        std::move(partId.uid),
        std::move(partId.mid),
        std::move(partId.hids)
    };
    const std::string plain = multipleMessagePartFormat + ":" + toJson(payload).str();

    const auto envelope = encrypt(plain, _keys, _iv);
    const std::string tagged = aesSidFormat + ":" + toJson(envelope).str();

    return encodeChars4Opera(encode_base64(tagged));
}

namespace {

// Compares two strings without stopping at the first differing byte, so the
// time taken does not reveal how long the matching prefix is. Strings of
// different length compare unequal immediately (the length is not secret).
bool equalsInConstantTime(const std::string& lhs, const std::string& rhs) {
    if (lhs.size() != rhs.size()) {
        return false;
    }
    unsigned char diff = 0;
    for (std::string::size_type i = 0; i < lhs.size(); ++i) {
        const auto l = static_cast<unsigned char>(lhs[i]);
        const auto r = static_cast<unsigned char>(rhs[i]);
        diff = static_cast<unsigned char>(diff | (l ^ r));
    }
    return diff == 0;
}

} // namespace

// Unpacks a sid into the part id it was built from.
//
// Steps: undo the '*' -> '+' substitution and base64-decode the sid; require
// the aes_sid tag; parse the JSON envelope; look up the AES and HMAC keys by
// id; verify the HMAC over raw IV + ciphertext; decrypt; detect the payload
// format from its tag; parse the payload and check that it has not expired.
//
// Returns a part_id::Variant holding Old, Temporary, SingleMessagePart or
// MultipleMessagePart.
//
// Throws:
//   OldSidFormat     - the decoded sid does not start with the aes_sid tag;
//   InvalidAesKeyId  - the envelope names an unknown AES key;
//   InvalidHmacKeyId - the envelope names an unknown HMAC key, or the HMAC
//                      does not match (exception type kept for callers);
//   DecryptionError  - aesDecrypt reported a std::runtime_error;
//   IllformedSid     - the decrypted payload has no delimiter or carries the
//                      aes_sid tag;
//   InvalidSid       - an old-format payload could not be parsed;
//   ExpiredSid       - the current time is greater than the payload timestamp.
// NOTE(review): decode_base64 and fromJson may throw on malformed input as
// well; their exception types are not visible here - confirm against callers.
part_id::Variant Unpacker::operator()(const std::string& sid) const {
    using namespace part_id;
    using yamail::data::deserialization::fromJson;

    const std::string in = decode_base64(decodeChars4Opera(sid));

    const bool hasAesSidTag = in.size() > aesSidFormat.size()
        && in.compare(0, aesSidFormat.size(), aesSidFormat) == 0;
    if (!hasAesSidTag) {
        throw OldSidFormat("Old sid format");
    }

    // Skip the tag and the single separator character that follows it.
    const auto crypted = fromJson<aesSid>(in.substr(aesSidFormat.size() + 1));

    const auto aesKey = _keyContainer.aesKeys.find(crypted.aesKeyId);
    if (aesKey == _keyContainer.aesKeys.end()) {
        throw InvalidAesKeyId("invalid sid: unknown aesKeyId '" + crypted.aesKeyId + "'");
    }

    const auto hmacKey = _keyContainer.hmacKeys.find(crypted.hmacKeyId);
    if (hmacKey == _keyContainer.hmacKeys.end()) {
        throw InvalidHmacKeyId("invalid sid: unknown hmacKeyId '" + crypted.hmacKeyId + "'");
    }

    const std::string ivStr = decode_base64(crypted.ivBase64);
    const std::string sidStr = decode_base64(crypted.sidBase64);

    // The sid comes from an untrusted client, so the HMAC is compared in
    // constant time to avoid leaking how many leading bytes were correct.
    const std::string expectedHmac = encode_base64(hmac(ivStr + sidStr, hmacKey->second));
    if (!equalsInConstantTime(expectedHmac, crypted.hmacBase64)) {
        throw InvalidHmacKeyId("invalid sid: hmac did not match");
    }

    const blob iv = {ivStr.data(), ivStr.data() + ivStr.size()};
    const blob cipherText = {sidStr.data(), sidStr.data() + sidStr.size()};

    std::string decrypted;
    try {
        decrypted = aesDecrypt(cipherText, aesKey->second, iv);
    } catch (const std::runtime_error& e) {
        throw DecryptionError(e.what());
    }

    auto format = getSidFormat(decrypted);
    const auto currentTs = timestamp();
    // NOTE(review): the ExpiredSid message embeds the decrypted payload;
    // confirm this is acceptable wherever the message ends up.
    const auto checkTs = [&] (auto ts) {
        if (currentTs > ts) {
            throw ExpiredSid("invalid sid: timed out by " + std::to_string(currentTs - ts)
                + " seconds sid=" + decrypted);
        }
    };

    switch (format) {
        case SidFormat::old: {
            part_id::Old result;
            std::time_t ts;
            parseSid(decrypted, ts, result.stid, result.hid);
            checkTs(ts);
            return result;
        }
        case SidFormat::temporary: {
            auto timed = fromJson<TimedTemporary>(decrypted.substr(temporaryFormat.size() + 1));
            checkTs(timed.ts);
            return Temporary {std::move(timed.stid), std::move(timed.hid)};
        }
        case SidFormat::singleMessagePart: {
            auto timed = fromJson<TimedSingleMessagePart>(decrypted.substr(singleMessagePartFormat.size() + 1));
            checkTs(timed.ts);
            return SingleMessagePart {std::move(timed.uid), std::move(timed.mid), std::move(timed.hid)};
        }
        case SidFormat::multipleMessagePart: {
            auto timed = fromJson<TimedMultipleMessagePart>(decrypted.substr(multipleMessagePartFormat.size() + 1));
            checkTs(timed.ts);
            return MultipleMessagePart {std::move(timed.uid), std::move(timed.mid), std::move(timed.hids)};
        }
        case SidFormat::aesSid: {
            // The decrypted payload must not itself be an encrypted envelope.
            throw IllformedSid("Unexpected sid format: " + std::to_string(static_cast<int>(format)));
        }
    }
    throw IllformedSid("Unmatched sid format: " + std::to_string(static_cast<int>(format)));
}

} // namespace mail_getter
