#include "tvm_auth.h"

namespace NSearchProxy {

    void AuthorizeUserTicket(const TTvmAuthParams& params, TTvmAuthResult& result) {
        if (!params.Tickets.UserTicket) {
            result.Status = TTvmAuthResult::NoUserTicketFromClient;
            return;
        }

        TMaybe<ui64> uid = params.Clients.Main->CheckUserTicket(params.Tickets.UserTicket);
        if (!uid) {
            result.Status = TTvmAuthResult::BadUserTicket;
            return;
        }

        result.Uid = *uid;

        result.Status = params.AllowedUids.contains(*uid)
            ? TTvmAuthResult::Authorized
            : TTvmAuthResult::NotAllowedUid; // FIXME
    }

    // Entry point of TVM authorization for one request.
    //
    // Flow:
    //  1. No service ticket -> NoServiceTicketFromClient.
    //  2. The main TVM client checks the service ticket. If it yields a source
    //     id equal to the configured TvmProxyId, authorization continues with
    //     the user ticket (see AuthorizeUserTicket) and that outcome is final.
    //  3. If the main client rejects the ticket and a FlowMirror client is
    //     configured, the ticket is re-checked with FlowMirror.
    //  4. Still no source id -> BadServiceTicket; otherwise the source id is
    //     matched against AllowedSourceIds.
    //
    // The returned result carries SrcId whenever a client accepted the ticket.
    TTvmAuthResult AuthorizeTvm(const TTvmAuthParams& params) {
        TTvmAuthResult result;
        const auto& serviceTicket = params.Tickets.ServiceTicket;

        if (!serviceTicket) {
            result.Status = TTvmAuthResult::NoServiceTicketFromClient;
            return result;
        }

        result.SrcId = params.Clients.Main->CheckServiceTicket(serviceTicket);

        if (!result.SrcId.Defined()) {
            // The fallback client is consulted only when the main one rejected the ticket.
            if (params.Clients.FlowMirror) {
                result.SrcId = params.Clients.FlowMirror->CheckServiceTicket(serviceTicket);
            }
        } else if (params.TvmProxyId && *result.SrcId == params.TvmProxyId) {
            // Requests coming through the TVM proxy are authorized by user ticket.
            AuthorizeUserTicket(params, result);
            return result;
        }

        if (!result.SrcId.Defined()) {
            result.Status = TTvmAuthResult::BadServiceTicket;
            return result;
        }

        if (params.AllowedSourceIds.contains(*result.SrcId)) {
            result.Status = TTvmAuthResult::Authorized;
        } else {
            result.Status = TTvmAuthResult::NotAllowedSrcClientId;
        }

        return result;
    }

    // Returns true when the request explicitly asks for TVM auth through the
    // `force_tvm_auth` CGI parameter holding a value that IsTrue() accepts.
    bool TTvmTraits::IsAuthForced(const TCgiParameters& cgi) {
        const auto& forceFlag = cgi.Get(NSearchProxyCgi::force_tvm_auth);
        return IsTrue(forceFlag);
    }

    // Extracts the TVM service ticket sent by the client.
    //
    // Lookup order: the `ya_service_ticket` CGI parameter first, then the
    // `X-Ya-Service-Ticket` request header. Returns an empty string when
    // neither is present.
    TString TTvmTraits::GetServiceTicketString(const TSearchRequestData& rd) {
        if (const auto it = rd.CgiParam.Find(NSearchProxyCgi::ya_service_ticket); it != rd.CgiParam.end()) {
            // No emptiness check: a present-but-empty CGI value still takes
            // precedence over the header.
            return it->second;
        }

        if (const auto* header = rd.HeadersIn().FindPtr("X-Ya-Service-Ticket")) {
            return *header;
        }

        return {};
    }

    // Extracts the TVM user ticket sent by the client.
    //
    // Lookup order: the `ya_user_ticket` CGI parameter first, then the
    // `X-Ya-User-Ticket` request header. Returns an empty string when
    // neither is present.
    TString TTvmTraits::GetUserTicketString(const TSearchRequestData& rd) {
        if (const auto it = rd.CgiParam.Find(NSearchProxyCgi::ya_user_ticket); it != rd.CgiParam.end()) {
            // No emptiness check: a present-but-empty CGI value still takes
            // precedence over the header.
            return it->second;
        }

        if (const auto* header = rd.HeadersIn().FindPtr("X-Ya-User-Ticket")) {
            return *header;
        }

        return {};
    }

    // Collects both TVM tickets (service and user) from the request; either
    // may come back empty if the client did not send it.
    TTvmTraits::TTickets TTvmTraits::GetTickets(const TSearchRequestData& rd) {
        TTickets tickets{
            .ServiceTicket = GetServiceTicketString(rd),
            .UserTicket = GetUserTicketString(rd),
        };
        return tickets;
    }

    namespace {
        // ITvmClient that delegates every check to a list of underlying
        // clients. The clients are tried in vector order, and the first
        // defined result is returned. If no client accepts the ticket, or the
        // list is empty, the result is Nothing().
        //
        // NOTE(review): entries are dereferenced without a null check, so
        // callers presumably pass only non-null clients. Verify this at the
        // construction sites.
        class TMultiTvmClient final: public ITvmClient {
        public:
            // `explicit` prevents an accidental implicit conversion from a
            // TVector<ITvmClientPtr> to a client.
            explicit TMultiTvmClient(TVector<ITvmClientPtr> clients)
                : Clients(std::move(clients))
            {
            }

            // Returns the source id reported by the first client that accepts
            // the service ticket.
            TMaybe<ui32> CheckServiceTicket(const TString& ticket) const override {
                for (const auto& client : Clients) {
                    if (auto ret = client->CheckServiceTicket(ticket)) {
                        return ret;
                    }
                }
                return Nothing();
            }

            // Returns the uid reported by the first client that accepts the
            // user ticket.
            TMaybe<ui64> CheckUserTicket(const TString& ticket) const override {
                for (const auto& client : Clients) {
                    if (auto ret = client->CheckUserTicket(ticket)) {
                        return ret;
                    }
                }
                return Nothing();
            }

        private:
            TVector<ITvmClientPtr> Clients;
        };
    }

    // Combines `clients` into a single ITvmClient.
    //
    // Exactly one client is returned unchanged, with no fan-out wrapper. Any
    // other count, including zero, is wrapped in a TMultiTvmClient. With zero
    // clients, that wrapper rejects every ticket.
    ITvmClientPtr CreateMultiTvmClient(TVector<ITvmClientPtr> clients) {
        if (clients.size() == 1) {
            return std::move(clients.front());
        }

        return ITvmClientPtr(new TMultiTvmClient(std::move(clients)));
    }
}
