#include "imap_command.h"
#include "search/node.h"
#include "search/factory.h"
#include <common/errors.h>
#include <common/uid_map.h>
#include <common/types.h>
#include <common/settings.h>
#include <common/sequence_ranges.h>
#include <common/quoted_string.h>
#include <yplatform/yield.h>

namespace yimap {

/// Shared implementation of the IMAP SEARCH and UID SEARCH commands.
///
/// exec() spawns operator() as a coroutine that:
///   1. parses the request into a search tree (honouring an optional CHARSET),
///   2. delays the request when needThrottle() says the session searches too often,
///   3. resolves MESSAGE-ID criteria through the metadata backend,
///   4. loads message metadata for the narrowest applicable range,
///   5. runs full-text criteria (SearchByExpression nodes) on the search backend,
///   6. reduces the tree into the matching messages and sends "* SEARCH ...".
/// UID SEARCH additionally calls updateFolderAndSendDiff() before completing.
///
/// State used across `yield` points is kept in members rather than in locals of
/// the coroutine body (assumes reenter/yield are stackless macros, as in
/// Boost.Asio — TODO confirm for yplatform).
struct SearchBase : ImapSelectedCommand
{
    using YieldCtx = yplatform::yield_context<SearchBase>;

    // true for UID SEARCH: results are reported as UIDs and the request sits one
    // level deeper in the command AST.
    const bool uidMode;
    // The selected folder, taken from the session state when the coroutine starts.
    FolderRef mailbox;
    // Value of the CHARSET argument, or empty when none was given.
    string charset;
    // Textual form of the search criteria, forwarded to the search backend.
    string originalRequest;
    // Parsed search criteria; async steps store their results in its nodes.
    SearchNodePtr searchTree;
    // Message range whose metadata is loaded (see buildBoundSeqset()).
    SeqRangePtr bounds;
    // Metadata of the messages inside `bounds`.
    UidMapPtr searchScope;
    // MESSAGE-ID nodes resolved before metadata is loaded.
    std::vector<SearchNodePtr> nodesForAsyncPresearch;
    // Full-text nodes resolved by the search backend.
    std::vector<SearchNodePtr> nodesForAsyncSearch;
    // Iterator shared by both async loops; a member so it survives `yield`.
    std::vector<SearchNodePtr>::const_iterator it;
    // Timer used to delay throttled requests.
    Timer timer;
    // Error code captured from the most recent async step.
    ErrorCode ec;
    // Final set of matching messages.
    UidMapPtr result;

    /// @param cmdArgs  command arguments forwarded to ImapSelectedCommand.
    /// @param uidMode  true for UID SEARCH, false for plain SEARCH.
    SearchBase(ImapCommandArgs& cmdArgs, bool uidMode)
        : ImapSelectedCommand(cmdArgs), uidMode(uidMode), timer(imapContext->ioService)
    {
    }

    /// Spawns the coroutine on the I/O service. shared_from(this) presumably keeps
    /// the command alive until the coroutine finishes.
    void exec() override
    {
        yplatform::spawn(ioService(), yplatform::shared_from(this));
    }

    /// Coroutine body. Every backend round-trip suspends with `yield`. Exceptions
    /// are handled by the overload taking YieldCtx::exception_type.
    void operator()(YieldCtx yieldCtx)
    {
        reenter(yieldCtx)
        {
            mailbox = imapContext->sessionState.selectedFolder;
            charset = getCharset(tree);
            originalRequest = node_to_string(getSearchRequestTree(tree).value);
            searchTree = buildSearchTree(
                getSearchRequestTree(tree), charset, settings_->searchTreeMaxHeight);

            if (needThrottle())
            {
                yield wait(yieldCtx.capture(ec));
                // Ignore ec: the search proceeds whether or not the wait completed normally.
            }
            updateStats();

            // XXX extract
            nodesForAsyncPresearch = filterNodesForAsyncPresearch(searchTree);
            for (it = nodesForAsyncPresearch.begin(); it != nodesForAsyncPresearch.end(); ++it)
            {
                // Resolve each MESSAGE-ID node on its own. Passing the tree root here
                // throws "invalid node type" whenever the root is not itself a
                // MESSAGE-ID node, i.e. for any combined request.
                yield presearchAndStoreResultInTree(*it, mailbox).then(yieldCtx);
            }

            bounds = buildBoundSeqset(searchTree, mailbox);
            yield loadMetadata(searchTree, mailbox, bounds).then(yieldCtx.capture(searchScope));
            // Nothing to search in: answer immediately with an empty result.
            if (searchScope->empty()) return completeOkNoMessages();

            // XXX extract
            nodesForAsyncSearch = filterNodesForAsyncSearch(searchTree);
            for (it = nodesForAsyncSearch.begin(); it != nodesForAsyncSearch.end(); ++it)
            {
                yield searchAndStoreResultInTree(
                    *it, originalRequest, mailbox, yieldCtx.capture(ec));
                if (ec) return completeNoBackendError(ec.message());
            }

            result = collectResults(searchTree, mailbox, searchScope);
            sendFoundMessages(result);

            if (uidMode)
            {
                yield updateFolderAndSendDiff().then(yieldCtx);
            }
            completeOk();
        }
    }

    /// Maps an exception escaping the coroutine to a tagged IMAP completion:
    /// parse errors -> BAD, unknown charset -> NO [BADCHARSET],
    /// anything else -> NO [UNAVAILABLE].
    void operator()(YieldCtx::exception_type exception)
    {
        try
        {
            std::rethrow_exception(exception);
        }
        catch (const InvalidSeqSet& e)
        {
            completeBadInvalidFormat(e.what());
        }
        catch (const InvalidFormat& e)
        {
            completeBadInvalidFormat(e.what());
        }
        catch (const UnsupportedEncodingError& /*e*/)
        {
            completeNoUnsupportedEncoding();
        }
        catch (const std::exception& e)
        {
            completeNoBackendError(e.what());
        }
        catch (...)
        {
            completeNoBackendError();
        }
    }

    /// Decides whether this request must wait before running.
    /// Returns false while the session has made fewer than notThrottledRequestsCount
    /// searches. Otherwise returns true if the penultimate request is younger than
    /// delayAfterPenultRequest or the last one is younger than delayAfterLastRequest.
    bool needThrottle()
    {
        auto& throttlingSettings = settings_->searchSettings.throttling;
        auto& searchState = imapContext->searchState;
        if (searchState.requestsCount < throttlingSettings.notThrottledRequestsCount)
        {
            return false;
        }
        if (Clock::now() - searchState.penultRequestTs < throttlingSettings.delayAfterPenultRequest)
        {
            return true;
        }
        if (Clock::now() - searchState.lastRequestTs < throttlingSettings.delayAfterLastRequest)
        {
            return true;
        }
        return false;
    }

    /// Arms `timer` so that `handler` runs once both throttling windows have
    /// elapsed (the later of the two deadlines). If both have already passed, the
    /// delay is non-positive and the timer presumably fires immediately.
    template <typename Handler>
    void wait(const Handler& handler)
    {
        auto& throttlingSettings = settings_->searchSettings.throttling;
        auto& searchState = imapContext->searchState;
        auto delay = std::max(
            searchState.lastRequestTs + throttlingSettings.delayAfterLastRequest - Clock::now(),
            searchState.penultRequestTs + throttlingSettings.delayAfterPenultRequest -
                Clock::now());
        timer.expires_from_now(delay);
        timer.async_wait(handler);
    }

    /// Delegates to the session's search state (presumably records this request
    /// for future throttling decisions — TODO confirm).
    void updateStats()
    {
        imapContext->searchState.updateStats();
    }

    /// Returns the first argument node after the SEARCH keyword: the CHARSET
    /// specification when present, otherwise the root of the search criteria.
    /// For UID SEARCH the arguments are nested under the second child.
    TreeNode& firstSearchArgument(CommandASTPtr tree)
    {
        Trees& cargs = tree->data.trees[0].children;
        return uidMode ? cargs[1].children[0] : cargs[1];
    }

    /// Returns the CHARSET value, or an empty string when the request has none.
    string getCharset(CommandASTPtr tree)
    {
        if (!containsCharset(tree))
        {
            return {};
        }
        return quoted_string(firstSearchArgument(tree));
    }

    /// Returns the AST node holding the search criteria, skipping the CHARSET
    /// argument when present.
    const TreeNode& getSearchRequestTree(CommandASTPtr tree)
    {
        if (uidMode)
        {
            auto index = containsCharset(tree) ? 1 : 0;
            return tree->data.trees[0].children[1].children[index];
        }
        else
        {
            auto index = containsCharset(tree) ? 2 : 1;
            return tree->data.trees[0].children[index];
        }
    }

    /// True when the first SEARCH argument is a CHARSET specification.
    bool containsCharset(CommandASTPtr tree)
    {
        return firstSearchArgument(tree).value.id().to_long() == lex_ids::SEARCH_CHARSET;
    }

    /// Collects all MESSAGE-ID nodes of `tree` (post-order).
    std::vector<SearchNodePtr> filterNodesForAsyncPresearch(SearchNodePtr tree)
    {
        std::vector<SearchNodePtr> result;
        auto visitor = [&result](auto&& node) {
            if (auto casted = nodeCast<SearchNodeMessageId>(node))
            {
                result.push_back(casted);
            }
        };
        visitSearchTree(tree, visitor);
        return result;
    }

    /// Looks up messages by Message-ID in `mailbox` and stores them in the node.
    /// @throws std::runtime_error if `tree` is not a SearchNodeMessageId.
    Future<void> presearchAndStoreResultInTree(SearchNodePtr tree, FolderRef& mailbox)
    {
        if (auto casted = nodeCast<SearchNodeMessageId>(tree))
        {
            return metaBackend->getMessagesByMessageId(mailbox, casted->messageId)
                .then([casted](auto future) { casted->found = future.get(); });
        }
        else
        {
            throw std::runtime_error("invalid node type");
        }
    }

    /// Computes the message range whose metadata must be loaded.
    /// Requests containing OR/NOT use the whole mailbox. Otherwise, the union of
    /// sequence-number sets and the union of UID sets (plus, in UID mode, the UIDs
    /// found by MESSAGE-ID) are collected, and selectBounds() picks one of them.
    SeqRangePtr buildBoundSeqset(SearchNodePtr tree, FolderRef& mailbox)
    {
        auto wholeMailbox = fullSeqRange(mailbox);

        if (complexRequest(tree))
        {
            return wholeMailbox;
        }

        auto numbers = std::make_shared<seq_range>(mailbox.seqRange(false));
        auto uids = std::make_shared<seq_range>(mailbox.seqRange(true));

        auto collectRanges = [&](auto node) {
            if (auto casted = nodeCast<SearchNodeSeqsetUid>(node))
            {
                *uids += casted->seqRangeFor(mailbox);
            }
            else if (auto casted = nodeCast<SearchNodeSeqset>(node))
            {
                *numbers += casted->seqRangeFor(mailbox);
            }
            else if (auto casted = nodeCast<SearchNodeMessageId>(node))
            {
                if (uidMode && casted->found && casted->found->size())
                {
                    *uids += casted->found->uidRange();
                }
            }
        };

        visitSearchTree(tree, collectRanges);

        return selectBounds(numbers, uids, wholeMailbox);
    }

    /// Loads metadata for `bounds`; full messages are requested only when some
    /// node needs them (see needFullMessages()).
    Future<UidMapPtr> loadMetadata(SearchNodePtr tree, FolderRef& mailbox, SeqRangePtr bounds)
    {
        auto partial = !needFullMessages(tree);
        return metaBackend->loadMessages(mailbox, *bounds, partial);
    }

    /// Collects all full-text (SearchByExpression) nodes of `tree` (post-order).
    std::vector<SearchNodePtr> filterNodesForAsyncSearch(SearchNodePtr tree)
    {
        std::vector<SearchNodePtr> result;
        auto visitor = [&result](auto&& node) {
            if (auto casted = nodeCast<SearchByExpression>(node))
            {
                result.push_back(casted);
            }
        };
        visitSearchTree(tree, visitor);
        return result;
    }

    /// Runs one full-text search on the search backend, stores the result in the
    /// node and then invokes `handler(ec)`.
    /// @throws std::runtime_error if `tree` is not a SearchByExpression.
    template <typename Handler>
    void searchAndStoreResultInTree(
        SearchNodePtr tree,
        const string& originalRequest,
        FolderRef& mailbox,
        Handler handler)
    {
        if (auto casted = nodeCast<SearchByExpression>(tree))
        {
            auto storeResult = [casted, handler](auto ec, auto result) mutable {
                casted->found = std::move(result);
                handler(ec);
            };
            searchBackend->search(
                casted->searchExpression, originalRequest, mailbox.fid(), storeResult);
        }
        else
        {
            throw std::runtime_error("invalid node type");
        }
    }

    /// Recursively reduces the search tree against `metadata`.
    /// Inner nodes must be SearchNodeLogical and combine their children's results.
    /// Leaves filter `metadata`. Leaves of unknown type match everything.
    /// @throws std::logic_error if an inner node is not a SearchNodeLogical.
    UidMapPtr collectResults(SearchNodePtr tree, FolderRef& mailbox, UidMapPtr metadata)
    {
        if (tree->children.size())
        {
            std::vector<UidMapPtr> results;
            results.reserve(tree->children.size());
            for (auto&& child : tree->children)
            {
                results.push_back(collectResults(child, mailbox, metadata));
            }
            auto casted = nodeCast<SearchNodeLogical>(tree);
            if (!casted) throw std::logic_error("incorrect search node type");
            return casted->reduce(metadata, results);
        }
        if (auto casted = nodeCast<SearchByExpression>(tree))
        {
            return casted->reduce(metadata);
        }
        else if (auto casted = nodeCast<SearchNodeMessageId>(tree))
        {
            return casted->reduce(metadata);
        }
        else if (auto casted = nodeCast<SearchNodeStatic>(tree))
        {
            return casted->reduce(metadata);
        }
        else if (auto casted = nodeCast<SearchNodeSeqset>(tree))
        {
            return casted->reduce(metadata, mailbox);
        }
        return metadata;
    }

    /// Picks the loading range: the only non-empty one of `numbers`/`uids`, or
    /// `wholeMailbox` when both or neither are populated.
    SeqRangePtr selectBounds(SeqRangePtr numbers, SeqRangePtr uids, SeqRangePtr wholeMailbox)
    {
        bool bothEmpty = numbers->empty() && uids->empty();
        bool bothFull = !numbers->empty() && !uids->empty();
        if (bothFull || bothEmpty)
        {
            return wholeMailbox;
        }
        else if (uids->empty())
        {
            return numbers;
        }
        else
        {
            return uids;
        }
    }

    /// True if the tree contains an OR or NOT node anywhere.
    bool complexRequest(SearchNodePtr tree)
    {
        auto isOr = std::dynamic_pointer_cast<SearchNodeOr>(tree);
        auto isNot = std::dynamic_pointer_cast<SearchNodeNot>(tree);
        if (isOr || isNot) return true;

        for (auto&& child : tree->children)
        {
            if (complexRequest(child)) return true;
        }
        return false;
    }

    /// True if any node is full-text, date or size based. Those are the node kinds
    /// for which partial metadata is not requested.
    bool needFullMessages(SearchNodePtr tree)
    {
        bool result = false;
        auto filter = [&result](auto node) {
            if (nodeCast<SearchByExpression>(node) || nodeCast<SearchNodeDate>(node) ||
                nodeCast<SearchNodeSize>(node))
            {
                result = true;
            }
        };
        visitSearchTree(tree, filter);
        return result;
    }

    /// Sequence-number range 1..messageCount of `mailbox`.
    /// NOTE(review): for an empty folder this adds range_t(1, 0); confirm that
    /// seq_range tolerates an inverted/empty range.
    SeqRangePtr fullSeqRange(FolderRef& mailbox)
    {
        auto result = std::make_shared<seq_range>(mailbox.seqRange(false));
        *result += range_t(1, mailbox.info().messageCount);
        return result;
    }

    /// Post-order traversal: calls `h` on every child subtree before the node itself.
    /// Arguments are taken by reference so recursion does not copy the node
    /// pointer or the handler at every level.
    template <typename Handler>
    void visitSearchTree(const SearchNodePtr& tree, const Handler& h) const
    {
        for (auto&& child : tree->children)
        {
            visitSearchTree(child, h);
        }
        h(tree);
    }

    /// Sends the untagged "* SEARCH" response, draining `messages` in pop() order.
    /// Writes UIDs in UID mode, sequence numbers otherwise.
    void sendFoundMessages(UidMapPtr messages)
    {
        std::ostringstream result_stream;
        result_stream << "* SEARCH";

        while (!messages->empty())
        {
            const MessageData msg = messages->pop();
            result_stream << ' ' << (uidMode ? msg.uid : msg.num);
        }

        result_stream << "\r\n";
        sendClient() << result_stream.str();
    }

    /// Empty untagged result followed by a tagged OK.
    void completeOkNoMessages()
    {
        sendClient() << "* SEARCH\r\n";
        completeOk("", "Completed (no messages).");
    }

    /// Tagged NO [BADCHARSET] for an unsupported CHARSET.
    void completeNoUnsupportedEncoding()
    {
        completeNo("[BADCHARSET]", "Unsupported text encoding.");
    }

    /// Logs `error` and answers with a tagged NO [UNAVAILABLE].
    void completeNoBackendError(const string& error = "unknown"s)
    {
        logError() << "SEARCH error=" << error;
        completeNo("[UNAVAILABLE]", "Backend error.");
    }

    /// Logs `error` and answers with a tagged BAD that includes the detail.
    void completeBadInvalidFormat(const string& error)
    {
        logError() << "SEARCH error=" << error;
        // Keep the fixed prefix and the detail visually separate; previously they
        // were concatenated directly ("Invalid input formatXYZ").
        completeBad(
            "",
            error.empty() ? string("Invalid input format") : "Invalid input format: " + error);
    }

    /// Per-node-type counts of the parsed search tree, for statistics.
    ExtraStatFields statExtra() const override
    {
        ExtraStatFields stats;
        if (searchTree)
        {
            std::map<string, int> nodeCounts;
            auto calcAllNodes = [&nodeCounts](auto&& node) { nodeCounts[node->nodeName()]++; };
            visitSearchTree(searchTree, calcAllNodes);
            for (auto&& [name, count] : nodeCounts)
            {
                stats[name] = std::to_string(count);
            }
        }
        return stats;
    }
};

/// IMAP SEARCH: SearchBase in sequence-number mode (uidMode == false).
struct Search : SearchBase
{
    // explicit: this constructor is not meant as an implicit conversion from
    // ImapCommandArgs.
    explicit Search(ImapCommandArgs& cmdArgs) : SearchBase(cmdArgs, false)
    {
    }
};

/// IMAP UID SEARCH: SearchBase in UID mode (uidMode == true).
struct UidSearch : SearchBase
{
    // explicit: this constructor is not meant as an implicit conversion from
    // ImapCommandArgs.
    explicit UidSearch(ImapCommandArgs& cmdArgs) : SearchBase(cmdArgs, true)
    {
    }
};

/// Creates the handler for plain SEARCH (results reported as sequence numbers).
CommandPtr CommandSearch(ImapCommandArgs& commandArgs)
{
    return CommandPtr{ new Search{ commandArgs } };
}

/// Creates the handler for UID SEARCH (results reported as UIDs).
CommandPtr CommandUidSearch(ImapCommandArgs& commandArgs)
{
    return CommandPtr{ new UidSearch{ commandArgs } };
}
}
