#include <iostream>
#include <string>
#include <map>
#include <unordered_map>
#include <vector>
#include <memory>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <string.h>

#include <net/ethernet.h>
#include <netinet/ether.h>
#include <netinet/ip.h>
#include <netinet/ip6.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <arpa/inet.h>

#include <sys/wait.h>
#include <signal.h>

#include "buffer.h"
#include "log.h"
#include "classtime.h"
#include "classip.h"
#include "net.h"
#include "epoll.h"
#include "tcpstream.h"
#include "arguments.h"
#include "pcap.h"

//============================================================================================


// Non-zero stops TSender's main loop; set by signal_handler() and when capture ends.
// NOTE(review): written from a signal handler, so `volatile sig_atomic_t` would be the
// portable type; kept as int in case other translation units declare it `extern int`.
int exit_flag{0};
// Log detail level: the number of -v flags given on the command line.
int verbosity{0};
// Capture handle, created in main() and used by PacketHandler() to query the link type.
std::unique_ptr<TPcap> Pcap{};
// Mirror every StreamFraction-th new TCP stream (-n option); 1 mirrors all streams.
uint32_t StreamFraction{1};


//============================================================================================


#define TCP_TIMEOUT 10
class TTCPStreamStore {
public:
    std::unique_ptr<TIP>SinkIP = nullptr;
private:
    struct STNet {
        std::unique_ptr<TNet> Net;
        bool                  Close;
        uint32_t              StreamSeqNumber;
    };
    std::unordered_map<TTCPStream, struct STNet> TCPStreamMap;
    std::unordered_map<int, TTCPStream> SocketMap;
public:
    TTCPStreamStore() {
    }
    ~TTCPStreamStore() {
    }

    uint64_t Size() {
        return TCPStreamMap.size();
    }
    void ScheduleClose(TTCPStream *TCPStream) {
        auto it = TCPStreamMap.find(*TCPStream);
        if (it != TCPStreamMap.end())
            it->second.Close = true;
        if (verbosity > 0)
            Log("Schedule closure for tcp stream %s.%u -> %s.%u",
                TCPStream->GetSrc()->GetIPChar(), TCPStream->GetSrc()->GetPort(),
                TCPStream->GetDst()->GetIPChar(), TCPStream->GetDst()->GetPort());
    }
    TNet *FindStream(TTCPStream *TCPStream) {
        auto it = TCPStreamMap.find(*TCPStream);
        if (it != TCPStreamMap.end())
            return it->second.Net.get();
        return nullptr;
    }
    TNet *FindStreamBySocket(int Socket) {
        auto it = SocketMap.find(Socket);
        if (it != SocketMap.end()) {
            return TCPStreamMap[it->second].Net.get();
        }
        return nullptr;
    }
    bool SequenceIsOk(TTCPStream *TCPStream, uint32_t TCPLoadLength, uint32_t SeqNumber) {
        auto it = TCPStreamMap.find(*TCPStream);
        if (it != TCPStreamMap.end())
            if (it->second.StreamSeqNumber == SeqNumber) {
                it->second.StreamSeqNumber += TCPLoadLength;
                return true;
            }
        return false;
    }
    int AddStream(TTCPStream *TCPStream, uint32_t SeqNumber) {
        int SocketTCP;
        std::unique_ptr<TNet>Sink = std::unique_ptr<TNet>(new TNet(*(SinkIP.get())));
        if ((SocketTCP = Sink->ConnectTCP()) < 0) {
            Log("Sink connection failed!");
            return -1;
        }
        SocketMap[Sink->GetTCPSocket()] = *TCPStream;
        TCPStreamMap[*TCPStream] = {std::move(Sink), false, SeqNumber + 1};
        if (verbosity > 0)
            Log("Add tcp stream %s.%u -> %s.%u (%u)",
                TCPStream->GetSrc()->GetIPChar(), TCPStream->GetSrc()->GetPort(),
                TCPStream->GetDst()->GetIPChar(), TCPStream->GetDst()->GetPort(), SocketTCP);
        return SocketTCP;
    }
    void DeleteStream(TTCPStream *TCPStream) {
        auto it = TCPStreamMap.find(*TCPStream);
        if (it != TCPStreamMap.end()) {
            if (verbosity > 0)
                Log("Close tcp stream %s.%u -> %s.%u",
                    TCPStream->GetSrc()->GetIPChar(), TCPStream->GetSrc()->GetPort(),
                    TCPStream->GetDst()->GetIPChar(), TCPStream->GetDst()->GetPort());
            SocketMap.erase(SocketMap.find(it->second.Net->GetTCPSocket()));
            TCPStreamMap.erase(it);
        }
    }
    void TryRemoveStreamBySocket(int Socket) {
        auto it = SocketMap.find(Socket);
        if (it != SocketMap.end()) {
            auto StreamIt = TCPStreamMap.find(it->second);
            if (StreamIt->second.Net->GetOutBuffer()->size() == 0 && StreamIt->second.Close) {
                if (verbosity > 0)
                    Log("Removing by socket empty stream %s.%u -> %s.%u",
                        StreamIt->first.GetSrc()->GetIPChar(), StreamIt->first.GetSrc()->GetPort(),
                        StreamIt->first.GetDst()->GetIPChar(), StreamIt->first.GetDst()->GetPort());
                TCPStreamMap.erase(StreamIt);
                SocketMap.erase(it);
            }
        }
    }
    void RemoveStreamBySocket(int Socket) {
        auto it = SocketMap.find(Socket);
        if (it != SocketMap.end()) {
            auto StreamIt = TCPStreamMap.find(it->second);
            if (verbosity > 0)
                Log("Removing by socket stream %s.%u -> %s.%u",
                    StreamIt->first.GetSrc()->GetIPChar(), StreamIt->first.GetSrc()->GetPort(),
                    StreamIt->first.GetDst()->GetIPChar(), StreamIt->first.GetDst()->GetPort());
            TCPStreamMap.erase(StreamIt);
            SocketMap.erase(it);
        }
    }
    void RemoveStaleStreams(Time CurrentTime) {
        for (auto it = TCPStreamMap.begin(); it != TCPStreamMap.end(); ) {
            if (CurrentTime - *(it->second.Net->GetLastByteSendOKTime()) > TCP_TIMEOUT) {
                if (verbosity > 0)
                    Log("Timeout for stream %s.%u -> %s.%u",
                        it->first.GetSrc()->GetIPChar(), it->first.GetSrc()->GetPort(),
                        it->first.GetDst()->GetIPChar(), it->first.GetDst()->GetPort());
                SocketMap.erase(SocketMap.find(it->second.Net->GetTCPSocket()));
                TCPStreamMap.erase(it++);
            }
            else if (it->second.Net->GetOutBuffer()->size() == 0 && it->second.Close) {
                if (verbosity > 0)
                    Log("Removing by cleanup empty stream %s.%u -> %s.%u",
                        it->first.GetSrc()->GetIPChar(), it->first.GetSrc()->GetPort(),
                        it->first.GetDst()->GetIPChar(), it->first.GetDst()->GetPort());
                SocketMap.erase(SocketMap.find(it->second.Net->GetTCPSocket()));
                TCPStreamMap.erase(it++);
            }
            else
                ++it;
        }
    }
};


//============================================================================================


/*
 * Render a 32-bit value as lowercase hexadecimal (no "0x" prefix, no padding).
 * The text lives in a static buffer that the next call overwrites, so copy it if
 * it has to survive another utos() call.
 */
char *utos(uint32_t n) {
    static char Hex[10] = {0};          // at most 8 hex digits plus the terminating NUL
    snprintf(Hex, sizeof Hex, "%x", n);
    return Hex;
}

/*
 * Decode the Ethernet II header at the start of Packet.
 *
 * Returns the EtherType in host byte order. At verbosity > 2 the EtherType and the
 * source/destination MAC addresses are logged.
 * The caller must guarantee at least sizeof(struct ether_header) readable bytes.
 */
uint16_t EthernetHandle(const u_char *Packet) {
    const struct ether_header *eptr = (const struct ether_header *)Packet;
    uint16_t Type = ntohs(eptr->ether_type);

    if (verbosity > 2) {
        // ether_ntoa() formats into one shared static buffer, so calling it twice in a
        // single argument list printed the same MAC for source and destination.
        // Format each address into its own buffer instead.
        char SrcMac[32];
        char DstMac[32];
        ether_ntoa_r((const struct ether_addr *)eptr->ether_shost, SrcMac);
        ether_ntoa_r((const struct ether_addr *)eptr->ether_dhost, DstMac);
        Log("ethernet (%s): %s -> %s",
            (Type == ETHERTYPE_IP) ? "IP" : (Type == ETHERTYPE_IPV6) ? "IPv6" :
            (Type == ETHERTYPE_ARP) ? "ARP" : (Type == ETHERTYPE_REVARP) ? "RARP" :
            (Type == ETHERTYPE_VLAN) ? "VLAN" : (Type == ETHERTYPE_LOOPBACK) ? "LO" : utos(Type),
            SrcMac, DstMac);
    }
    return Type;
}

/*
 * Validate the IPv4 header at Packet.
 *
 * pkthdr - pcap header of the frame. The IP packet length is derived as pkthdr->len
 *          minus an Ethernet header (sizeof(struct ether_header)).
 *          NOTE(review): for LINKTYPE_LINUX_SLL captures the link header is 16 bytes,
 *          so this length is 2 bytes too large there; callers should also bound-check
 *          against the real link-header offset.
 * Packet - start of the IPv4 header inside the captured frame.
 *
 * Returns the IPv4 header length in bytes (>= 20), or 0 when the packet must be skipped:
 * too short, not version 4, bad header/total length, truncated, or a non-first fragment.
 */
uint16_t IPHandle(const struct pcap_pkthdr *pkthdr, const u_char *Packet) {
    const struct iphdr *ip = (const struct iphdr *)Packet;
    uint32_t PacketLength;
    uint16_t FragmentOffset, TotalLength, HeaderLength;

    // Without this guard a runt frame would wrap the subtraction below around to a
    // huge length, and every following truncation check would pass.
    if (pkthdr->len < sizeof(struct ether_header)) {
        Log("truncated ip %u", pkthdr->len);
        return 0;
    }
    PacketLength = pkthdr->len - sizeof(struct ether_header);
    if (PacketLength < sizeof(struct iphdr))
    {
        Log("truncated ip %u", PacketLength);
        return 0;
    }
    if (ip->version != 4)
    {
        Log("unknown version %d", ip->version);
        return 0;
    }
    HeaderLength = ip->ihl << 2;
    if (HeaderLength < 20) {
        Log("bad header length %d", HeaderLength);
        return 0;
    }
    TotalLength = ntohs(ip->tot_len);
    // A total length shorter than the header would make the payload length negative
    // (and underflow the unsigned arithmetic of the caller).
    if (TotalLength < HeaderLength) {
        Log("bad total length %d < header length %d", TotalLength, HeaderLength);
        return 0;
    }
    if (PacketLength < TotalLength) {
        Log("truncated IP - %u bytes missing", TotalLength - PacketLength);
        return 0;
    }
    // Skip every fragment except the first one (fragment offset field non-zero).
    FragmentOffset = ntohs(ip->frag_off);
    if ((FragmentOffset & 0x1fff) != 0)
        return 0;

    return HeaderLength;
}

#define IPV6HEADERLENGTH 40
/*
 * Validate the fixed IPv6 header at Packet.
 *
 * pkthdr - pcap header of the frame. The IP packet length is derived as pkthdr->len
 *          minus an Ethernet header (sizeof(struct ether_header)).
 * Packet - start of the IPv6 header inside the captured frame.
 *
 * Returns IPV6HEADERLENGTH when the header is usable, or 0 if the packet is too short,
 * not version 6, or its payload length claims more bytes than were captured.
 * Extension headers are not parsed.
 */
uint16_t IP6Handle(const struct pcap_pkthdr *pkthdr, const u_char *Packet) {
    const struct ip6_hdr *ip6 = (const struct ip6_hdr *)Packet;
    uint32_t PacketLength, PayloadLength;
    uint16_t IPVersion;

    // Guard the subtraction below against runt frames (it would wrap around).
    if (pkthdr->len < sizeof(struct ether_header)) {
        Log("truncated ip6 %u", pkthdr->len);
        return 0;
    }
    PacketLength = pkthdr->len - sizeof(struct ether_header);
    if (PacketLength < sizeof(struct ip6_hdr))
    {
        Log("truncated ip6 %u", PacketLength);
        return 0;
    }
    IPVersion = ip6->ip6_ctlun.ip6_un2_vfc >> 4;
    if (IPVersion != 6)
    {
        Log("unknown version %d", IPVersion);
        return 0;
    }
    // The payload length comes from the wire; the caller copies that many bytes, so it
    // must fit inside the captured frame.
    PayloadLength = ntohs(ip6->ip6_ctlun.ip6_un1.ip6_un1_plen);
    if (PacketLength < IPV6HEADERLENGTH + PayloadLength) {
        Log("truncated IP - %u bytes missing", IPV6HEADERLENGTH + PayloadLength - PacketLength);
        return 0;
    }
    return IPV6HEADERLENGTH;
}


//============================================================================================


/*
 * Linux "cooked" capture header (LINKTYPE_LINUX_SLL) preceding the network-layer
 * packet. The struct is overlaid directly on captured bytes, and multi-byte fields
 * are in network byte order (ProtocolType is read through ntohs()).
 */
struct LinkType_SLL {
    uint16_t Type;                    // packet type
    uint16_t ARPHRD_Type;             // link-layer device type (ARPHRD_* value)
    uint16_t LinkLayerAddressLength;  // number of valid bytes in LinkLayerAddress
    uint8_t LinkLayerAddress[8];      // link-layer address, padded to 8 bytes
    uint16_t ProtocolType;            // EtherType of the encapsulated packet
};
// Overlaying the struct on packet data is only valid if it matches the 16-byte wire header.
static_assert(sizeof(struct LinkType_SLL) == 16, "LinkType_SLL must match the 16-byte SLL header");

#define SLL_HEADER_LENGTH 16
/*
 * Parse one captured frame and mirror its TCP payload to the sink.
 *
 * pTCPStreamStore - registry of mirrored streams.
 * pkthdr, Packet  - pcap header and bytes of the frame. The link-layer header is Linux
 *                   cooked (SLL) or Ethernet, depending on Pcap's link type.
 *
 * A SYN without ACK/FIN/RST opens a sink connection for every StreamFraction-th new
 * stream. In-order payload of a mirrored stream is appended to its sink's output buffer,
 * and FIN/RST schedules the stream for closing.
 *
 * Returns the socket of a newly opened sink connection (for the caller to register with
 * epoll), or -1 when no connection was opened.
 */
int PacketHandler(TTCPStreamStore *pTCPStreamStore, const struct pcap_pkthdr *pkthdr, const u_char *Packet) {
    static uint32_t StreamCounter = 0;
    TTCPStream TCPStreamTemplate;
    uint16_t PacketType;
    uint32_t IPPacketOffset;
    uint16_t IPHeaderLen;
    uint32_t IPPayloadLength;   // bytes after the IP header: TCP header + TCP payload
    uint32_t TCPOffset;
    uint32_t TCPLoadLength;
    const u_char *IPPacket;
    struct tcphdr *TCPHeader;

    if (pkthdr->caplen != pkthdr->len) {
        Log("Captured a part of a packet %d vs real %d bytes", pkthdr->caplen, pkthdr->len);
        return -1;
    }
    if (pkthdr->len > IP_MAXPACKET) {
        Log("The packet is too big for our buffer %d > %d", pkthdr->len, IP_MAXPACKET);
        return -1;
    }
    if (Pcap->GetIfaceLinkType() == LINKTYPE_LINUX_SLL) {
        if (pkthdr->len < sizeof(struct LinkType_SLL)) {
            Log("The packet length is too small %u bytes", pkthdr->len);
            return -1;
        }
        const struct LinkType_SLL *SLL = (const struct LinkType_SLL *)Packet;
        PacketType = ntohs(SLL->ProtocolType);
        IPPacketOffset = SLL_HEADER_LENGTH;
    }
    else {
        // EthernetHandle() reads the whole Ethernet header unconditionally.
        if (pkthdr->len < sizeof(struct ether_header)) {
            Log("The packet length is too small %u bytes", pkthdr->len);
            return -1;
        }
        PacketType = EthernetHandle(Packet);
        IPPacketOffset = sizeof(struct ether_header);
    }

    IPPacket = Packet + IPPacketOffset;
    if (PacketType == ETHERTYPE_IP) {
        const struct iphdr *IPv4 = (const struct iphdr *)IPPacket;
        IPHeaderLen = IPHandle(pkthdr, IPPacket);
        if (IPHeaderLen == 0 || IPv4->protocol != IPPROTO_TCP)
            return -1;
        uint16_t TotalLength = ntohs(IPv4->tot_len);
        if (TotalLength < IPHeaderLen) {
            Log("bad total length %d < header length %d", TotalLength, IPHeaderLen);
            return -1;
        }
        IPPayloadLength = TotalLength - IPHeaderLen;
    }
    else if (PacketType == ETHERTYPE_IPV6) {
        const struct ip6_hdr *IPv6 = (const struct ip6_hdr *)IPPacket;
        IPHeaderLen = IP6Handle(pkthdr, IPPacket);
        if (IPHeaderLen == 0 || IPv6->ip6_ctlun.ip6_un1.ip6_un1_nxt != IPPROTO_TCP)
            return -1;
        IPPayloadLength = ntohs(IPv6->ip6_ctlun.ip6_un1.ip6_un1_plen);
    }
    else
        return -1;

    // Length fields come from the network and cannot be trusted. Make sure the IP payload
    // lies inside the captured frame (using the real link-header offset) and holds a
    // complete TCP header before touching it. Otherwise a bogus length underflows
    // TCPLoadLength and the payload copy below reads far past the end of the packet.
    uint32_t NeededLength = IPPacketOffset + IPHeaderLen + IPPayloadLength;
    if (NeededLength > pkthdr->len) {
        Log("truncated IP payload - %u bytes missing", NeededLength - pkthdr->len);
        return -1;
    }
    if (IPPayloadLength < sizeof(struct tcphdr)) {
        Log("truncated TCP header %u", IPPayloadLength);
        return -1;
    }
    TCPHeader = (struct tcphdr *)(IPPacket + IPHeaderLen);
    TCPOffset = TCPHeader->doff << 2;
    if (TCPOffset < sizeof(struct tcphdr) || TCPOffset > IPPayloadLength) {
        Log("bad TCP header length %u", TCPOffset);
        return -1;
    }
    TCPLoadLength = IPPayloadLength - TCPOffset;

    if (PacketType == ETHERTYPE_IP) {
        TCPStreamTemplate.Construct(
                    &((struct ip *)IPPacket)->ip_src, ntohs(TCPHeader->source),
                    &((struct ip *)IPPacket)->ip_dst, ntohs(TCPHeader->dest));
        if (verbosity > 1) {
            // inet_ntoa() returns one shared static buffer, so two calls in one argument
            // list printed the same address twice; format into separate buffers instead.
            char IPCharSrc[INET_ADDRSTRLEN];
            char IPCharDst[INET_ADDRSTRLEN];
            inet_ntop(AF_INET, &((struct ip *)IPPacket)->ip_src, IPCharSrc, sizeof(IPCharSrc));
            inet_ntop(AF_INET, &((struct ip *)IPPacket)->ip_dst, IPCharDst, sizeof(IPCharDst));
            Log("IPv4 TCP: %15s:%-5u -> %15s:%-5u size:%-5u TCPMapSize:%lu",
                IPCharSrc, ntohs(TCPHeader->source),
                IPCharDst, ntohs(TCPHeader->dest),
                TCPLoadLength, pTCPStreamStore->Size());
        }
    }
    else {
        TCPStreamTemplate.Construct6(
                    &((struct ip6_hdr *)IPPacket)->ip6_src, ntohs(TCPHeader->source),
                    &((struct ip6_hdr *)IPPacket)->ip6_dst, ntohs(TCPHeader->dest));
        if (verbosity > 1) {
            char IPCharSrc[INET6_ADDRSTRLEN + 1];
            char IPCharDst[INET6_ADDRSTRLEN + 1];
            inet_ntop(AF_INET6, &(((struct ip6_hdr *)IPPacket)->ip6_src), IPCharSrc, INET6_ADDRSTRLEN);
            inet_ntop(AF_INET6, &(((struct ip6_hdr *)IPPacket)->ip6_dst), IPCharDst, INET6_ADDRSTRLEN);
            Log("IPv6 TCP: %39s.%-5u -> %39s.%-5u size:%-5u TCPMapSize:%lu",
                IPCharSrc, ntohs(TCPHeader->source),
                IPCharDst, ntohs(TCPHeader->dest),
                TCPLoadLength, pTCPStreamStore->Size());
        }
    }

    TNet *pTCPStreamSink = pTCPStreamStore->FindStream(&TCPStreamTemplate);

    if (TCPHeader->syn == 1 && TCPHeader->ack == 0 && TCPHeader->fin == 0 && TCPHeader->rst == 0) {
        // Mirror only every StreamFraction-th new stream; treat 0 as 1 so that a bad
        // setting cannot cause a division by zero.
        uint32_t Fraction = (StreamFraction > 0) ? StreamFraction : 1;
        if ((StreamCounter = (StreamCounter + 1) % Fraction) != 0)
            return -1;
        if (pTCPStreamSink != nullptr)
            pTCPStreamStore->DeleteStream(&TCPStreamTemplate);
        return pTCPStreamStore->AddStream(&TCPStreamTemplate, ntohl(TCPHeader->seq));
    }
    // Only in-order payload is forwarded; SequenceIsOk() advances the expected sequence.
    if (pTCPStreamSink != nullptr && TCPLoadLength > 0) {
        if (pTCPStreamStore->SequenceIsOk(&TCPStreamTemplate, TCPLoadLength, ntohl(TCPHeader->seq))) {
            if (verbosity > 2)
                Log("Adding packet with len:%u and seq:%u to output buffer", TCPLoadLength, ntohl(TCPHeader->seq));
            pTCPStreamSink->GetOutBuffer()->append((char *)(IPPacket + IPHeaderLen + TCPOffset), TCPLoadLength);
        }
    }
    if (TCPHeader->fin == 1 || TCPHeader->rst == 1) {
        if (pTCPStreamSink != nullptr)
            pTCPStreamStore->ScheduleClose(&TCPStreamTemplate);
    }

    return -1;
}


//============================================================================================


class TSender {
#define POLLING_INTERVAL 10
#define PERIODIC_TIME 1
#define SOCKET_BUFFER_SIZE 1024*1024
public:
    TTCPStreamStore *pTCPStreamStore;
    TPcap *pPcap;
    int PcapFD = 0;

    /*
     * pStore - registry of mirrored streams (not owned).
     * pCap   - initialised capture handle (not owned); switched to non-blocking mode here.
     */
    TSender(TTCPStreamStore *pStore, TPcap *pCap) {
        pTCPStreamStore = pStore;
        pPcap = pCap;
        PcapFD = pPcap->GetFileDescriptor();
        pPcap->SetNonBlock();
    }
    ~TSender() {
    }

    /*
     * Event loop, runs until exit_flag is set.
     * - Capture descriptor readable: read one packet and feed it to PacketHandler();
     *   newly opened sink sockets are added to epoll.
     * - Sink writable: flush its output buffer; drop it on send error, or once it is
     *   flushed and scheduled for closing.
     * - Sink readable: discard whatever the sink sent back.
     * - Sink error/hangup: drop that stream only.
     * Every PERIODIC_TIME (in Time's units) timed-out and finished streams are removed.
     *
     * NOTE(review): sink sockets stay registered for EPOLLOUT. If TEpoll is level-triggered
     * (TODO confirm), idle writable sockets wake Wait() on every iteration, and EPOLLIN is
     * only serviced when EPOLLOUT is not reported.
     */
    void Start() {
        Time CurrentTime;
        Time NextCheckTime;
        std::unique_ptr<TEpoll>Epoll(new TEpoll());

        Epoll->Add(PcapFD, EPOLLIN);
        while (!exit_flag)
        {
            Epoll->Wait(POLLING_INTERVAL);
            for (struct epoll_event *pEpollEvent = Epoll->begin(); pEpollEvent < Epoll->end(); ++pEpollEvent) {
                int EventSocket = pEpollEvent->data.fd;

                // New packet arrived
                if (EventSocket == PcapFD) {
                    struct pcap_pkthdr *Header = nullptr;
                    const u_char *Data = nullptr;

                    int GetpacketStatus = pPcap->GetPacket(&Header, &Data);
                    if (GetpacketStatus < 0) {
                        // Capture error or end of input: Header/Data do not describe a
                        // packet, so they must not be parsed.
                        exit_flag = 1;
                        Log("Getting no more packets");
                    }
                    else if (Header != nullptr && Data != nullptr) {
                        // Presumably the pointers stay null when a wake-up yielded no packet.
                        int TCPSocket = PacketHandler(pTCPStreamStore, Header, Data);
                        // -1 means "no new connection"; descriptor 0 is a valid socket.
                        if (TCPSocket >= 0)
                            Epoll->Add(TCPSocket, EPOLLOUT | EPOLLIN);
                    }
                }
                // TCP Output could be written
                else if (pEpollEvent->events & EPOLLOUT) {
                    TNet *Sink = pTCPStreamStore->FindStreamBySocket(EventSocket);
                    // The stream may already have been dropped earlier in this batch.
                    if (Sink == nullptr)
                        continue;
                    size_t BufSize = Sink->GetOutBuffer()->size();

                    if (BufSize > 0) {
                        if (verbosity > 1)
                            Log("Sending data to %s!", Sink->GetIP()->GetIPChar());
                        ssize_t BytesSent = Sink->SendTCP();
                        if (BytesSent < 0) {
                            Epoll->Del(EventSocket);
                            pTCPStreamStore->RemoveStreamBySocket(EventSocket);
                        }
                        else if (static_cast<size_t>(BytesSent) == BufSize)
                            pTCPStreamStore->TryRemoveStreamBySocket(EventSocket);
                    }
                    else
                        pTCPStreamStore->TryRemoveStreamBySocket(EventSocket);
                }
                // TCP Input arrived
                else if (pEpollEvent->events & EPOLLIN) {
                    TNet *Sink = pTCPStreamStore->FindStreamBySocket(EventSocket);
                    if (Sink == nullptr)
                        continue;
                    if (Sink->RecvTCP() > 0)
                        Sink->GetInBuffer()->clear();
                }
                else if (pEpollEvent->events & EPOLLERR) {
                    int error = 0;
                    socklen_t errlen = sizeof(error);
                    if (verbosity > 0 && getsockopt(EventSocket, SOL_SOCKET, SO_ERROR, (void *)&error, &errlen) == 0)
                        Log("Socket (%i) error in epoll, %i: %s", EventSocket, error, strerror(error));
                    Epoll->Del(EventSocket);
                    pTCPStreamStore->RemoveStreamBySocket(EventSocket);
                }
                else if (pEpollEvent->events & EPOLLHUP) {
                    // One sink hanging up must not abort the whole capture: drop just that
                    // stream, as done for EPOLLERR.
                    if (verbosity > 0)
                        Log("Socket (%i) hangup in epoll", EventSocket);
                    Epoll->Del(EventSocket);
                    pTCPStreamStore->RemoveStreamBySocket(EventSocket);
                }
                else {
                    Fatal("Unhandled event: %d", pEpollEvent->events);
                }
            } // Socket loop

            // Maintenance
            CurrentTime.Update();
            if (CurrentTime.uSec > NextCheckTime.uSec) {
                NextCheckTime = CurrentTime + PERIODIC_TIME;
                pTCPStreamStore->RemoveStaleStreams(CurrentTime);
            }
        } // Main loop
    }
};


//============================================================================================


/*
 * Handler that main() installs for SIGCHLD, SIGPIPE, SIGTERM, SIGINT, SIGHUP, SIGQUIT
 * and SIGALRM.
 *
 * SIGHUP and SIGPIPE are only logged. SIGPIPE is raised when writing to a sink
 * connection the peer has closed; that write then fails with EPIPE, which the send
 * path already handles per stream (presumably SendTCP() returns < 0 and the stream is
 * dropped), so one closed sink must not stop the whole capture.
 * Every other signal requests shutdown of the main loop through exit_flag and reaps
 * any exited child processes.
 */
void signal_handler(int sig_num)
{
    // NOTE(review): Log() is presumably not async-signal-safe; consider deferring it.
    Log("Signal %i", sig_num);

    switch (sig_num) {
    case SIGHUP:
    case SIGPIPE:
        break;
    default: {
        int Status;   // separate from sig_num, which waitpid() used to overwrite
        exit_flag = 1;
        while (waitpid(-1, &Status, WNOHANG) > 0);
        break;
    }
    }
}


//============================================================================================


/*
 * Entry point: install signal handlers, parse options, open the capture and run the
 * mirroring loop until exit_flag is set.
 *
 * Options: -i interface (default "any"), -F capture filter, -d sink Address[:Port]
 * (required), -c number of packets to capture (-1 when absent), -n mirror every nth
 * stream, -v verbosity (repeat for more detail).
 */
int main(int argc, char* argv[])
{
    int PacketsMaxCount;
    std::string SniffIface = "any";
    TTCPStreamStore TCPStreamStore;

    (void)signal(SIGCHLD, signal_handler);
    (void)signal(SIGPIPE, signal_handler);
    (void)signal(SIGTERM, signal_handler);
    (void)signal(SIGINT, signal_handler);
    (void)signal(SIGHUP, signal_handler);
    (void)signal(SIGQUIT, signal_handler);
    (void)signal(SIGALRM, signal_handler);

    std::unique_ptr<TArguments>Args(new TArguments("[<options>] "));
    Args->AddKey("i", TArguments::strarg, "Interface to sniff network data");
    Args->AddKey("F", TArguments::strarg, "Filter rule for network data");
    Args->AddKey("d", TArguments::strarg, "Address[:Port] to send TCP data to");
    Args->AddKey("c", TArguments::intarg, "number of packets to capture");
    Args->AddKey("n", TArguments::intarg, "send every nth stream to destination");
    Args->AddKey("v", TArguments::noarg,  "verbose output");
    Args->Parse(argc, argv);

    verbosity = Args->GetKey("v").Count;

    if (Args->GetKey("d").isPresent)
        TCPStreamStore.SinkIP = std::unique_ptr<TIP>(new TIP(Args->GetKey("d").StrValue));
    else
        Args->Usage("destination must be set");

    if (Args->GetKey("n").isPresent) {
        // Streams are sampled with "% StreamFraction": reject 0 (division by zero) and
        // negative values (which would wrap to a huge unsigned fraction).
        if (Args->GetKey("n").IntValue < 1)
            Args->Usage("n must be a positive number");
        else
            StreamFraction = Args->GetKey("n").IntValue;
    }

    PacketsMaxCount = (Args->GetKey("c").isPresent) ? Args->GetKey("c").IntValue : -1;

    Pcap = std::unique_ptr<TPcap>(new TPcap());
    if (Args->GetKey("i").isPresent) {
        SniffIface = Args->GetKey("i").StrValue;
        if (Pcap->CheckIface(SniffIface) < 0)
            Fatal("Cannot find interface '%s'", SniffIface.c_str());
    }

    Pcap->Init(SniffIface, Args->GetKey("F").StrValue, PacketsMaxCount);

    TSender SenderLoop(&TCPStreamStore, Pcap.get());
    SenderLoop.Start();

    return 0;
}
