#pragma once

#include <tvm_guard/types.h>
#include <tvm_guard/helpers.h>

#include <boost/range/algorithm/find_if.hpp>
#include <boost/variant/apply_visitor.hpp>
#include <boost/range/algorithm/find.hpp>

#include <library/c/tvmauth/tvmauth_wrapper.h>

namespace tvm_guard {
namespace detail {

/// Applies a rule's default action to `factory`.
///
/// The recorded reason depends on that action: `Reason::ruleDefaultPolicy` when the
/// default is `Action::accept`, otherwise the caller-supplied `errorReason`
/// (presumably so a non-accepting fallback still says what went wrong).
///
/// @param factory       response builder being filled in
/// @param defaultAction the matched rule's default action
/// @param errorReason   reason recorded when `defaultAction` is not `Action::accept`
/// @return the reference produced by the `action(...).reason(...)` builder chain
inline ResponseFactory& withRuleDefaultAction(ResponseFactory& factory, Action defaultAction, Reason errorReason) {
    const Reason reason = (defaultAction == Action::accept) ? Reason::ruleDefaultPolicy : errorReason;
    return factory.action(defaultAction).reason(reason);
}

/// Visitor applied to the result of `get_native_user_ticket_or_error`.
///
/// The visited value is either a `boost::system::error_code` or a native user-ticket
/// object `T` described by `TicketTraits<T>`. Both overloads finish building the
/// response on the caller's `factory` and return the produced `Response`.
///
/// @tparam String type of the uid taken from the request arguments; it is compared
///         against `std::string` values, so `std::string == String` must be well-formed.
template <class String>
struct UserTicketVisitor {
    const bool strongUidCheck;        // require `uidFromArgs` to be one of the ticket uids
    const String& uidFromArgs;        // uid passed in the request arguments
    const Action defaultRuleAction;   // matched rule's default action, used on ticket errors
    ResponseFactory& factory;         // builder owned by the caller (may already hold data)

    /// The user ticket could not be obtained/checked: fall back to the rule default.
    Response operator ()(const boost::system::error_code& error) const {
        return withRuleDefaultAction(factory, defaultRuleAction, Reason::wrongUserTicket).error(error).product();
    }

    /// A native user ticket was produced.
    ///
    /// An unsuccessful ticket falls back to the rule default with its status as the error.
    /// A successful ticket is accepted with `Reason::rule`, unless `strongUidCheck` is set
    /// and `uidFromArgs` is not among the ticket uids, in which case it is rejected with
    /// `Reason::uidsMismatch`. Both outcomes carry the ticket uids and its default uid.
    template <class T>
    Response operator ()(const T& userTicket) const {
        if (!TicketTraits<T>::ticketIsSuccessful(userTicket)) {
            return withRuleDefaultAction(factory, defaultRuleAction, Reason::wrongUserTicket)
                    .error(boost::system::error_code(userTicket.GetStatus())).product();
        }

        // Bound by reference: no copy of the uid list when GetUids() returns a reference,
        // and the temporary's lifetime is extended when it returns by value.
        const auto& uids = userTicket.GetUids();
        std::vector<std::string> strUids;
        strUids.reserve(uids.size());
        for (const u_int64_t uid : uids) {
            strUids.push_back(std::to_string(uid));
        }
        const auto defaultUid = TicketTraits<T>::defaultUid(userTicket);

        // Short-circuit keeps the linear search off the path when the check is disabled.
        const bool uidMismatch = strongUidCheck
                && std::find(strUids.begin(), strUids.end(), uidFromArgs) == strUids.end();
        if (uidMismatch) {
            return factory.action(Action::reject)
                            .reason(Reason::uidsMismatch)
                            .uidsFromUserTicket(strUids)
                            .defaultUid(defaultUid)
                            .product();
        }

        return factory.action(Action::accept)
                        .reason(Reason::rule)
                        .uidsFromUserTicket(strUids)
                        .defaultUid(defaultUid)
                        .product();
    }
};

/// Visitor applied to the result of `get_native_service_ticket_or_error` for a request
/// whose path matched `rule`.
///
/// The visited value is either a `boost::system::error_code` or a native service-ticket
/// object `T` described by `TicketTraits<T>`. For a successful ticket the source client
/// id decides the outcome, checked in this order:
///  - the configured root client: accepted only when the ticket carries an issuer uid;
///  - a client listed in `rule.acceptByUser`: the user ticket is checked as well;
///  - a client listed in `rule.acceptByService`: accepted with `Reason::rule`;
///  - anything else: the rule's default action with `Reason::unknownService`.
///
/// @throws Exception when a user ticket must be checked but no `bbEnv` is configured.
template <class Tvm2Module, class String, class OptionalString>
struct ServiceTicketVisitor {
    const std::shared_ptr<Tvm2Module>& tvm;          // client used to check the user ticket
    // NOTE(review): not read by this visitor -- every fallback uses `rule.defaultAction`.
    // Kept because callers aggregate-initialise these fields positionally.
    const Action defaultAction;
    const boost::optional<TA_EBlackboxEnv> bbEnv;    // required only for user-ticket checks
    const boost::optional<u_int32_t> rootClientId;   // optional "root" service client id
    const OptionalString& user;                      // raw user ticket, if any
    const Rule& rule;                                // rule matched by the request path
    const bool strongUidCheck;                       // forwarded to UserTicketVisitor
    const String& uidFromArgs;                       // forwarded to UserTicketVisitor

    /// The service ticket could not be obtained/checked: fall back to the rule default.
    Response operator ()(const boost::system::error_code& error) const {
        ResponseFactory builder;
        return withRuleDefaultAction(builder, rule.defaultAction, Reason::wrongServiceTicket)
                .error(error).product();
    }

    /// A native service ticket was produced; decide based on its status and source.
    template <class T>
    Response operator ()(const T& serviceTicket) const {
        using Traits = TicketTraits<T>;
        ResponseFactory factory;

        if (!Traits::ticketIsSuccessful(serviceTicket)) {
            const boost::system::error_code status(serviceTicket.GetStatus());
            return withRuleDefaultAction(factory, rule.defaultAction, Reason::wrongServiceTicket)
                    .error(status).product();
        }

        const u_int32_t src = serviceTicket.GetSrc();
        factory.source(src);  // recorded for every outcome below

        const bool fromRootClient = rootClientId && *rootClientId == src;
        if (fromRootClient) {
            if (!Traits::issuerUidIsEmpty(serviceTicket)) {
                return factory.action(Action::accept)
                                .reason(Reason::rootServiceTicket)
                                .issuerUid(Traits::issuerUid(serviceTicket))
                                .product();
            }
            return factory.action(Action::reject).reason(Reason::rootServiceTicketWithoutUid).product();
        }

        if (rule.acceptByUser.count(src) != 0) {
            return checkUserTicket(factory);
        }

        if (rule.acceptByService.count(src) != 0) {
            return factory.action(Action::accept).reason(Reason::rule).product();
        }

        return factory.action(rule.defaultAction).reason(Reason::unknownService).product();
    }

    /// Checks the request's user ticket through `tvm` and lets UserTicketVisitor finish
    /// `factory`. An absent user ticket is handed to the TVM client as an empty string.
    Response checkUserTicket(ResponseFactory& factory) const {
        if (!bbEnv) {
            throw Exception("Cannot check user ticket without 'bb_env'");
        }

        auto userTicket = tvm->get_native_user_ticket_or_error(
                *bbEnv, user ? detail::toString(*user) : std::string());
        return boost::apply_visitor(
            UserTicketVisitor<String> {strongUidCheck, uidFromArgs, rule.defaultAction, factory},
            userTicket
        );
    }
};

} // namespace detail

template<class Tvm2Module>
class Guard {
    std::vector<Rule> rules;
    std::shared_ptr<Tvm2Module> tvm;
    Action defaultAction;
    boost::optional<TA_EBlackboxEnv> bbEnv;
    boost::optional<u_int32_t> rootClientId;
    bool strongUidCheck;

public:
    Guard(std::vector<Rule> r, Action action, boost::optional<TA_EBlackboxEnv> env,
          boost::optional<u_int32_t> rootId, bool strongUidCheck, std::shared_ptr<Tvm2Module> module)
        : rules(std::move(r))
        , tvm(module)
        , defaultAction(action)
        , bbEnv(env)
        , rootClientId(rootId)
        , strongUidCheck(strongUidCheck)
    { }

    template <class String, class OptionalString>
    Response check(const String& path, const String& uidFromArgs,
                   const OptionalString& service, const OptionalString& user) const {

        ResponseFactory factory;

        const auto it = boost::find_if(rules, [&path] (const Rule& r) {
            return boost::find(r.paths, path) != r.paths.end();
        });

        if (it == rules.end()) {
            return factory.action(defaultAction).reason(Reason::defaultPolicy).product();
        }

        const Rule& rule = *it;

        if (!service || service->empty()) {
            return factory.action(rule.defaultAction).reason(Reason::ruleDefaultPolicy).product();
        }

        auto ticket = tvm->get_native_service_ticket_or_error(detail::toString(*service));

        return boost::apply_visitor(
            detail::ServiceTicketVisitor<Tvm2Module, String, OptionalString> {
                tvm, defaultAction, bbEnv, rootClientId, user, rule, strongUidCheck, uidFromArgs
            },
            ticket
        );
    }

    template<class Request>
    Response check(const Request& req) const {
        return check(endpoint(req),
                     req.getOptionalArg("uid").value_or(""),
                     req.getOptionalHeader(header::serviceTicket()),
                     req.getOptionalHeader(header::userTicket()));
    }

    std::shared_ptr<Tvm2Module> module() const {
        return tvm;
    }

    std::ostream& print (std::ostream& out) const;
    template <typename Out>
    void to_tskv(Out& out, const char* key) const;
private:
    template<class Request>
    static auto endpoint(const Request& req) -> decltype(req.method) {
        return req.method;
    }

    template<class Request>
    static auto endpoint(const Request& req) -> decltype(req.endpoint()) {
        return req.endpoint();
    }
};

}
