#include "tcp.h"
#include "data_stream.h"

#include <solomon/agent/misc/logger.h>

#include <library/cpp/neh/asio/asio.h>
#include <library/cpp/neh/asio/executor.h>

#include <util/datetime/base.h>
#include <util/generic/size_literals.h>
#include <util/stream/mem.h>

#include <atomic>
#include <exception>
#include <utility>

using namespace NAsio;
using namespace NAddr;

using TTcpSocketPtr = TAtomicSharedPtr<TTcpSocket>;

namespace NSolomon {
namespace NAgent {

using THandlerPtr = THolder<TTcpServer::TConnectionHandler>;

namespace {

class TConnection : public IConnection {
public:
    TConnection(THandlerPtr handler, TTcpSocketPtr socket, size_t bufferSize, ITcpServerStatusListener* listener)
        : Handler_{std::move(handler)}
        , Socket_{socket}
        , RemoteAddress_{Socket_->RemoteEndpoint().IpToString()}
        , Stream_{bufferSize}
        , Listener_{listener}
    {
        SA_LOG(INFO) << "Connection from " << RemoteAddress_;
    }

    ~TConnection() {
        auto time = TInstant::Now() - Start_;
        SA_LOG(INFO) << "Received " << ReadTotal_ << " from " << RemoteAddress_ << " in " << time.MilliSeconds() << "ms";

        if (Listener_ != nullptr) {
            Listener_->OnRequestCompleted(ReadTotal_, TInstant::Now() - Start_);
        }
    }

    void Write(const TBlob& data) override {
        Socket_->WriteSome(data.Data(), data.Length());
    }

    void Close() override {
        Closed_ = 1;
        Socket_->AsyncCancel();

        auto blob = Stream_.AsBlob();
        Handler_->OnConnectionClose(blob);
    }

    const TString& RemoteAddress() const override {
        return RemoteAddress_;
    }

    auto MakeReadHandler() {
        return [self = TIntrusivePtr<TConnection>(this)] (const TErrorCode& err, size_t amount, IHandlingContext&) {
            self->HandleRead(err, amount);
        };
    }

    static void Serve(TTcpServer::TConnectionHandler* handler, TTcpSocketPtr socket, size_t bufferSize, ITcpServerStatusListener* listener) {
        SetNonBlock(socket->Native());

        auto self = ::MakeIntrusive<TConnection>(THolder<TTcpServer::TConnectionHandler>(handler), socket, bufferSize, listener);
        self->Handler_->OnConnectionOpen(self.Get());

        self->AsyncRead(self->MakeReadHandler());
    }

    bool IsClosed() const {
        return Closed_.load(std::memory_order_relaxed) || !Socket_->IsOpen();
    }

    void AsyncRead(NAsio::TTcpSocket::TReadHandler readHandler) {
        Socket_->AsyncReadSome(Stream_.ReadBuf(), Stream_.Avail(), readHandler, TDuration::Seconds(1));
    }

    void HandleRead(const TErrorCode& err, size_t amount) {
        if (IsClosed() || amount == 0 || err) {
            Close();
            return;
        }

        ReadTotal_ += amount;
        Stream_.MarkWritten(amount);

        auto data = Stream_.AsBlob();
        const size_t bytesRead = Handler_->OnData(data);
        Stream_.MarkRead(bytesRead);

        if (Stream_.IsFull()) {
            SA_LOG(ERROR) << "Buffer overflow";
            Close();
            return;
        }

        AsyncRead(MakeReadHandler());
    }


private:
    THandlerPtr Handler_;
    TTcpSocketPtr Socket_;
    TString RemoteAddress_;

    TDataStream Stream_;

    std::atomic_bool Closed_ {0};

    ITcpServerStatusListener* Listener_ {nullptr};
    ui64 ReadTotal_ {0};
    TInstant Start_ {TInstant::Now()};
};

} // namespace

/// Implementation of TTcpServer.
///
/// Creates one acceptor per address that the bind address resolves to. All acceptors run on
/// the AcceptorThread_ io service. Each accepted socket is created on an io service obtained
/// from GetWorkerThread().
class TTcpServer::TImpl {
public:
    explicit TImpl(ICallbackPtr cb, const TTcpServer::TOptions& opts, ITcpServerStatusListener* listener)
        : Address_{opts.BindAddress, opts.BindPort}
        , Options_{opts}
        , Callback_{cb}
        , WorkerThreads_{Options_.ThreadCount}
        , Listener_{listener}
    {
        Y_ENSURE(Options_.ThreadCount > 0);
    }

    /// For every resolved address: bind, listen with the configured backlog, and start the
    /// accept loop.
    void Start() {
        for (auto it = Address_.Begin(); it != Address_.End(); ++it) {
            TEndpoint endpoint{new TAddrInfo(&*it)};
            Acceptors_.emplace_back(new TTcpAcceptor{AcceptorThread_.GetIOService()});

            auto* acceptor = Acceptors_.back().Get();
            acceptor->Bind(endpoint);
            acceptor->Listen(Options_.BacklogSize);
            SA_LOG(DEBUG) << "start listening on " << endpoint.IpToString();

            Accept(acceptor);
        }
    }

    /// Posts one asynchronous accept on `acceptor`. The new socket is created on a worker io
    /// service; the completion is handled by HandleAccept.
    void Accept(TTcpAcceptor* acceptor) {
        auto socket = ::MakeAtomicShared<TTcpSocket>(GetWorkerThread());

        // Explicit capture list: the implicit `this` capture through [=] is deprecated in C++20.
        acceptor->AsyncAccept(*socket, [this, socket, acceptor] (const TErrorCode& err, IHandlingContext&) {
            HandleAccept(err, socket, acceptor);
        });
    }

    /// Returns the io service that a newly accepted socket should live on.
    /// NOTE(review): the fallback to the acceptor's service looks unreachable given the
    /// ThreadCount > 0 check in the constructor, assuming Size() reports that thread count.
    TIOService& GetWorkerThread() {
        if (WorkerThreads_.Size() == 0) {
            return AcceptorThread_.GetIOService();
        }

        return WorkerThreads_.GetExecutor().GetIOService();
    }

    /// Accept completion.
    ///
    /// On success, hands the socket to a new TConnection. ECANCELED ends the loop for this
    /// acceptor. In every other case the error is logged and the next accept is posted.
    void HandleAccept(const TErrorCode& err, TTcpSocketPtr socket, TTcpAcceptor* acceptor) {
        if (err) {
            if (err.Value() == ECANCELED) {
                return;
            }
            SA_LOG(WARN) << "Error in accept: " << err.Text();
        } else {
            // Setting up a connection can throw, e.g. when the peer has already gone away before
            // its address is queried. Letting that escape would skip Accept() below and
            // silently stop the server from accepting on this address.
            try {
                TConnection::Serve(Callback_->CreateHandler(), std::move(socket), Options_.BufferSize, Listener_);
            } catch (const std::exception& e) {
                SA_LOG(ERROR) << "Cannot serve accepted connection: " << e.what();
            }
        }

        Accept(acceptor);
    }

    /// Shuts the server down.
    ///
    /// Stops intake first, so no freshly accepted connection is handed to a worker io service
    /// that is already being shut down. Then it shuts down the worker io services.
    void Stop() {
        AcceptorThread_.SyncShutdown();
        WorkerThreads_.SyncShutdown();
        Acceptors_.clear();
    }

private:
    TNetworkAddress Address_;
    TOptions Options_;
    ICallbackPtr Callback_;

    // NOTE(review): Acceptors_ is declared before AcceptorThread_, so if Stop() is never called
    // the acceptors are destroyed after the io service they were created on. Confirm that
    // TTcpAcceptor's destructor does not touch the service, or ensure Stop() always runs.
    TVector<THolder<TTcpAcceptor>> Acceptors_;
    TIOServiceExecutor AcceptorThread_;
    TExecutorsPool WorkerThreads_;

    ITcpServerStatusListener* Listener_;
};

/// Builds the server; all state lives in the pimpl, which validates the options.
/// Nothing is bound or listened on until Start() is called.
TTcpServer::TTcpServer(ICallbackPtr cb, const TOptions& opts, ITcpServerStatusListener* listener)
    : Impl_(::MakeHolder<TImpl>(cb, opts, listener))
{
}

// Defined out of line, in this translation unit, so that TImpl is a complete type when the
// pimpl holder is destroyed.
TTcpServer::~TTcpServer() = default;

void TTcpServer::Start() {
    Impl_->Start();
}

void TTcpServer::Stop() {
    Impl_->Stop();
}

} // namespace NAgent
} // namespace NSolomon
