diff --git a/errors_test.go b/errors_test.go index 2ef75f1..82fb8df 100644 --- a/errors_test.go +++ b/errors_test.go @@ -12,6 +12,8 @@ import ( ) func TestErrorFuncs(t *testing.T) { + userUnit := requireUserTestUnit(t) + systemUnit := requireSystemTestUnit(t) errFuncs := []func(ctx context.Context, unit string, opts Options) error{ func(ctx context.Context, unit string, opts Options) error { return Enable(ctx, unit, opts) }, func(ctx context.Context, unit string, opts Options) error { return Disable(ctx, unit, opts) }, @@ -29,11 +31,11 @@ func TestErrorFuncs(t *testing.T) { // try nonexistant unit in user mode as user {"nonexistant", ErrDoesNotExist, Options{UserMode: true}, true}, // try existing unit in user mode as user - {"syncthing", nil, Options{UserMode: true}, true}, + {userUnit, nil, Options{UserMode: true}, true}, // try nonexisting unit in system mode as user {"nonexistant", ErrInsufficientPermissions, Options{UserMode: false}, true}, // try existing unit in system mode as user - {"nginx", ErrInsufficientPermissions, Options{UserMode: false}, true}, + {systemUnit, ErrInsufficientPermissions, Options{UserMode: false}, true}, /* End user tests*/ @@ -42,9 +44,9 @@ func TestErrorFuncs(t *testing.T) { // try nonexistant unit in system mode as system {"nonexistant", ErrDoesNotExist, Options{UserMode: false}, false}, // try existing unit in system mode as system - {"nginx", ErrBusFailure, Options{UserMode: true}, false}, + {systemUnit, ErrBusFailure, Options{UserMode: true}, false}, // try existing unit in system mode as system - {"nginx", nil, Options{UserMode: false}, false}, + {systemUnit, nil, Options{UserMode: false}, false}, /* End superuser tests*/ diff --git a/helpers_test.go b/helpers_test.go index a468d8c..56e8c6e 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -21,6 +21,8 @@ func TestGetStartTime(t *testing.T) { if testing.Short() { t.Skip("skipping in short mode") } + userUnit := requireUserTestUnit(t) + systemUnit := requireSystemTestUnit(t) testCases := []struct { unit string err error @@ -31,23 +33,23 @@ func TestGetStartTime(t *testing.T) { // try nonexistant unit in user mode as user {"nonexistant", ErrUnitNotActive, Options{UserMode: false}, true}, // try existing unit in user mode as user - {"syncthing", ErrUnitNotActive, Options{UserMode: true}, true}, + {userUnit, ErrUnitNotActive, Options{UserMode: true}, true}, // try existing unit in system mode as user - {"nginx", nil, Options{UserMode: false}, true}, + {systemUnit, nil, Options{UserMode: false}, true}, // Run these tests only as a superuser // try nonexistant unit in system mode as system {"nonexistant", ErrUnitNotActive, Options{UserMode: false}, false}, // try existing unit in system mode as system - {"nginx", ErrBusFailure, Options{UserMode: true}, false}, + {systemUnit, ErrBusFailure, Options{UserMode: true}, false}, // try existing unit in system mode as system - {"nginx", nil, Options{UserMode: false}, false}, + {systemUnit, nil, Options{UserMode: false}, false}, } ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() - Restart(ctx, "syncthing", Options{UserMode: true}) - Stop(ctx, "syncthing", Options{UserMode: true}) + Restart(ctx, userUnit, Options{UserMode: true}) + Stop(ctx, userUnit, Options{UserMode: true}) time.Sleep(1 * time.Second) for _, tc := range testCases { t.Run(fmt.Sprintf("%s as %s, UserMode=%v", tc.unit, userString, tc.opts.UserMode), func(t *testing.T) { @@ -64,6 +66,7 @@ func TestGetStartTime(t *testing.T) { } }) } + Start(ctx, userUnit, Options{UserMode: true}) // Prove start time changes after a restart t.Run("prove start time changes", func(t *testing.T) { if userString != "root" && userString != "system" { @@ -72,17 +75,17 @@ func TestGetStartTime(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - startTime, err := GetStartTime(ctx, "nginx", Options{UserMode: false}) + startTime, err := GetStartTime(ctx, systemUnit, Options{UserMode: false}) if err != nil { t.Errorf("issue getting start time of nginx: %v", err) } time.Sleep(1 * time.Second) - err = Restart(ctx, "nginx", Options{UserMode: false}) + err = Restart(ctx, systemUnit, Options{UserMode: false}) if err != nil { t.Errorf("issue restarting nginx as %s: %v", userString, err) } time.Sleep(100 * time.Millisecond) - newStartTime, err := GetStartTime(ctx, "nginx", Options{UserMode: false}) + newStartTime, err := GetStartTime(ctx, systemUnit, Options{UserMode: false}) if err != nil { t.Errorf("issue getting second start time of nginx: %v", err) } @@ -94,6 +97,11 @@ func TestGetStartTime(t *testing.T) { } func TestGetNumRestarts(t *testing.T) { + userUnit := requireUserTestUnit(t) + systemUnit := requireSystemTestUnit(t) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + Start(ctx, userUnit, Options{UserMode: true}) type testCase struct { unit string err error @@ -106,18 +114,18 @@ func TestGetNumRestarts(t *testing.T) { // try nonexistant unit in user mode as user {"nonexistant", ErrValueNotSet, Options{UserMode: false}, true}, // try existing unit in user mode as user (loaded, so NRestarts=0 is valid) - {"syncthing", nil, Options{UserMode: true}, true}, + {userUnit, nil, Options{UserMode: true}, true}, // try existing unit in system mode as user - {"nginx", nil, Options{UserMode: false}, true}, + {systemUnit, nil, Options{UserMode: false}, true}, // Run these tests only as a superuser // try nonexistant unit in system mode as system {"nonexistant", ErrValueNotSet, Options{UserMode: false}, false}, // try existing unit in system mode as system - {"nginx", ErrBusFailure, Options{UserMode: true}, false}, + {systemUnit, ErrBusFailure, Options{UserMode: true}, false}, // try existing unit in system mode as system - {"nginx", nil, Options{UserMode: false}, false}, + {systemUnit, nil, Options{UserMode: false}, false}, } for _, tc := range testCases { func(tc testCase) { @@ -148,17 +156,17 @@ func TestGetNumRestarts(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - restarts, err := GetNumRestarts(ctx, "nginx", Options{UserMode: false}) + restarts, err := GetNumRestarts(ctx, systemUnit, Options{UserMode: false}) if err != nil { t.Errorf("issue getting number of restarts for nginx: %v", err) } - pid, err := GetPID(ctx, "nginx", Options{UserMode: false}) + pid, err := GetPID(ctx, systemUnit, Options{UserMode: false}) if err != nil { t.Errorf("issue getting MainPID for nginx as %s: %v", userString, err) } syscall.Kill(pid, syscall.SIGKILL) for { - running, errIsActive := IsActive(ctx, "nginx", Options{UserMode: false}) + running, errIsActive := IsActive(ctx, systemUnit, Options{UserMode: false}) if errIsActive != nil { t.Errorf("error asserting nginx is up: %v", errIsActive) break @@ -166,7 +174,7 @@ func TestGetNumRestarts(t *testing.T) { break } } - secondRestarts, err := GetNumRestarts(ctx, "nginx", Options{UserMode: false}) + secondRestarts, err := GetNumRestarts(ctx, systemUnit, Options{UserMode: false}) if err != nil { t.Errorf("issue getting second reading on number of restarts for nginx: %v", err) } @@ -177,6 +185,11 @@ func TestGetNumRestarts(t *testing.T) { } func TestGetMemoryUsage(t *testing.T) { + userUnit := requireUserTestUnit(t) + systemUnit := requireSystemTestUnit(t) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + Start(ctx, userUnit, Options{UserMode: true}) type testCase struct { unit string err error @@ -189,18 +202,18 @@ func TestGetMemoryUsage(t *testing.T) { // try nonexistant unit in user mode as user {"nonexistant", ErrValueNotSet, Options{UserMode: false}, true}, // try existing unit in user mode as user - {"syncthing", ErrValueNotSet, Options{UserMode: true}, true}, + {userUnit, nil, Options{UserMode: true}, true}, // try existing unit in system mode as user - {"nginx", nil, Options{UserMode: false}, true}, + {systemUnit, nil, Options{UserMode: false}, true}, // Run these tests only as a superuser // try nonexistant unit in system mode as system {"nonexistant", ErrValueNotSet, Options{UserMode: false}, false}, // try existing unit in system mode as system - {"nginx", ErrBusFailure, Options{UserMode: true}, false}, + {systemUnit, ErrBusFailure, Options{UserMode: true}, false}, // try existing unit in system mode as system - {"nginx", nil, Options{UserMode: false}, false}, + {systemUnit, nil, Options{UserMode: false}, false}, } for _, tc := range testCases { func(tc testCase) { @@ -224,7 +237,7 @@ func TestGetMemoryUsage(t *testing.T) { t.Run("prove memory usage values change across services", func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() - bytes, err := GetMemoryUsage(ctx, "nginx", Options{UserMode: false}) + bytes, err := GetMemoryUsage(ctx, systemUnit, Options{UserMode: false}) if err != nil { t.Errorf("issue getting memory usage of nginx: %v", err) } @@ -287,6 +300,11 @@ func TestGetUnits(t *testing.T) { } func TestGetPID(t *testing.T) { + userUnit := requireUserTestUnit(t) + systemUnit := requireSystemTestUnit(t) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + Start(ctx, userUnit, Options{UserMode: true}) type testCase struct { unit string err error @@ -300,18 +318,18 @@ func TestGetPID(t *testing.T) { // try nonexistant unit in user mode as user {"nonexistant", nil, Options{UserMode: false}, true}, // try existing unit in user mode as user - {"syncthing", nil, Options{UserMode: true}, true}, + {userUnit, nil, Options{UserMode: true}, true}, // try existing unit in system mode as user - {"nginx", nil, Options{UserMode: false}, true}, + {systemUnit, nil, Options{UserMode: false}, true}, // Run these tests only as a superuser // try nonexistant unit in system mode as system {"nonexistant", nil, Options{UserMode: false}, false}, // try existing unit in system mode as system - {"nginx", ErrBusFailure, Options{UserMode: true}, false}, + {systemUnit, ErrBusFailure, Options{UserMode: true}, false}, // try existing unit in system mode as system - {"nginx", nil, Options{UserMode: false}, false}, + {systemUnit, nil, Options{UserMode: false}, false}, } for _, tc := range testCases { func(tc testCase) { @@ -338,7 +356,7 @@ func TestGetPID(t *testing.T) { if userString != "root" && userString != "system" { t.Skip("skipping superuser test while running as user") } - unit := "nginx" + unit := systemUnit ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() Restart(ctx, unit, Options{UserMode: true}) diff --git a/systemctl_test.go b/systemctl_test.go index af8817a..e630347 100644 --- a/systemctl_test.go +++ b/systemctl_test.go @@ -461,7 +461,7 @@ func TestStart(t *testing.T) { } func TestStatus(t *testing.T) { - unit := "nginx" + unit := requireSystemTestUnit(t) userMode := false opts := Options{UserMode: userMode} ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) diff --git a/testenv_test.go b/testenv_test.go new file mode 100644 index 0000000..e5799bf --- /dev/null +++ b/testenv_test.go @@ -0,0 +1,138 @@ +package systemctl + +import ( + "context" + "fmt" + "strings" + "sync" + "testing" + "time" +) + +var ( + testUnitOnce sync.Once + testUserUnit string + testSystemUnit string +) + +func initTestUnits(t *testing.T) { + t.Helper() + testUnitOnce.Do(func() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + testUserUnit = findUserTestUnit(ctx) + testSystemUnit = findSystemTestUnit(ctx) + }) +} + +func requireUserTestUnit(t *testing.T) string { + t.Helper() + initTestUnits(t) + if testUserUnit == "" { + t.Skip("skipping: no manageable active user service found") + } + return testUserUnit +} + +func requireSystemTestUnit(t *testing.T) string { + t.Helper() + initTestUnits(t) + if testSystemUnit == "" { + t.Skip("skipping: no readable active system service found") + } + return testSystemUnit +} + +func findUserTestUnit(ctx context.Context) string { + units, err := GetUnits(ctx, Options{UserMode: true}) + if err != nil { + return "" + } + preferred := []string{"ha-to-openclaw.service", "mail-to-openclaw.service", "buxfer-sync.service", "openclaw-gateway.service"} + for _, unit := range preferred { + if userUnitUsable(ctx, unit) { + return trimServiceSuffix(unit) + } + } + for _, unit := range units { + if unit.Load != "loaded" || unit.Active != "active" { + continue + } + if userUnitUsable(ctx, unit.Name) { + return trimServiceSuffix(unit.Name) + } + } + return "" +} + +func userUnitUsable(ctx context.Context, unit string) bool { + trimmed := trimServiceSuffix(unit) + if _, err := GetPID(ctx, trimmed, Options{UserMode: true}); err != nil { + return false + } + if _, err := GetStartTime(ctx, trimmed, Options{UserMode: true}); err != nil { + return false + } + if _, err := GetMemoryUsage(ctx, trimmed, Options{UserMode: true}); err != nil { + return false + } + return true +} + +func findSystemTestUnit(ctx context.Context) string { + units, err := GetUnits(ctx, Options{UserMode: false}) + if err != nil { + return "" + } + preferred := []string{"ssh.service", "cron.service", "dbus.service", "chrony.service", "containerd.service"} + for _, unit := range preferred { + if systemUnitUsable(ctx, unit) { + return trimServiceSuffix(unit) + } + } + for _, unit := range units { + if unit.Load != "loaded" || unit.Active != "active" || !strings.HasSuffix(unit.Name, ".service") { + continue + } + if systemUnitUsable(ctx, unit.Name) { + return trimServiceSuffix(unit.Name) + } + } + return "" +} + +func systemUnitUsable(ctx context.Context, unit string) bool { + trimmed := trimServiceSuffix(unit) + if _, err := GetPID(ctx, trimmed, Options{UserMode: false}); err != nil { + return false + } + if _, err := GetStartTime(ctx, trimmed, Options{UserMode: false}); err != nil { + return false + } + if _, err := GetMemoryUsage(ctx, trimmed, Options{UserMode: false}); err != nil { + return false + } + return true +} + +func trimServiceSuffix(unit string) string { + return strings.TrimSuffix(unit, ".service") +} + +func waitForActiveState(ctx context.Context, unit string, opts Options, want bool) error { + for { + active, err := IsActive(ctx, unit, opts) + if err != nil { + return err + } + if active == want { + return nil + } + select { + case <-ctx.Done(): + return fmt.Errorf("timed out waiting for %s active=%v: %w", unit, want, ctx.Err()) + case <-time.After(100 * time.Millisecond): + } + } +}