diff --git a/liquidjava-verifier/src/main/java/liquidjava/processor/refinement_checker/VCChecker.java b/liquidjava-verifier/src/main/java/liquidjava/processor/refinement_checker/VCChecker.java index 51084672..ac213fb9 100644 --- a/liquidjava-verifier/src/main/java/liquidjava/processor/refinement_checker/VCChecker.java +++ b/liquidjava-verifier/src/main/java/liquidjava/processor/refinement_checker/VCChecker.java @@ -58,8 +58,8 @@ public void processSubtyping(Predicate expectedType, List list, CtEl } SMTResult result = verifySMTSubtype(expected, premises, element.getPosition()); if (result.isError()) { - throw new RefinementError(element.getPosition(), expectedType.simplify(), premisesBeforeChange.simplify(), - map, result.getCounterexample(), customMessage); + throw new RefinementError(element.getPosition(), expectedType.simplify(context), + premisesBeforeChange.simplify(context), map, result.getCounterexample(), customMessage); } } @@ -277,7 +277,7 @@ protected void throwRefinementError(SourcePosition position, Predicate expected, gatherVariables(found, lrv, mainVars); TranslationTable map = new TranslationTable(); Predicate premises = joinPredicates(expected, mainVars, lrv, map).toConjunctions(); - throw new RefinementError(position, expected.simplify(), premises.simplify(), map, counterexample, + throw new RefinementError(position, expected.simplify(context), premises.simplify(context), map, counterexample, customMessage); } @@ -288,7 +288,7 @@ protected void throwStateRefinementError(SourcePosition position, Predicate foun TranslationTable map = new TranslationTable(); VCImplication foundState = joinPredicates(found, mainVars, lrv, map); throw new StateRefinementError(position, expected.getExpression(), - foundState.toConjunctions().simplify().getValue(), map, customMessage); + foundState.toConjunctions().simplify(context).getValue(), map, customMessage); } protected void throwStateConflictError(SourcePosition position, Predicate expected) throws StateConflictError { diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/Predicate.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/Predicate.java index c91380ee..1445517f 100644 --- a/liquidjava-verifier/src/main/java/liquidjava/rj_language/Predicate.java +++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/Predicate.java @@ -187,8 +187,14 @@ public Expression getExpression() { return exp; } - public ValDerivationNode simplify() { - return ExpressionSimplifier.simplify(exp.clone()); + public ValDerivationNode simplify(Context context) { + // collect aliases from context + Map aliases = new HashMap<>(); + for (AliasWrapper aw : context.getAliases()) { + aliases.put(aw.getName(), aw.createAliasDTO()); + } + // simplify expression + return ExpressionSimplifier.simplify(exp.clone(), aliases); } private static boolean isBooleanLiteral(Expression expr, boolean value) { diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/AliasExpansion.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/AliasExpansion.java new file mode 100644 index 00000000..b8ab6b3f --- /dev/null +++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/AliasExpansion.java @@ -0,0 +1,91 @@ +package liquidjava.rj_language.opt; + +import java.util.Map; + +import liquidjava.processor.facade.AliasDTO; +import liquidjava.rj_language.ast.AliasInvocation; +import liquidjava.rj_language.ast.BinaryExpression; +import liquidjava.rj_language.ast.Expression; +import liquidjava.rj_language.ast.GroupExpression; +import liquidjava.rj_language.ast.UnaryExpression; +import liquidjava.rj_language.ast.Var; +import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode; +import liquidjava.rj_language.opt.derivation_node.DerivationNode; +import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode; +import liquidjava.rj_language.opt.derivation_node.ValDerivationNode; + +public class AliasExpansion { + + /** + * Expands alias invocations in a derivation node to their definitions, storing the expanded body as the origin of + * each alias invocation node. + */ + public static ValDerivationNode expand(ValDerivationNode node, Map aliases) { + return expandRecursive(node, aliases); + } + + private static ValDerivationNode expandRecursive(ValDerivationNode node, Map aliases) { + Expression exp = node.getValue(); + + // expand alias invocation + if (exp instanceof AliasInvocation ai) { + return expandAlias(ai, aliases); + } + + // recurse into binary expressions + if (exp instanceof BinaryExpression binary) { + ValDerivationNode leftNode; + ValDerivationNode rightNode; + if (node.getOrigin()instanceof BinaryDerivationNode binOrigin) { + leftNode = expandRecursive(binOrigin.getLeft(), aliases); + rightNode = expandRecursive(binOrigin.getRight(), aliases); + } else { + leftNode = expandRecursive(new ValDerivationNode(binary.getFirstOperand(), null), aliases); + rightNode = expandRecursive(new ValDerivationNode(binary.getSecondOperand(), null), aliases); + } + boolean hasExpansion = leftNode.getOrigin() != null || rightNode.getOrigin() != null; + DerivationNode origin = hasExpansion ? new BinaryDerivationNode(leftNode, rightNode, binary.getOperator()) + : node.getOrigin(); + return new ValDerivationNode(exp, origin); + } + + // recurse into unary expressions + if (exp instanceof UnaryExpression unary) { + ValDerivationNode operandNode; + if (node.getOrigin()instanceof UnaryDerivationNode unaryOrigin) { + operandNode = expandRecursive(unaryOrigin.getOperand(), aliases); + } else { + operandNode = expandRecursive(new ValDerivationNode(unary.getChildren().get(0), null), aliases); + } + DerivationNode origin = operandNode.getOrigin() != null + ? new UnaryDerivationNode(operandNode, unary.getOp()) : node.getOrigin(); + return new ValDerivationNode(exp, origin); + } + + // recurse into group expressions + if (exp instanceof GroupExpression group && group.getChildren().size() == 1) { + return expandRecursive(new ValDerivationNode(group.getChildren().get(0), node.getOrigin()), aliases); + } + + return node; + } + + private static ValDerivationNode expandAlias(AliasInvocation ai, Map aliases) { + AliasDTO dto = aliases.get(ai.getName()); + + // no alias found + if (dto == null || dto.getExpression() == null) { + return new ValDerivationNode(ai, null); + } + + // substitute parameters in the body with the invocation arguments + Expression body = dto.getExpression().clone(); + for (int i = 0; i < ai.getArgs().size() && i < dto.getVarNames().size(); i++) { + body = body.substitute(new Var(dto.getVarNames().get(i)), ai.getArgs().get(i)); + } + + // recursively expand the body + ValDerivationNode expandedBody = expandRecursive(new ValDerivationNode(body, null), aliases); + return new ValDerivationNode(ai, expandedBody); + } +} diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionSimplifier.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionSimplifier.java index 2e43e326..289b2e38 100644 --- a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionSimplifier.java +++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionSimplifier.java @@ -1,5 +1,8 @@ package liquidjava.rj_language.opt; +import java.util.Map; + +import liquidjava.processor.facade.AliasDTO; import liquidjava.rj_language.ast.BinaryExpression; import liquidjava.rj_language.ast.Expression; import liquidjava.rj_language.ast.LiteralBoolean; @@ -10,12 +13,18 @@ public class ExpressionSimplifier { /** - * Simplifies an expression by applying constant propagation, constant folding and removing redundant conjuncts - * Returns a derivation node representing the tree of simplifications applied + * Simplifies an expression by applying constant propagation, constant folding, removing redundant conjuncts and + * expanding aliases Returns a derivation node representing the tree of simplifications applied */ + public static ValDerivationNode simplify(Expression exp, Map aliases) { + ValDerivationNode node = new ValDerivationNode(exp, null); + ValDerivationNode fixedPoint = simplifyToFixedPoint(node, exp); + ValDerivationNode simplified = simplifyValDerivationNode(fixedPoint); + return AliasExpansion.expand(simplified, aliases); + } + public static ValDerivationNode simplify(Expression exp) { - ValDerivationNode fixedPoint = simplifyToFixedPoint(null, exp); - return simplifyValDerivationNode(fixedPoint); + return simplify(exp, Map.of()); } /** diff --git a/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java b/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java index b49ce805..bbdb8381 100644 --- a/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java +++ b/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java @@ -2,6 +2,11 @@ import static org.junit.jupiter.api.Assertions.*; +import java.util.List; +import java.util.Map; + +import liquidjava.processor.facade.AliasDTO; +import liquidjava.rj_language.ast.AliasInvocation; import liquidjava.rj_language.ast.BinaryExpression; import liquidjava.rj_language.ast.Expression; import liquidjava.rj_language.ast.LiteralBoolean; @@ -550,6 +555,72 @@ void testTransitive() { assertEquals("a == 1", result.getValue().toString(), "Expected result to be a == 1"); } + @Test + void testByteAliasExpansion() { + // Given: Byte(b) with alias Byte(int b) { b >= -128 && b <= 127 } + AliasDTO byteAlias = new AliasDTO("Byte", List.of("int"), List.of("b"), "b >= -128 && b <= 127"); + byteAlias.parse(""); + Map aliases = Map.of("Byte", byteAlias); + Expression exp = new AliasInvocation("Byte", List.of(new Var("b"))); + + // When + ValDerivationNode result = ExpressionSimplifier.simplify(exp, aliases); + + // Then + assertEquals("Byte(b)", result.getValue().toString()); + assertNotNull(result.getOrigin(), "Origin should contain the expanded body"); + ValDerivationNode origin = (ValDerivationNode) result.getOrigin(); + assertEquals("b >= -128 && b <= 127", origin.getValue().toString()); + } + + @Test + void testPositiveAliasExpansion() { + // Given: Positive(x) with alias Positive(int v) { v > 0 } + AliasDTO positiveAlias = new AliasDTO("Positive", List.of("int"), List.of("v"), "v > 0"); + positiveAlias.parse(""); + Map aliases = Map.of("Positive", positiveAlias); + Expression exp = new AliasInvocation("Positive", List.of(new Var("x"))); + + // When + ValDerivationNode result = ExpressionSimplifier.simplify(exp, aliases); + + // Then + assertEquals("Positive(x)", result.getValue().toString()); + assertNotNull(result.getOrigin(), "Origin should contain the expanded body"); + ValDerivationNode origin = (ValDerivationNode) result.getOrigin(); + assertEquals("x > 0", origin.getValue().toString()); + } + + @Test + void testTwoArgAliasWithNormalExpression() { + // Given: Bounded(v, 100) && v > 50 with alias Bounded(int x, int n) { x > 0 && x < n } + AliasDTO boundedAlias = new AliasDTO("Bounded", List.of("int", "int"), List.of("x", "n"), "x > 0 && x < n"); + boundedAlias.parse(""); + Map aliases = Map.of("Bounded", boundedAlias); + + Expression varV = new Var("v"); + Expression bounded = new AliasInvocation("Bounded", List.of(varV, new LiteralInt(100))); + Expression vGt50 = new BinaryExpression(varV, ">", new LiteralInt(50)); + Expression fullExpression = new BinaryExpression(bounded, "&&", vGt50); + + // When + ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression, aliases); + + // Then + assertEquals("Bounded(v, 100) && v > 50", result.getValue().toString()); + assertInstanceOf(BinaryDerivationNode.class, result.getOrigin()); + BinaryDerivationNode binOrigin = (BinaryDerivationNode) result.getOrigin(); + assertEquals("&&", binOrigin.getOp()); + ValDerivationNode leftNode = binOrigin.getLeft(); + assertEquals("Bounded(v, 100)", leftNode.getValue().toString()); + assertNotNull(leftNode.getOrigin(), "Alias invocation should have expanded body as origin"); + ValDerivationNode expandedBody = (ValDerivationNode) leftNode.getOrigin(); + assertEquals("v > 0 && v < 100", expandedBody.getValue().toString()); + ValDerivationNode rightNode = binOrigin.getRight(); + assertEquals("v > 50", rightNode.getValue().toString()); + assertNull(rightNode.getOrigin()); + } + /** * Helper method to compare two derivation nodes recursively */