#include <fcntl.h>

#include "classip.h"
#include "log.h"
#include "net.h"

/**
 * Build a network endpoint bound to the given address.
 * @param IPAddr address/port this object will listen on, connect to or send to.
 */
TNet::TNet(TIP IPAddr) {
    this->IP = IPAddr;
}
/**
 * Release the endpoint. Descriptors are closed only if this object owns them
 * (the decision is delegated to Down()).
 */
TNet::~TNet() {
    this->Down();
}

/**
 * Shut down and close the TCP and UDP descriptors held by this object, but
 * only when it owns them (see SetSocketOwner()/ResetSocketOwner()). A peer
 * produced by RecvUDP() borrows the listener's UDP socket and is marked as
 * non-owner, so its Down() is a no-op.
 *
 * Each descriptor is reset to -1 right after it is closed. Without that, a
 * second call (e.g. an explicit Down() on an error path followed by the
 * destructor) would close the same descriptor number again, and by then the
 * number may have been reused for an unrelated file or connection.
 *
 * Note: a value of 0 is treated as "no socket" by the `> 0` checks below.
 */
void TNet::Down() {
    if (!SocketOwner)
        return;

    if (SocketTCP > 0) {
        if (DEBUG)
            Log("Closing TCP socket [%s]:%u", IP.GetIPChar(), IP.GetPort());
        shutdown(SocketTCP, SHUT_RDWR);
        close(SocketTCP);
        SocketTCP = -1;  // make Down() idempotent; never close a reused fd
    }
    if (SocketUDP > 0) {
        if (DEBUG)
            Log("Closing UDP socket [%s]:%u", IP.GetIPChar(), IP.GetPort());
        shutdown(SocketUDP, SHUT_RDWR);
        close(SocketUDP);
        SocketUDP = -1;  // make Down() idempotent; never close a reused fd
    }
}

void TNet::SetSocketOwner() {
    SocketOwner = true;
}

void TNet::ResetSocketOwner() {
    SocketOwner = false;
}

int TNet::GetTCPSocket() {
    return SocketTCP;
}

int TNet::GetUDPSocket() {
    return SocketUDP;
}

const struct timeval *TNet::GetTimeout() {
    return &Timeout;
}

void TNet::SetTimeout(struct timeval *TO) {
    memcpy(&Timeout, TO, sizeof(struct timeval));
}

/** @return the receive buffer that RecvTCP()/RecvUDP() fill. */
TMyBuffer *TNet::GetInBuffer() {
    return &this->IBuf;
}

/** @return the send buffer that SendTCP()/SendUDP() drain. */
TMyBuffer *TNet::GetOutBuffer() {
    return &this->OBuf;
}

/** @return the address this object is associated with. */
TIP *TNet::GetIP() {
    return &this->IP;
}

/** @return the timestamp refreshed after each successful send. */
const Time *TNet::GetLastByteSendOKTime() {
    return &this->LastByteSendOKTime;
}

/** @return the timestamp refreshed after each successful receive. */
const Time *TNet::GetLastByteRecvOKTime() {
    return &this->LastByteRecvOKTime;
}

int TNet::SetNonBlockTCP() {
    return SetNonBlock(SocketTCP);
}

int TNet::SetNonBlockUDP() {
    return SetNonBlock(SocketUDP);
}

/**
 * Add O_NONBLOCK to the file status flags of Socket.
 * On failure the error is logged and Down() is called (which closes both of
 * this object's descriptors if it owns them).
 * @return Socket on success, -1 on failure.
 */
int TNet::SetNonBlock(int Socket) {
    const int CurrentFlags = fcntl(Socket, F_GETFL, 0);
    // F_SETFL is attempted only when F_GETFL succeeded.
    const bool Switched = CurrentFlags >= 0 &&
                          fcntl(Socket, F_SETFL, CurrentFlags | O_NONBLOCK) >= 0;
    if (Switched)
        return Socket;

    Log("Can't make NONBLOCK socket for [%s]:%u", IP.GetIPChar(), IP.GetPort());
    Down();
    return -1;
}

int TNet::SetBufSizeTCP(uint32_t sockbufsize) {
    return SetBufSize(sockbufsize, SocketTCP);
}

int TNet::SetBufSizeUDP(uint32_t sockbufsize) {
    return SetBufSize(sockbufsize, SocketUDP);
}

/**
 * Apply sockbufsize as both SO_SNDBUF and SO_RCVBUF on Socket.
 * SO_RCVBUF is only attempted if SO_SNDBUF succeeded. On failure the error is
 * logged and Down() is called.
 * @return Socket on success, -1 on failure.
 */
int TNet::SetBufSize(uint32_t sockbufsize, int Socket) {
    const bool SendSizeSet =
        setsockopt(Socket, SOL_SOCKET, SO_SNDBUF, &sockbufsize, sizeof(sockbufsize)) >= 0;
    if (SendSizeSet &&
        setsockopt(Socket, SOL_SOCKET, SO_RCVBUF, &sockbufsize, sizeof(sockbufsize)) >= 0)
        return Socket;

    Log("Can't set socket bufsize [%s]:%u - %s (%i)", IP.GetIPChar(), IP.GetPort(), strerror(errno), errno);
    Down();
    return -1;
}

/**
 * Open a UDP datagram socket of IP's address family and store it in SocketUDP.
 * @return the new descriptor, or -1 (logged) if socket() failed.
 */
int TNet::CreateUDP() {
    SocketUDP = socket(IP.GetFamily(), SOCK_DGRAM, IPPROTO_UDP);
    if (SocketUDP >= 0)
        return SocketUDP;

    Log("Can't open socket dgram [%s]:%u in CreateUDP() - %s (%i)", IP.GetIPChar(), IP.GetPort(), strerror(errno), errno);
    return -1;
}

/**
 * Create a UDP socket and bind it to IP (IPv4 or IPv6 depending on family).
 *
 * On bind failure the error is logged and the freshly created descriptor is
 * closed and cleared. errno is captured before close() can clobber it. The
 * descriptor is cleared so that a later Down() or the destructor does not
 * close the same (possibly reused) fd number a second time.
 *
 * @return the bound UDP descriptor, or -1 on failure.
 */
int TNet::ListenUDP() {
    if (CreateUDP() < 0)
        return -1;

    const int BindResult = (IP.GetFamily() == AF_INET) ?
            bind(SocketUDP, (struct sockaddr *)IP.GetSockAddr(), sizeof(struct sockaddr_in)) :
            bind(SocketUDP, (struct sockaddr *)IP.GetSockAddr6(), sizeof(struct sockaddr_in6));
    if (BindResult < 0) {
        const int Err = errno;  // snapshot before close() may overwrite it
        Log("Can't bind to [%s]:%u (%u) in ListenUDP() - %s (%i)", IP.GetIPChar(), IP.GetPort(), SocketUDP, strerror(Err), Err);
        close(SocketUDP);
        SocketUDP = -1;
        return -1;
    }
    Log("Listening udp on [%s]:%u", IP.GetIPChar(), IP.GetPort());
    return SocketUDP;
}

/**
 * Send the whole pending output buffer as a single datagram to IP over SocketUDP.
 * Sent bytes are removed from OBuf and LastByteSendOKTime is refreshed.
 * @return bytes sent; 0 if OBuf is empty or the call would block;
 *         -1 on any other sendto() error (logged).
 */
ssize_t TNet::SendUDP() {
    const size_t Pending = OBuf.size();
    if (Pending == 0)
        return 0;

    ssize_t Sent;
    if (IP.GetFamily() == AF_INET)
        Sent = sendto(SocketUDP, OBuf.data(), Pending, 0, (struct sockaddr *)IP.GetSockAddr(), sizeof(struct sockaddr_in));
    else
        Sent = sendto(SocketUDP, OBuf.data(), Pending, 0, (struct sockaddr *)IP.GetSockAddr6(), sizeof(struct sockaddr_in6));

    if (Sent < 0) {
        if (errno != EAGAIN && errno != EWOULDBLOCK) {
            Log("Can't send UDP to [%s]:%u (%u) in SendUDP() - %s (%i)", IP.GetIPChar(), IP.GetPort(), SocketUDP, strerror(errno), errno);
            return -1;
        }
        return 0;  // would block: try again later
    }

    // sendto() never reports more than Pending bytes.
    if (static_cast<size_t>(Sent) < Pending)
        OBuf.erase(0, Sent);
    else
        OBuf.clear();
    LastByteSendOKTime.Update();
    return Sent;
}

std::unique_ptr<TNet>TNet::RecvUDP() {
    int Length = 0;
    struct sockaddr_in6 SockAddr = {0, 0, 0, IN6ADDR_ANY_INIT, 0};
    socklen_t SockAddrLength = sizeof(struct sockaddr_in6);

    TNet *Peer = new TNet();
    Peer->IBuf.resize(DEFAULTBUFFERLENGTH);
    if ((Length = recvfrom(SocketUDP, (void *)Peer->IBuf.data(), DEFAULTBUFFERLENGTH, 0, (struct sockaddr *)&SockAddr, &SockAddrLength)) < 0) {
        delete Peer;
        Log("Can't recv UDP for [%s]:%u (%u) in RecvUDP() - %s (%i)", IP.GetIPChar(), IP.GetPort(), SocketUDP, strerror(errno), errno);
        return nullptr;
    }
    Peer->IBuf.resize(Length);
    Peer->SocketUDP = SocketUDP; // XXX - Peer reply automatically will be from the input socket
    Peer->SocketOwner = false;
    Peer->IP.ConstructFromSockaddr((struct sockaddr *)&SockAddr);
    LastByteRecvOKTime.Update();
    return std::unique_ptr<TNet>(Peer);
}

/**
 * Receive one datagram (up to DEFAULTBUFFERLENGTH bytes) on SocketUDP into an
 * existing Peer object, reusing it instead of allocating a new one.
 *
 * On success Peer->IBuf holds exactly the payload, Peer->IP is the sender and
 * Peer borrows this listener's SocketUDP as a non-owner.
 * On failure the error is logged and Peer->IBuf is left empty. Previously it
 * was left at DEFAULTBUFFERLENGTH bytes of stale data, which a caller could
 * mistake for a received datagram. Peer's address/socket are untouched.
 *
 * @param Peer destination object; a null pointer is ignored.
 */
void TNet::RecvUDP(TNet *Peer) {
    if (Peer == nullptr)
        return;

    // Value-initialise instead of positional init: sockaddr_in6 member order
    // differs between platforms (e.g. BSD has a leading sin6_len).
    struct sockaddr_in6 SockAddr = {};
    socklen_t SockAddrLength = sizeof(SockAddr);

    Peer->IBuf.resize(DEFAULTBUFFERLENGTH);
    const ssize_t Length = recvfrom(SocketUDP, (void *)Peer->IBuf.data(), DEFAULTBUFFERLENGTH, 0,
                                    (struct sockaddr *)&SockAddr, &SockAddrLength);
    if (Length < 0) {
        Log("Can't recv UDP for [%s]:%u (%u) in RecvUDP() - %s (%i)", IP.GetIPChar(), IP.GetPort(), SocketUDP, strerror(errno), errno);
        Peer->IBuf.clear();  // don't expose the scratch buffer as data
        return;
    }
    Peer->IBuf.resize(static_cast<size_t>(Length));
    Peer->SocketUDP = SocketUDP; // XXX - Peer reply automatically will be from the input socket
    Peer->SocketOwner = false;   // the socket belongs to this listener, not the peer
    Peer->IP.ConstructFromSockaddr((struct sockaddr *)&SockAddr);
    LastByteRecvOKTime.Update();
}

/**
 * Create a TCP socket with SO_REUSEADDR, bind it to IP and start listening
 * with a backlog of 10.
 *
 * On any failure after socket() succeeded, errno is captured and logged
 * before close() can overwrite it. The descriptor is then closed and cleared,
 * so that a later Down() or the destructor does not close the same (possibly
 * reused) fd number again.
 *
 * @return the listening descriptor, or -1 on failure.
 */
int TNet::ListenTCP() {
    if ((SocketTCP = socket(IP.GetFamily(), SOCK_STREAM, 0)) < 0) {
        Log("Can't open socket stream [%s]:%u in ListenTCP() - %s (%i)", IP.GetIPChar(), IP.GetPort(), strerror(errno), errno);
        return -1;
    }

    int yes = 1;
    if (setsockopt(SocketTCP, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof(yes)) < 0) {
        const int Err = errno;
        Log("Can't set socket options [%s]:%u in ListenTCP() - %s (%i)", IP.GetIPChar(), IP.GetPort(), strerror(Err), Err);
        close(SocketTCP);
        SocketTCP = -1;
        return -1;
    }

    const int BindResult = (IP.GetFamily() == AF_INET) ?
            bind(SocketTCP, (struct sockaddr *)IP.GetSockAddr(), sizeof(struct sockaddr_in)) :
            bind(SocketTCP, (struct sockaddr *)IP.GetSockAddr6(), sizeof(struct sockaddr_in6));
    if (BindResult < 0) {
        const int Err = errno;
        Log("Can't bind to [%s]:%u in ListenTCP() - %s (%i)", IP.GetIPChar(), IP.GetPort(), strerror(Err), Err);
        close(SocketTCP);
        SocketTCP = -1;
        return -1;
    }

    if (listen(SocketTCP, 10) < 0) {
        const int Err = errno;
        Log("Can't listen on [%s]:%u in ListenTCP() - %s (%i)", IP.GetIPChar(), IP.GetPort(), strerror(Err), Err);
        close(SocketTCP);
        SocketTCP = -1;
        return -1;
    }

    Log("Listening tcp on [%s]:%u", IP.GetIPChar(), IP.GetPort());
    return SocketTCP;
}

/**
 * Open a TCP socket and connect it to IP.
 *
 * @param Block false: the socket is made non-blocking and the function
 *              returns right after connect(), which may still be in progress
 *              (EINPROGRESS is accepted). Completion presumably is checked
 *              later by the caller, e.g. via PollTCPSocket()/TCPSocketError().
 *              true: the socket stays blocking; after connect() the function
 *              waits up to Timeout for the socket to become writable and then
 *              checks SO_ERROR.
 * @return the connected/connecting descriptor, or -1 on failure (logged, then
 *         Down() is called).
 *
 * Fixes relative to the earlier version:
 *  - select() is given a copy of Timeout. Linux's select() writes the
 *    remaining time back into its argument, which used to shrink the
 *    configured timeout on every call.
 *  - The reported error is the real cause: errno from connect()/select(),
 *    ETIMEDOUT when select() timed out, or the SO_ERROR value. It is captured
 *    and logged *before* Down(), whose shutdown()/close() may overwrite errno.
 *
 * NOTE(review): Down() only closes descriptors when SocketOwner is set; on a
 * non-owner object the descriptor created here is not closed on failure.
 */
int TNet::ConnectTCP(bool Block) {
    if ((SocketTCP = socket(IP.GetFamily(), SOCK_STREAM, 0)) < 0) {
        Log("Can't open socket stream for [%s]:%u in ConnectTCP() - %s (%i)", IP.GetIPChar(), IP.GetPort(), strerror(errno), errno);
        return -1;
    }
    if (!Block)
        if (SetNonBlockTCP() < 0)
            return -1;

    const int ConnectResult = (IP.GetFamily() == AF_INET) ?
            connect(SocketTCP, (struct sockaddr *)IP.GetSockAddr(), sizeof(struct sockaddr_in)) :
            connect(SocketTCP, (struct sockaddr *)IP.GetSockAddr6(), sizeof(struct sockaddr_in6));
    if (ConnectResult < 0 && errno != EINPROGRESS) {
        const int Err = errno;
        Log("Can't connect to [%s]:%u (%u) in ConnectTCP() - %s (%i)", IP.GetIPChar(), IP.GetPort(), SocketTCP, strerror(Err), Err);
        Down();
        return -1;
    }

    if (Block) {
        struct timeval Wait = Timeout;  // select() may modify its timeout argument
        fd_set wset;
        FD_ZERO(&wset);
        FD_SET(SocketTCP, &wset);

        int Err = 0;
        socklen_t ErrLen = sizeof(Err);
        const int Ready = select(SocketTCP + 1, NULL, &wset, NULL, &Wait);
        if (Ready < 0)
            Err = errno;
        else if (Ready == 0)
            Err = ETIMEDOUT;  // select() leaves errno untouched on timeout
        else if (getsockopt(SocketTCP, SOL_SOCKET, SO_ERROR, &Err, &ErrLen) < 0)
            Err = errno;

        if (Err != 0) {
            Log("Failed to connect to [%s]:%u (%u) in ConnectTCP() - %s (%i)", IP.GetIPChar(), IP.GetPort(), SocketTCP, strerror(Err), Err);
            Down();
            return -1;
        }
        Log("Connected tcp to [%s]:%u", IP.GetIPChar(), IP.GetPort());
    }
    return SocketTCP;
}

/**
 * Check, without waiting, whether the TCP descriptor is writable.
 * @return select()'s result: >0 writable, 0 not yet, -1 on error.
 */
int TNet::PollTCPSocket() {
    fd_set WriteSet;
    FD_ZERO(&WriteSet);
    FD_SET(SocketTCP, &WriteSet);

    struct timeval NoWait;
    NoWait.tv_sec = 0;
    NoWait.tv_usec = 0;
    return select(SocketTCP + 1, nullptr, &WriteSet, nullptr, &NoWait);
}

/**
 * Fetch (and, per SO_ERROR semantics, clear) the pending error on the TCP
 * descriptor.
 *
 * @return 0 if no error is pending; the pending socket error code (e.g.
 *         ECONNREFUSED) if there is one; or errno if getsockopt() itself
 *         failed.
 *
 * The previous version returned errno whenever SO_ERROR was non-zero.
 * getsockopt() had succeeded in that case, so errno was unrelated (often 0
 * or stale), and a failed connection could be reported as "no error".
 */
int TNet::TCPSocketError() {
    int err = 0;
    socklen_t err_l = sizeof(err);

    if (getsockopt(SocketTCP, SOL_SOCKET, SO_ERROR, &err, &err_l) < 0)
        return errno;
    return err;
}

std::unique_ptr<TNet>TNet::AcceptTCP() {
    int Sock;
    struct sockaddr_in6 SockAddr = {0, 0, 0, IN6ADDR_ANY_INIT, 0};
    socklen_t SockAddrLength = sizeof(struct sockaddr_in6);
    if ((Sock = accept(SocketTCP, (struct sockaddr *)&SockAddr, &SockAddrLength)) < 0) {
        Log("Can't accept for [%s]:%u (%u) in AcceptTCP() - %s (%i)", IP.GetIPChar(), IP.GetPort(), SocketTCP, strerror(errno), errno);
        return NULL;
    }
    TNet *Peer = (IP.GetFamily() == AF_INET) ?
                new TNet((struct sockaddr_in *)&SockAddr) :
                new TNet((struct sockaddr_in6 *)&SockAddr);
    Peer->SocketTCP = Sock;
    return std::unique_ptr<TNet>(Peer);
}

/**
 * Write as much of the pending output buffer as the TCP socket accepts.
 * Sent bytes are removed from the front of OBuf, and LastByteSendOKTime is
 * refreshed after a successful send().
 * @return bytes sent; 0 if OBuf is empty or the call would block;
 *         -1 on any other send() error (logged).
 */
ssize_t TNet::SendTCP() {
    const size_t Pending = OBuf.size();
    if (Pending == 0)
        return 0;

    const ssize_t Sent = send(SocketTCP, OBuf.data(), Pending, 0);
    if (Sent < 0) {
        if (errno != EAGAIN && errno != EWOULDBLOCK) {
            Log("Can't send TCP to [%s]:%u (%u) in SendTCP() - %s (%i)", IP.GetIPChar(), IP.GetPort(), SocketTCP, strerror(errno), errno);
            return -1;
        }
        return 0;  // would block: try again later
    }

    // A short write keeps the unsent tail queued for the next call.
    if (static_cast<size_t>(Sent) < Pending)
        OBuf.erase(0, Sent);
    else
        OBuf.clear();
    LastByteSendOKTime.Update();
    return Sent;
}

/**
 * Append up to DEFAULTBUFFERLENGTH bytes read from the TCP socket to IBuf.
 *
 * @return number of bytes appended; 0 if the read would block or the peer
 *         performed an orderly shutdown (recv() returned 0); -1 on any other
 *         recv() error (logged). IBuf is restored to its previous size on
 *         error.
 *
 * LastByteRecvOKTime is refreshed only when at least one byte arrived. A
 * 0-byte recv() means the peer shut down, and treating it as activity would
 * keep refreshing the timestamp of a dead connection.
 * NOTE(review): EOF and "would block" both return 0 here, as before.
 */
int TNet::RecvTCP() {
    const size_t Size = IBuf.size();
    IBuf.resize(Size + DEFAULTBUFFERLENGTH);
    const ssize_t Length = recv(SocketTCP, &IBuf[Size], DEFAULTBUFFERLENGTH, 0);
    if (Length < 0) {
        const int Err = errno;  // snapshot before any further library calls
        IBuf.resize(Size);
        if (Err == EAGAIN || Err == EWOULDBLOCK)
            return 0;
        Log("Can't recieve TCP for [%s]:%u (%u) in RecvTCP() - %s (%i)", IP.GetIPChar(), IP.GetPort(), SocketTCP, strerror(Err), Err);
        return -1;
    }
    IBuf.resize(Size + static_cast<size_t>(Length));
    if (Length > 0)
        LastByteRecvOKTime.Update();
    return static_cast<int>(Length);
}
