diff --git a/packages/plugins/policy/src/expression-transformer.ts b/packages/plugins/policy/src/expression-transformer.ts index 27d77432a..e943e00a3 100644 --- a/packages/plugins/policy/src/expression-transformer.ts +++ b/packages/plugins/policy/src/expression-transformer.ts @@ -40,7 +40,9 @@ import { ReferenceNode, SelectionNode, SelectQueryNode, + sql, TableNode, + UnaryOperationNode, ValueListNode, ValueNode, WhereNode, @@ -253,13 +255,20 @@ export class ExpressionTransformer { if (ValueListNode.is(right)) { return BinaryOperationNode.create(left, OperatorNode.create('in'), right); } else { - // array contains const leftFieldDef = this.getFieldDefFromFieldRef(normalizedLeft, context); const comparand = leftFieldDef && QueryUtils.isEnum(this.schema, leftFieldDef.type) - ? // cast lhs otherwise dialect like pg can reject due to type mismatch - this.dialect.castText(new ExpressionWrapper(left)).toOperationNode() + ? this.dialect.castText(new ExpressionWrapper(left)).toOperationNode() : left; + + // if RHS is a subquery selecting an array column, use + // a cross-db EXISTS approach instead of `= ANY(subquery)` + const rightFieldDef = this.getFieldDefFromFieldRef(normalizedRight, context); + if (rightFieldDef?.array && SelectQueryNode.is(right)) { + return this.buildArrayInExists(comparand, right as SelectQueryNode); + } + + // array contains return BinaryOperationNode.create( comparand, OperatorNode.create('='), @@ -311,16 +320,11 @@ export class ExpressionTransformer { } private normalizeBinaryOperationOperands(expr: BinaryExpression, context: ExpressionTransformerContext) { - if (context.contextValue) { - // no normalization needed if evaluating against a value object - return { normalizedLeft: expr.left, normalizedRight: expr.right }; - } - - // if relation fields are used directly in comparison, it can only be compared with null, - // so we normalize the args with the id field (use the first id field if multiple) + // If relation fields are used directly in comparison, normalize both sides to the + // first id field. This is required both for SQL-backed relation comparisons and for + // value-backed auth/binding objects carried through collection predicates. let normalizedLeft: Expression = expr.left; if (this.isRelationField(expr.left, context)) { - invariant(ExpressionUtils.isNull(expr.right), 'only null comparison is supported for relation field'); const leftRelDef = this.getFieldDefFromFieldRef(expr.left, context); invariant(leftRelDef, 'failed to get relation field definition'); const idFields = QueryUtils.requireIdFields(this.schema, leftRelDef.type); @@ -328,7 +332,6 @@ export class ExpressionTransformer { } let normalizedRight: Expression = expr.right; if (this.isRelationField(expr.right, context)) { - invariant(ExpressionUtils.isNull(expr.left), 'only null comparison is supported for relation field'); const rightRelDef = this.getFieldDefFromFieldRef(expr.right, context); invariant(rightRelDef, 'failed to get relation field definition'); const idFields = QueryUtils.requireIdFields(this.schema, rightRelDef.type); @@ -340,7 +343,7 @@ export class ExpressionTransformer { private transformCollectionPredicate(expr: BinaryExpression, context: ExpressionTransformerContext) { this.ensureCollectionPredicateOperator(expr.op); - if (this.isAuthMember(expr.left) || context.contextValue) { + if (this.isAuthMember(expr.left) || (context.contextValue && !this.isThisRootedMember(expr.left))) { invariant( ExpressionUtils.isMember(expr.left) || ExpressionUtils.isField(expr.left), 'expected member or field expression', @@ -358,7 +361,7 @@ export class ExpressionTransformer { // get LHS's type const baseType = this.isAuthMember(expr.left) ? this.authType : context.modelOrType; - const memberType = this.getMemberType(baseType, expr.left); + const memberType = this.getMemberType(baseType, expr.left, context); // transform the entire expression with a value LHS and the correct context type return this.transformValueCollectionPredicate(receiver, expr, { ...context, modelOrType: memberType }); @@ -409,13 +412,10 @@ export class ExpressionTransformer { ...context, modelOrType: newContextModel, alias: undefined, + contextValue: undefined, bindingScope: bindingScope, }); - if (expr.op === '!') { - predicateFilter = logicalNot(this.dialect, predicateFilter); - } - const count = FunctionNode.create('count', [ValueNode.createImmediate(1)]); const predicateResult = match(expr.op) @@ -502,12 +502,26 @@ export class ExpressionTransformer { } } - private getMemberType(receiverType: string, expr: MemberExpression | FieldExpression) { + private getMemberType( + receiverType: string, + expr: MemberExpression | FieldExpression, + context?: ExpressionTransformerContext, + ) { if (ExpressionUtils.isField(expr)) { const fieldDef = QueryUtils.requireField(this.schema, receiverType, expr.field); return fieldDef.type; } else { let currType = receiverType; + if (ExpressionUtils.isThis(expr.receiver)) { + invariant(context, 'context is required for resolving this-rooted member types'); + currType = context.thisType; + } else if (ExpressionUtils.isBinding(expr.receiver)) { + invariant(context, 'context is required for resolving binding-rooted member types'); + currType = this.requireBindingScope(expr.receiver, context).type; + } else if (ExpressionUtils.isField(expr.receiver)) { + const fieldDef = QueryUtils.requireField(this.schema, receiverType, expr.receiver.field); + currType = fieldDef.type; + } for (const member of expr.members) { const fieldDef = QueryUtils.requireField(this.schema, currType, member); currType = fieldDef.type; @@ -515,6 +529,9 @@ export class ExpressionTransformer { return currType; } } + private isThisRootedMember(expr: Expression) { + return ExpressionUtils.isMember(expr) && ExpressionUtils.isThis(expr.receiver); + } private transformAuthBinary(expr: BinaryExpression, context: ExpressionTransformerContext) { if (expr.op !== '==' && expr.op !== '!=') { @@ -702,7 +719,11 @@ export class ExpressionTransformer { } else { // transform the first segment into a relation access, then continue with the rest of the members const firstMemberFieldDef = QueryUtils.requireField(this.schema, context.thisType, expr.members[0]!); - receiver = this.transformRelationAccess(expr.members[0]!, firstMemberFieldDef.type, restContext); + receiver = this.transformRelationAccess(expr.members[0]!, firstMemberFieldDef.type, { + ...restContext, + modelOrType: context.thisType, + alias: context.thisAlias, + }); members = expr.members.slice(1); // startType should be the type of the relation access startType = firstMemberFieldDef.type; @@ -756,7 +777,7 @@ export class ExpressionTransformer { currType = fieldDef.type; } - let currNode: SelectQueryNode | ColumnNode | ReferenceNode | undefined = undefined; + let currNode: SelectQueryNode | ColumnNode | ReferenceNode | FunctionNode | undefined = undefined; for (let i = members.length - 1; i >= 0; i--) { const member = members[i]!; @@ -788,7 +809,9 @@ export class ExpressionTransformer { invariant(i === members.length - 1, 'plain field access must be the last segment'); invariant(!currNode, 'plain field access must be the last segment'); - currNode = ColumnNode.create(member); + currNode = fieldDef.array && this.schema.provider.type === 'postgresql' + ? FunctionNode.create('unnest', [ColumnNode.create(member)]) + : ColumnNode.create(member); } } @@ -1015,6 +1038,36 @@ export class ExpressionTransformer { ExpressionUtils.isThis(expr.receiver) ) { return QueryUtils.getField(this.schema, model, expr.members[0]!); + } else if ( + ExpressionUtils.isMember(expr) && + ExpressionUtils.isThis(expr.receiver) && + expr.members.length > 1 + ) { + // `this.relation.field` chain — walk from the @@allow model + const firstDef = QueryUtils.getField(this.schema, model, expr.members[0]!); + if (!firstDef?.relation) return undefined; + let currModel = firstDef.type; + for (let i = 1; i < expr.members.length - 1; i++) { + const hopDef = QueryUtils.getField(this.schema, currModel, expr.members[i]!); + if (!hopDef?.relation) return undefined; + currModel = hopDef.type; + } + return QueryUtils.getField(this.schema, currModel, expr.members[expr.members.length - 1]!); + } else if (ExpressionUtils.isMember(expr) && ExpressionUtils.isBinding(expr.receiver)) { + const binding = this.requireBindingScope(expr.receiver, context); + const firstDef = QueryUtils.getField(this.schema, binding.type, expr.members[0]!); + if (!firstDef) return undefined; + if (expr.members.length === 1) { + return firstDef; + } + if (!firstDef.relation) return undefined; + let currModel = firstDef.type; + for (let i = 1; i < expr.members.length - 1; i++) { + const hopDef = QueryUtils.getField(this.schema, currModel, expr.members[i]!); + if (!hopDef?.relation) return undefined; + currModel = hopDef.type; + } + return QueryUtils.getField(this.schema, currModel, expr.members[expr.members.length - 1]!); } else if (ExpressionUtils.isMember(expr) && ExpressionUtils.isField(expr.receiver)) { // relation chain access (e.g. `owner.id`, `user.profile.uuid_field`): walk the // relation hops and return the terminal field's FieldDef so native-type info @@ -1032,4 +1085,72 @@ export class ExpressionTransformer { return undefined; } } + /** + * Build a cross-database EXISTS subquery for `scalar IN relation.arrayField`. + * Preserves the original subquery FROM to handle joined relations (e.g. m2m). + */ + private buildArrayInExists( + scalar: OperationNode, + subquery: SelectQueryNode, + ): OperationNode { + // PG: subquery already has unnest() from _member, just use = ANY(subquery) + if (this.schema.provider.type === 'postgresql') { + return BinaryOperationNode.create( + scalar, + OperatorNode.create('='), + FunctionNode.create('any', [subquery as unknown as OperationNode]), + ); + } + + const eb = this.eb; + + const table = subquery.from!.froms[0] as TableNode; + const tableName = table.table.identifier.name; + + const sel = subquery.selections![0]!; + const alias = sel.selection as AliasNode; + const colName = (alias.node as ColumnNode).column.name; + + const tableRef = eb.ref(`${tableName}.${colName}`); + const scalarRef = new ExpressionWrapper(scalar); + + let arrayCheck: OperationNode; + if (this.schema.provider.type === 'sqlite') { + arrayCheck = eb + .exists( + eb + .selectFrom(eb.fn('json_each', [tableRef]).as('_je')) + .select(eb.lit(1).as('_')) + .where(eb.ref('_je.value'), '=', scalarRef), + ) + .toOperationNode(); + } else { + // mysql + arrayCheck = eb + .exists( + eb + .selectFrom( + sql`JSON_TABLE(${tableRef}, '$[*]' COLUMNS(value JSON PATH '$'))`.as('_jt'), + ) + .select(eb.lit(1).as('_')) + .where(eb.ref('_jt.value'), '=', scalarRef), + ) + .toOperationNode(); + } + + const combinedWhere = subquery.where + ? conjunction(this.dialect, [subquery.where.where, arrayCheck]) + : arrayCheck; + + return UnaryOperationNode.create( + { + ...(subquery as SelectQueryNode), + where: WhereNode.create(combinedWhere), + selections: [ + SelectionNode.create(AliasNode.create(ValueNode.createImmediate(1), IdentifierNode.create('_'))), + ], + } as SelectQueryNode, + OperatorNode.create('exists'), + ); + } } diff --git a/tests/e2e/orm/policy/auth-access.test.ts b/tests/e2e/orm/policy/auth-access.test.ts index 56942de49..7643e5368 100644 --- a/tests/e2e/orm/policy/auth-access.test.ts +++ b/tests/e2e/orm/policy/auth-access.test.ts @@ -475,4 +475,208 @@ model Channel { userDb2.channel.update({ where: { id: 1 }, data: { name: 'general-updated' } }), ).resolves.toBeTruthy(); }); + + it('resolves this.relation.field against @@allow model in collection predicates (Fix #1)', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id @default(autoincrement()) + level Int + permissions Permission[] + posts Post[] + @@auth +} + +model Permission { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int + clearance Int +} + +model Post { + id Int @id @default(autoincrement()) + author User @relation(fields: [authorId], references: [id]) + authorId Int + + @@allow('read', auth().permissions?[p, p.clearance >= this.author.level]) +} +`, + { provider: 'postgresql' }, + ); + + await db.$unuseAll().post.create({ + data: { id: 1, author: { create: { id: 1, level: 5 } } }, + }); + await db.$unuseAll().post.create({ + data: { id: 2, author: { create: { id: 2, level: 10 } } }, + }); + + // no auth: no permissions → cannot read any post + await expect(db.post.findMany()).resolves.toHaveLength(0); + + // clearance 5: can read author level ≤ 5 → only post 1 (author level 5) + const user1 = db.$setAuth({ + id: 3, + permissions: [{ id: 1, clearance: 5 }], + }); + const posts1 = await user1.post.findMany(); + expect(posts1.map((p) => p.id).sort()).toEqual([1]); + + // clearance 10: can read author level ≤ 10 → both posts + const user2 = db.$setAuth({ + id: 4, + permissions: [{ id: 2, clearance: 10 }], + }); + const posts2 = await user2.post.findMany(); + expect(posts2.map((p) => p.id).sort()).toEqual([1, 2]); + }); + + it('handles this.relation.arrayField with in operator (Fix #2)', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id @default(autoincrement()) + permissions Permission[] + @@auth +} + +model Group { + id Int @id @default(autoincrement()) + visibleDocIds Int[] + docs Doc[] +} + +model Permission { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int + allowedDocIds Int[] +} + +model Doc { + id Int @id @default(autoincrement()) + group Group @relation(fields: [groupId], references: [id]) + groupId Int + + @@allow('read', + auth().permissions?[p, this.id in p.allowedDocIds] || + this.id in this.group.visibleDocIds + ) +} +`, + { provider: 'postgresql' }, + ); + + await db.$unuseAll().group.create({ + data: { id: 1, visibleDocIds: [1] }, + }); + await db.$unuseAll().group.create({ + data: { id: 2, visibleDocIds: [] }, + }); + await db.$unuseAll().user.create({ + data: { id: 1 }, + }); + await db.$unuseAll().user.create({ + data: { id: 2 }, + }); + await db.$unuseAll().permission.create({ + data: { id: 10, userId: 2, allowedDocIds: [2] }, + }); + await db.$unuseAll().doc.createMany({ + data: [ + { id: 1, groupId: 1 }, + { id: 2, groupId: 2 }, + ], + }); + + // User 1 (no perms): doc 1 visible via group.visibleDocIds + const user1 = db.$setAuth({ id: 1, permissions: [] }); + expect((await user1.doc.findMany()).map((d) => d.id).sort()).toEqual([1]); + + // User 2 (perm allows doc 2): sees doc 1 (group-visible) + doc 2 (permission) + const user2 = db.$setAuth({ + id: 2, + permissions: [{ id: 10, allowedDocIds: [2] }], + }); + expect((await user2.doc.findMany()).map((d) => d.id).sort()).toEqual([1, 2]); + }); + + it('keeps this-rooted collection predicates on the @@allow model inside auth bindings (Fix #3)', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + assignments RoleAssignment[] + @@auth +} + +model Scope { + id Int @id + parentId Int? + parent Scope? @relation("ScopeParent", fields: [parentId], references: [id]) + children Scope[] @relation("ScopeParent") + ancestors ScopeClosure[] @relation("Descendant") + descendants ScopeClosure[] @relation("Ancestor") + docs Doc[] + assignments RoleAssignment[] + @@allow('all', true) +} + +model ScopeClosure { + ancestorId Int + descendantId Int + ancestor Scope @relation("Ancestor", fields: [ancestorId], references: [id]) + descendant Scope @relation("Descendant", fields: [descendantId], references: [id]) + @@id([ancestorId, descendantId]) + @@allow('all', true) +} + +model RoleAssignment { + id Int @id + userId Int + scopeId Int + user User @relation(fields: [userId], references: [id]) + scope Scope @relation(fields: [scopeId], references: [id]) + @@allow('all', true) +} + +model Doc { + id Int @id + authScopeId Int + authScope Scope @relation(fields: [authScopeId], references: [id]) + + @@allow('read', auth().assignments?[rs, + rs.scope == this.authScope || + this.authScope.ancestors?[ancestor == rs.scope] + ]) +} +`, + { provider: 'postgresql' }, + ); + + await db.$unuseAll().scope.createMany({ + data: [ + { id: 1 }, + { id: 2, parentId: 1 }, + ], + }); + await db.$unuseAll().scopeClosure.createMany({ + data: [ + { ancestorId: 1, descendantId: 1 }, + { ancestorId: 2, descendantId: 2 }, + { ancestorId: 1, descendantId: 2 }, + ], + }); + await db.$unuseAll().doc.create({ + data: { id: 1, authScopeId: 2 }, + }); + + const reader = db.$setAuth({ + id: 1, + assignments: [{ id: 1, scopeId: 1, scope: { id: 1 } }], + }); + + await expect(reader.doc.findUnique({ where: { id: 1 } })).toResolveTruthy(); + }); });