diff --git a/internal/controller/nodedisruptionbudget_controller.go b/internal/controller/nodedisruptionbudget_controller.go index 070d54f..9f930bd 100644 --- a/internal/controller/nodedisruptionbudget_controller.go +++ b/internal/controller/nodedisruptionbudget_controller.go @@ -176,6 +176,12 @@ type NodeDisruptionBudgetResolver struct { Resolver resolver.Resolver } +func computeNodeDisruptionBudgetDisruptionsAllowed(maxDisruptedNodes, minUndisruptedNodes, watchedNodes, currentDisruptions int) int { + disruptionsForMax := maxDisruptedNodes - currentDisruptions + disruptionsForMin := (watchedNodes - currentDisruptions) - minUndisruptedNodes + return int(math.Min(float64(disruptionsForMax), float64(disruptionsForMin))) +} + // Sync ensure the budget's status is up to date func (r *NodeDisruptionBudgetResolver) Sync(ctx context.Context) error { nodeNames, err := r.GetSelectedNodes(ctx) @@ -192,9 +198,12 @@ func (r *NodeDisruptionBudgetResolver) Sync(ctx context.Context) error { r.NodeDisruptionBudget.Status.WatchedNodes = nodes r.NodeDisruptionBudget.Status.CurrentDisruptions = disruptionCount - disruptionsForMax := r.NodeDisruptionBudget.Spec.MaxDisruptedNodes - disruptionCount - disruptionsForMin := (len(nodes) - disruptionCount) - r.NodeDisruptionBudget.Spec.MinUndisruptedNodes - r.NodeDisruptionBudget.Status.DisruptionsAllowed = int(math.Min(float64(disruptionsForMax), float64(disruptionsForMin))) - disruptionCount + r.NodeDisruptionBudget.Status.DisruptionsAllowed = computeNodeDisruptionBudgetDisruptionsAllowed( + r.NodeDisruptionBudget.Spec.MaxDisruptedNodes, + r.NodeDisruptionBudget.Spec.MinUndisruptedNodes, + len(nodes), + disruptionCount, + ) r.NodeDisruptionBudget.Status.Disruptions = disruptions return nil } diff --git a/internal/controller/nodedisruptionbudget_math_test.go b/internal/controller/nodedisruptionbudget_math_test.go new file mode 100644 index 0000000..67e3249 --- /dev/null +++ b/internal/controller/nodedisruptionbudget_math_test.go @@ -0,0 +1,61 @@ +package controller + +import "testing" + +func TestComputeNodeDisruptionBudgetDisruptionsAllowed(t *testing.T) { + tests := []struct { + name string + maxDisruptedNodes int + minUndisruptedNodes int + watchedNodes int + currentDisruptions int + expected int + }{ + { + name: "at max disruption returns zero", + maxDisruptedNodes: 1, + minUndisruptedNodes: 0, + watchedNodes: 34, + currentDisruptions: 1, + expected: 0, + }, + { + name: "max side is limiting factor", + maxDisruptedNodes: 3, + minUndisruptedNodes: 0, + watchedNodes: 10, + currentDisruptions: 2, + expected: 1, + }, + { + name: "min undisrupted side is limiting factor", + maxDisruptedNodes: 10, + minUndisruptedNodes: 10, + watchedNodes: 12, + currentDisruptions: 1, + expected: 1, + }, + { + name: "returns negative when min undisrupted cannot be respected", + maxDisruptedNodes: 10, + minUndisruptedNodes: 10, + watchedNodes: 5, + currentDisruptions: 1, + expected: -6, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := computeNodeDisruptionBudgetDisruptionsAllowed( + tt.maxDisruptedNodes, + tt.minUndisruptedNodes, + tt.watchedNodes, + tt.currentDisruptions, + ) + if got != tt.expected { + t.Fatalf("unexpected disruptions allowed: got=%d expected=%d", got, tt.expected) + } + }) + } +}