Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ public void processSubtyping(Predicate expectedType, List<GhostState> list, CtEl
}
SMTResult result = verifySMTSubtype(expected, premises, element.getPosition());
if (result.isError()) {
throw new RefinementError(element.getPosition(), expectedType.simplify(), premisesBeforeChange.simplify(),
map, result.getCounterexample(), customMessage);
throw new RefinementError(element.getPosition(), expectedType.simplify(context),
premisesBeforeChange.simplify(context), map, result.getCounterexample(), customMessage);
}
}

Expand Down Expand Up @@ -277,7 +277,7 @@ protected void throwRefinementError(SourcePosition position, Predicate expected,
gatherVariables(found, lrv, mainVars);
TranslationTable map = new TranslationTable();
Predicate premises = joinPredicates(expected, mainVars, lrv, map).toConjunctions();
throw new RefinementError(position, expected.simplify(), premises.simplify(), map, counterexample,
throw new RefinementError(position, expected.simplify(context), premises.simplify(context), map, counterexample,
customMessage);
}

Expand All @@ -288,7 +288,7 @@ protected void throwStateRefinementError(SourcePosition position, Predicate foun
TranslationTable map = new TranslationTable();
VCImplication foundState = joinPredicates(found, mainVars, lrv, map);
throw new StateRefinementError(position, expected.getExpression(),
foundState.toConjunctions().simplify().getValue(), map, customMessage);
foundState.toConjunctions().simplify(context).getValue(), map, customMessage);
}

protected void throwStateConflictError(SourcePosition position, Predicate expected) throws StateConflictError {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,14 @@ public Expression getExpression() {
return exp;
}

public ValDerivationNode simplify() {
return ExpressionSimplifier.simplify(exp.clone());
public ValDerivationNode simplify(Context context) {
// collect aliases from context
Map<String, AliasDTO> aliases = new HashMap<>();
for (AliasWrapper aw : context.getAliases()) {
aliases.put(aw.getName(), aw.createAliasDTO());
}
// simplify expression
return ExpressionSimplifier.simplify(exp.clone(), aliases);
}

private static boolean isBooleanLiteral(Expression expr, boolean value) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package liquidjava.rj_language.opt;

import java.util.Map;

import liquidjava.processor.facade.AliasDTO;
import liquidjava.rj_language.ast.AliasInvocation;
import liquidjava.rj_language.ast.BinaryExpression;
import liquidjava.rj_language.ast.Expression;
import liquidjava.rj_language.ast.GroupExpression;
import liquidjava.rj_language.ast.UnaryExpression;
import liquidjava.rj_language.ast.Var;
import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode;
import liquidjava.rj_language.opt.derivation_node.DerivationNode;
import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode;
import liquidjava.rj_language.opt.derivation_node.ValDerivationNode;

public class AliasExpansion {

/**
* Expands alias invocations in a derivation node to their definitions, storing the expanded body as the origin of
* each alias invocation node.
*/
public static ValDerivationNode expand(ValDerivationNode node, Map<String, AliasDTO> aliases) {
return expandRecursive(node, aliases);
}

private static ValDerivationNode expandRecursive(ValDerivationNode node, Map<String, AliasDTO> aliases) {
Expression exp = node.getValue();

// expand alias invocation
if (exp instanceof AliasInvocation ai) {
return expandAlias(ai, aliases);
}

// recurse into binary expressions
if (exp instanceof BinaryExpression binary) {
ValDerivationNode leftNode;
ValDerivationNode rightNode;
if (node.getOrigin()instanceof BinaryDerivationNode binOrigin) {
leftNode = expandRecursive(binOrigin.getLeft(), aliases);
rightNode = expandRecursive(binOrigin.getRight(), aliases);
} else {
leftNode = expandRecursive(new ValDerivationNode(binary.getFirstOperand(), null), aliases);
rightNode = expandRecursive(new ValDerivationNode(binary.getSecondOperand(), null), aliases);
}
boolean hasExpansion = leftNode.getOrigin() != null || rightNode.getOrigin() != null;
DerivationNode origin = hasExpansion ? new BinaryDerivationNode(leftNode, rightNode, binary.getOperator())
: node.getOrigin();
return new ValDerivationNode(exp, origin);
}

// recurse into unary expressions
if (exp instanceof UnaryExpression unary) {
ValDerivationNode operandNode;
if (node.getOrigin()instanceof UnaryDerivationNode unaryOrigin) {
operandNode = expandRecursive(unaryOrigin.getOperand(), aliases);
} else {
operandNode = expandRecursive(new ValDerivationNode(unary.getChildren().get(0), null), aliases);
}
DerivationNode origin = operandNode.getOrigin() != null
? new UnaryDerivationNode(operandNode, unary.getOp()) : node.getOrigin();
return new ValDerivationNode(exp, origin);
}

// recurse into group expressions
if (exp instanceof GroupExpression group && group.getChildren().size() == 1) {
return expandRecursive(new ValDerivationNode(group.getChildren().get(0), node.getOrigin()), aliases);
}

return node;
}

private static ValDerivationNode expandAlias(AliasInvocation ai, Map<String, AliasDTO> aliases) {
AliasDTO dto = aliases.get(ai.getName());

// no alias found
if (dto == null || dto.getExpression() == null) {
return new ValDerivationNode(ai, null);
}

// substitute parameters in the body with the invocation arguments
Expression body = dto.getExpression().clone();
for (int i = 0; i < ai.getArgs().size() && i < dto.getVarNames().size(); i++) {
body = body.substitute(new Var(dto.getVarNames().get(i)), ai.getArgs().get(i));
}

// recursively expand the body
ValDerivationNode expandedBody = expandRecursive(new ValDerivationNode(body, null), aliases);
return new ValDerivationNode(ai, expandedBody);
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
package liquidjava.rj_language.opt;

import java.util.Map;

import liquidjava.processor.facade.AliasDTO;
import liquidjava.rj_language.ast.BinaryExpression;
import liquidjava.rj_language.ast.Expression;
import liquidjava.rj_language.ast.LiteralBoolean;
Expand All @@ -10,12 +13,18 @@
public class ExpressionSimplifier {

/**
* Simplifies an expression by applying constant propagation, constant folding and removing redundant conjuncts
* Returns a derivation node representing the tree of simplifications applied
* Simplifies an expression by applying constant propagation, constant folding, removing redundant conjuncts and
* expanding aliases Returns a derivation node representing the tree of simplifications applied
*/
public static ValDerivationNode simplify(Expression exp, Map<String, AliasDTO> aliases) {
ValDerivationNode node = new ValDerivationNode(exp, null);
ValDerivationNode fixedPoint = simplifyToFixedPoint(node, exp);
ValDerivationNode simplified = simplifyValDerivationNode(fixedPoint);
return AliasExpansion.expand(simplified, aliases);
}

public static ValDerivationNode simplify(Expression exp) {
ValDerivationNode fixedPoint = simplifyToFixedPoint(null, exp);
return simplifyValDerivationNode(fixedPoint);
return simplify(exp, Map.of());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

import static org.junit.jupiter.api.Assertions.*;

import java.util.List;
import java.util.Map;

import liquidjava.processor.facade.AliasDTO;
import liquidjava.rj_language.ast.AliasInvocation;
import liquidjava.rj_language.ast.BinaryExpression;
import liquidjava.rj_language.ast.Expression;
import liquidjava.rj_language.ast.LiteralBoolean;
Expand Down Expand Up @@ -550,6 +555,72 @@ void testTransitive() {
assertEquals("a == 1", result.getValue().toString(), "Expected result to be a == 1");
}

@Test
void testByteAliasExpansion() {
// Given: Byte(b) with alias Byte(int b) { b >= -128 && b <= 127 }
AliasDTO byteAlias = new AliasDTO("Byte", List.of("int"), List.of("b"), "b >= -128 && b <= 127");
byteAlias.parse("");
Map<String, AliasDTO> aliases = Map.of("Byte", byteAlias);
Expression exp = new AliasInvocation("Byte", List.of(new Var("b")));

// When
ValDerivationNode result = ExpressionSimplifier.simplify(exp, aliases);

// Then
assertEquals("Byte(b)", result.getValue().toString());
assertNotNull(result.getOrigin(), "Origin should contain the expanded body");
ValDerivationNode origin = (ValDerivationNode) result.getOrigin();
assertEquals("b >= -128 && b <= 127", origin.getValue().toString());
}

@Test
void testPositiveAliasExpansion() {
// Given: Positive(x) with alias Positive(int v) { v > 0 }
AliasDTO positiveAlias = new AliasDTO("Positive", List.of("int"), List.of("v"), "v > 0");
positiveAlias.parse("");
Map<String, AliasDTO> aliases = Map.of("Positive", positiveAlias);
Expression exp = new AliasInvocation("Positive", List.of(new Var("x")));

// When
ValDerivationNode result = ExpressionSimplifier.simplify(exp, aliases);

// Then
assertEquals("Positive(x)", result.getValue().toString());
assertNotNull(result.getOrigin(), "Origin should contain the expanded body");
ValDerivationNode origin = (ValDerivationNode) result.getOrigin();
assertEquals("x > 0", origin.getValue().toString());
}

@Test
void testTwoArgAliasWithNormalExpression() {
// Given: Bounded(v, 100) && v > 50 with alias Bounded(int x, int n) { x > 0 && x < n }
AliasDTO boundedAlias = new AliasDTO("Bounded", List.of("int", "int"), List.of("x", "n"), "x > 0 && x < n");
boundedAlias.parse("");
Map<String, AliasDTO> aliases = Map.of("Bounded", boundedAlias);

Expression varV = new Var("v");
Expression bounded = new AliasInvocation("Bounded", List.of(varV, new LiteralInt(100)));
Expression vGt50 = new BinaryExpression(varV, ">", new LiteralInt(50));
Expression fullExpression = new BinaryExpression(bounded, "&&", vGt50);

// When
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression, aliases);

// Then
assertEquals("Bounded(v, 100) && v > 50", result.getValue().toString());
assertInstanceOf(BinaryDerivationNode.class, result.getOrigin());
BinaryDerivationNode binOrigin = (BinaryDerivationNode) result.getOrigin();
assertEquals("&&", binOrigin.getOp());
ValDerivationNode leftNode = binOrigin.getLeft();
assertEquals("Bounded(v, 100)", leftNode.getValue().toString());
assertNotNull(leftNode.getOrigin(), "Alias invocation should have expanded body as origin");
ValDerivationNode expandedBody = (ValDerivationNode) leftNode.getOrigin();
assertEquals("v > 0 && v < 100", expandedBody.getValue().toString());
ValDerivationNode rightNode = binOrigin.getRight();
assertEquals("v > 50", rightNode.getValue().toString());
assertNull(rightNode.getOrigin());
}

/**
* Helper method to compare two derivation nodes recursively
*/
Expand Down
Loading