package main

import (
	"context"
	"flag"
	"fmt"
	"log"
	"os"
	"os/signal"
	"syscall"

	"github.com/kolide/osquery-go"
	"github.com/kolide/osquery-go/plugin/table"

	"a.yandex-team.ru/security/osquery/extensions/gosecure/dbus/logind"
	"a.yandex-team.ru/security/osquery/extensions/gosecure/dbus/systemd"
)

// Names under which the extension and its tables are registered with osquery
// in main: SystemdExtensionName is passed to the extension manager server,
// and each remaining constant is the name of one table plugin.
const (
	SystemdExtensionName           = "systemd"
	SystemdUnitsTableName          = "systemd_units"
	SystemdUnitFilesTableName      = "systemd_unit_files"
	SystemdUnitPaths               = "systemd_unit_paths"
	SystemdJobsInfoTable           = "systemd_jobs"
	SystemdLogindSessionsTableName = "systemd_logind_sessions"
)

// SystemdUnitsTable returns the column schema of the systemd_units table.
// Every column is text except job_id, which is an integer column.
func SystemdUnitsTable() []table.ColumnDefinition {
	text, integer := table.TextColumn, table.IntegerColumn

	schema := []table.ColumnDefinition{
		text("name"),
		text("description"),
		text("load_state"),
		text("active_state"),
		text("sub_state"),
		text("followed_by"),
		text("object_path"),
		integer("job_id"),
		text("job_type"),
		text("job_object_path"),
	}
	return schema
}

// GenerateUnitsTableData produces the rows of the systemd_units table, one
// row per entry returned by systemd.ListUnits, in the same order.
//
// The ctx and queryCtx arguments are not used. An error from
// systemd.ListUnits is returned unchanged together with a nil row set.
func GenerateUnitsTableData(ctx context.Context, queryCtx table.QueryContext) ([]map[string]string, error) {
	units, err := systemd.ListUnits()
	if err != nil {
		return nil, err
	}

	rows := make([]map[string]string, 0, len(units))
	for _, u := range units {
		row := make(map[string]string, 10)
		row["name"] = u.Name
		row["description"] = u.Description
		row["load_state"] = u.LoadState
		row["active_state"] = u.ActiveState
		row["sub_state"] = u.SubState
		row["followed_by"] = u.FollowedBy
		row["object_path"] = string(u.ObjectPath)
		// Rows are string maps, so the job id is rendered as text.
		row["job_id"] = fmt.Sprint(u.JobID)
		row["job_type"] = u.JobType
		row["job_object_path"] = string(u.JobObjectPath)
		rows = append(rows, row)
	}

	return rows, nil
}

// SystemdUnitFilesTable returns the column schema of the systemd_unit_files
// table: two text columns, path and status.
func SystemdUnitFilesTable() []table.ColumnDefinition {
	schema := make([]table.ColumnDefinition, 0, 2)
	for _, column := range []string{"path", "status"} {
		schema = append(schema, table.TextColumn(column))
	}
	return schema
}

// GenerateUnitFilesTableData produces the rows of the systemd_unit_files
// table, one row per entry returned by systemd.ListUnitFiles, in the same
// order.
//
// The ctx and queryCtx arguments are not used. An error from
// systemd.ListUnitFiles is returned unchanged together with a nil row set.
func GenerateUnitFilesTableData(ctx context.Context, queryCtx table.QueryContext) ([]map[string]string, error) {
	files, err := systemd.ListUnitFiles()
	if err != nil {
		return nil, err
	}

	rows := make([]map[string]string, 0, len(files))
	for _, f := range files {
		rows = append(rows, map[string]string{
			"path":   f.Path,
			"status": f.Status,
		})
	}

	return rows, nil
}

// SystemdUnitPathsTable returns the column schema of the systemd_unit_paths
// table: a single text column named path.
func SystemdUnitPathsTable() []table.ColumnDefinition {
	pathColumn := table.TextColumn("path")
	return []table.ColumnDefinition{pathColumn}
}

// GenerateUnitPathsTableData produces the rows of the systemd_unit_paths
// table, one single-column row per path returned by systemd.UnitPath, in the
// same order.
//
// The ctx and queryCtx arguments are not used. An error from
// systemd.UnitPath is returned unchanged together with a nil row set.
func GenerateUnitPathsTableData(ctx context.Context, queryCtx table.QueryContext) ([]map[string]string, error) {
	unitPaths, err := systemd.UnitPath()
	if err != nil {
		return nil, err
	}

	rows := make([]map[string]string, 0, len(unitPaths))
	for _, unitPath := range unitPaths {
		rows = append(rows, map[string]string{"path": unitPath})
	}

	return rows, nil
}

// SystemdJobsTable returns the column schema of the systemd_jobs table: an
// integer id column followed by five text columns.
func SystemdJobsTable() []table.ColumnDefinition {
	schema := []table.ColumnDefinition{table.IntegerColumn("id")}
	for _, column := range []string{"unit_name", "type", "state", "object", "unit_object"} {
		schema = append(schema, table.TextColumn(column))
	}
	return schema
}

// GenerateJobsTableData produces the rows of the systemd_jobs table, one row
// per entry returned by systemd.ListJobs, in the same order.
//
// The ctx and queryCtx arguments are not used. An error from
// systemd.ListJobs is returned unchanged together with a nil row set.
func GenerateJobsTableData(ctx context.Context, queryCtx table.QueryContext) ([]map[string]string, error) {
	pending, err := systemd.ListJobs()
	if err != nil {
		return nil, err
	}

	rows := make([]map[string]string, 0, len(pending))
	for _, j := range pending {
		// Rows are string maps, so the job id is rendered as text.
		jobID := fmt.Sprint(j.ID)
		rows = append(rows, map[string]string{
			"id":          jobID,
			"unit_name":   j.UnitName,
			"type":        j.Type,
			"state":       j.State,
			"object":      string(j.Object),
			"unit_object": string(j.Unit),
		})
	}

	return rows, nil
}

// SystedmLogindSessionsTable returns the column schema of the
// systemd_logind_sessions table. The identifier's spelling is kept as-is
// because main refers to the function by this name.
func SystedmLogindSessionsTable() []table.ColumnDefinition {
	text, integer, bigInt := table.TextColumn, table.IntegerColumn, table.BigIntColumn

	schema := []table.ColumnDefinition{
		text("id"),
		integer("uid"),
		text("name"),
		bigInt("timestamp"),
		bigInt("timestamp_monotonic"),
		integer("vtnr"),
		text("seat_id"),
		text("tty"),
		text("display"),
		text("remote"),
		text("remote_host"),
		text("remote_user"),
		text("service"),
		text("scope"),
		integer("leader"),
		integer("audit"),
		text("type"),
		text("class"),
		text("active"),
		text("idle_hint"),
		bigInt("idle_since_hint"),
		bigInt("idle_since_hint_monotonic"),
	}
	return schema
}

// GenerateSystemdLogindTableData produces the rows of the
// systemd_logind_sessions table, one row per session returned by
// logind.GetAllSessionObjects, in the same order.
//
// The remote, active and idle_hint values are formatted with the %t verb;
// the other non-string values are rendered with fmt.Sprint. The ctx and
// queryCtx arguments are not used. An error from logind.GetAllSessionObjects
// is returned unchanged together with a nil row set.
func GenerateSystemdLogindTableData(ctx context.Context, queryCtx table.QueryContext) ([]map[string]string, error) {
	all, err := logind.GetAllSessionObjects()
	if err != nil {
		return nil, err
	}

	rows := make([]map[string]string, 0, len(all))
	for _, s := range all {
		row := make(map[string]string, 22)

		row["id"] = s.ID
		row["uid"] = fmt.Sprint(s.User.UserID)
		row["name"] = s.Name
		row["timestamp"] = fmt.Sprint(s.Timestamp)
		row["timestamp_monotonic"] = fmt.Sprint(s.TimestampMonotonic)
		row["vtnr"] = fmt.Sprint(s.VTNr)
		row["seat_id"] = s.Seat.SeatID
		row["tty"] = s.TTY
		row["display"] = s.Display
		row["remote"] = fmt.Sprintf("%t", s.Remote)
		row["remote_host"] = s.RemoteHost
		row["remote_user"] = s.RemoteUser
		row["service"] = s.Service
		row["scope"] = s.Scope
		row["leader"] = fmt.Sprint(s.Leader)
		row["audit"] = fmt.Sprint(s.Audit)
		row["type"] = s.Type
		row["class"] = s.Class
		row["active"] = fmt.Sprintf("%t", s.Active)
		row["idle_hint"] = fmt.Sprintf("%t", s.IdleHint)
		row["idle_since_hint"] = fmt.Sprint(s.IdleSinceHint)
		row["idle_since_hint_monotonic"] = fmt.Sprint(s.IdleSinceHintMonotonic)

		rows = append(rows, row)
	}

	return rows, nil
}

// main connects to the osquery extension socket whose path is given as the
// single positional command-line argument, registers the systemd and logind
// table plugins, and serves them until the server stops on its own or the
// process receives SIGTERM or SIGINT.
//
// The process exits with status 1 when the argument is missing, the socket
// path does not exist, or the extension manager server cannot be created.
func main() {
	flag.Parse()
	if flag.NArg() != 1 {
		fmt.Println("invalid path to osquery socket")
		os.Exit(1)
	}
	socketPath := flag.Arg(0)

	if _, err := os.Stat(socketPath); os.IsNotExist(err) {
		fmt.Println("can't find socket at", socketPath)
		os.Exit(1)
	}

	server, err := osquery.NewExtensionManagerServer(SystemdExtensionName, socketPath)
	if err != nil {
		fmt.Println("can't initialize extension manager:", err.Error())
		os.Exit(1)
	}

	server.RegisterPlugin(table.NewPlugin(SystemdUnitsTableName, SystemdUnitsTable(), GenerateUnitsTableData))
	server.RegisterPlugin(table.NewPlugin(SystemdUnitFilesTableName, SystemdUnitFilesTable(), GenerateUnitFilesTableData))
	server.RegisterPlugin(table.NewPlugin(SystemdUnitPaths, SystemdUnitPathsTable(), GenerateUnitPathsTableData))
	server.RegisterPlugin(table.NewPlugin(SystemdJobsInfoTable, SystemdJobsTable(), GenerateJobsTableData))
	server.RegisterPlugin(table.NewPlugin(SystemdLogindSessionsTableName, SystedmLogindSessionsTable(), GenerateSystemdLogindTableData))

	// runDone is closed as soon as server.Run returns, so the signal
	// goroutine can exit even if no signal ever arrives. Without it, a Run
	// that returns nil for any reason other than a signal would leave main
	// blocked forever on idleConnsClosed.
	runDone := make(chan struct{})
	idleConnsClosed := make(chan struct{})
	go func() {
		defer close(idleConnsClosed)

		gracefulStop := make(chan os.Signal, 1)
		signal.Notify(gracefulStop, syscall.SIGTERM, syscall.SIGINT)
		defer signal.Stop(gracefulStop)

		select {
		case <-gracefulStop:
			if err := server.Shutdown(context.Background()); err != nil {
				log.Printf("Server shutdown error: %v", err)
			}
		case <-runDone:
			// The server stopped without a signal; there is nothing left
			// to shut down from here.
		}
	}()

	err = server.Run()
	close(runDone)
	if err != nil {
		log.Fatal(err)
	}

	// Wait for an in-flight Shutdown (signal path) to finish before exiting.
	<-idleConnsClosed
}
