diff --git a/internal/extension/extension.go b/internal/extension/extension.go index 8e54afc..68e2813 100644 --- a/internal/extension/extension.go +++ b/internal/extension/extension.go @@ -35,6 +35,7 @@ const ( DdParentId ddTraceContext = "x-datadog-parent-id" DdSpanId ddTraceContext = "x-datadog-span-id" DdSamplingPriority ddTraceContext = "x-datadog-sampling-priority" + DdOrigin ddTraceContext = "x-datadog-origin" DdInvocationError ddTraceContext = "x-datadog-invocation-error" DdInvocationErrorMsg ddTraceContext = "x-datadog-invocation-error-msg" DdInvocationErrorType ddTraceContext = "x-datadog-invocation-error-type" diff --git a/internal/trace/constants.go b/internal/trace/constants.go index 21a66e6..36cf649 100644 --- a/internal/trace/constants.go +++ b/internal/trace/constants.go @@ -12,6 +12,7 @@ const ( traceIDHeader = "x-datadog-trace-id" parentIDHeader = "x-datadog-parent-id" samplingPriorityHeader = "x-datadog-sampling-priority" + originHeader = "x-datadog-origin" ) const ( diff --git a/internal/trace/context.go b/internal/trace/context.go index ac8320e..c245bec 100644 --- a/internal/trace/context.go +++ b/internal/trace/context.go @@ -160,6 +160,13 @@ func getTraceContext(ctx context.Context, headers map[string]string) (TraceConte samplingPriority = "1" //sampler-keep } + // try to pull datadog origin from either headers or context + if origin, ok := headers[originHeader]; ok { + tc[originHeader] = origin + } else if origin, ok := ctx.Value(extension.DdOrigin).(string); ok { + tc[originHeader] = origin + } + tc[samplingPriorityHeader] = samplingPriority tc[traceIDHeader] = traceID tc[parentIDHeader] = parentID diff --git a/internal/trace/context_test.go b/internal/trace/context_test.go index 7c98839..9b12217 100644 --- a/internal/trace/context_test.go +++ b/internal/trace/context_test.go @@ -44,7 +44,7 @@ func mockLambdaXRayTraceContext(ctx context.Context, traceID, parentID string, s return context.WithValue(ctx, xray.LambdaTraceHeaderKey, headerString) } -func mockTraceContext(traceID, parentID, samplingPriority string) context.Context { +func mockTraceContext(traceID, parentID, samplingPriority, origin string) context.Context { ctx := context.Background() if traceID != "" { ctx = context.WithValue(ctx, extension.DdTraceId, traceID) @@ -55,6 +55,9 @@ func mockTraceContext(traceID, parentID, samplingPriority string) context.Contex if samplingPriority != "" { ctx = context.WithValue(ctx, extension.DdSamplingPriority, samplingPriority) } + if origin != "" { + ctx = context.WithValue(ctx, extension.DdOrigin, origin) + } return ctx } @@ -135,6 +138,7 @@ func TestGetDatadogTraceContextFromContextObject(t *testing.T) { traceID string parentID string samplingPriority string + origin string expectTC TraceContext expectOk bool }{ @@ -142,6 +146,20 @@ func TestGetDatadogTraceContextFromContextObject(t *testing.T) { "trace", "parent", "sampling", + "origin", + TraceContext{ + "x-datadog-trace-id": "trace", + "x-datadog-parent-id": "parent", + "x-datadog-sampling-priority": "sampling", + "x-datadog-origin": "origin", + }, + true, + }, + { + "trace", + "parent", + "sampling", + "", TraceContext{ "x-datadog-trace-id": "trace", "x-datadog-parent-id": "parent", @@ -153,6 +171,7 @@ func TestGetDatadogTraceContextFromContextObject(t *testing.T) { "", "parent", "sampling", + "origin", TraceContext{}, false, }, @@ -160,6 +179,7 @@ func TestGetDatadogTraceContextFromContextObject(t *testing.T) { "trace", "", "sampling", + "", TraceContext{}, false, }, @@ -167,6 +187,7 @@ func TestGetDatadogTraceContextFromContextObject(t *testing.T) { "trace", "parent", "", + "", TraceContext{ "x-datadog-trace-id": "trace", "x-datadog-parent-id": "parent", @@ -178,8 +199,8 @@ func TestGetDatadogTraceContextFromContextObject(t *testing.T) { ev := loadRawJSON(t, "../testdata/non-proxy-no-headers.json") for _, test := range testcases { - t.Run(test.traceID+test.parentID+test.samplingPriority, func(t *testing.T) { - ctx := mockTraceContext(test.traceID, test.parentID, test.samplingPriority) + t.Run(test.traceID+test.parentID+test.samplingPriority+test.origin, func(t *testing.T) { + ctx := mockTraceContext(test.traceID, test.parentID, test.samplingPriority, test.origin) tc, ok := getTraceContext(ctx, getHeadersFromEventHeaders(ctx, *ev)) assert.Equal(t, test.expectTC, tc) assert.Equal(t, test.expectOk, ok)