Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions packages/appkit/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,11 @@
"@opentelemetry/sdk-trace-base": "2.6.0",
"@opentelemetry/semantic-conventions": "1.38.0",
"@types/semver": "7.7.1",
"cors": "^2.8.6",
"dotenv": "16.6.1",
"express": "4.22.0",
"get-port": "7.2.0",
"helmet": "^8.1.0",
"js-yaml": "4.1.1",
"obug": "2.1.1",
"pg": "8.18.0",
Expand All @@ -88,6 +90,7 @@
},
"devDependencies": {
"@opentelemetry/context-async-hooks": "2.6.1",
"@types/cors": "^2.8.19",
"@types/express": "4.17.25",
"@types/js-yaml": "4.0.9",
"@types/json-schema": "7.0.15",
Expand Down
8 changes: 8 additions & 0 deletions packages/appkit/src/plugins/server/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import { instrumentations } from "../../telemetry";
import { sanitizeClientConfig } from "./client-config-sanitizer";
import manifest from "./manifest.json";
import { RemoteTunnelController } from "./remote-tunnel/remote-tunnel-controller";
import { registerErrorHandler, registerSecurityMiddleware } from "./security";
import { StaticServer } from "./static-server";
import type { ServerConfig } from "./types";
import { getRoutes, type PluginEndpoints, printRoutes } from "./utils";
Expand Down Expand Up @@ -111,6 +112,10 @@ export class ServerPlugin extends Plugin {
*/
async start(): Promise<express.Application> {
this.serverApplication.use(requestMetricsMiddleware);

// Security middleware first — inspects headers only, no body needed
registerSecurityMiddleware(this.serverApplication, this.config.security);

this.serverApplication.use(
express.json({
// Express's stock 100kb default is too tight for modern apps —
Expand Down Expand Up @@ -147,6 +152,9 @@ export class ServerPlugin extends Plugin {

await this.setupFrontend(endpoints, pluginConfigs);

// Error handler last — catches unhandled errors from API routes
registerErrorHandler(this.serverApplication, this.config.security);

const listenPort = await this.resolveListenPort();

const server = this.serverApplication.listen(
Expand Down
69 changes: 69 additions & 0 deletions packages/appkit/src/plugins/server/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,75 @@
"staticPath": {
"type": "string",
"description": "Path to static files directory (auto-detected if not provided)"
},
"bodyLimit": {
"type": "string",
"description": "JSON body size limit (e.g. '100kb', '1mb'). Default: '100kb'"
},
"security": {
"type": "object",
"description": "Security configuration. Secure defaults applied when omitted.",
"properties": {
"csrf": {
"oneOf": [
{
"type": "object",
"properties": {
"allowedOrigins": {
"type": "array",
"items": { "type": "string" },
"description": "Additional trusted origins for CSRF validation"
}
}
},
{ "const": false }
]
},
"helmet": {
"oneOf": [
{
"type": "object",
"description": "HelmetOptions — fully replaces defaults"
},
{ "const": false }
]
},
"cors": {
"oneOf": [
{
"type": "object",
"properties": {
"allowedOrigins": {
"type": "array",
"items": { "type": "string" }
},
"credentials": { "type": "boolean" },
"maxAge": { "type": "number" },
"allowedMethods": {
"type": "array",
"items": { "type": "string" }
},
"allowedHeaders": {
"type": "array",
"items": { "type": "string" }
}
}
},
{ "const": false }
]
},
"errorHandler": {
"oneOf": [
{
"type": "object",
"properties": {
"includeErrorCode": { "type": "boolean" }
}
},
{ "const": false }
]
}
}
}
}
}
Expand Down
184 changes: 184 additions & 0 deletions packages/appkit/src/plugins/server/security/csrf.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
import type { NextFunction, Request, Response } from "express";
import { createLogger } from "../../../logging/logger";
import type { CsrfConfig } from "./types";

const logger = createLogger("server");

const STATE_CHANGING_METHODS = new Set(["POST", "PUT", "DELETE", "PATCH"]);

/**
* Parse a comma-separated env var into trimmed, non-empty strings.
*/
function parseEnvOrigins(envVar: string | undefined): string[] {
if (!envVar) return [];
return envVar
.split(",")
.map((s) => s.trim())
.filter(Boolean);
}

/**
* Build the set of trusted origins from all sources:
* 1. DATABRICKS_APP_URL env var
* 2. Config allowedOrigins
* 3. APPKIT_CSRF_ALLOWED_ORIGINS env var
*/
function buildTrustedOrigins(config?: CsrfConfig): Set<string> {
const origins = new Set<string>();

const appUrl = process.env.DATABRICKS_APP_URL;
if (appUrl) {
try {
origins.add(new URL(appUrl).origin.toLowerCase());
} catch {
logger.warn(
"DATABRICKS_APP_URL is not a valid URL: %s — skipping for CSRF",
appUrl,
);
}
}

for (const o of config?.allowedOrigins ?? []) {
origins.add(o.toLowerCase().replace(/\/$/, ""));
}

for (const o of parseEnvOrigins(process.env.APPKIT_CSRF_ALLOWED_ORIGINS)) {
origins.add(o.toLowerCase().replace(/\/$/, ""));
}

return origins;
}

/**
* Check if an origin matches localhost (any port).
*/
function isLocalhostOrigin(origin: string): boolean {
try {
const url = new URL(origin);
return url.hostname === "localhost" || url.hostname === "127.0.0.1";
} catch {
return false;
}
}

/**
* Same-origin heuristic: compare Origin against Host header.
* Used as fallback when no trusted origins are configured.
*/
function isSameOrigin(origin: string, req: Request): boolean {
const host = req.headers.host;
if (!host) return false;

try {
const originUrl = new URL(origin);
const originHost = originUrl.host.toLowerCase();
return originHost === host.toLowerCase();
} catch {
return false;
}
}

/**
* Create CSRF protection middleware using Origin header validation.
*
* - Applies to state-changing methods (POST, PUT, DELETE, PATCH) only
* - Allows absent/empty Origin (same-origin browser or non-browser client)
* - Rejects `Origin: null` (sandboxed iframe attack vector)
* - In dev mode, auto-allows localhost origins
* - Falls back to Host header comparison when no trusted origins are configured
*/
export function createCsrfMiddleware(
config?: CsrfConfig | false,
): (req: Request, res: Response, next: NextFunction) => void {
if (config === false) {
return (_req, _res, next) => next();
}

const isDev = process.env.NODE_ENV === "development";
const trustedOrigins = buildTrustedOrigins(
config === undefined ? undefined : config,
);

if (!isDev && trustedOrigins.size === 0) {
logger.warn(
"DATABRICKS_APP_URL not set and no CSRF origins configured — CSRF will use Host header fallback. Set DATABRICKS_APP_URL for full protection.",
);
}

return (req: Request, res: Response, next: NextFunction) => {
if (!STATE_CHANGING_METHODS.has(req.method)) {
return next();
}

const origin = req.headers.origin;

// No Origin header — allow (same-origin or non-browser client)
if (!origin || origin === "") {
return next();
}

// Reject Origin: null (sandboxed iframe, data: URI)
if (origin === "null") {
logger.debug("CSRF rejected: null Origin on %s %s", req.method, req.path);
return res.status(403).json(
isDev
? {
error: "CSRF validation failed",
detail:
"Origin: null rejected — possible sandboxed iframe or data: URI",
}
: { error: "CSRF validation failed" },
);
}

const normalizedOrigin = origin.toLowerCase().replace(/\/$/, "");

// In dev mode, allow localhost origins
if (isDev && isLocalhostOrigin(normalizedOrigin)) {
return next();
}

// In production, reject non-HTTPS origins
if (!isDev && !normalizedOrigin.startsWith("https://")) {
logger.debug(
"CSRF rejected: non-HTTPS Origin %s on %s %s",
origin,
req.method,
req.path,
);
return res.status(403).json(
isDev
? {
error: "CSRF validation failed",
detail: `Origin must use HTTPS in production: ${origin}`,
}
: { error: "CSRF validation failed" },
);
}

// Check against trusted origins
if (trustedOrigins.has(normalizedOrigin)) {
return next();
}

// Fallback: same-origin heuristic (compare Origin vs Host)
if (trustedOrigins.size === 0 && isSameOrigin(origin, req)) {
return next();
}

logger.debug(
"CSRF rejected: Origin %s not trusted on %s %s",
origin,
req.method,
req.path,
);
return res.status(403).json(
isDev
? {
error: "CSRF validation failed",
detail: `Origin ${origin} not in trusted set`,
}
: { error: "CSRF validation failed" },
);
};
}
85 changes: 85 additions & 0 deletions packages/appkit/src/plugins/server/security/error-handler.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import type { NextFunction, Request, Response } from "express";
import { AppKitError } from "../../../errors/base";
import { createLogger } from "../../../logging/logger";
import type { ErrorHandlerConfig } from "./types";

const logger = createLogger("server");

/**
* Create a global error handler middleware that prevents information disclosure.
*
* - Logs full error details server-side (using AppKitError.toJSON() for safe sanitization)
* - Returns generic error messages in production
* - Includes message/stack in dev mode for debugging
* - Handles SyntaxError from JSON body parsing (returns 400)
* - Respects headersSent to avoid double-send
*/
export function createErrorHandler(
config?: ErrorHandlerConfig | false,
): (err: Error, req: Request, res: Response, next: NextFunction) => void {
if (config === false) {
return (_err, _req, _res, next) => next(_err);
}

const isDev = process.env.NODE_ENV === "development";
const includeErrorCode = config?.includeErrorCode ?? true;

return (err: Error, _req: Request, res: Response, next: NextFunction) => {
// If headers already sent, delegate to Express default handler
if (res.headersSent) {
return next(err);
}

// Log the error server-side
if (err instanceof AppKitError) {
logger.error("Unhandled error: %O", err.toJSON());
} else {
logger.error("Unhandled error: %s", err.message);
if (err.stack) {
logger.debug("Stack trace: %s", err.stack);
}
}

// Handle JSON parsing errors from express.json()
if (
err instanceof SyntaxError &&
"status" in err &&
(err as { status?: number }).status === 400
) {
return res
.status(400)
.json(
isDev
? { error: "Bad Request", message: err.message }
: { error: "Bad Request" },
);
}

// Handle AppKitError with proper status code
if (err instanceof AppKitError) {
const body: Record<string, unknown> = {
error: isDev ? err.message : "Internal Server Error",
};

if (includeErrorCode) {
body.code = err.code;
}

if (isDev && err.stack) {
body.stack = err.stack;
}

return res.status(err.statusCode).json(body);
}

// Generic error
return res.status(500).json(
isDev
? {
error: err.message || "Internal Server Error",
stack: err.stack,
}
: { error: "Internal Server Error" },
);
};
}
Loading
Loading