Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import { UnauthorizedException } from '@nestjs/common';
import { Test, TestingModule } from '@nestjs/testing';

import { beforeEach, describe, expect, it, vi } from 'vitest';
Expand All @@ -7,14 +6,13 @@ import { OidcRedirectUriService } from '@app/unraid-api/graph/resolvers/sso/clie
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
import { validateRedirectUri } from '@app/unraid-api/utils/redirect-uri-validator.js';

// Mock the redirect URI validator
vi.mock('@app/unraid-api/utils/redirect-uri-validator.js', () => ({
validateRedirectUri: vi.fn(),
}));

describe('OidcRedirectUriService', () => {
let service: OidcRedirectUriService;
let oidcConfig: any;
let oidcConfig: { getConfig: ReturnType<typeof vi.fn> };

beforeEach(async () => {
vi.clearAllMocks();
Expand All @@ -39,19 +37,16 @@ describe('OidcRedirectUriService', () => {
});

describe('getRedirectUri', () => {
it('should return valid redirect URI when validation passes', async () => {
const requestOrigin = 'https://example.com';
const requestHeaders = {
'x-forwarded-proto': 'https',
'x-forwarded-host': 'example.com',
};

it('returns a callback URI when validation passes', async () => {
(validateRedirectUri as any).mockReturnValue({
isValid: true,
validatedUri: 'https://example.com',
});

const result = await service.getRedirectUri(requestOrigin, requestHeaders);
const result = await service.getRedirectUri('https://example.com', {
protocol: 'https',
host: 'example.com',
});

expect(result).toBe('https://example.com/graphql/api/auth/oidc/callback');
expect(validateRedirectUri).toHaveBeenCalledWith(
Expand All @@ -63,41 +58,35 @@ describe('OidcRedirectUriService', () => {
);
});

it('should throw UnauthorizedException when validation fails', async () => {
const requestOrigin = 'https://evil.com';
const requestHeaders = {
'x-forwarded-proto': 'https',
'x-forwarded-host': 'example.com',
};

it('throws when validation fails', async () => {
(validateRedirectUri as any).mockReturnValue({
isValid: false,
reason: 'Origin not allowed',
});

await expect(service.getRedirectUri(requestOrigin, requestHeaders)).rejects.toThrow(
UnauthorizedException
);
await expect(
service.getRedirectUri('https://evil.com', {
protocol: 'https',
host: 'example.com',
})
).rejects.toThrow();
});

it('should handle missing allowed origins', async () => {
it('passes through missing allowed origins', async () => {
oidcConfig.getConfig.mockResolvedValue({
providers: [],
defaultAllowedOrigins: undefined,
});

const requestOrigin = 'https://example.com';
const requestHeaders = {
'x-forwarded-proto': 'https',
'x-forwarded-host': 'example.com',
};

(validateRedirectUri as any).mockReturnValue({
isValid: true,
validatedUri: 'https://example.com',
});

const result = await service.getRedirectUri(requestOrigin, requestHeaders);
const result = await service.getRedirectUri('https://example.com', {
protocol: 'https',
host: 'example.com',
});

expect(result).toBe('https://example.com/graphql/api/auth/oidc/callback');
expect(validateRedirectUri).toHaveBeenCalledWith(
Expand All @@ -109,114 +98,63 @@ describe('OidcRedirectUriService', () => {
);
});

it('should extract protocol from headers correctly', async () => {
const requestOrigin = 'https://example.com';
const requestHeaders = {
'x-forwarded-proto': ['https', 'http'],
host: 'example.com',
};

it('uses the trusted request origin info provided by Fastify', async () => {
(validateRedirectUri as any).mockReturnValue({
isValid: true,
validatedUri: 'https://example.com',
validatedUri: 'https://nas.domain.com/graphql/api/auth/oidc/callback',
});

const result = await service.getRedirectUri(requestOrigin, requestHeaders);

expect(result).toBe('https://example.com/graphql/api/auth/oidc/callback');
expect(validateRedirectUri).toHaveBeenCalledWith(
'https://example.com',
'https', // Should use first value from array
'example.com',
expect.anything(),
expect.anything()
const result = await service.getRedirectUri(
'https://nas.domain.com/graphql/api/auth/oidc/callback',
{
protocol: 'https',
host: 'nas.domain.com',
}
);
});

it('should use host header as fallback', async () => {
const requestOrigin = 'https://example.com';
const requestHeaders = {
host: 'example.com',
};

(validateRedirectUri as any).mockReturnValue({
isValid: true,
validatedUri: 'https://example.com',
});

const result = await service.getRedirectUri(requestOrigin, requestHeaders);

expect(result).toBe('https://example.com/graphql/api/auth/oidc/callback');
expect(result).toBe('https://nas.domain.com/graphql/api/auth/oidc/callback');
expect(validateRedirectUri).toHaveBeenCalledWith(
'https://example.com',
'https', // Inferred from requestOrigin when x-forwarded-proto not present
'example.com',
'https://nas.domain.com/graphql/api/auth/oidc/callback',
'https',
'nas.domain.com',
expect.anything(),
expect.anything()
);
});

it('should prefer x-forwarded-host over host header', async () => {
const requestOrigin = 'https://example.com';
const requestHeaders = {
'x-forwarded-host': 'forwarded.example.com',
host: 'original.example.com',
};

it('allows host values with ports', async () => {
(validateRedirectUri as any).mockReturnValue({
isValid: true,
validatedUri: 'https://example.com',
});

const result = await service.getRedirectUri(requestOrigin, requestHeaders);
const result = await service.getRedirectUri('https://example.com', {
protocol: 'https',
host: 'forwarded.example.com:8443',
});

expect(result).toBe('https://example.com/graphql/api/auth/oidc/callback');
expect(validateRedirectUri).toHaveBeenCalledWith(
'https://example.com',
'https', // Inferred from requestOrigin when x-forwarded-proto not present
'forwarded.example.com', // Should use x-forwarded-host
'https',
'forwarded.example.com:8443',
expect.anything(),
expect.anything()
);
});

it('should throw when URL construction fails', async () => {
const requestOrigin = 'https://example.com';
const requestHeaders = {};

(validateRedirectUri as any).mockReturnValue({
isValid: true,
validatedUri: 'invalid-url', // Invalid URL
});

await expect(service.getRedirectUri(requestOrigin, requestHeaders)).rejects.toThrow(
UnauthorizedException
);
});

it('should handle array values in headers correctly', async () => {
const requestOrigin = 'https://example.com';
const requestHeaders = {
'x-forwarded-proto': ['https'],
'x-forwarded-host': ['forwarded.example.com', 'another.example.com'],
host: ['original.example.com'],
};

it('throws when URL construction fails after validation', async () => {
(validateRedirectUri as any).mockReturnValue({
isValid: true,
validatedUri: 'https://example.com',
validatedUri: 'invalid-url',
});

const result = await service.getRedirectUri(requestOrigin, requestHeaders);

expect(result).toBe('https://example.com/graphql/api/auth/oidc/callback');
expect(validateRedirectUri).toHaveBeenCalledWith(
'https://example.com',
'https',
'forwarded.example.com', // Should use first value from array
expect.anything(),
expect.anything()
);
await expect(
service.getRedirectUri('https://example.com', {
protocol: 'https',
host: 'example.com',
})
).rejects.toThrow();
});
});
});
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { Injectable, Logger, UnauthorizedException } from '@nestjs/common';

import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
import { RequestOriginInfo } from '@app/unraid-api/graph/resolvers/sso/utils/oidc-request-origin.util.js';
import { validateRedirectUri } from '@app/unraid-api/utils/redirect-uri-validator.js';

@Injectable()
Expand All @@ -10,12 +11,9 @@ export class OidcRedirectUriService {

constructor(private readonly oidcConfig: OidcConfigPersistence) {}

async getRedirectUri(
requestOrigin: string,
requestHeaders: Record<string, string | string[] | undefined>
): Promise<string> {
async getRedirectUri(requestOrigin: string, requestOriginInfo: RequestOriginInfo): Promise<string> {
// Extract protocol and host from headers for validation
const { protocol, host } = this.getRequestOriginInfo(requestHeaders, requestOrigin);
const { protocol, host } = requestOriginInfo;

// Get the global allowed origins from OIDC config
const config = await this.oidcConfig.getConfig();
Expand Down Expand Up @@ -61,37 +59,4 @@ export class OidcRedirectUriService {
throw new UnauthorizedException('Invalid redirect_uri');
}
}

private getRequestOriginInfo(
requestHeaders: Record<string, string | string[] | undefined>,
requestOrigin?: string
): {
protocol: string;
host: string | undefined;
} {
// Extract protocol from x-forwarded-proto or infer from requestOrigin, default to http
const forwardedProto = requestHeaders['x-forwarded-proto'];
const protocol = forwardedProto
? Array.isArray(forwardedProto)
? forwardedProto[0]
: forwardedProto
: requestOrigin?.startsWith('https')
? 'https'
: 'http';

// Extract host from x-forwarded-host or host header
const forwardedHost = requestHeaders['x-forwarded-host'];
const hostHeader = requestHeaders['host'];
const host = forwardedHost
? Array.isArray(forwardedHost)
? forwardedHost[0]
: forwardedHost
: hostHeader
? Array.isArray(hostHeader)
? hostHeader[0]
: hostHeader
: undefined;

return { protocol, host };
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,10 @@ describe('OidcService Integration Tests - Enhanced Logging', () => {
providerId: 'auth-url-test',
state: 'test-state',
requestOrigin: 'http://test.local',
requestHeaders: { host: 'test.local' },
requestOriginInfo: {
protocol: 'http',
host: 'test.local',
},
});

// Verify URL building logs
Expand Down Expand Up @@ -278,9 +281,9 @@ describe('OidcService Integration Tests - Enhanced Logging', () => {
providerId: 'manual-endpoints',
state: 'test-state',
requestOrigin: 'http://test.local',
requestHeaders: {
'x-forwarded-host': 'test.local',
'x-forwarded-proto': 'http',
requestOriginInfo: {
protocol: 'http',
host: 'test.local',
},
});

Expand Down
16 changes: 13 additions & 3 deletions api/src/unraid-api/graph/resolvers/sso/core/oidc.service.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,16 @@ describe('OidcService Integration', () => {
providerId: 'custom-provider',
state: 'client-state-123',
requestOrigin: 'https://example.com',
requestHeaders: { host: 'example.com' },
requestOriginInfo: {
protocol: 'https',
host: 'example.com',
},
};

const url = await service.getAuthorizationUrl(params);

expect(redirectUriService.getRedirectUri).toHaveBeenCalledWith('https://example.com', {
protocol: 'https',
host: 'example.com',
});

Expand Down Expand Up @@ -177,7 +181,10 @@ describe('OidcService Integration', () => {
providerId: 'discovery-provider',
state: 'client-state-123',
requestOrigin: 'https://example.com',
requestHeaders: {},
requestOriginInfo: {
protocol: 'https',
host: 'example.com',
},
};

const url = await service.getAuthorizationUrl(params);
Expand All @@ -193,7 +200,10 @@ describe('OidcService Integration', () => {
providerId: 'non-existent',
state: 'state',
requestOrigin: 'https://example.com',
requestHeaders: {},
requestOriginInfo: {
protocol: 'https',
host: 'example.com',
},
};

await expect(service.getAuthorizationUrl(params)).rejects.toThrow(UnauthorizedException);
Expand Down
Loading
Loading