#include <linux/bpf.h>
#include <linux/version.h>
#include <sys/socket.h>

#include <bpf_helpers.h>

#include <map_defs/ya_nets.h>

#include <bpf_tcp_sock.h>
#include <bpf_utils.h>
#include <defs.h>

/* Fallback option numbers for build environments whose UAPI headers predate
 * these options. NOTE(review): the values are meant to mirror the upstream
 * kernel headers (TCP_BPF_RTO_MIN in linux/bpf.h, SO_TXREHASH in
 * asm-generic/socket.h) — keep them in sync if those ever change.
 */
#ifndef TCP_BPF_RTO_MIN
#define TCP_BPF_RTO_MIN 1004
#endif

#ifndef SO_TXREHASH
#define SO_TXREHASH 74
#endif

/* Global configuration knobs, one __u32 per CONF_* key (switches such as
 * CONF_YTTL, CONF_CONG_CONTROL, and values such as CONF_MY_DC, CONF_JIFFIE_MS).
 */
struct bpf_map_def SEC("maps") config_map = {
    .type = BPF_MAP_TYPE_ARRAY,
    .max_entries = CONF__COUNT,
    .key_size = sizeof(__u32),
    .value_size = sizeof(__u32),
};

/* Initial RTO per destination datacenter; slot N holds the value for
 * datacenter N + 1 (get_rto() looks up target_dc - 1). The value is handed
 * back as the BPF_SOCK_OPS_TIMEOUT_INIT reply.
 */
struct bpf_map_def SEC("maps") rto_map = {
    .type = BPF_MAP_TYPE_ARRAY,
    .max_entries = DC__COUNT - 1,
    .key_size = sizeof(__u32),
    .value_size = sizeof(__u32),
};

/* Per-CPU 64-bit byte counters indexed by TCP_BYTES_* and accumulated when
 * sockets close (only built with TCP_EXTENDED_COUNTERS).
 */
struct bpf_map_def SEC("maps") tcp_bytes_counters_map = {
    .type = BPF_MAP_TYPE_PERCPU_ARRAY,
    .max_entries = TCP_BYTES__COUNT,
    .key_size = sizeof(__u32),
    .value_size = sizeof(__u64),
};

/* Longest-prefix-match set of networks excluded from the YTTL trick.
 * Key layout: { __u32 prefixlen, 4 address bytes } — i.e. the second 32-bit
 * word of an IPv6 address (see is_in_yttl_blacklist()).
 */
struct bpf_map_def SEC("maps") yttl_blacklist_nets_map = {
    .type = BPF_MAP_TYPE_LPM_TRIE,
    .map_flags = BPF_F_NO_PREALLOC,
    .max_entries = 10000,
    .key_size = 2 * sizeof(__u32),
    .value_size = sizeof(__u32),
};

/* Congestion-control algorithm name (NUL-terminated, MAX_CONG_LEN bytes) per
 * destination datacenter for backbone connections; slot = datacenter - 1.
 */
struct bpf_map_def SEC("maps") bb_cong_map = {
    .type = BPF_MAP_TYPE_ARRAY,
    .max_entries = (DC__COUNT - 1),
    .key_size = sizeof(__u32),
    .value_size = sizeof(char) * MAX_CONG_LEN,
};

/* Same as bb_cong_map, but consulted for fastbone connections. */
struct bpf_map_def SEC("maps") fb_cong_map = {
    .type = BPF_MAP_TYPE_ARRAY,
    .max_entries = (DC__COUNT - 1),
    .key_size = sizeof(__u32),
    .value_size = sizeof(char) * MAX_CONG_LEN,
};

/* Report whether an IPv6 address falls into yttl_blacklist_nets_map.
 *
 * Only addresses whose first 32-bit word equals YA_NETS are candidates; for
 * them the second 32-bit word is matched as a /32 against the LPM trie.
 * Returns true when the trie has a covering entry, false otherwise.
 *
 * The words are read through a volatile pointer, presumably so clang emits
 * exactly these 32-bit context loads (the verifier can reject narrowed or
 * merged loads) — NOTE(review): confirm.
 */
static __always_inline bool is_in_yttl_blacklist(const __u32 _addr[4])
{
    const volatile __u32 *words = (const volatile __u32 *)_addr;

    if (unlikely(words[0] != YA_NETS))
        return false;

    __u32 lpm_key[2];
    lpm_key[0] = 32;        /* prefix length in bits */
    lpm_key[1] = words[1];  /* address bytes, kept in stored byte order */

    return bpf_map_lookup_elem(&yttl_blacklist_nets_map, &lpm_key) != 0;
}

/* Datacenter this host lives in, as configured under CONF_MY_DC in
 * config_map; DC_UNKNOWN when the entry cannot be read.
 */
static __always_inline enum EDatacenter get_my_dc()
{
    __u32 idx = CONF_MY_DC;
    __u32 *stored = bpf_map_lookup_elem(&config_map, &idx);

    if (!stored)
        return DC_UNKNOWN;

    return (enum EDatacenter)*stored;
}

static __always_inline bool is_inside_dc(const __u32 _addr[4])
{
    enum EDatacenter my_dc = get_my_dc();
    if (unlikely(!my_dc))
        return false;

    enum EDatacenter target_dc = get_dc(_addr);
    return my_dc == target_dc;
}

#ifdef TCP_EXTENDED_COUNTERS
/* Classify the peer for the "external retransmits" counter.
 * Every AF_INET connection counts as external; for any other family the
 * peer is external unless the first 32-bit word of remote_ip6 is YA_NETS.
 */
static __always_inline bool is_external_connection(struct bpf_sock_ops *skops) {
    if (skops->family != AF_INET) {
        /* volatile: keep the 32-bit context load as written. */
        const volatile __u32 *peer = (const volatile __u32 *)skops->remote_ip6;
        __u32 head = peer[0];

        return head != YA_NETS;
    }

    return true;
}
#endif

/* Initial RTO to use for a connection to _addr, or -1 for "no override".
 *
 * The destination's datacenter/network comes from get_dc_net() (low 16 bits:
 * datacenter, high 16 bits: network). The RTO is read from rto_map at
 * (datacenter - 1). It is returned when:
 *   - the destination is in this host's datacenter, or
 *   - it is in another datacenter, is not on NET_FASTBONE, and the
 *     CONF_CROSS_DC_RTO switch in config_map is non-zero.
 * An unknown destination datacenter (0) or a missing rto_map slot yields -1.
 */
static __always_inline int get_rto(const __u32 _addr[4])
{
    __u32 target_dc_net = get_dc_net(_addr);
    enum EDatacenter target_dc = (__u16)target_dc_net;
    enum ENetwork target_net = (__u16)(target_dc_net >> 16);

    if (!target_dc)
        return -1;

    __u32 key = target_dc - 1;
    __u32 *rto = bpf_map_lookup_elem(&rto_map, &key);
    if (unlikely(!rto))
        return -1;

    if (get_my_dc() != target_dc) {
        /* Cross-DC fastbone traffic never gets an override. */
        if (target_net == NET_FASTBONE)
            return -1;

        /* Only consult the cross-DC switch on the path that needs it. */
        key = CONF_CROSS_DC_RTO;
        __u32 *cross_dc = bpf_map_lookup_elem(&config_map, &key);
        if (!cross_dc || !*cross_dc)
            return -1;
    }

    return *rto;
}

/* True when config_map[key] exists and holds a non-zero value. */
static __always_inline bool is_config_enabled(__u32 key)
{
    __u32 *flag = bpf_map_lookup_elem(&config_map, &key);

    if (!flag)
        return false;

    return *flag != 0;
}

static __always_inline char *get_cong(struct bpf_sock_ops *skops) {
    char *cong = 0;

    if (is_config_enabled(CONF_CONG_CONTROL)) {
        __u32 my_dc_net = get_dc_net(skops->local_ip6);
        enum ENetwork my_net = (__u16)(my_dc_net >> 16);

        __u32 target_dc_net = get_dc_net(skops->remote_ip6);
        __u32 target_dc = (__u16)target_dc_net;
        enum ENetwork target_net = (__u16)(target_dc_net >> 16);

        if (target_dc && my_net && my_net == target_net) {
            __u32 key = target_dc - 1;

            if (my_net == NET_BACKBONE)
                cong = bpf_map_lookup_elem(&bb_cong_map, &key);
            else if (my_net == NET_FASTBONE)
                cong = bpf_map_lookup_elem(&fb_cong_map, &key);
        }
    }

    return cong;
}

#ifndef TXREHASH
static __always_inline void random_flowlabel(struct bpf_sock_ops *skops)
{
    skops->sk_txhash = bpf_get_prandom_u32();
}

static __always_inline void force_hash_flowlabel(struct bpf_sock_ops *skops)
{
    skops->sk_txhash = 0;
}
#endif

/* Ask for a new transmit hash on this socket.
 * With TXREHASH defined this enables SO_TXREHASH via bpf_setsockopt() (the
 * result is deliberately ignored); otherwise it falls back to writing a
 * random sk_txhash.
 */
static __always_inline void txrehash_enable(struct bpf_sock_ops *skops) {
#ifndef TXREHASH
    random_flowlabel(skops);
#else
    int enable = 1;

    (void)bpf_setsockopt(skops, SOL_SOCKET, SO_TXREHASH, &enable, sizeof(enable));
#endif
}

#define YATTL_BIT (1 << 2)

SEC("sockops")
int tcp_rto(struct bpf_sock_ops *skops)
{
    int tclass, rv = -1;

    if (skops->family == AF_INET6) {
        bool yttl_fl_allowed = is_config_enabled(CONF_YTTL) && !is_in_yttl_blacklist(skops->local_ip6) && !is_in_yttl_blacklist(skops->remote_ip6);
        char *cong = get_cong(skops);

        switch (skops->op) {
        case BPF_SOCK_OPS_TIMEOUT_INIT:
            {
                rv = get_rto(skops->remote_ip6);
                if (rv > 0 && is_config_enabled(CONF_SOCK_MIN_RTO)) {
                    __u32 key = CONF_JIFFIE_MS;
                    __u32 *jiffieMs = bpf_map_lookup_elem(&config_map, &key);
                    if (jiffieMs && *jiffieMs) {
                        int rtoMinUs = rv * (*jiffieMs) * 1000;
                        (void)bpf_setsockopt(skops, SOL_TCP, TCP_BPF_RTO_MIN, &rtoMinUs, sizeof(rtoMinUs));
                    }
                }
            }
            break;
        case BPF_SOCK_OPS_TCP_CONNECT_CB:
            rv = bpf_sock_ops_cb_flags_set(skops, BPF_SOCK_OPS_RTO_CB_FLAG | BPF_SOCK_OPS_STATE_CB_FLAG);
#ifndef TXREHASH
            force_hash_flowlabel(skops);
#endif
            break;
        case BPF_SOCK_OPS_TCP_LISTEN_CB:
            rv = bpf_sock_ops_cb_flags_set(skops, BPF_SOCK_OPS_STATE_CB_FLAG);
            break;
        case BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB:
            if (cong && cong[0])
                set_congestion_control(skops, cong);
            if (likely(yttl_fl_allowed)) {
                tclass = get_sock_tc(skops);
                if (tclass & YATTL_BIT)
                    rv = set_sock_tc(skops, tclass & ~YATTL_BIT);
            }
            break;
        case BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB:
            if (cong && cong[0])
                set_congestion_control(skops, cong);
            break;
        case BPF_SOCK_OPS_RTO_CB:
            if (skops->state == BPF_TCP_SYN_SENT) {
                if (skops->total_retrans == 0) {
                    if (likely(yttl_fl_allowed) && is_inside_dc(skops->remote_ip6)) {
                        tclass = get_sock_tc(skops);
                        rv = set_sock_tc(skops, tclass | YATTL_BIT);
                    } else if (!is_stateful_addr(skops->remote_ip6))
                        txrehash_enable(skops);
                } else if (skops->total_retrans == 1) {
                    if (likely(yttl_fl_allowed)) {
                        tclass = get_sock_tc(skops);
                        if (tclass & YATTL_BIT)
                            rv = set_sock_tc(skops, tclass & ~YATTL_BIT);
                    }
                    if (!is_stateful_addr(skops->remote_ip6))
                        txrehash_enable(skops);
                } else {
                    if (!is_stateful_addr(skops->remote_ip6))
                        txrehash_enable(skops);
                }
            } else if (!is_stateful_addr(skops->remote_ip6))
                txrehash_enable(skops);
            break;
#ifdef TCP_EXTENDED_COUNTERS
        case BPF_SOCK_OPS_STATE_CB:
            {
                __u32 new_state = skops->args[1];
                if (new_state == BPF_TCP_CLOSE) {
                    struct bpf_tcp_sock_ext *tp;
                    struct bpf_sock *sk;

                    sk = skops->sk;
                    if (!sk)
                        return 1;

                    tp = (struct bpf_tcp_sock_ext *)bpf_tcp_sock(sk);
                    if (!tp)
                        return 1;

                    for (__u32 i = 0; i < TCP_BYTES__COUNT; ++i) {
                        __u32 key = i;
                        __u64 *value = bpf_map_lookup_elem(&tcp_bytes_counters_map, &key);
                        if (value) {
                            switch (key) {
                            case TCP_BYTES_ACKED:
                                if (is_config_enabled(CONF_TCP_BYTES_ACKED))
                                    *value += tp->bytes_acked;
                                break;
                            case TCP_BYTES_SENT:
                                if (is_config_enabled(CONF_TCP_BYTES_SENT))
                                    *value += tp->bytes_sent;
                                break;
                            case TCP_BYTES_RETRANS:
                                if (is_config_enabled(CONF_TCP_BYTES_RETRANS))
                                    *value += tp->bytes_retrans;
                                break;
                            default:
                                break;
                            }
                        }
                    }
                }
            }
            break;
#endif
        default:
            break;
        }
    }

#ifdef TCP_EXTENDED_COUNTERS
    switch (skops->op) {
    case BPF_SOCK_OPS_STATE_CB:
        {
            __u32 new_state = skops->args[1];
            if (new_state == BPF_TCP_CLOSE) {
                struct bpf_tcp_sock_ext *tp;
                struct bpf_sock *sk;

                sk = skops->sk;
                if (!sk)
                    return 1;

                tp = (struct bpf_tcp_sock_ext *)bpf_tcp_sock(sk);
                if (!tp)
                    return 1;

                __u32 key = TCP_BYTES_EXT_RETRANS;
                __u64 *value = bpf_map_lookup_elem(&tcp_bytes_counters_map, &key);
                if (value &&
                    is_config_enabled(CONF_TCP_BYTES_EXT_RETRANS) &&
                    is_external_connection(skops))
                        *value += tp->bytes_retrans;
            }
        }
        break;
    default:
        break;
    }
#endif

    /* tcp_bpf_ca_needs_ecn requires that all sockops bpf programs reply true
     * for result to be true. As long as tcp_tos is attached before tcp_rto,
     * reply would be copied from tcp_tos. Without tcp_tos default reply is 0.
     */
    if (skops->op != BPF_SOCK_OPS_NEEDS_ECN)
        skops->reply = rv;

    return 1;
}

/* License of this BPF object, read by the loader from the "license" section;
 * GPL-only kernel helpers require a GPL-compatible value here.
 */
char _license[] SEC("license") = "GPL";
/* Kernel version the object was built against (LINUX_VERSION_CODE from
 * linux/version.h), stored in the "version" section.
 */
__u32 _version SEC("version") = LINUX_VERSION_CODE;
