diff --git a/Lql/Nimblesite.Lql.Core/Parsing/LqlCodeParser.cs b/Lql/Nimblesite.Lql.Core/Parsing/LqlCodeParser.cs index 678e9b7..d24bf5a 100644 --- a/Lql/Nimblesite.Lql.Core/Parsing/LqlCodeParser.cs +++ b/Lql/Nimblesite.Lql.Core/Parsing/LqlCodeParser.cs @@ -122,10 +122,9 @@ public static Result Parse(string lqlCode) // Check for circular references in let statements var letStatements = new Dictionary(); - var definedVariables = new HashSet(); // Track all defined variables var lines = lqlCode.Split('\n'); - // First pass: collect all let statements and defined variables + // First pass: collect all let statements. foreach (var line in lines) { var trimmedLine = line.Trim(); @@ -143,9 +142,6 @@ public static Result Parse(string lqlCode) var varName = parts[0][4..].Trim(); // Remove "let " prefix var expression = parts[1].Trim(); - // Add to defined variables - definedVariables.Add(varName); - // Extract the first identifier from the expression (before |>) var pipeIndex = expression.IndexOf("|>", StringComparison.Ordinal); if (pipeIndex > 0) @@ -210,31 +206,10 @@ public static Result Parse(string lqlCode) } } - // Check for undefined variables (identifiers with underscores that appear as pipeline bases) - // BUT exclude variables that are defined in let statements - if (trimmedLine.Contains("|>", StringComparison.Ordinal)) - { - var pipeIndex = trimmedLine.IndexOf("|>", StringComparison.Ordinal); - var beforePipe = trimmedLine[..pipeIndex].Trim(); - - // Check if the identifier before the pipe contains underscores (indicating it might be an undefined variable) - // BUT only flag it as undefined if it's NOT in our definedVariables set - if ( - beforePipe.Contains('_', StringComparison.Ordinal) - && !beforePipe.Contains('(', StringComparison.Ordinal) - && !beforePipe.Contains('.', StringComparison.Ordinal) - && beforePipe.All(c => char.IsLetterOrDigit(c) || c == '_') - && !definedVariables.Contains(beforePipe) - ) // Only flag if NOT defined in let statement - { - return SqlError.WithPosition( - $"Syntax error: Undefined variable '{beforePipe}'", - 1, - 0, - lqlCode - ); - } - } + // Pipeline bases can be table names such as tenant_members. The + // parser cannot distinguish those from variables without schema + // metadata, so undefined-variable validation belongs in a later + // semantic pass with table context. } return null; diff --git a/Lql/Nimblesite.Lql.Tests/LqlErrorHandlingTests.cs b/Lql/Nimblesite.Lql.Tests/LqlErrorHandlingTests.cs index e0c2836..17456c0 100644 --- a/Lql/Nimblesite.Lql.Tests/LqlErrorHandlingTests.cs +++ b/Lql/Nimblesite.Lql.Tests/LqlErrorHandlingTests.cs @@ -174,21 +174,20 @@ public void InvalidFilterFunction_ShouldReturnError() } [Fact] - public void UndefinedVariable_ShouldReturnError() + public void UnderscorePipelineBase_ShouldParseAsTableName() { // Arrange const string lqlCode = """ - undefined_variable |> select(id, name) + tenant_members |> select(id, name) """; // Act var result = LqlStatementConverter.ToStatement(lqlCode); // Assert - Assert.IsType.Error>(result); - var failure = (Result.Error)result; - Assert.Contains("Syntax error", failure.Value.Message, StringComparison.Ordinal); - Assert.NotNull(failure.Value.Position); + Assert.IsType.Ok>(result); + var success = (Result.Ok)result; + Assert.NotNull(success.Value); } [Fact] diff --git a/Migration/Nimblesite.DataProvider.Migration.Core/LqlFunctionBodyTranspiler.cs b/Migration/Nimblesite.DataProvider.Migration.Core/LqlFunctionBodyTranspiler.cs new file mode 100644 index 0000000..f5ec415 --- /dev/null +++ b/Migration/Nimblesite.DataProvider.Migration.Core/LqlFunctionBodyTranspiler.cs @@ -0,0 +1,64 @@ +using Outcome; +using StringError = Outcome.Result< + string, + Nimblesite.DataProvider.Migration.Core.MigrationError +>.Error; +using StringOk = Outcome.Result.Ok< + string, + Nimblesite.DataProvider.Migration.Core.MigrationError +>; + +namespace Nimblesite.DataProvider.Migration.Core; + +/// +/// Transpiles LQL scalar expressions into PostgreSQL SQL-language function bodies. +/// +public static class LqlFunctionBodyTranspiler +{ + /// + /// Translates a PostgreSQL function bodyLql expression to a SQL function body. + /// + /// LQL scalar expression, optionally prefixed with SELECT. + /// Function name used in diagnostic messages. + /// A SQL-language function body beginning with SELECT. + public static Result TranslatePostgresBody( + string bodyLql, + string functionName + ) + { + var expression = StripSelectPrefix(StripTrailingSemicolon(bodyLql.Trim())); + if (string.IsNullOrWhiteSpace(expression)) + { + return new StringError( + MigrationError.RlsLqlParse(functionName, "function bodyLql is empty") + ); + } + + var result = RlsPredicateTranspiler.Translate( + expression, + RlsPlatform.Postgres, + functionName + ); + return result switch + { + StringOk ok => new StringOk($"SELECT {ok.Value.Trim()}"), + StringError err => new StringError(err.Value), + }; + } + + private static string StripTrailingSemicolon(string value) => + value.EndsWith(';') ? value[..^1].TrimEnd() : value; + + private static string StripSelectPrefix(string value) + { + const string select = "select"; + if (!value.StartsWith(select, StringComparison.OrdinalIgnoreCase)) + { + return value; + } + + return value.Length == select.Length || char.IsWhiteSpace(value[select.Length]) + ? value[select.Length..].TrimStart() + : value; + } +} diff --git a/Migration/Nimblesite.DataProvider.Migration.Core/RlsCurrentSettingRewriter.cs b/Migration/Nimblesite.DataProvider.Migration.Core/RlsCurrentSettingRewriter.cs new file mode 100644 index 0000000..95ced79 --- /dev/null +++ b/Migration/Nimblesite.DataProvider.Migration.Core/RlsCurrentSettingRewriter.cs @@ -0,0 +1,362 @@ +using System.Text; +using Outcome; +using RewriteError = Outcome.Result< + Nimblesite.DataProvider.Migration.Core.RlsCurrentSettingRewrite, + Nimblesite.DataProvider.Migration.Core.MigrationError +>.Error< + Nimblesite.DataProvider.Migration.Core.RlsCurrentSettingRewrite, + Nimblesite.DataProvider.Migration.Core.MigrationError +>; +using RewriteOk = Outcome.Result< + Nimblesite.DataProvider.Migration.Core.RlsCurrentSettingRewrite, + Nimblesite.DataProvider.Migration.Core.MigrationError +>.Ok< + Nimblesite.DataProvider.Migration.Core.RlsCurrentSettingRewrite, + Nimblesite.DataProvider.Migration.Core.MigrationError +>; +using StringError = Outcome.Result< + string, + Nimblesite.DataProvider.Migration.Core.MigrationError +>.Error; +using StringOk = Outcome.Result.Ok< + string, + Nimblesite.DataProvider.Migration.Core.MigrationError +>; + +namespace Nimblesite.DataProvider.Migration.Core; + +internal sealed record RlsCurrentSettingReplacement(string Sentinel, string SqlExpression); + +internal sealed record RlsCurrentSettingRewrite( + string Text, + IReadOnlyList Replacements +); + +internal static class RlsCurrentSettingRewriter +{ + private const string FunctionName = "current_setting"; + private const string SentinelPrefix = "__RLS_CURRENT_SETTING_"; + + internal static Result ReplaceCallsForSimplePredicate( + string source, + RlsPlatform platform, + string contextName + ) + { + var rewritten = RewriteCalls(source, platform, contextName, useSentinels: false); + return rewritten switch + { + RewriteOk ok => new StringOk(ok.Value.Text), + RewriteError err => new StringError(err.Value), + }; + } + + internal static Result ReplaceCallsForPipeline( + string source, + RlsPlatform platform, + string contextName + ) => RewriteCalls(source, platform, contextName, useSentinels: true); + + internal static string RestoreSentinels( + string sql, + IReadOnlyList replacements + ) + { + var restored = sql; + foreach (var replacement in replacements) + { + restored = restored.Replace( + $"'{replacement.Sentinel}'", + replacement.SqlExpression, + StringComparison.Ordinal + ); + } + return restored; + } + + private static Result RewriteCalls( + string source, + RlsPlatform platform, + string contextName, + bool useSentinels + ) + { + var sb = new StringBuilder(source.Length); + var replacements = new List(); + var i = 0; + + while (i < source.Length) + { + if (source[i] == '\'') + { + CopyStringLiteral(source, sb, ref i); + continue; + } + + if (!TryReadIdentifier(source, i, out var word, out var afterWord)) + { + sb.Append(source[i]); + i++; + continue; + } + + if (!word.Equals(FunctionName, StringComparison.Ordinal)) + { + sb.Append(word); + i = afterWord; + continue; + } + + var afterWhitespace = afterWord; + SkipWhitespace(source, ref afterWhitespace); + if (afterWhitespace >= source.Length || source[afterWhitespace] != '(') + { + sb.Append(word); + i = afterWord; + continue; + } + + var call = TryParseCurrentSettingCall(source, i, contextName); + if ( + call + is Result.Error< + CurrentSettingCall, + MigrationError + > callErr + ) + { + return new RewriteError(callErr.Value); + } + + if ( + call + is not Result.Ok< + CurrentSettingCall, + MigrationError + > callOk + ) + { + return new RewriteError( + MigrationError.RlsLqlParse(contextName, "current_setting() could not be parsed") + ); + } + + if (platform != RlsPlatform.Postgres) + { + return new RewriteError( + MigrationError.RlsLqlTranspile( + contextName, + "current_setting() is currently supported only for PostgreSQL" + ) + ); + } + + var sqlExpression = + $"current_setting({callOk.Value.StringLiteral}, true){callOk.Value.Cast}"; + if (useSentinels) + { + var sentinel = $"{SentinelPrefix}{replacements.Count}__"; + sb.Append('\'').Append(sentinel).Append('\''); + replacements.Add(new RlsCurrentSettingReplacement(sentinel, sqlExpression)); + } + else + { + sb.Append(sqlExpression); + } + + i = callOk.Value.EndIndex; + } + + return new RewriteOk(new RlsCurrentSettingRewrite(sb.ToString(), replacements)); + } + + private static Result TryParseCurrentSettingCall( + string source, + int start, + string contextName + ) + { + var i = start + FunctionName.Length; + SkipWhitespace(source, ref i); + i++; + SkipWhitespace(source, ref i); + + if (!TryReadStringLiteral(source, i, out var literal, out var afterLiteral)) + { + return new Result.Error< + CurrentSettingCall, + MigrationError + >( + MigrationError.RlsLqlParse( + contextName, + "current_setting() requires exactly one string-literal argument" + ) + ); + } + + i = afterLiteral; + SkipWhitespace(source, ref i); + if (i >= source.Length || source[i] != ')') + { + return new Result.Error< + CurrentSettingCall, + MigrationError + >( + MigrationError.RlsLqlParse( + contextName, + "current_setting() requires exactly one string-literal argument" + ) + ); + } + + i++; + var cast = ReadOptionalCast(source, ref i, contextName); + return cast switch + { + StringOk ok => new Result.Ok< + CurrentSettingCall, + MigrationError + >(new CurrentSettingCall(literal, ok.Value, i)), + StringError err => new Result.Error< + CurrentSettingCall, + MigrationError + >(err.Value), + }; + } + + private static Result ReadOptionalCast( + string source, + ref int i, + string contextName + ) + { + if (i + 1 >= source.Length || source[i] != ':' || source[i + 1] != ':') + { + return new StringOk(string.Empty); + } + + var castStart = i; + i += 2; + var typeStart = i; + while ( + i < source.Length + && ( + char.IsLetterOrDigit(source[i]) + || source[i] == '_' + || source[i] == '.' + || source[i] == '[' + || source[i] == ']' + ) + ) + { + i++; + } + + return i == typeStart + ? new StringError( + MigrationError.RlsLqlParse( + contextName, + "current_setting() cast requires a type name after ::" + ) + ) + : new StringOk(source[castStart..i]); + } + + private static bool TryReadIdentifier( + string source, + int start, + out string word, + out int afterWord + ) + { + word = string.Empty; + afterWord = start; + if (!IsIdentifierStart(source, start)) + { + return false; + } + + var i = start + 1; + while (i < source.Length && IsIdentifierPart(source[i])) + { + i++; + } + + word = source[start..i]; + afterWord = i; + return true; + } + + private static bool IsIdentifierStart(string source, int index) => + index < source.Length + && (char.IsLetter(source[index]) || source[index] == '_') + && (index == 0 || !IsIdentifierPart(source[index - 1])); + + private static bool IsIdentifierPart(char c) => char.IsLetterOrDigit(c) || c == '_'; + + private static void SkipWhitespace(string source, ref int i) + { + while (i < source.Length && char.IsWhiteSpace(source[i])) + { + i++; + } + } + + private static void CopyStringLiteral(string source, StringBuilder sb, ref int i) + { + sb.Append(source[i]); + i++; + while (i < source.Length) + { + sb.Append(source[i]); + if (source[i] == '\'') + { + if (i + 1 < source.Length && source[i + 1] == '\'') + { + sb.Append(source[i + 1]); + i += 2; + continue; + } + i++; + break; + } + i++; + } + } + + private static bool TryReadStringLiteral( + string source, + int start, + out string literal, + out int afterLiteral + ) + { + literal = string.Empty; + afterLiteral = start; + if (start >= source.Length || source[start] != '\'') + { + return false; + } + + var i = start + 1; + while (i < source.Length) + { + if (source[i] == '\'') + { + if (i + 1 < source.Length && source[i + 1] == '\'') + { + i += 2; + continue; + } + afterLiteral = i + 1; + literal = source[start..afterLiteral]; + return true; + } + i++; + } + + return false; + } + + private sealed record CurrentSettingCall(string StringLiteral, string Cast, int EndIndex); +} diff --git a/Migration/Nimblesite.DataProvider.Migration.Core/RlsPredicateTranspiler.cs b/Migration/Nimblesite.DataProvider.Migration.Core/RlsPredicateTranspiler.cs index 536b84d..c011a5c 100644 --- a/Migration/Nimblesite.DataProvider.Migration.Core/RlsPredicateTranspiler.cs +++ b/Migration/Nimblesite.DataProvider.Migration.Core/RlsPredicateTranspiler.cs @@ -66,9 +66,7 @@ string policyName var trimmed = lql.Trim(); return TryParseExistsWrapper(trimmed, out var inner) ? TranslateExistsSubquery(inner, platform, policyName) - : new Result.Ok( - TranslateSimplePredicate(trimmed, platform) - ); + : TranslateSimplePredicate(trimmed, platform, policyName); } /// @@ -141,8 +139,38 @@ string policyName // survives LQL transpilation, then transpile the pipeline, then // replace the sentinel with the platform-specific expression. var withSentinel = SubstituteCurrentUserIdWithSentinel(innerLql); + var withSettings = RlsCurrentSettingRewriter.ReplaceCallsForPipeline( + withSentinel, + platform, + policyName + ); + if ( + withSettings + is Result.Error< + RlsCurrentSettingRewrite, + MigrationError + > settingsErr + ) + { + return new Result.Error( + settingsErr.Value + ); + } - var statementResult = LqlStatementConverter.ToStatement(withSentinel); + if ( + withSettings + is not Result.Ok< + RlsCurrentSettingRewrite, + MigrationError + > settingsOk + ) + { + return new Result.Error( + MigrationError.RlsLqlParse(policyName, "unknown current_setting rewrite failure") + ); + } + + var statementResult = LqlStatementConverter.ToStatement(settingsOk.Value.Text); if (statementResult is Result.Error sErr) { return new Result.Error( @@ -180,21 +208,53 @@ string policyName } var sql = ReplaceSentinelInSql(tOk.Value, platform); + sql = RlsCurrentSettingRewriter.RestoreSentinels(sql, settingsOk.Value.Replacements); return new Result.Ok($"EXISTS ({sql})"); } - private static string TranslateSimplePredicate(string predicate, RlsPlatform platform) + private static Result TranslateSimplePredicate( + string predicate, + RlsPlatform platform, + string policyName + ) { - // Replace current_user_id() literal with platform expression. - // Quote bare identifiers per platform (best-effort, conservative). - var withSession = ReplaceCurrentUserIdLiteral(predicate, platform); - return platform switch + var withSettings = RlsCurrentSettingRewriter.ReplaceCallsForSimplePredicate( + predicate, + platform, + policyName + ); + + if ( + withSettings is Result.Error settingsErr + ) + { + return new Result.Error( + settingsErr.Value + ); + } + + if ( + withSettings is not Result.Ok settingsOk + ) + { + return new Result.Error( + MigrationError.RlsLqlParse(policyName, "unknown current_setting rewrite failure") + ); + } + + // Replace current_user_id() literal with platform expression after + // current_setting() LQL calls have been rewritten. The Postgres + // current_user_id() expansion itself uses current_setting(..., true). + var withSession = ReplaceCurrentUserIdLiteral(settingsOk.Value, platform); + var translated = platform switch { RlsPlatform.Postgres => QuoteSimpleIdentifiers(withSession, '"', '"'), RlsPlatform.Sqlite => QuoteSimpleIdentifiers(withSession, '[', ']'), RlsPlatform.SqlServer => QuoteSimpleIdentifiers(withSession, '[', ']'), _ => withSession, }; + + return new Result.Ok(translated); } private static string SubstituteCurrentUserIdWithSentinel(string lql) => diff --git a/Migration/Nimblesite.DataProvider.Migration.Core/SchemaDefinition.cs b/Migration/Nimblesite.DataProvider.Migration.Core/SchemaDefinition.cs index 28da82a..86fa60b 100644 --- a/Migration/Nimblesite.DataProvider.Migration.Core/SchemaDefinition.cs +++ b/Migration/Nimblesite.DataProvider.Migration.Core/SchemaDefinition.cs @@ -76,6 +76,12 @@ public sealed record PostgresFunctionDefinition /// Function body placed between PostgreSQL dollar quotes. public string Body { get; init; } = string.Empty; + /// + /// LQL function body expression. Mutually exclusive with . + /// Emits a SQL-language function body for PostgreSQL. + /// + public string? BodyLql { get; init; } + /// Roles granted EXECUTE on this function. public IReadOnlyList ExecuteRoles { get; init; } = []; diff --git a/Migration/Nimblesite.DataProvider.Migration.Core/SchemaDiff.Support.cs b/Migration/Nimblesite.DataProvider.Migration.Core/SchemaDiff.Support.cs index 9cb95e3..ed649f6 100644 --- a/Migration/Nimblesite.DataProvider.Migration.Core/SchemaDiff.Support.cs +++ b/Migration/Nimblesite.DataProvider.Migration.Core/SchemaDiff.Support.cs @@ -106,10 +106,29 @@ PostgresFunctionDefinition desired || !SameSqlToken(current.Volatility, desired.Volatility) || current.SecurityDefiner != desired.SecurityDefiner || current.RevokePublicExecute != desired.RevokePublicExecute - || current.Body.Trim() != desired.Body.Trim() + || FunctionBodyForDiff(current) != FunctionBodyForDiff(desired) || !currentExecuteRoles.SetEquals(desired.ExecuteRoles); } + private static string FunctionBodyForDiff(PostgresFunctionDefinition function) + { + if (string.IsNullOrWhiteSpace(function.BodyLql)) + { + return function.Body.Trim(); + } + + var result = LqlFunctionBodyTranspiler.TranslatePostgresBody( + function.BodyLql, + $"{function.Schema}.{function.Name}" + ); + return result switch + { + Outcome.Result.Ok ok => ok.Value.Trim(), + Outcome.Result.Error => + function.BodyLql.Trim(), + }; + } + private static string FunctionKey(PostgresFunctionDefinition function) => string.Join( "|", diff --git a/Migration/Nimblesite.DataProvider.Migration.Core/SchemaYamlSerializer.cs b/Migration/Nimblesite.DataProvider.Migration.Core/SchemaYamlSerializer.cs index 87e425c..ebbf0fc 100644 --- a/Migration/Nimblesite.DataProvider.Migration.Core/SchemaYamlSerializer.cs +++ b/Migration/Nimblesite.DataProvider.Migration.Core/SchemaYamlSerializer.cs @@ -68,16 +68,25 @@ public static class SchemaYamlSerializer /// /// Schema to serialize. /// YAML representation of the schema. - public static string ToYaml(SchemaDefinition schema) => Serializer.Serialize(schema); + public static string ToYaml(SchemaDefinition schema) + { + ValidateSupportFunctionBodies(schema); + return Serializer.Serialize(schema); + } /// /// Deserialize a schema definition from YAML string. /// /// YAML string. /// Deserialized schema definition. - public static SchemaDefinition FromYaml(string yaml) => - Deserializer.Deserialize(yaml) - ?? new SchemaDefinition { Name = string.Empty, Tables = [] }; + public static SchemaDefinition FromYaml(string yaml) + { + var schema = + Deserializer.Deserialize(yaml) + ?? new SchemaDefinition { Name = string.Empty, Tables = [] }; + ValidateSupportFunctionBodies(schema); + return schema; + } /// /// Load a schema definition from a YAML file. @@ -100,6 +109,23 @@ public static void ToYamlFile(SchemaDefinition schema, string filePath) var yaml = ToYaml(schema); File.WriteAllText(filePath, yaml); } + + private static void ValidateSupportFunctionBodies(SchemaDefinition schema) + { + foreach (var function in schema.Functions) + { + if ( + !string.IsNullOrWhiteSpace(function.Body) + && !string.IsNullOrWhiteSpace(function.BodyLql) + ) + { + throw new InvalidOperationException( + "PostgreSQL function body and bodyLql are mutually exclusive: " + + $"{function.Schema}.{function.Name}" + ); + } + } + } } /// diff --git a/Migration/Nimblesite.DataProvider.Migration.Postgres/PostgresSupportDdlGenerator.cs b/Migration/Nimblesite.DataProvider.Migration.Postgres/PostgresSupportDdlGenerator.cs index c9ffaa6..e9f1ac5 100644 --- a/Migration/Nimblesite.DataProvider.Migration.Postgres/PostgresSupportDdlGenerator.cs +++ b/Migration/Nimblesite.DataProvider.Migration.Postgres/PostgresSupportDdlGenerator.cs @@ -44,6 +44,7 @@ private static string GenerateCreateOrReplaceFunction(CreateOrReplaceFunctionOpe function.Arguments.Select(ArgumentDeclaration) ); var signature = FunctionSignature(function); + var body = FunctionBody(function); var sb = new StringBuilder(); sb.AppendLine( @@ -58,7 +59,7 @@ private static string GenerateCreateOrReplaceFunction(CreateOrReplaceFunctionOpe sb.AppendLine("SECURITY DEFINER"); } sb.AppendLine("AS $function$"); - sb.AppendLine(function.Body.Trim()); + sb.AppendLine(body); sb.Append("$function$"); if (function.RevokePublicExecute) @@ -96,6 +97,36 @@ private static string ArgumentDeclaration(PostgresFunctionArgumentDefinition arg ? argument.Type : $"{QuoteIdent(argument.Name)} {argument.Type}"; + private static string FunctionBody(PostgresFunctionDefinition function) + { + if ( + !string.IsNullOrWhiteSpace(function.Body) + && !string.IsNullOrWhiteSpace(function.BodyLql) + ) + { + throw new InvalidOperationException( + "PostgreSQL function body and bodyLql are mutually exclusive: " + + $"{function.Schema}.{function.Name}" + ); + } + + if (string.IsNullOrWhiteSpace(function.BodyLql)) + { + return function.Body.Trim(); + } + + var result = LqlFunctionBodyTranspiler.TranslatePostgresBody( + function.BodyLql, + $"{function.Schema}.{function.Name}" + ); + return result switch + { + Outcome.Result.Ok ok => ok.Value, + Outcome.Result.Error err => + throw new InvalidOperationException(err.Value.Message), + }; + } + private static string FunctionSignature(PostgresFunctionDefinition function) => $"{QuoteIdent(function.Schema)}.{QuoteIdent(function.Name)}({string.Join(", ", function.Arguments.Select(a => a.Type))})"; diff --git a/Migration/Nimblesite.DataProvider.Migration.Tests/PostgresFunctionBodyLqlE2ETests.cs b/Migration/Nimblesite.DataProvider.Migration.Tests/PostgresFunctionBodyLqlE2ETests.cs new file mode 100644 index 0000000..4a07b93 --- /dev/null +++ b/Migration/Nimblesite.DataProvider.Migration.Tests/PostgresFunctionBodyLqlE2ETests.cs @@ -0,0 +1,239 @@ +using System.Globalization; + +namespace Nimblesite.DataProvider.Migration.Tests; + +[Collection(PostgresTestSuite.Name)] +[System.Diagnostics.CodeAnalysis.SuppressMessage( + "Usage", + "CA1001:Types that own disposable fields should be disposable", + Justification = "Disposed via IAsyncLifetime.DisposeAsync" +)] +public sealed class PostgresFunctionBodyLqlE2ETests(PostgresContainerFixture fixture) + : IAsyncLifetime +{ + private NpgsqlConnection _connection = null!; + private readonly ILogger _logger = NullLogger.Instance; + + public async Task InitializeAsync() + { + _connection = await fixture.CreateDatabaseAsync("function_body_lql").ConfigureAwait(false); + } + + public async Task DisposeAsync() + { + await _connection.DisposeAsync().ConfigureAwait(false); + } + + [Fact] + public void BodyLqlSupportFunctionsAndLqlPolicies_EnforceNapTenantIsolation() + { + var suffix = Guid.NewGuid().ToString("N")[..8]; + var names = Names.Create(suffix); + + Apply(Schema(names)); + + var tenantA = Guid.NewGuid(); + var tenantB = Guid.NewGuid(); + var userA = Guid.NewGuid(); + Seed(tenantA, tenantB, userA); + + using var tx = _connection.BeginTransaction(); + SetAppSession(tx, names.AppUserRole, tenantA, userA); + + Assert.Equal(1, CountVisibleDocuments(tx)); + Assert.Throws(() => InsertDocument(tx, tenantB, "blocked")); + } + + private void Apply(SchemaDefinition schema) + { + var current = ( + (SchemaResultOk)PostgresSchemaInspector.Inspect(_connection, "public", _logger) + ).Value; + var ops = ( + (OperationsResultOk)SchemaDiff.Calculate(current, schema, logger: _logger) + ).Value; + var apply = MigrationRunner.Apply( + _connection, + ops, + PostgresDdlGenerator.Generate, + MigrationOptions.Default, + _logger + ); + var failure = apply is MigrationApplyResultError error ? error.Value.Message : "unknown"; + + Assert.True(apply is MigrationApplyResultOk, $"Migration failed: {failure}"); + } + + private void Seed(Guid tenantA, Guid tenantB, Guid userA) + { + Exec( + $"INSERT INTO public.tenant_members(id, tenant_id, user_id) VALUES ('{Guid.NewGuid()}', '{tenantA}', '{userA}')" + ); + Exec( + $"INSERT INTO public.documents(id, tenant_id, title) VALUES ('{Guid.NewGuid()}', '{tenantA}', 'visible')" + ); + Exec( + $"INSERT INTO public.documents(id, tenant_id, title) VALUES ('{Guid.NewGuid()}', '{tenantB}', 'hidden')" + ); + } + + private void SetAppSession(NpgsqlTransaction tx, string role, Guid tenant, Guid user) + { + Exec(tx, $"SET LOCAL ROLE {role}"); + Exec(tx, $"SET LOCAL app.tenant_id = '{tenant}'"); + Exec(tx, $"SET LOCAL app.user_id = '{user}'"); + } + + private int CountVisibleDocuments(NpgsqlTransaction tx) + { + using var command = _connection.CreateCommand(); + command.Transaction = tx; + command.CommandText = "SELECT count(*) FROM public.documents"; + return Convert.ToInt32(command.ExecuteScalar(), CultureInfo.InvariantCulture); + } + + private void InsertDocument(NpgsqlTransaction tx, Guid tenant, string title) => + Exec( + tx, + $"INSERT INTO public.documents(id, tenant_id, title) VALUES ('{Guid.NewGuid()}', '{tenant}', '{title}')" + ); + + private void Exec(string sql) + { + using var command = _connection.CreateCommand(); + command.CommandText = sql; + command.ExecuteNonQuery(); + } + + private void Exec(NpgsqlTransaction tx, string sql) + { + using var command = _connection.CreateCommand(); + command.Transaction = tx; + command.CommandText = sql; + command.ExecuteNonQuery(); + } + + private static SchemaDefinition Schema(Names names) => + new() + { + Name = "body_lql", + Roles = [new PostgresRoleDefinition { Name = names.AppUserRole, GrantTo = ["test"] }], + Tables = [TenantMembersTable(), DocumentsTable(names)], + Functions = + [ + new PostgresFunctionDefinition + { + Name = names.AppTenantFunction, + Returns = "uuid", + BodyLql = "current_setting('app.tenant_id')::uuid", + ExecuteRoles = [names.AppUserRole], + }, + new PostgresFunctionDefinition + { + Name = names.AppUserFunction, + Returns = "uuid", + BodyLql = "current_setting('app.user_id')::uuid", + ExecuteRoles = [names.AppUserRole], + }, + new PostgresFunctionDefinition + { + Name = names.IsMemberFunction, + Returns = "boolean", + Arguments = + [ + new PostgresFunctionArgumentDefinition { Name = "u", Type = "uuid" }, + new PostgresFunctionArgumentDefinition { Name = "t", Type = "uuid" }, + ], + SecurityDefiner = true, + BodyLql = + "exists(tenant_members |> filter(fn(m) => m.user_id = u and m.tenant_id = t))", + ExecuteRoles = [names.AppUserRole], + }, + ], + Grants = + [ + new PostgresGrantDefinition + { + Schema = "public", + Target = PostgresGrantTarget.Schema, + Privileges = ["USAGE"], + Roles = [names.AppUserRole], + }, + new PostgresGrantDefinition + { + Schema = "public", + Target = PostgresGrantTarget.AllTablesInSchema, + Privileges = ["SELECT", "INSERT", "UPDATE", "DELETE"], + Roles = [names.AppUserRole], + }, + ], + }; + + private static TableDefinition TenantMembersTable() => + new() + { + Schema = "public", + Name = "tenant_members", + Columns = [RequiredUuid("id"), RequiredUuid("tenant_id"), RequiredUuid("user_id")], + PrimaryKey = new PrimaryKeyDefinition { Columns = ["id"] }, + }; + + private static TableDefinition DocumentsTable(Names names) => + new() + { + Schema = "public", + Name = "documents", + Columns = + [ + RequiredUuid("id"), + RequiredUuid("tenant_id"), + new ColumnDefinition + { + Name = "title", + Type = PortableTypes.Text, + IsNullable = false, + }, + ], + PrimaryKey = new PrimaryKeyDefinition { Columns = ["id"] }, + RowLevelSecurity = new RlsPolicySetDefinition + { + Forced = true, + Policies = + [ + new RlsPolicyDefinition + { + Name = "documents_member", + Roles = [names.AppUserRole], + UsingLql = + $"tenant_id = {names.AppTenantFunction}() and {names.IsMemberFunction}({names.AppUserFunction}(), {names.AppTenantFunction}())", + WithCheckLql = + $"tenant_id = {names.AppTenantFunction}() and {names.IsMemberFunction}({names.AppUserFunction}(), {names.AppTenantFunction}())", + }, + ], + }, + }; + + private static ColumnDefinition RequiredUuid(string name) => + new() + { + Name = name, + Type = PortableTypes.Uuid, + IsNullable = false, + }; + + private sealed record Names( + string AppUserRole, + string AppTenantFunction, + string AppUserFunction, + string IsMemberFunction + ) + { + public static Names Create(string suffix) => + new( + $"body_lql_user_{suffix}", + $"app_tenant_id_{suffix}", + $"app_user_id_{suffix}", + $"is_member_{suffix}" + ); + } +} diff --git a/Migration/Nimblesite.DataProvider.Migration.Tests/PostgresFunctionBodyLqlTests.cs b/Migration/Nimblesite.DataProvider.Migration.Tests/PostgresFunctionBodyLqlTests.cs new file mode 100644 index 0000000..f3b768d --- /dev/null +++ b/Migration/Nimblesite.DataProvider.Migration.Tests/PostgresFunctionBodyLqlTests.cs @@ -0,0 +1,265 @@ +using BodyError = Outcome.Result< + string, + Nimblesite.DataProvider.Migration.Core.MigrationError +>.Error; +using BodyOk = Outcome.Result.Ok< + string, + Nimblesite.DataProvider.Migration.Core.MigrationError +>; + +namespace Nimblesite.DataProvider.Migration.Tests; + +public sealed class PostgresFunctionBodyLqlTests +{ + [Theory] + [InlineData( + "current_setting('app.tenant_id')::uuid", + "SELECT current_setting('app.tenant_id', true)::uuid" + )] + [InlineData( + "SELECT current_setting('app.user_id')::uuid", + "SELECT current_setting('app.user_id', true)::uuid" + )] + [InlineData("true", "SELECT true")] + [InlineData( + "is_member(current_setting('app.user_id')::uuid, current_setting('app.tenant_id')::uuid)", + "SELECT is_member(current_setting('app.user_id', true)::uuid, current_setting('app.tenant_id', true)::uuid)" + )] + public void TranslatePostgresBody_NapScalarShapes_EmitsSqlBody(string bodyLql, string expected) + { + var sql = Body(bodyLql); + + Assert.Equal(expected, sql); + } + + [Fact] + public void TranslatePostgresBody_ExistsPipeline_EmitsSelectExists() + { + var sql = Body( + """ + exists( + tenant_members + |> filter(fn(m) => m.user_id = u and m.tenant_id = t) + ) + """ + ); + + Assert.StartsWith("SELECT EXISTS (", sql, StringComparison.Ordinal); + Assert.Contains("FROM tenant_members", sql, StringComparison.Ordinal); + Assert.Contains("user_id = u", sql, StringComparison.Ordinal); + Assert.Contains("tenant_id = t", sql, StringComparison.Ordinal); + } + + [Fact] + public void FromYaml_FunctionBodyLql_Deserializes() + { + var schema = SchemaYamlSerializer.FromYaml( + """ + name: nap + functions: + - schema: public + name: app_tenant_id + returns: uuid + bodyLql: current_setting('app.tenant_id')::uuid + tables: [] + """ + ); + + Assert.Single(schema.Functions); + Assert.Equal("current_setting('app.tenant_id')::uuid", schema.Functions[0].BodyLql); + Assert.Equal(string.Empty, schema.Functions[0].Body); + } + + [Fact] + public void FromYaml_FunctionBodyAndBodyLql_Throws() + { + var ex = Assert.Throws(() => + SchemaYamlSerializer.FromYaml( + """ + name: nap + functions: + - name: app_tenant_id + returns: uuid + body: SELECT NULL::uuid + bodyLql: current_setting('app.tenant_id')::uuid + tables: [] + """ + ) + ); + + Assert.Contains("mutually exclusive", ex.Message, StringComparison.Ordinal); + Assert.Contains("public.app_tenant_id", ex.Message, StringComparison.Ordinal); + } + + [Fact] + public void ToYaml_FunctionBodyLql_EmitsBodyLqlNotBody() + { + var yaml = SchemaYamlSerializer.ToYaml( + new SchemaDefinition + { + Name = "nap", + Functions = + [ + new PostgresFunctionDefinition + { + Name = "app_tenant_id", + Returns = "uuid", + BodyLql = "current_setting('app.tenant_id')::uuid", + }, + ], + } + ); + + Assert.Contains("bodyLql: current_setting('app.tenant_id')::uuid", yaml); + Assert.DoesNotContain("body:", yaml, StringComparison.Ordinal); + } + + [Fact] + public void Generate_CreateFunction_BodyLqlCurrentSetting_EmitsDollarQuotedSqlBody() + { + var ddl = PostgresDdlGenerator.Generate( + new CreateOrReplaceFunctionOperation( + new PostgresFunctionDefinition + { + Name = "app_tenant_id", + Returns = "uuid", + BodyLql = "current_setting('app.tenant_id')::uuid", + } + ) + ); + + Assert.Contains("AS $function$", ddl, StringComparison.Ordinal); + Assert.Contains( + "SELECT current_setting('app.tenant_id', true)::uuid", + ddl, + StringComparison.Ordinal + ); + Assert.DoesNotContain("bodyLql", ddl, StringComparison.Ordinal); + } + + [Fact] + public void Generate_CreateFunction_BodyLqlExists_EmitsSelectExistsBody() + { + var ddl = PostgresDdlGenerator.Generate( + new CreateOrReplaceFunctionOperation( + new PostgresFunctionDefinition + { + Name = "is_member", + Returns = "boolean", + Arguments = + [ + new PostgresFunctionArgumentDefinition { Name = "u", Type = "uuid" }, + new PostgresFunctionArgumentDefinition { Name = "t", Type = "uuid" }, + ], + SecurityDefiner = true, + BodyLql = + "exists(tenant_members |> filter(fn(m) => m.user_id = u and m.tenant_id = t))", + } + ) + ); + + Assert.Contains("SELECT EXISTS (", ddl, StringComparison.Ordinal); + Assert.Contains("FROM tenant_members", ddl, StringComparison.Ordinal); + Assert.Contains("user_id = u", ddl, StringComparison.Ordinal); + Assert.Contains("tenant_id = t", ddl, StringComparison.Ordinal); + } + + [Fact] + public void Generate_CreateFunction_BodyAndBodyLql_Throws() + { + var ex = Assert.Throws(() => + PostgresDdlGenerator.Generate( + new CreateOrReplaceFunctionOperation( + new PostgresFunctionDefinition + { + Name = "bad", + Body = "SELECT true", + BodyLql = "true", + } + ) + ) + ); + + Assert.Contains("mutually exclusive", ex.Message, StringComparison.Ordinal); + } + + [Fact] + public void SchemaDiff_BodyLqlEquivalentToInspectedBody_HasNoOperations() + { + var current = new SchemaDefinition + { + Name = "nap", + Functions = + [ + new PostgresFunctionDefinition + { + Name = "app_tenant_id", + Returns = "uuid", + Body = "SELECT current_setting('app.tenant_id', true)::uuid", + }, + ], + }; + var desired = current with + { + Functions = + [ + new PostgresFunctionDefinition + { + Name = "app_tenant_id", + Returns = "uuid", + BodyLql = "current_setting('app.tenant_id')::uuid", + }, + ], + }; + + var result = SchemaDiff.Calculate(current, desired); + + Assert.True(result is OperationsResultOk); + Assert.Empty(((OperationsResultOk)result).Value); + } + + [Fact] + public void SchemaDiff_BodyLqlChange_EmitsCreateOrReplaceFunction() + { + var current = new SchemaDefinition + { + Name = "nap", + Functions = + [ + new PostgresFunctionDefinition + { + Name = "app_tenant_id", + Returns = "uuid", + Body = "SELECT current_setting('app.tenant_id', true)::uuid", + }, + ], + }; + var desired = current with + { + Functions = + [ + new PostgresFunctionDefinition + { + Name = "app_tenant_id", + Returns = "text", + BodyLql = "current_setting('app.tenant_id')", + }, + ], + }; + + var result = SchemaDiff.Calculate(current, desired); + + Assert.True(result is OperationsResultOk); + Assert.Contains( + ((OperationsResultOk)result).Value, + op => op is CreateOrReplaceFunctionOperation + ); + } + + private static string Body(string bodyLql) + { + var result = LqlFunctionBodyTranspiler.TranslatePostgresBody(bodyLql, "public.test"); + Assert.True(result is BodyOk, result is BodyError e ? e.Value.Message : "expected Ok"); + return ((BodyOk)result).Value; + } +} diff --git a/Migration/Nimblesite.DataProvider.Migration.Tests/RlsCurrentSettingLqlTests.cs b/Migration/Nimblesite.DataProvider.Migration.Tests/RlsCurrentSettingLqlTests.cs new file mode 100644 index 0000000..a839116 --- /dev/null +++ b/Migration/Nimblesite.DataProvider.Migration.Tests/RlsCurrentSettingLqlTests.cs @@ -0,0 +1,199 @@ +using TranspileError = Outcome.Result< + string, + Nimblesite.DataProvider.Migration.Core.MigrationError +>.Error; +using TranspileOk = Outcome.Result< + string, + Nimblesite.DataProvider.Migration.Core.MigrationError +>.Ok; + +namespace Nimblesite.DataProvider.Migration.Tests; + +public sealed class RlsCurrentSettingLqlTests +{ + [Theory] + [InlineData("tenant_id = current_setting('app.tenant_id')::uuid", "'app.tenant_id'")] + [InlineData("user_id = current_setting('app.user_id')::uuid", "'app.user_id'")] + [InlineData("workspace_id = current_setting('app.workspace_id')::uuid", "'app.workspace_id'")] + [InlineData("api_key_id = current_setting('app.api_key_id')::uuid", "'app.api_key_id'")] + [InlineData("role = current_setting('app.role')", "'app.role'")] + [InlineData( + "tenant_id = current_setting('request.jwt.claims.tenant_id')::uuid", + "'request.jwt.claims.tenant_id'" + )] + [InlineData("created_by = current_setting('app.user_id')::uuid", "'app.user_id'")] + [InlineData("updated_by = current_setting('app.user_id')::uuid", "'app.user_id'")] + [InlineData("owner_id = current_setting('app.user_id')::uuid", "'app.user_id'")] + [InlineData("actor_id = current_setting('app.user_id')::uuid", "'app.user_id'")] + [InlineData( + "tenant_id = current_setting('app.tenant_id')::uuid and is_member(current_setting('app.user_id')::uuid, current_setting('app.tenant_id')::uuid)", + "'app.tenant_id'" + )] + [InlineData( + "user_id = current_setting('app.user_id')::uuid or is_owner(current_setting('app.user_id')::uuid, current_setting('app.tenant_id')::uuid)", + "'app.user_id'" + )] + public void Translate_CurrentSetting_NapPolicyShapes_EmitsMissingOkArgument( + string lql, + string keyLiteral + ) + { + var sql = Pg(lql); + + Assert.Contains($"current_setting({keyLiteral}, true)", sql, StringComparison.Ordinal); + Assert.DoesNotContain("__RLS_CURRENT_SETTING_", sql, StringComparison.Ordinal); + } + + [Fact] + public void Translate_CurrentSetting_WithUuidCast_KeepsCastTypeUnquoted() + { + var sql = Pg("tenant_id = current_setting('app.tenant_id')::uuid"); + + Assert.Contains("\"tenant_id\"", sql, StringComparison.Ordinal); + Assert.Contains( + "current_setting('app.tenant_id', true)::uuid", + sql, + StringComparison.Ordinal + ); + Assert.DoesNotContain("\"uuid\"", sql, StringComparison.Ordinal); + } + + [Fact] + public void Translate_CurrentSetting_MultipleKeys_RewritesEachCall() + { + var sql = Pg( + "tenant_id = current_setting('app.tenant_id')::uuid and user_id = current_setting('app.user_id')::uuid" + ); + + Assert.Contains( + "current_setting('app.tenant_id', true)::uuid", + sql, + StringComparison.Ordinal + ); + Assert.Contains( + "current_setting('app.user_id', true)::uuid", + sql, + StringComparison.Ordinal + ); + } + + [Fact] + public void Translate_CurrentSetting_InsideStringLiteral_IsNotRewritten() + { + var sql = Pg("note = 'current_setting(''app.tenant_id'')'"); + + Assert.Contains("'current_setting(''app.tenant_id'')'", sql, StringComparison.Ordinal); + Assert.DoesNotContain(", true", sql, StringComparison.Ordinal); + } + + [Fact] + public void Translate_CurrentSetting_ExistsPipeline_RewritesSettingsAfterLqlTranspile() + { + var sql = Pg( + """ + exists( + tenant_members + |> filter(fn(m) => m.tenant_id = current_setting('app.tenant_id')::uuid and m.user_id = current_setting('app.user_id')::uuid) + ) + """ + ); + + Assert.StartsWith("EXISTS (", sql, StringComparison.Ordinal); + Assert.Contains("FROM tenant_members", sql, StringComparison.Ordinal); + Assert.Contains( + "current_setting('app.tenant_id', true)::uuid", + sql, + StringComparison.Ordinal + ); + Assert.Contains( + "current_setting('app.user_id', true)::uuid", + sql, + StringComparison.Ordinal + ); + Assert.DoesNotContain("__RLS_CURRENT_SETTING_", sql, StringComparison.Ordinal); + } + + [Fact] + public void Translate_CurrentSetting_ExistsPipeline_FnCallArg_RewritesSetting() + { + var sql = Pg( + """ + exists( + tenant_members + |> filter(fn(m) => is_member(current_setting('app.user_id')::uuid, m.tenant_id)) + ) + """ + ); + + Assert.Contains( + "is_member(current_setting('app.user_id', true)::uuid", + sql, + StringComparison.Ordinal + ); + Assert.DoesNotContain("__RLS_CURRENT_SETTING_", sql, StringComparison.Ordinal); + } + + [Fact] + public void Generate_RlsPolicy_CurrentSetting_UsesPostgresMissingOkArgument() + { + var ddl = PostgresDdlGenerator.Generate( + new CreateRlsPolicyOperation( + "public", + "agent_configs", + new RlsPolicyDefinition + { + Name = "tenant_member", + Roles = ["app_user"], + UsingLql = "tenant_id = current_setting('app.tenant_id')::uuid", + WithCheckLql = "tenant_id = current_setting('app.tenant_id')::uuid", + } + ) + ); + + Assert.Contains( + "current_setting('app.tenant_id', true)::uuid", + ddl, + StringComparison.Ordinal + ); + Assert.DoesNotContain("current_setting('app.tenant_id')", ddl, StringComparison.Ordinal); + } + + [Fact] + public void Translate_CurrentSetting_Sqlite_ReturnsUnsupportedError() + { + var result = RlsPredicateTranspiler.Translate( + "tenant_id = current_setting('app.tenant_id')::uuid", + RlsPlatform.Sqlite, + "tenant" + ); + + Assert.True(result is TranspileError); + Assert.Contains("supported only for PostgreSQL", ((TranspileError)result).Value.Message); + } + + [Fact] + public void Translate_CurrentSetting_NonLiteralArgument_ReturnsParseError() + { + var result = RlsPredicateTranspiler.Translate( + "tenant_id = current_setting(app.tenant_id)::uuid", + RlsPlatform.Postgres, + "tenant" + ); + + Assert.True(result is TranspileError); + Assert.Contains( + "requires exactly one string-literal argument", + ((TranspileError)result).Value.Message + ); + } + + private static string Pg(string lql) + { + var result = RlsPredicateTranspiler.Translate(lql, RlsPlatform.Postgres, "p"); + Assert.True( + result is TranspileOk, + result is TranspileError e ? e.Value.Message : "expected Ok" + ); + return ((TranspileOk)result).Value; + } +}