Code adjustments

This commit is contained in:
Daan Selen
2025-05-27 16:57:22 +02:00
parent 1d85d9bcf4
commit 7549437c99
11 changed files with 145 additions and 107 deletions

View File

@@ -17,7 +17,7 @@ class connect:
return session
@staticmethod
async def run(session: meshctrl.Session, command: str, nodeids: list[str]) -> None:
async def run(session: meshctrl.Session, command: str, nodeids: str) -> None:
try:
response = await session.run_command(nodeids=nodeids,
command=command,

View File

@@ -12,20 +12,20 @@ def cmd_flags() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Process command-line arguments")
parser.add_argument("-lo", "--list-online", action='store_true', help="Specify if the program needs to list online devices.")
parser.add_argument("-rc", "--run", action='store_true', help="Make the program run a command.")
parser.add_argument("--run", action='store_true', help="Make the program run a command.")
parser.add_argument("--command", type=str, help="Specify the actual command that is going to run.")
parser.add_argument('--nodeids', nargs='+', help='List of node IDs')
parser.add_argument('--nodeid', nargs='+', help='List of node IDs')
parser.add_argument("-i", "--indent", action='store_true', help="Specify whether the output needs to be indented.")
return parser.parse_args()
async def prepare_command(command: str, nodeids: list[str]) -> list[str]: # Have some checks so it happens correctly.
if len(nodeids) < 1 or len(command) < 1:
print("No nodeids or command passed... quiting.")
return []
async def prepare_command(command: str, nodeid: str) -> str: # Have some checks so it happens correctly.
if len(nodeid) < 1 or len(command) < 1:
print("No nodeid or command passed... quiting.")
return ""
return nodeids
return nodeid
async def main() -> None:
args = cmd_flags()
@@ -43,15 +43,15 @@ async def main() -> None:
return await connect.quit(session) # Exit gracefully. Because python.
if args.run:
if not args.command or not args.nodeids:
print("When using run, also use --comand and --nodeids")
if not args.command or not args.nodeid:
print("When using run, also use --command and --nodeid")
return await connect.quit(session) # Exit gracefully. Because python.
command = args.command
nodeids = args.nodeids
nodeids = await prepare_command(command, nodeids)
nodeid = args.nodeid
nodeid = await prepare_command(command, nodeid)
await connect.run(session, command, nodeids)
await connect.run(session, command, nodeid)
await session.close()

View File

@@ -30,5 +30,5 @@ func main() {
log.Println(utilities.InfoTag, "Letting TimeKeeper take over...")
log.Println(utilities.InfoTag, fmt.Sprintf("Interval set at: %d seconds.", cfg.Interval))
timekeeper.KeepTime(cfg.Interval, cfg)
timekeeper.KeepTime(cfg.Interval, cfg.PyVenvName)
}

View File

@@ -2,7 +2,6 @@ package database
import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"ghostrunner-server/modules/encrypt"
@@ -91,9 +90,9 @@ func RetrieveTokenNames() []string {
return tokenNames
}
func InsertTask(name, command string, nodeids []string, date, status string) error {
func InsertTask(name, command string, nodeids []string, date string) error {
for _, singleNodeid := range nodeids {
_, err := db.Exec(declStat.CreateTask, name, command, string(singleNodeid), date, status)
_, err := db.Exec(declStat.CreateTask, name, command, string(singleNodeid), date)
if err != nil {
return fmt.Errorf("failed to create task: %w", err)
}
@@ -101,9 +100,18 @@ func InsertTask(name, command string, nodeids []string, date, status string) err
return nil
}
func RemoveTask(name string) error {
_, err := db.Exec(declStat.DeleteTask, name)
func RemoveTask(name, nodeid string) error {
var count int
err := db.QueryRow(declStat.CountTasks, name).Scan(&count)
if err != nil {
return fmt.Errorf("failed to count the task occurence: %w", err)
}
if count == 0 {
return fmt.Errorf("task '%s' not found", name)
}
if _, err = db.Exec(declStat.DeleteTask, name, nodeid); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return fmt.Errorf("token not found")
}
@@ -113,7 +121,7 @@ func RemoveTask(name string) error {
return nil
}
func RetrieveTasks() []utilities.TaskData {
func RetrieveTasks() []utilities.InternalQTaskData {
rows, err := db.Query(declStat.ListAllTasks)
if err != nil {
log.Println("Query error:", err)
@@ -121,24 +129,17 @@ func RetrieveTasks() []utilities.TaskData {
}
defer rows.Close()
var tasks []utilities.TaskData
var tasks []utilities.InternalQTaskData
for rows.Next() {
var task utilities.TaskData
var nodeidsStr string
var task utilities.InternalQTaskData
err := rows.Scan(&task.Name, &task.Command, &nodeidsStr, &task.Creation, &task.Status)
err := rows.Scan(&task.Name, &task.Command, &task.Nodeid, &task.Creation)
if err != nil {
log.Println("Row scan error:", err)
continue
}
err = json.Unmarshal([]byte(nodeidsStr), &task.Nodeids)
if err != nil {
log.Println("Unmarshal error:", err)
continue
}
tasks = append(tasks, task)
}

View File

@@ -13,6 +13,7 @@ type Statements struct {
CreateTask string
DeleteTask string
ListAllTasks string
CountTasks string
}
var declStat = Statements{
@@ -27,9 +28,15 @@ var declStat = Statements{
name TEXT NOT NULL,
command TEXT NOT NULL,
nodeid TEXT NOT NULL,
creation TEXT NOT NULL,
status TEXT NOT NULL,
result TEXT DEFAULT NULL
creation TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS completed (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
command TEXT NOT NULL,
nodeid TEXT NOT NULL,
completion TEXT NOT NULL,
result TEXT NOT NULL
);`,
AdminTokenCreate: `
@@ -50,10 +57,14 @@ var declStat = Statements{
SELECT name FROM tokens`,
CreateTask: `
INSERT INTO tasks (name, command, nodeid, creation, status)
VALUES (?, ?, ?, ?, ?);`,
INSERT INTO tasks (name, command, nodeid, creation)
VALUES (?, ?, ?, ?);`,
DeleteTask: `
DELETE FROM tasks WHERE name = ?;`,
DELETE FROM tasks WHERE name = ? AND nodeid = ?;`,
ListAllTasks: `
Select name, command, nodeid, creation, status from tasks;`,
Select name, command, nodeid, creation from tasks;`,
CountTasks: `
SELECT COUNT(*)
FROM tasks
WHERE name = ?;`,
}

View File

@@ -13,9 +13,10 @@ import (
"slices"
)
const (
constCreationStatus string = "Created"
)
type authPayload interface {
GetAuthToken() string
GetName() string
}
func generalAuth(w http.ResponseWriter, securedCandidate string) bool {
tokens := database.RetrieveTokens()
@@ -27,55 +28,36 @@ func generalAuth(w http.ResponseWriter, securedCandidate string) bool {
return true
}
func parseTokenAndAuth(w http.ResponseWriter, r *http.Request, hmacKey []byte) (utilities.TokenCreateBody, bool) {
var data utilities.TokenCreateBody
func parseAndAuth[T authPayload](w http.ResponseWriter, r *http.Request, hmacKey []byte) (T, bool) {
var data T
if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
log.Println(utilities.ErrTag, "Decode error:", err)
http.Error(w, "Invalid request body", http.StatusBadRequest)
return data, false
}
if data.AuthToken == "" || data.Details.Name == "" {
if data.GetAuthToken() == "" || data.GetName() == "" {
log.Println("[ERROR] Missing required fields")
http.Error(w, "Missing required fields", http.StatusBadRequest)
return data, false
}
givenToken := data.AuthToken
securedCandidate := encrypt.CreateHMAC(givenToken, hmacKey)
return data, generalAuth(w, securedCandidate)
}
func parseTaskAndAuth(w http.ResponseWriter, r *http.Request, hmacKey []byte) (utilities.TaskBody, bool) {
var data utilities.TaskBody
if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
log.Println(utilities.ErrTag, "Decode error:", err)
http.Error(w, "Invalid request body", http.StatusBadRequest)
return data, false
}
if data.AuthToken == "" || data.Details.Name == "" {
log.Println("[ERROR] Missing required fields")
http.Error(w, "Missing required fields", http.StatusBadRequest)
return data, false
}
givenToken := data.AuthToken
securedCandidate := encrypt.CreateHMAC(givenToken, hmacKey)
securedCandidate := encrypt.CreateHMAC(data.GetAuthToken(), hmacKey)
return data, generalAuth(w, securedCandidate)
}
/*
The following section portrains to Token creation and deletion.
The following section pertrains to Token creation and deletion.
*/
func createTokenHandler(hmacKey []byte) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
data, ok := parseTokenAndAuth(w, r, hmacKey)
data, ok := parseAndAuth[utilities.TokenCreateBody](w, r, hmacKey)
if !ok {
return
}
data.Details.Name = strings.ToLower(data.Details.Name) //Transform to lower
token, err := createToken(data.Details.Name, hmacKey)
if err != nil {
log.Println(utilities.ErrTag, "createToken failed:", err)
@@ -87,7 +69,7 @@ func createTokenHandler(hmacKey []byte) http.HandlerFunc {
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(utilities.InfoResponse{
Status: http.StatusCreated,
Message: "Token Succesfully Created.",
Message: "Token Successfully Created.",
Data: token,
})
}
@@ -95,7 +77,7 @@ func createTokenHandler(hmacKey []byte) http.HandlerFunc {
func deleteTokenHandler(hmacKey []byte) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
data, ok := parseTokenAndAuth(w, r, hmacKey)
data, ok := parseAndAuth[utilities.TokenCreateBody](w, r, hmacKey)
if !ok {
return
}
@@ -141,7 +123,7 @@ func listTokenHandler(hmacKey []byte) http.HandlerFunc {
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(utilities.InfoResponse{
Status: http.StatusOK,
Message: "Succesfully Retrieved Tokens",
Message: "Successfully Retrieved Tokens",
Data: data,
})
}
@@ -168,11 +150,12 @@ The following section portrains to Task creation and deletion.
func createTaskHandler(hmacKey []byte) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
data, ok := parseTaskAndAuth(w, r, hmacKey)
data, ok := parseAndAuth[utilities.TaskCreateBody](w, r, hmacKey)
if !ok {
return
}
data.Details.Name = strings.ToLower(data.Details.Name) //Transform to lower
if err := createTask(data.Details.Name, data.Details.Command, data.Details.Nodeids); err != nil {
log.Println(utilities.ErrTag, "createTask failed:", err)
http.Error(w, "Task creation failed", http.StatusInternalServerError)
@@ -183,19 +166,20 @@ func createTaskHandler(hmacKey []byte) http.HandlerFunc {
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(utilities.InfoResponse{
Status: http.StatusOK,
Message: "Task '" + data.Details.Name + "' Created Succesfully.",
Message: "Task '" + data.Details.Name + "' Created Successfully.",
})
}
}
func deleteTaskHandler(hmacKey []byte) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
data, ok := parseTaskAndAuth(w, r, hmacKey)
data, ok := parseAndAuth[utilities.TaskCreateBody](w, r, hmacKey)
if !ok {
return
}
nodeid := data.Details.Nodeids[0]
if err := deleteTask(data.Details.Name); err != nil {
if err := deleteTask(data.Details.Name, nodeid); err != nil {
log.Println(utilities.ErrTag, "createTask failed:", err)
http.Error(w, "Task deletion failed", http.StatusInternalServerError)
return
@@ -205,7 +189,7 @@ func deleteTaskHandler(hmacKey []byte) http.HandlerFunc {
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(utilities.InfoResponse{
Status: http.StatusOK,
Message: "Task '" + data.Details.Name + "' Deleted Succesfully.",
Message: "Task '" + data.Details.Name + "' Deleted Successfully.",
})
}
}
@@ -236,7 +220,7 @@ func listTasksHandler(hmacKey []byte) http.HandlerFunc {
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(utilities.InfoResponse{
Status: http.StatusOK,
Message: "Succesfully Retrieved Tasks",
Message: "Successfully Retrieved Tasks",
Data: data,
})
}
@@ -250,12 +234,11 @@ func flushTaskListHandler(hmacKey []byte) http.HandlerFunc {
func createTask(taskName, command string, nodeids []string) error {
creationDate := time.Now().Format("02-01-2006 15:04:05")
creationStatus := constCreationStatus
taskName = strings.ToLower(taskName)
return database.InsertTask(taskName, command, nodeids, creationDate, creationStatus)
return database.InsertTask(taskName, command, nodeids, creationDate)
}
func deleteTask(taskName string) error {
return database.RemoveTask(taskName)
func deleteTask(taskName, nodeid string) error {
return database.RemoveTask(taskName, nodeid)
}

View File

@@ -1,29 +1,42 @@
package timekeeper
import (
"fmt"
"ghostrunner-server/modules/database"
"ghostrunner-server/modules/utilities"
"ghostrunner-server/modules/wrapper"
"log"
"strings"
)
func routine(cfg utilities.ConfigStruct, pyListArgs []string) {
d := listDevices(cfg, pyListArgs) // Retrieve the Online devices.
func routine(venvName string, pyListArgs []string) {
d := listDevices(venvName, pyListArgs) // Retrieve the Online devices.
curTasks := database.RetrieveTasks()
for index, task := range curTasks {
relevantNodeids := task.Nodeids
relevantNodeid := task.Nodeid
log.Printf("Processing Task %d", index)
for _, nodeid := range relevantNodeids {
if isNodeOnline(nodeid, d.OnlineDevices) {
//result := wrapper.ExecCommand(nodeid, task.Command)
log.Printf("Node online: %s", nodeid)
}
if isNodeOnline(relevantNodeid, d.OnlineDevices) {
log.Printf("Node online: %s", relevantNodeid)
forgeAndExec(venvName, relevantNodeid, task.Command)
database.RemoveTask(task.Name, task.Nodeid)
} else {
log.Printf("Node offline %s", relevantNodeid)
}
}
}
func listDevices(venvName string, pyArgs []string) utilities.PyOnlineDevices {
onDevices, err := wrapper.PyListOnline(venvName, pyArgs)
if err != nil {
log.Println(utilities.ErrTag, err)
}
return onDevices
}
func isNodeOnline(nodeid string, onlineDevices []utilities.Device) bool {
for _, device := range onlineDevices {
if device.NodeID == nodeid {
@@ -32,3 +45,12 @@ func isNodeOnline(nodeid string, onlineDevices []utilities.Device) bool {
}
return false
}
func forgeAndExec(venvName string, nodeid, command string) {
log.Printf("Triggered %s, on %s", command, nodeid)
pyArgs := strings.Fields(fmt.Sprintf("--run --nodeid %s --command", nodeid))
pyArgs = append(pyArgs, command)
wrapper.ExecTask(venvName, pyArgs)
}

View File

@@ -2,7 +2,6 @@ package timekeeper
import (
"ghostrunner-server/modules/utilities"
"ghostrunner-server/modules/wrapper"
"log"
"strings"
"time"
@@ -12,7 +11,7 @@ var ( // Debugging
pyListArgs = strings.Fields("-lo")
)
func KeepTime(interval int, cfg utilities.ConfigStruct) {
func KeepTime(interval int, venvName string) {
transInterval := time.Duration(interval) * time.Second
ticker := time.NewTicker(transInterval)
@@ -21,15 +20,6 @@ func KeepTime(interval int, cfg utilities.ConfigStruct) {
for t := range ticker.C {
log.Println(utilities.InfoTag, "Tick at:", t)
log.Println(utilities.InfoTag, "Starting Routine.")
routine(cfg, pyListArgs)
routine(venvName, pyListArgs)
}
}
func listDevices(cfg utilities.ConfigStruct, pyArgs []string) utilities.PyOnlineDevices {
onDevices, err := wrapper.PyListOnline(cfg.PyVenvName, pyArgs)
if err != nil {
log.Println(utilities.ErrTag, err)
}
return onDevices
}

View File

@@ -34,7 +34,7 @@ type TokenListBody struct {
AuthToken string `json:"authtoken"`
}
type TaskData struct {
type RequestTaskData struct {
Name string `json:"name"`
Command string `json:"command"`
Nodeids []string `json:"nodeids"`
@@ -42,9 +42,24 @@ type TaskData struct {
Status string `json:"status"`
}
type TaskBody struct {
AuthToken string `json:"authtoken"`
Details TaskData `json:"details"`
type InternalQTaskData struct {
Name string `json:"name"`
Command string `json:"command"`
Nodeid string `json:"nodeid"`
Creation string `json:"creation"`
}
type InternalCTaskData struct {
Name string `json:"name"`
Command string `json:"command"`
Nodeid string `json:"nodeid"`
Completion string `json:"completion"`
Result string `json:"result"`
}
type TaskCreateBody struct {
AuthToken string `json:"authtoken"`
Details RequestTaskData `json:"details"`
}
// Python wrapper objects.

View File

@@ -14,6 +14,12 @@ const (
ErrTag = "[ERROR]"
)
func (t TokenCreateBody) GetAuthToken() string { return t.AuthToken }
func (t TokenCreateBody) GetName() string { return t.Details.Name }
func (t TaskCreateBody) GetAuthToken() string { return t.AuthToken }
func (t TaskCreateBody) GetName() string { return t.Details.Name }
func CheckDatabaseRemnants(databaseDir, fullDatabasePath string) {
remnantDir := StatPath(databaseDir)
if !remnantDir {

View File

@@ -13,13 +13,17 @@ const (
pyFile = "./../runner/runner.py"
)
func PyListOnline(venvName string, pyArgs []string) (utilities.PyOnlineDevices, error) {
func pyExec(venvName string, pyArgs []string) ([]byte, error) {
pyBin := fmt.Sprintf("./../runner/%s/bin/python", venvName)
runtimeArgs := append([]string{pyFile}, pyArgs...)
cmd := exec.Command(pyBin, runtimeArgs...)
rawData, err := cmd.CombinedOutput()
return cmd.CombinedOutput()
}
func PyListOnline(venvName string, pyArgs []string) (utilities.PyOnlineDevices, error) {
rawData, err := pyExec(venvName, pyArgs)
if err != nil {
cwd, _ := os.Getwd()
return utilities.PyOnlineDevices{}, fmt.Errorf("python execution failed, working directory: %s", cwd)
@@ -33,6 +37,12 @@ func PyListOnline(venvName string, pyArgs []string) (utilities.PyOnlineDevices,
return data, nil
}
func ExecCommand(nodeid, command string) {
log.Printf("Triggered %s, on %s", command, nodeid)
func ExecTask(venvName string, pyArgs []string) {
rawData, err := pyExec(venvName, pyArgs)
if err != nil {
cwd, _ := os.Getwd()
log.Println("FAILED,", err, "CWD:", cwd)
}
log.Println(string(rawData))
}