diff --git a/pkg/workflows/wasm/host/execution.go b/pkg/workflows/wasm/host/execution.go index ec9fd1bbfd..a12a6ceb33 100644 --- a/pkg/workflows/wasm/host/execution.go +++ b/pkg/workflows/wasm/host/execution.go @@ -11,6 +11,7 @@ import ( "google.golang.org/protobuf/proto" "github.com/smartcontractkit/chainlink-common/pkg/config" + "github.com/smartcontractkit/chainlink-common/pkg/settings/limits" sdkpb "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" wfpb "github.com/smartcontractkit/chainlink-protos/workflows/go/v2" ) @@ -21,6 +22,7 @@ type execution[T any] struct { ctx context.Context capabilityResponses map[int32]<-chan *sdkpb.CapabilityResponse secretsResponses map[int32]<-chan *secretsResponse + pendingCallsLimiter limits.ResourcePoolLimiter[int] lock sync.RWMutex module *module executor ExecutionHelper @@ -38,12 +40,20 @@ type execution[T any] struct { // channel and storing each channel with a unique identifier for future // retrieval on await. func (e *execution[T]) callCapAsync(ctx context.Context, req *sdkpb.CapabilityRequest) error { + // Acquire a slot from the pool limiter to bound concurrency. + free, err := e.pendingCallsLimiter.Wait(ctx, 1) + if err != nil { + return err + } + ch := make(chan *sdkpb.CapabilityResponse, 1) e.lock.Lock() defer e.lock.Unlock() e.capabilityResponses[req.CallbackId] = ch go func() { + defer free() + resp, err := e.executor.CallCapability(ctx, req) if err != nil { @@ -95,12 +105,20 @@ type secretsResponse struct { } func (e *execution[T]) getSecretsAsync(ctx context.Context, req *sdkpb.GetSecretsRequest) error { + // Acquire a slot from the pool limiter to bound concurrency. 
+ free, err := e.pendingCallsLimiter.Wait(ctx, 1) + if err != nil { + return err + } + ch := make(chan *secretsResponse, 1) e.lock.Lock() defer e.lock.Unlock() e.secretsResponses[req.CallbackId] = ch go func() { + defer free() + resp, err := e.executor.GetSecrets(ctx, req) sr := &secretsResponse{responses: resp, err: err} diff --git a/pkg/workflows/wasm/host/execution_await_order_test.go b/pkg/workflows/wasm/host/execution_await_order_test.go index 29084302a2..f250dd69fc 100644 --- a/pkg/workflows/wasm/host/execution_await_order_test.go +++ b/pkg/workflows/wasm/host/execution_await_order_test.go @@ -10,6 +10,8 @@ import ( "google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/emptypb" + "github.com/smartcontractkit/chainlink-common/pkg/settings/cresettings" + "github.com/smartcontractkit/chainlink-common/pkg/settings/limits" sdkpb "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" wfpb "github.com/smartcontractkit/chainlink-protos/workflows/go/v2" ) @@ -69,6 +71,7 @@ func TestAwaitCapabilities_headOfLineBlocksOnEarlierID(t *testing.T) { exec := &execution[*sdkpb.ExecutionResult]{ ctx: t.Context(), capabilityResponses: make(map[int32]<-chan *sdkpb.CapabilityResponse), + pendingCallsLimiter: limits.GlobalResourcePoolLimiter(cresettings.Default.PerWorkflow.CapabilityConcurrencyLimit.DefaultValue), executor: stub, } diff --git a/pkg/workflows/wasm/host/execution_semaphore_test.go b/pkg/workflows/wasm/host/execution_semaphore_test.go new file mode 100644 index 0000000000..ed7c8720df --- /dev/null +++ b/pkg/workflows/wasm/host/execution_semaphore_test.go @@ -0,0 +1,289 @@ +package host + +import ( + "context" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/emptypb" + + "github.com/smartcontractkit/chainlink-common/pkg/settings/limits" + sdkpb 
"github.com/smartcontractkit/chainlink-protos/cre/go/sdk" + wfpb "github.com/smartcontractkit/chainlink-protos/workflows/go/v2" +) + +// slowCapStub delays CallCapability by a configurable duration and counts in-flight calls. +type slowCapStub struct { + delay time.Duration + inflight atomic.Int32 + peakLoad atomic.Int32 + callCount atomic.Int32 +} + +func (s *slowCapStub) CallCapability(_ context.Context, _ *sdkpb.CapabilityRequest) (*sdkpb.CapabilityResponse, error) { + s.callCount.Add(1) + cur := s.inflight.Add(1) + for { + peak := s.peakLoad.Load() + if cur <= peak || s.peakLoad.CompareAndSwap(peak, cur) { + break + } + } + time.Sleep(s.delay) + s.inflight.Add(-1) + + payload, _ := anypb.New(&emptypb.Empty{}) + return &sdkpb.CapabilityResponse{ + Response: &sdkpb.CapabilityResponse_Payload{Payload: payload}, + }, nil +} + +func (s *slowCapStub) GetSecrets(context.Context, *sdkpb.GetSecretsRequest) ([]*sdkpb.SecretResponse, error) { + return nil, nil +} +func (s *slowCapStub) GetWorkflowExecutionID() string { return "test-exec" } +func (s *slowCapStub) GetNodeTime() time.Time { return time.Now() } +func (s *slowCapStub) GetDONTime() (time.Time, error) { return time.Now(), nil } +func (s *slowCapStub) EmitUserLog(string) error { return nil } +func (s *slowCapStub) EmitUserMetric(context.Context, *wfpb.WorkflowUserMetric) error { + return nil +} + +var _ ExecutionHelper = (*slowCapStub)(nil) + +func newTestExec(maxPending int, stub ExecutionHelper) *execution[*sdkpb.ExecutionResult] { + return &execution[*sdkpb.ExecutionResult]{ + ctx: context.Background(), + capabilityResponses: make(map[int32]<-chan *sdkpb.CapabilityResponse), + secretsResponses: make(map[int32]<-chan *secretsResponse), + pendingCallsLimiter: limits.GlobalResourcePoolLimiter[int](maxPending), + executor: stub, + } +} + +// TestSemaphore_BackpressureBlocksCallN proves that call N+1 blocks when +// N == MaxPendingCalls and nothing has been awaited yet. 
+func TestSemaphore_BackpressureBlocksCallN(t *testing.T) {
+	t.Parallel()
+	const maxPending = 5 // not named "max": that would shadow the Go 1.21 builtin
+
+	// The stub must outlast the 200ms check window below so every slot stays held.
+	stub := &slowCapStub{delay: time.Second}
+	exec := newTestExec(maxPending, stub)
+
+	ctx := t.Context()
+
+	// Fill the limiter: each call takes one slot and keeps it while its stub call runs.
+	for i := int32(0); i < maxPending; i++ {
+		require.NoError(t, exec.callCapAsync(ctx, &sdkpb.CapabilityRequest{CallbackId: i}))
+	}
+
+	// The next call should block inside callCapAsync; its result arrives on the channel.
+	blocked := make(chan error, 1)
+	go func() {
+		// Buffered send, so this goroutine exits even if the test has already failed.
+		blocked <- exec.callCapAsync(ctx, &sdkpb.CapabilityRequest{CallbackId: maxPending})
+	}()
+
+	select {
+	case <-blocked:
+		t.Fatal("call max+1 did not block; semaphore backpressure broken")
+	case <-time.After(200 * time.Millisecond):
+		// expected — still blocked
+	}
+
+	// Wait for call 0 to finish; slots are released by the call goroutines, not by the await.
+	resp, err := exec.awaitCapabilities(ctx, &sdkpb.AwaitCapabilitiesRequest{Ids: []int32{0}})
+	require.NoError(t, err)
+	require.Len(t, resp.Responses, 1)
+
+	// Now the blocked call should proceed, and it must succeed rather than merely return.
+	select {
+	case callErr := <-blocked:
+		require.NoError(t, callErr, "unblocked call should succeed once a slot is free")
+	case <-time.After(2 * time.Second):
+		t.Fatal("call max+1 did not unblock after await freed a slot")
+	}
+}
+
+// TestSemaphore_HighThroughputBounded issues many calls in batches,
+// awaiting each batch before the next. Peak in-flight goroutines must never
+// exceed MaxPendingCalls.
+func TestSemaphore_HighThroughputBounded(t *testing.T) {
+	t.Parallel()
+	const maxPending = 10 // not named "max": that would shadow the Go 1.21 builtin
+	const batches = 50
+	const callsPerBatch = maxPending // every batch fills the limiter exactly
+
+	stub := &slowCapStub{delay: time.Millisecond}
+	exec := newTestExec(maxPending, stub)
+
+	ctx := t.Context()
+	var callID int32 // callback IDs stay unique across all batches
+
+	for range batches { // issue one full batch, then drain it before starting the next
+		ids := make([]int32, callsPerBatch)
+		for i := range ids {
+			ids[i] = callID
+			require.NoError(t, exec.callCapAsync(ctx, &sdkpb.CapabilityRequest{CallbackId: callID}))
+			callID++
+		}
+		resp, err := exec.awaitCapabilities(ctx, &sdkpb.AwaitCapabilitiesRequest{Ids: ids})
+		require.NoError(t, err)
+		require.Len(t, resp.Responses, callsPerBatch)
+	}
+
+	assert.LessOrEqual(t, int(stub.peakLoad.Load()), maxPending,
+		"peak in-flight goroutines exceeded MaxPendingCalls")
+	assert.Equal(t, int32(batches*callsPerBatch), stub.callCount.Load())
+}
+
+// TestSemaphore_ContextCancelUnblocksCall verifies that callCapAsync, issued while
+// every slot is held, returns an error matching context.Canceled after cancellation.
+func TestSemaphore_ContextCancelUnblocksCall(t *testing.T) {
+	t.Parallel()
+	const maxPending = 2 // not named "max": that would shadow the Go 1.21 builtin
+
+	// Very slow stub: its 5s delay exceeds the 2s timeout below, so slots stay held.
+	stub := &slowCapStub{delay: 5 * time.Second}
+	exec := newTestExec(maxPending, stub)
+
+	ctx, cancel := context.WithCancel(t.Context())
+	defer cancel() // still cancelled if a require in the fill loop aborts the test
+
+	// Fill semaphore.
+	for i := int32(0); i < maxPending; i++ {
+		require.NoError(t, exec.callCapAsync(ctx, &sdkpb.CapabilityRequest{CallbackId: i}))
+	}
+
+	// Next call will block on semaphore; its result is reported on the buffered channel.
+	done := make(chan error, 1)
+	go func() {
+		done <- exec.callCapAsync(ctx, &sdkpb.CapabilityRequest{CallbackId: maxPending})
+	}()
+
+	// Cancel context.
+	cancel()
+
+	select {
+	case callErr := <-done:
+		require.ErrorIs(t, callErr, context.Canceled)
+	case <-time.After(2 * time.Second):
+		t.Fatal("callCapAsync did not unblock after context cancel")
+	}
+}
+
+// TestSemaphore_SlotsRecycledCorrectly ensures that after many await cycles,
+// the semaphore is back to its full capacity and new calls can proceed.
+func TestSemaphore_SlotsRecycledCorrectly(t *testing.T) {
+	t.Parallel()
+	const maxPending = 5 // not named "max": that would shadow the Go 1.21 builtin
+	const rounds = 100
+
+	stub := &slowCapStub{} // zero delay: calls complete as fast as they are scheduled
+	exec := newTestExec(maxPending, stub)
+
+	ctx := t.Context()
+
+	for r := range rounds { // each round fills the limiter, then awaits every call
+		ids := make([]int32, maxPending)
+		for i := range ids {
+			id := int32(r*maxPending + i) // unique callback ID across rounds
+			ids[i] = id
+			require.NoError(t, exec.callCapAsync(ctx, &sdkpb.CapabilityRequest{CallbackId: id}))
+		}
+		_, err := exec.awaitCapabilities(ctx, &sdkpb.AwaitCapabilitiesRequest{Ids: ids})
+		require.NoError(t, err)
+	}
+
+	// After all rounds, all slots should be available again.
+	// Goroutines release slots via defer after the channel send, so allow a
+	// brief window for the last batch of defers to execute.
+	assert.Eventually(t, func() bool {
+		avail, err := exec.pendingCallsLimiter.Available(ctx)
+		return err == nil && avail == maxPending
+	}, time.Second, 5*time.Millisecond,
+		"limiter still has occupied slots after all awaits completed")
+}
+
+// TestSemaphore_MapCleanedOnAwait verifies the capabilityResponses map
+// doesn't leak entries.
+func TestSemaphore_MapCleanedOnAwait(t *testing.T) { + t.Parallel() + const max = 10 + const total = 200 + + stub := &slowCapStub{delay: 0} + exec := newTestExec(max, stub) + + ctx := t.Context() + + for i := int32(0); i < total; i += max { + ids := make([]int32, max) + for j := int32(0); j < max; j++ { + id := i + j + ids[j] = id + require.NoError(t, exec.callCapAsync(ctx, &sdkpb.CapabilityRequest{CallbackId: id})) + } + _, err := exec.awaitCapabilities(ctx, &sdkpb.AwaitCapabilitiesRequest{Ids: ids}) + require.NoError(t, err) + } + + exec.lock.RLock() + mapLen := len(exec.capabilityResponses) + exec.lock.RUnlock() + + assert.Equal(t, 0, mapLen, "capabilityResponses map leaked %d entries", mapLen) +} + +// TestSemaphore_ConcurrentCallAndAwait exercises concurrent callers issuing +// callCapAsync from multiple goroutines while others await, simulating the +// real engine dispatching multiple workflow executions. +func TestSemaphore_ConcurrentCallAndAwait(t *testing.T) { + t.Parallel() + const max = 10 + const workers = 20 + const callsPerWorker = 50 + + stub := &slowCapStub{delay: 10 * time.Microsecond} + // Each worker gets its own execution (like real CRE — one per WASM invocation). + // We want to prove that WITHIN a single execution, concurrent isn't needed because + // WASM is single-threaded. But let's stress the shared semaphore anyway. + exec := newTestExec(max, stub) + + ctx := t.Context() + var wg sync.WaitGroup + + // Simulate sequential call-then-await pattern from a single WASM thread + // (the real case). We run it in parallel workers to stress-test the lock. 
+ for w := 0; w < workers; w++ { + wg.Go(func() { + for i := 0; i < callsPerWorker; i++ { + id := int32(w*callsPerWorker + i) + err := exec.callCapAsync(ctx, &sdkpb.CapabilityRequest{CallbackId: id}) + if err != nil { + return + } + _, err = exec.awaitCapabilities(ctx, &sdkpb.AwaitCapabilitiesRequest{Ids: []int32{id}}) + if err != nil { + return + } + } + }) + } + + wg.Wait() + + assert.LessOrEqual(t, int(stub.peakLoad.Load()), max) + assert.Equal(t, int32(workers*callsPerWorker), stub.callCount.Load()) + assert.Eventually(t, func() bool { + avail, err := exec.pendingCallsLimiter.Available(context.Background()) + return err == nil && avail == max + }, time.Second, 5*time.Millisecond, + "limiter still has occupied slots after all awaits completed") +} diff --git a/pkg/workflows/wasm/host/module.go b/pkg/workflows/wasm/host/module.go index 07b42a00eb..cce932515a 100644 --- a/pkg/workflows/wasm/host/module.go +++ b/pkg/workflows/wasm/host/module.go @@ -25,6 +25,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/custmsg" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/settings" + "github.com/smartcontractkit/chainlink-common/pkg/settings/cresettings" "github.com/smartcontractkit/chainlink-common/pkg/settings/limits" dagsdk "github.com/smartcontractkit/chainlink-common/pkg/workflows/sdk" "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm" @@ -61,16 +62,20 @@ type DeterminismConfig struct { Seed int64 } type ModuleConfig struct { - TickInterval time.Duration - Timeout *time.Duration - MaxMemoryMBs uint64 - MinMemoryMBs uint64 - MemoryLimiter limits.BoundLimiter[config.Size] // supersedes Max/MinMemoryMBs if set - InitialFuel uint64 - Logger logger.Logger - IsUncompressed bool - Fetch func(ctx context.Context, req *FetchRequest) (*FetchResponse, error) - MaxFetchRequests int + TickInterval time.Duration + Timeout *time.Duration + MaxMemoryMBs uint64 + MinMemoryMBs uint64 + MemoryLimiter 
limits.BoundLimiter[config.Size] // supersedes Max/MinMemoryMBs if set + InitialFuel uint64 + Logger logger.Logger + IsUncompressed bool + Fetch func(ctx context.Context, req *FetchRequest) (*FetchResponse, error) + MaxFetchRequests int + // PendingCallsLimiter bounds concurrent in-flight capability and secrets + // calls. When scoped (e.g. ScopeWorkflow), each workflow ID gets its own + // pool; when global/unscoped, the limit is shared across all callers. + PendingCallsLimiter limits.ResourcePoolLimiter[int] MaxCompressedBinarySize uint64 MaxCompressedBinaryLimiter limits.BoundLimiter[config.Size] // supersedes MaxCompressedBinarySize if set MaxDecompressedBinarySize uint64 @@ -192,6 +197,15 @@ func NewModule(ctx context.Context, modCfg *ModuleConfig, binary []byte, opts .. modCfg.MaxFetchRequests = defaultMaxFetchRequests } + if modCfg.PendingCallsLimiter == nil { + lf := limits.Factory{Logger: modCfg.Logger} + var err error + modCfg.PendingCallsLimiter, err = limits.MakeResourcePoolLimiter(lf, cresettings.Default.PerWorkflow.CapabilityConcurrencyLimit) + if err != nil { + return nil, fmt.Errorf("failed to make pending calls limiter: %w", err) + } + } + if modCfg.Labeler == nil { modCfg.Labeler = &unimplementedMessageEmitter{} } @@ -693,6 +707,7 @@ func runWasm[I, O proto.Message]( ctx: ctxWithTimeout, capabilityResponses: map[int32]<-chan *sdkpb.CapabilityResponse{}, secretsResponses: map[int32]<-chan *secretsResponse{}, + pendingCallsLimiter: m.cfg.PendingCallsLimiter, module: m, executor: helper, donSeed: donSeed, diff --git a/pkg/workflows/wasm/host/module_test.go b/pkg/workflows/wasm/host/module_test.go index ff7307ca79..7aa4343621 100644 --- a/pkg/workflows/wasm/host/module_test.go +++ b/pkg/workflows/wasm/host/module_test.go @@ -15,6 +15,8 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/custmsg" "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/settings/cresettings" + 
"github.com/smartcontractkit/chainlink-common/pkg/settings/limits" "github.com/smartcontractkit/chainlink-common/pkg/utils/matches" wasmpb "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/pb" sdkpb "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" @@ -633,6 +635,7 @@ func Test_CallAwaitRace(t *testing.T) { exec := &execution[*wasmpb.ExecutionResult]{ module: m, capabilityResponses: map[int32]<-chan *sdkpb.CapabilityResponse{}, + pendingCallsLimiter: limits.GlobalResourcePoolLimiter(cresettings.Default.PerWorkflow.CapabilityConcurrencyLimit.DefaultValue), ctx: t.Context(), executor: mockExecHelper, } diff --git a/pkg/workflows/wasm/host/wasm_nodag_test.go b/pkg/workflows/wasm/host/wasm_nodag_test.go index 917692b523..d7461b0e7a 100644 --- a/pkg/workflows/wasm/host/wasm_nodag_test.go +++ b/pkg/workflows/wasm/host/wasm_nodag_test.go @@ -11,6 +11,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/protoc/pkg/test_capabilities/basictrigger" "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/settings/cresettings" "github.com/smartcontractkit/chainlink-common/pkg/settings/limits" "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" wfpb "github.com/smartcontractkit/chainlink-protos/workflows/go/v2" @@ -313,8 +314,9 @@ func Test_NoDAG_EmitMetricDisabled(t *testing.T) { func defaultNoDAGModCfg(t testing.TB) *ModuleConfig { return &ModuleConfig{ - Logger: logger.Test(t), - IsUncompressed: true, + Logger: logger.Test(t), + IsUncompressed: true, + PendingCallsLimiter: limits.GlobalResourcePoolLimiter(cresettings.Default.PerWorkflow.CapabilityConcurrencyLimit.DefaultValue), } }