Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 143 additions & 22 deletions packages/plugins/policy/src/expression-transformer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ import {
ReferenceNode,
SelectionNode,
SelectQueryNode,
sql,
TableNode,
UnaryOperationNode,
ValueListNode,
ValueNode,
WhereNode,
Expand Down Expand Up @@ -253,13 +255,20 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
if (ValueListNode.is(right)) {
return BinaryOperationNode.create(left, OperatorNode.create('in'), right);
} else {
// array contains
const leftFieldDef = this.getFieldDefFromFieldRef(normalizedLeft, context);
const comparand =
leftFieldDef && QueryUtils.isEnum(this.schema, leftFieldDef.type)
? // cast lhs otherwise dialect like pg can reject due to type mismatch
this.dialect.castText(new ExpressionWrapper(left)).toOperationNode()
? this.dialect.castText(new ExpressionWrapper(left)).toOperationNode()
: left;

// if RHS is a subquery selecting an array column, use
// a cross-db EXISTS approach instead of `= ANY(subquery)`
const rightFieldDef = this.getFieldDefFromFieldRef(normalizedRight, context);
if (rightFieldDef?.array && SelectQueryNode.is(right)) {
return this.buildArrayInExists(comparand, right as SelectQueryNode);
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

// array contains
return BinaryOperationNode.create(
comparand,
OperatorNode.create('='),
Expand Down Expand Up @@ -311,24 +320,18 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
}

private normalizeBinaryOperationOperands(expr: BinaryExpression, context: ExpressionTransformerContext) {
if (context.contextValue) {
// no normalization needed if evaluating against a value object
return { normalizedLeft: expr.left, normalizedRight: expr.right };
}

// if relation fields are used directly in comparison, it can only be compared with null,
// so we normalize the args with the id field (use the first id field if multiple)
// If relation fields are used directly in comparison, normalize both sides to the
// first id field. This is required both for SQL-backed relation comparisons and for
// value-backed auth/binding objects carried through collection predicates.
let normalizedLeft: Expression = expr.left;
if (this.isRelationField(expr.left, context)) {
invariant(ExpressionUtils.isNull(expr.right), 'only null comparison is supported for relation field');
const leftRelDef = this.getFieldDefFromFieldRef(expr.left, context);
invariant(leftRelDef, 'failed to get relation field definition');
const idFields = QueryUtils.requireIdFields(this.schema, leftRelDef.type);
normalizedLeft = this.makeOrAppendMember(normalizedLeft, idFields[0]!);
}
let normalizedRight: Expression = expr.right;
if (this.isRelationField(expr.right, context)) {
invariant(ExpressionUtils.isNull(expr.left), 'only null comparison is supported for relation field');
const rightRelDef = this.getFieldDefFromFieldRef(expr.right, context);
invariant(rightRelDef, 'failed to get relation field definition');
const idFields = QueryUtils.requireIdFields(this.schema, rightRelDef.type);
Expand All @@ -340,7 +343,7 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
private transformCollectionPredicate(expr: BinaryExpression, context: ExpressionTransformerContext) {
this.ensureCollectionPredicateOperator(expr.op);

if (this.isAuthMember(expr.left) || context.contextValue) {
if (this.isAuthMember(expr.left) || (context.contextValue && !this.isThisRootedMember(expr.left))) {
invariant(
ExpressionUtils.isMember(expr.left) || ExpressionUtils.isField(expr.left),
'expected member or field expression',
Expand All @@ -358,7 +361,7 @@ export class ExpressionTransformer<Schema extends SchemaDef> {

// get LHS's type
const baseType = this.isAuthMember(expr.left) ? this.authType : context.modelOrType;
const memberType = this.getMemberType(baseType, expr.left);
const memberType = this.getMemberType(baseType, expr.left, context);

// transform the entire expression with a value LHS and the correct context type
return this.transformValueCollectionPredicate(receiver, expr, { ...context, modelOrType: memberType });
Expand Down Expand Up @@ -409,13 +412,10 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
...context,
modelOrType: newContextModel,
alias: undefined,
contextValue: undefined,
bindingScope: bindingScope,
});

if (expr.op === '!') {
predicateFilter = logicalNot(this.dialect, predicateFilter);
}

const count = FunctionNode.create('count', [ValueNode.createImmediate(1)]);

const predicateResult = match(expr.op)
Expand Down Expand Up @@ -502,19 +502,36 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
}
}

private getMemberType(receiverType: string, expr: MemberExpression | FieldExpression) {
private getMemberType(
receiverType: string,
expr: MemberExpression | FieldExpression,
context?: ExpressionTransformerContext,
) {
if (ExpressionUtils.isField(expr)) {
const fieldDef = QueryUtils.requireField(this.schema, receiverType, expr.field);
return fieldDef.type;
} else {
let currType = receiverType;
if (ExpressionUtils.isThis(expr.receiver)) {
invariant(context, 'context is required for resolving this-rooted member types');
currType = context.thisType;
} else if (ExpressionUtils.isBinding(expr.receiver)) {
invariant(context, 'context is required for resolving binding-rooted member types');
currType = this.requireBindingScope(expr.receiver, context).type;
} else if (ExpressionUtils.isField(expr.receiver)) {
const fieldDef = QueryUtils.requireField(this.schema, receiverType, expr.receiver.field);
currType = fieldDef.type;
}
for (const member of expr.members) {
const fieldDef = QueryUtils.requireField(this.schema, currType, member);
currType = fieldDef.type;
}
return currType;
}
}
private isThisRootedMember(expr: Expression) {
return ExpressionUtils.isMember(expr) && ExpressionUtils.isThis(expr.receiver);
}

private transformAuthBinary(expr: BinaryExpression, context: ExpressionTransformerContext) {
if (expr.op !== '==' && expr.op !== '!=') {
Expand Down Expand Up @@ -702,7 +719,11 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
} else {
// transform the first segment into a relation access, then continue with the rest of the members
const firstMemberFieldDef = QueryUtils.requireField(this.schema, context.thisType, expr.members[0]!);
receiver = this.transformRelationAccess(expr.members[0]!, firstMemberFieldDef.type, restContext);
receiver = this.transformRelationAccess(expr.members[0]!, firstMemberFieldDef.type, {
...restContext,
modelOrType: context.thisType,
alias: context.thisAlias,
});
Comment thread
coderabbitai[bot] marked this conversation as resolved.
members = expr.members.slice(1);
// startType should be the type of the relation access
startType = firstMemberFieldDef.type;
Expand Down Expand Up @@ -756,7 +777,7 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
currType = fieldDef.type;
}

let currNode: SelectQueryNode | ColumnNode | ReferenceNode | undefined = undefined;
let currNode: SelectQueryNode | ColumnNode | ReferenceNode | FunctionNode | undefined = undefined;

for (let i = members.length - 1; i >= 0; i--) {
const member = members[i]!;
Expand Down Expand Up @@ -788,7 +809,9 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
invariant(i === members.length - 1, 'plain field access must be the last segment');
invariant(!currNode, 'plain field access must be the last segment');

currNode = ColumnNode.create(member);
currNode = fieldDef.array && this.schema.provider.type === 'postgresql'
? FunctionNode.create('unnest', [ColumnNode.create(member)])
: ColumnNode.create(member);
}
}

Expand Down Expand Up @@ -1015,6 +1038,36 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
ExpressionUtils.isThis(expr.receiver)
) {
return QueryUtils.getField(this.schema, model, expr.members[0]!);
} else if (
ExpressionUtils.isMember(expr) &&
ExpressionUtils.isThis(expr.receiver) &&
expr.members.length > 1
) {
// `this.relation.field` chain — walk from the @@allow model
const firstDef = QueryUtils.getField(this.schema, model, expr.members[0]!);
if (!firstDef?.relation) return undefined;
let currModel = firstDef.type;
for (let i = 1; i < expr.members.length - 1; i++) {
const hopDef = QueryUtils.getField(this.schema, currModel, expr.members[i]!);
if (!hopDef?.relation) return undefined;
currModel = hopDef.type;
}
return QueryUtils.getField(this.schema, currModel, expr.members[expr.members.length - 1]!);
} else if (ExpressionUtils.isMember(expr) && ExpressionUtils.isBinding(expr.receiver)) {
const binding = this.requireBindingScope(expr.receiver, context);
const firstDef = QueryUtils.getField(this.schema, binding.type, expr.members[0]!);
if (!firstDef) return undefined;
if (expr.members.length === 1) {
return firstDef;
}
if (!firstDef.relation) return undefined;
let currModel = firstDef.type;
for (let i = 1; i < expr.members.length - 1; i++) {
const hopDef = QueryUtils.getField(this.schema, currModel, expr.members[i]!);
if (!hopDef?.relation) return undefined;
currModel = hopDef.type;
}
return QueryUtils.getField(this.schema, currModel, expr.members[expr.members.length - 1]!);
} else if (ExpressionUtils.isMember(expr) && ExpressionUtils.isField(expr.receiver)) {
// relation chain access (e.g. `owner.id`, `user.profile.uuid_field`): walk the
// relation hops and return the terminal field's FieldDef so native-type info
Expand All @@ -1032,4 +1085,72 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
return undefined;
}
}
/**
* Build a cross-database EXISTS subquery for `scalar IN relation.arrayField`.
* Preserves the original subquery FROM to handle joined relations (e.g. m2m).
*/
private buildArrayInExists(
scalar: OperationNode,
subquery: SelectQueryNode,
): OperationNode {
// PG: subquery already has unnest() from _member, just use = ANY(subquery)
if (this.schema.provider.type === 'postgresql') {
return BinaryOperationNode.create(
scalar,
OperatorNode.create('='),
FunctionNode.create('any', [subquery as unknown as OperationNode]),
);
}

const eb = this.eb;

const table = subquery.from!.froms[0] as TableNode;
const tableName = table.table.identifier.name;

const sel = subquery.selections![0]!;
const alias = sel.selection as AliasNode;
const colName = (alias.node as ColumnNode).column.name;

const tableRef = eb.ref(`${tableName}.${colName}`);
const scalarRef = new ExpressionWrapper(scalar);

let arrayCheck: OperationNode;
if (this.schema.provider.type === 'sqlite') {
arrayCheck = eb
.exists(
eb
.selectFrom(eb.fn('json_each', [tableRef]).as('_je'))
.select(eb.lit(1).as('_'))
.where(eb.ref('_je.value'), '=', scalarRef),
)
.toOperationNode();
} else {
// mysql
arrayCheck = eb
.exists(
eb
.selectFrom(
sql`JSON_TABLE(${tableRef}, '$[*]' COLUMNS(value JSON PATH '$'))`.as('_jt'),
)
.select(eb.lit(1).as('_'))
.where(eb.ref('_jt.value'), '=', scalarRef),
)
.toOperationNode();
}

const combinedWhere = subquery.where
? conjunction(this.dialect, [subquery.where.where, arrayCheck])
: arrayCheck;

return UnaryOperationNode.create(
{
...(subquery as SelectQueryNode),
where: WhereNode.create(combinedWhere),
selections: [
SelectionNode.create(AliasNode.create(ValueNode.createImmediate(1), IdentifierNode.create('_'))),
],
} as SelectQueryNode,
OperatorNode.create('exists'),
);
}
}
Loading