#include <linux/bpf.h>
#include <linux/version.h>
#include <sys/socket.h>

#include <bpf_helpers.h>

#include <map_defs/ya_nets.h>

#include <bpf_utils.h>
#include <defs.h>

/*
 * Lookup table consulted by get_yt_tc(): the key is a 32-bit word taken
 * from an IPv6 address (word index 2), the value is a one-byte traffic
 * class. Holds at most 1000 entries.
 */
struct bpf_map_def SEC("maps") project_id_map = {
    .max_entries = 1000,
    .type        = BPF_MAP_TYPE_HASH,
    .value_size  = sizeof(__u8),
    .key_size    = sizeof(__u32),
};

/*
 * Return the traffic class stored in project_id_map for an IPv6 address,
 * or 0 when there is none.
 *
 * Only addresses whose first 32-bit word equals YA_NETS are considered.
 * For those, the third word (index 2) is used as the map key. Any other
 * address, or a key with no map entry, yields 0.
 *
 * NOTE(review): the words are read through a volatile pointer. Presumably
 * this forces the compiler to emit separate 32-bit loads from the sock_ops
 * context, as the BPF verifier requires, instead of merging or narrowing
 * them. Confirm before changing the access pattern.
 */
static __always_inline __u8 get_yt_tc(const __u32 _addr[4])
{
    const volatile __u32 *words = (const volatile __u32 *)_addr;
    __u32 project_key;
    __u8 *entry;

    if (unlikely(words[0] != YA_NETS))
        return 0;

    /* The key must live on the stack: the helper takes its address. */
    project_key = words[2];
    entry = bpf_map_lookup_elem(&project_id_map, &project_key);
    if (!entry)
        return 0;

    return *entry;
}

/*
 * Pick the traffic class for the IPv6 connection described by @skops.
 *
 * Both endpoints are classified with get_dc_net(). The upper 16 bits of
 * its result are treated as the network type (enum ENetwork). If the two
 * endpoints sit on different network types, the result is 0 (no marking).
 *
 *  - NET_FASTBONE: TC_CS2 if get_yt_tc() finds a class for the remote
 *    address, TC_CS1 otherwise. Only the remote endpoint is consulted.
 *  - NET_BACKBONE: the local endpoint's get_yt_tc() class, when both
 *    endpoints have one and their full get_dc_net() values are equal.
 *    Otherwise TC_CS3.
 *  - any other network type: 0.
 *
 * Returns a traffic-class byte, or 0 for "leave the socket alone".
 */
static __always_inline __u8 choose_tc(struct bpf_sock_ops *skops)
{
    __u32 local_dc_net = get_dc_net(skops->local_ip6);
    __u32 remote_dc_net = get_dc_net(skops->remote_ip6);

    enum ENetwork local_net_type = (__u16)(local_dc_net >> 16);
    /* Derived from the remote address so the mismatch check is meaningful. */
    enum ENetwork remote_net_type = (__u16)(remote_dc_net >> 16);

    if (local_net_type != remote_net_type)
        return 0;

    if (local_net_type == NET_FASTBONE) {
        /* Only the remote class matters here; skip the local map lookup. */
        return get_yt_tc(skops->remote_ip6) ? TC_CS2 : TC_CS1;
    }

    if (local_net_type == NET_BACKBONE) {
        __u8 local_yt_tc = get_yt_tc(skops->local_ip6);
        __u8 remote_yt_tc = get_yt_tc(skops->remote_ip6);

        if (local_yt_tc && remote_yt_tc && local_dc_net == remote_dc_net)
            return local_yt_tc;
        return TC_CS3;
    }

    return 0;
}

/*
 * sockops entry point: apply the traffic-class policy from choose_tc() to
 * IPv6 sockets.
 *
 *  - BPF_SOCK_OPS_NEEDS_ECN: reply 1 when the chosen class has ECN_BIT set.
 *  - BPF_SOCK_OPS_TCP_CONNECT_CB, BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB and
 *    BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB: when choose_tc() is non-zero,
 *    replace the socket's non-ECN class bits with those of the chosen
 *    class. The socket's current ECN_MASK bits are kept. The reply is
 *    set_sock_tc()'s result.
 *  - Every other case, including non-IPv6 sockets, replies -1.
 *
 * Always returns 1.
 */
SEC("sockops")
int tcp_tos(struct bpf_sock_ops *skops)
{
    int rv = -1;

    if (skops->family == AF_INET6) {
        __u8 tc;

        switch (skops->op) {
        case BPF_SOCK_OPS_NEEDS_ECN:
            tc = choose_tc(skops);
            if (tc & ECN_BIT)
                rv = 1;
            break;
        case BPF_SOCK_OPS_TCP_CONNECT_CB:
        case BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB:
        case BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB:
            tc = choose_tc(skops);
            if (tc) {
                int prev_tc = get_sock_tc(skops);
                /*
                 * Replace the class bits instead of OR-ing them into the
                 * previous value. OR-ing two different classes produces a
                 * third, unrelated one. Only the ECN bits of the current
                 * value are carried over.
                 * NOTE(review): a negative prev_tc (if get_sock_tc can
                 * fail) is not handled here; confirm the helper's contract.
                 */
                rv = set_sock_tc(skops, (prev_tc & ECN_MASK) | (tc & ~ECN_MASK));
            }
            break;
        default:
            break;
        }
    }

    skops->reply = rv;

    return 1;
}

/* Program license string, emitted into the "license" ELF section for the loader. */
char _license[] SEC("license") = "GPL";
/* Kernel version this object was built against (LINUX_VERSION_CODE from
 * <linux/version.h>), emitted into the "version" ELF section. */
__u32 _version SEC("version") = LINUX_VERSION_CODE;
