#include "parse_anyway.h"

#include "tvm_proxy.h"
#include "unittest.h"

#include <library/cpp/colorizer/colors.h>
#include <library/cpp/tvmauth/deprecated/service_context.h>
#include <library/cpp/tvmauth/deprecated/user_context.h>
#include <library/cpp/tvmauth/src/parser.h>
#include <library/cpp/tvmauth/src/utils.h>

#include <util/string/split.h>

using namespace NTvmAuth;

namespace NPassport::NTvmknife {
    // Renders the header fields common to every ticket kind, one per line:
    // version, expiration time (raw value, local time, and an "expired"/"ok"
    // verdict against the current wall clock), key id, signature verdict and
    // the integral-check status.
    TString TBaseResult::PrintBase() {
        // Printable label for a signature verdict; values outside the enum print nothing.
        const auto signLabel = [](ESign verdict) -> const char* {
            switch (verdict) {
                case ESign::Ok:
                    return "ok";
                case ESign::Broken:
                    return "broken";
                case ESign::NoSign:
                    return "no sign";
                case ESign::MalformedKeys:
                    return "keys malformed";
                case ESign::MissingKeys:
                    return "keys missing";
            }
            return "";
        };

        const bool isExpired = ExpireTime_ < time(nullptr);

        TStringStream out;
        out << "Version: " << Version_ << Endl
            << "Expire time: " << ExpireTime_
            << " == " << TInstant::Seconds(ExpireTime_).ToStringLocalUpToSeconds()
            << " (" << (isExpired ? "expired" : "ok") << ")" << Endl
            << "Key id: " << KeyId_ << Endl
            << "Sign: " << signLabel(Sign_) << Endl
            << "Integral check: " << StatusToString(IntegralCheck_).data() << Endl;

        return out.Str();
    }

    // Renders the common header, then the service-ticket body: source and
    // destination client ids, every scope, and the issuer uid ("<undefined>"
    // when the ticket did not carry one).
    TString TServiceResult::Print() {
        TStringStream out;
        out << PrintBase() << Endl
            << "Body" << Endl
            << "Src: " << Src_ << Endl
            << "Dst: " << Dst_ << Endl;

        for (const TStringBuf scope : Scopes_) {
            out << "Scope: " << scope << Endl;
        }

        out << "Issuer: ";
        if (Issuer_) {
            out << Issuer_;
        } else {
            out << "<undefined>";
        }
        out << Endl;

        return out.Str();
    }

    // Copies the 'service' section of a decoded ticket into this result:
    // src/dst client ids, all scopes and, if present, the issuer uid.
    // Throws yexception when the protobuf has no 'service' section.
    void TServiceResult::ParseTicket(const ticket2::Ticket& t) {
        if (!t.has_service()) {
            ythrow yexception() << "Service ticket does not have 'service' part of protobuf";
        }

        const auto& service = t.service();
        Src_ = service.srcclientid();
        Dst_ = service.dstclientid();

        for (const auto& scope : service.scopes()) {
            Scopes_.push_back(scope);
        }

        if (service.has_issueruid()) {
            Issuer_ = service.issueruid();
        }
    }

    // Runs two independent checks of a service ticket against `tvmKeys`:
    //  1. the library's integral check (TServiceContext for Dst_) on `ticketFull`,
    //     stored in IntegralCheck_;
    //  2. a manual check of `sign` over `ticketStrip` with the TVM key whose id
    //     equals KeyId_, stored in Sign_ (NoSign when `sign` is empty,
    //     MalformedKeys when the key protobuf cannot be parsed).
    // Exceptions from either check are caught; their messages are returned,
    // one line each, prefixed with the name of the check that failed.
    TString TServiceResult::CheckSign(TStringBuf tvmKeys, TStringBuf ticketFull, TStringBuf ticketStrip, TStringBuf sign) {
        TStringStream errors;

        try {
            TServiceContext context = TServiceContext::CheckingFactory(Dst_, tvmKeys);
            IntegralCheck_ = context.Check(ticketFull).GetStatus();
        } catch (const std::exception& e) {
            errors << "Integral check: " << e.what() << Endl;
        }

        if (!sign) {
            Sign_ = ESign::NoSign;
            return errors.Str();
        }

        try {
            tvm_keys::Keys protoKeys;
            if (!protoKeys.ParseFromString(TParserTvmKeys::ParseStrV1(tvmKeys))) {
                Sign_ = ESign::MalformedKeys;
                ythrow yexception() << "Tvm keys are malformed";
            }

            // Only the first key with a matching id is used.
            for (const tvm_keys::TvmKey& candidate : protoKeys.tvm()) {
                if (candidate.gen().id() != KeyId_) {
                    continue;
                }
                NRw::TRwPublicKey publicKey(candidate.gen().body());
                Sign_ = publicKey.CheckSign(ticketStrip, sign) ? ESign::Ok : ESign::Broken;
                break;
            }
        } catch (const std::exception& e) {
            errors << "Sign check: " << e.what() << Endl;
        }

        return errors.Str();
    }

    // Renders the common header, then the user-ticket body: every uid, the
    // default uid, every scope, the entry point and the environment name.
    TString TUserResult::Print() {
        // Printable environment name; values outside the known set print nothing.
        const auto envName = [](decltype(Env_) env) -> const char* {
            switch (env) {
                case tvm_keys::Prod:
                    return "prod";
                case tvm_keys::ProdYateam:
                    return "prod_yateam";
                case tvm_keys::Test:
                    return "test";
                case tvm_keys::TestYateam:
                    return "test_yateam";
                case tvm_keys::Stress:
                    return "stress";
            }
            return "";
        };

        TStringStream out;
        out << PrintBase() << Endl
            << "Body" << Endl;

        for (TUid uid : Uids_) {
            out << "Uid: " << uid << Endl;
        }
        out << "Default uid: " << DefaultUid_ << Endl;

        for (const TStringBuf scope : Scopes_) {
            out << "Scope: " << scope << Endl;
        }

        out << "Entry point: " << EntryPoint_ << Endl
            << "Env: " << envName(Env_) << Endl;

        return out.Str();
    }

    // Copies the 'user' section of a decoded ticket into this result:
    // all uids, the default uid, all scopes, the entry point and the env.
    // Throws yexception when the protobuf has no 'user' section.
    void TUserResult::ParseTicket(const ticket2::Ticket& t) {
        if (!t.has_user()) {
            ythrow yexception() << "User ticket does not have 'user' part of protobuf";
        }

        const auto& user = t.user();
        for (const auto& entry : user.users()) {
            Uids_.push_back(entry.uid());
        }
        DefaultUid_ = user.defaultuid();

        for (const auto& scope : user.scopes()) {
            Scopes_.push_back(scope);
        }
        EntryPoint_ = user.entrypoint();
        Env_ = user.env();
    }

    // Runs two independent checks of a user ticket against `tvmKeys`:
    //  1. the library's integral check (TUserContext for the ticket's env) on
    //     `ticketFull`, stored in IntegralCheck_;
    //  2. a manual check of `sign` over `ticketStrip` with the blackbox key whose
    //     env equals Env_ and whose id equals KeyId_, stored in Sign_ (NoSign when
    //     `sign` is empty, MalformedKeys when the key protobuf cannot be parsed).
    // Exceptions from either check are caught; their messages are returned,
    // one line each, prefixed with the name of the check that failed.
    TString TUserResult::CheckSign(TStringBuf tvmKeys, TStringBuf ticketFull, TStringBuf ticketStrip, TStringBuf sign) {
        TStringStream errors;

        try {
            // Protobuf env -> library env; anything unrecognised falls back to Prod.
            const auto toBlackboxEnv = [](decltype(Env_) env) {
                switch (env) {
                    case tvm_keys::Prod:
                        return EBlackboxEnv::Prod;
                    case tvm_keys::ProdYateam:
                        return EBlackboxEnv::ProdYateam;
                    case tvm_keys::Test:
                        return EBlackboxEnv::Test;
                    case tvm_keys::TestYateam:
                        return EBlackboxEnv::TestYateam;
                    case tvm_keys::Stress:
                        return EBlackboxEnv::Stress;
                }
                return EBlackboxEnv::Prod;
            };

            TUserContext context(toBlackboxEnv(Env_), tvmKeys);
            IntegralCheck_ = context.Check(ticketFull).GetStatus();
        } catch (const std::exception& e) {
            errors << "Integral check: " << e.what() << Endl;
        }

        if (!sign) {
            Sign_ = ESign::NoSign;
            return errors.Str();
        }

        try {
            tvm_keys::Keys protoKeys;
            if (!protoKeys.ParseFromString(TParserTvmKeys::ParseStrV1(tvmKeys))) {
                Sign_ = ESign::MalformedKeys;
                ythrow yexception() << "Tvm keys are malformed";
            }

            // Only the first key matching both env and id is used.
            for (const tvm_keys::BbKey& candidate : protoKeys.bb()) {
                if (candidate.env() == Env_ && candidate.gen().id() == KeyId_) {
                    NRw::TRwPublicKey publicKey(candidate.gen().body());
                    Sign_ = publicKey.CheckSign(ticketStrip, sign) ? ESign::Ok : ESign::Broken;
                    break;
                }
            }
        } catch (const std::exception& e) {
            errors << "Sign check: " << e.what() << Endl;
        }

        return errors.Str();
    }

    // Decodes a version-3 ticket regardless of its validity and returns a
    // human-readable report: the common header, the ticket-kind specific body
    // and, when the integral or sign check raised an exception, those messages
    // under an "Errors" header.
    //
    // `ticket`  - the full ticket string (DELIM-separated version/type/proto/sign).
    // `tvmKeys` - serialized public keys used for the integral and sign checks.
    // Throws yexception when the ticket cannot be split, its protobuf is
    // invalid, or its type is neither service nor user.
    TString TAnywayParser::ParseV3(TStringBuf ticket, TStringBuf tvmKeys) {
        TStrings strs = ParseStringsV3(ticket);
        ticket2::Ticket t;
        if (!t.ParseFromArray(strs.ProtoBin.data(), strs.ProtoBin.size())) {
            ythrow yexception() << "Protobuf is invalid";
        }

        std::unique_ptr<TBaseResult> res;
        if (strs.Type == TParserTickets::ServiceFlag()) {
            res = std::make_unique<TServiceResult>();
        } else if (strs.Type == TParserTickets::UserFlag()) {
            res = std::make_unique<TUserResult>();
        } else {
            ythrow yexception() << "Type is unknown: " << strs.Type;
        }

        res->ExpireTime_ = t.expirationtime();
        res->KeyId_ = t.keyid();
        res->ParseTicket(t);

        // CheckSign() fills IntegralCheck_/Sign_ (read by Print()) and returns the
        // messages of any exceptions it caught. They used to be dropped, which left
        // Print() showing a default status with no explanation; keep them.
        const TString checkErrors = res->CheckSign(tvmKeys, ticket, strs.ForCheck, strs.SignBin);

        TStringStream s;
        s << res->Print();
        if (checkErrors) {
            s << Endl << "Errors" << Endl << checkErrors;
        }
        return s.Str();
    }

    // Splits a version-3 ticket into its DELIM-separated parts and decodes them.
    // Returns {type, prefix of the ticket before the encoded sign (the data the
    // sign is checked against), decoded protobuf bytes, decoded sign bytes}.
    // Throws yexception on an unsupported version, an empty type or proto,
    // anything after the sign, or a proto that is not valid base64url.
    TAnywayParser::TStrings TAnywayParser::ParseStringsV3(TStringBuf body) {
        // NextTok() consumes `body`, so keep a view of the whole ticket first.
        const TStringBuf whole = body;

        const TStringBuf version = body.NextTok(TParserTickets::DELIM);
        if (version != "3") {
            ythrow yexception() << "Unsupported version: " << TString(version);
        }

        const TStringBuf ticketType = body.NextTok(TParserTickets::DELIM);
        if (!ticketType) {
            ythrow yexception() << "Type is empty";
        }

        const TStringBuf encodedProto = body.NextTok(TParserTickets::DELIM);
        if (!encodedProto) {
            ythrow yexception() << "Encoded proto is empty";
        }

        const TStringBuf encodedSign = body.NextTok(TParserTickets::DELIM);
        if (body) {
            ythrow yexception() << "Body has something after sign";
        }

        TString protoBin = NUtils::Base64url2bin(encodedProto);
        if (!protoBin) {
            ythrow yexception() << "Encoded proto is invalid base64url";
        }
        TString signBin = NUtils::Base64url2bin(encodedSign);

        const TStringBuf signedPart(whole.data(), whole.size() - encodedSign.size());
        return {ticketType, signedPart, protoBin, signBin};
    }

    // Parses any version-3 ticket and checks it against the keys returned by
    // TTvmProxy::PublicKeys(); returns the human-readable report.
    TString ParseAnyTicket(const TString& strTicket) {
        const auto& publicKeys = TTvmProxy::PublicKeys();
        return TAnywayParser::ParseV3(strTicket, publicKeys);
    }

    // Same as ParseAnyTicket(), but checks against TUnittest::GetPublicKeys()
    // and appends a colored note telling the reader which keys were used.
    TString ParseAnyTicketUnittest(const TString& strTicket) {
        const TString report = TAnywayParser::ParseV3(strTicket, TUnittest::GetPublicKeys());

        NColorizer::TColors colors(true);
        TStringStream out;
        out << report
            << colors.BrownColor()
            << "Using public keys from 'tvmknife unittest public_keys' for checking of signature"
            << colors.OldColor()
            << Endl;
        return out.Str();
    }
}
