diff --git a/internal/extension/extension.go b/internal/extension/extension.go index 68e2813..7b32714 100644 --- a/internal/extension/extension.go +++ b/internal/extension/extension.go @@ -24,6 +24,7 @@ import ( "time" "github.com/DataDog/datadog-lambda-go/internal/logger" + "github.com/aws/aws-lambda-go/lambdacontext" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace" ) @@ -43,6 +44,8 @@ const ( DdSeverlessSpan ddTraceContext = "dd-tracer-serverless-span" DdLambdaResponse ddTraceContext = "dd-response" + + lambdaRuntimeAwsRequestIdHeader = "lambda-runtime-aws-request-id" ) const ( @@ -119,6 +122,13 @@ func (em *ExtensionManager) checkAgentRunning() { func (em *ExtensionManager) SendStartInvocationRequest(ctx context.Context, eventPayload json.RawMessage) context.Context { body := bytes.NewBuffer(eventPayload) req, _ := http.NewRequest(http.MethodPost, em.startInvocationUrl, body) + + if lc, ok := lambdacontext.FromContext(ctx); ok { + req.Header.Set(lambdaRuntimeAwsRequestIdHeader, lc.AwsRequestID) + } else { + logger.Error(fmt.Errorf("missing AWS Lambda context. Unable to set lambda-runtime-aws-request-id header")) + } + response, err := em.httpClient.Do(req) if response != nil && response.Body != nil { defer func() { @@ -157,6 +167,11 @@ func (em *ExtensionManager) SendEndInvocationRequest(ctx context.Context, functi } body := bytes.NewBuffer(content) req, _ := http.NewRequest(http.MethodPost, em.endInvocationUrl, body) + if lc, ok := lambdacontext.FromContext(ctx); ok { + req.Header.Set(lambdaRuntimeAwsRequestIdHeader, lc.AwsRequestID) + } else { + logger.Error(fmt.Errorf("missing AWS Lambda context. Unable to set lambda-runtime-aws-request-id header")) + } // Mark the invocation as an error if any if cfg.Error != nil { diff --git a/internal/extension/extension_test.go b/internal/extension/extension_test.go index 5b391f4..07ae0e9 100644 --- a/internal/extension/extension_test.go +++ b/internal/extension/extension_test.go @@ -19,6 +19,7 @@ import ( "testing" "github.com/DataDog/datadog-lambda-go/internal/logger" + "github.com/aws/aws-lambda-go/lambdacontext" "github.com/stretchr/testify/assert" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" @@ -153,6 +154,41 @@ func TestExtensionStartInvoke(t *testing.T) { assert.Nil(t, samplingPriority) } +func TestExtensionStartInvokeLambdaRequestId(t *testing.T) { + headers := http.Header{} + capturingClient := capturingClient{hdr: headers} + + em := &ExtensionManager{ + startInvocationUrl: startInvocationUrl, + httpClient: capturingClient, + } + + lc := &lambdacontext.LambdaContext{ + AwsRequestID: "test-request-id-12345", + } + ctx := lambdacontext.NewContext(context.TODO(), lc) + em.SendStartInvocationRequest(ctx, []byte{}) + + err := em.Flush() + + assert.Nil(t, err) + assert.Equal(t, "test-request-id-12345", headers.Get("lambda-runtime-aws-request-id")) +} + +func TestExtensionStartInvokeLambdaRequestIdError(t *testing.T) { + em := &ExtensionManager{ + startInvocationUrl: startInvocationUrl, + httpClient: &ClientSuccessStartInvoke{}, + } + + logOutput := captureLog(func() { em.SendStartInvocationRequest(context.TODO(), []byte{}) }) + err := em.Flush() + assert.Nil(t, err) + assert.Contains(t, logOutput, "missing AWS Lambda context. Unable to set lambda-runtime-aws-request-id header") + lines := strings.Split(strings.TrimSpace(logOutput), "\n") + assert.Equal(t, 1, len(lines)) +} + func TestExtensionStartInvokeWithTraceContext(t *testing.T) { headers := http.Header{} headers.Set(string(DdTraceId), mockTraceId) @@ -205,8 +241,9 @@ func TestExtensionEndInvocation(t *testing.T) { endInvocationUrl: endInvocationUrl, httpClient: &ClientSuccessEndInvoke{}, } + ctx := lambdacontext.NewContext(context.TODO(), &lambdacontext.LambdaContext{}) span := tracer.StartSpan("aws.lambda") - logOutput := captureLog(func() { em.SendEndInvocationRequest(context.TODO(), span, ddtrace.FinishConfig{}) }) + logOutput := captureLog(func() { em.SendEndInvocationRequest(ctx, span, ddtrace.FinishConfig{}) }) span.Finish() // Expected because the noopSpanContext doesn't have the SamplingPriority() and we cannot use the mock for the agent assert.Contains(t, logOutput, "could not get sampling priority from getSamplingPriority()") @@ -215,6 +252,50 @@ func TestExtensionEndInvocation(t *testing.T) { assert.Equal(t, 1, len(lines)) } +func TestExtensionEndInvokeLambdaRequestId(t *testing.T) { + headers := http.Header{} + capturingClient := capturingClient{hdr: headers} + + em := &ExtensionManager{ + endInvocationUrl: endInvocationUrl, + httpClient: capturingClient, + } + + lc := &lambdacontext.LambdaContext{ + AwsRequestID: "test-request-id-12345", + } + + ctx := lambdacontext.NewContext(context.TODO(), lc) + span := tracer.StartSpan("aws.lambda") + span.Finish() + cfg := ddtrace.FinishConfig{} + em.SendEndInvocationRequest(ctx, span, cfg) + err := em.Flush() + assert.Nil(t, err) + assert.Equal(t, "test-request-id-12345", headers.Get("lambda-runtime-aws-request-id")) +} + +func TestExtensionEndInvokeLambdaRequestIdError(t *testing.T) { + headers := http.Header{} + capturingClient := capturingClient{hdr: headers} + ctx := context.WithValue(context.TODO(), DdSamplingPriority, mockSamplingPriority) + ctx = context.WithValue(ctx, DdTraceId, mockTraceId) + em := &ExtensionManager{ + endInvocationUrl: endInvocationUrl, + httpClient: capturingClient, + } + + span := tracer.StartSpan("aws.lambda") + logOutput := captureLog(func() { em.SendEndInvocationRequest(ctx, span, ddtrace.FinishConfig{}) }) + span.Finish() + + err := em.Flush() + assert.Nil(t, err) + assert.Contains(t, logOutput, "missing AWS Lambda context. Unable to set lambda-runtime-aws-request-id header") + lines := strings.Split(strings.TrimSpace(logOutput), "\n") + assert.Equal(t, 1, len(lines)) +} + func TestExtensionEndInvocationError(t *testing.T) { em := &ExtensionManager{ endInvocationUrl: endInvocationUrl,