Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import liquidjava.rj_language.ast.BinaryExpression;
import liquidjava.rj_language.ast.Expression;
import liquidjava.rj_language.ast.GroupExpression;
import liquidjava.rj_language.ast.Ite;
import liquidjava.rj_language.ast.LiteralBoolean;
import liquidjava.rj_language.ast.LiteralInt;
import liquidjava.rj_language.ast.LiteralReal;
import liquidjava.rj_language.ast.UnaryExpression;
import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode;
import liquidjava.rj_language.opt.derivation_node.DerivationNode;
import liquidjava.rj_language.opt.derivation_node.IteDerivationNode;
import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode;
import liquidjava.rj_language.opt.derivation_node.ValDerivationNode;

Expand All @@ -26,6 +28,9 @@ public static ValDerivationNode fold(ValDerivationNode node) {
if (exp instanceof UnaryExpression)
return foldUnary(node);

if (exp instanceof Ite)
return foldIte(node);

if (exp instanceof GroupExpression group) {
if (group.getChildren().size() == 1) {
return fold(new ValDerivationNode(group.getChildren().get(0), node.getOrigin()));
Expand Down Expand Up @@ -191,4 +196,45 @@ private static ValDerivationNode foldUnary(ValDerivationNode node) {
DerivationNode origin = operandNode.getOrigin() != null ? new UnaryDerivationNode(operandNode, operator) : null;
return new ValDerivationNode(unaryExp, origin);
}

/**
* Folds ternary expressions by checking if condition is a boolean literal or both branches are the same
*/
private static ValDerivationNode foldIte(ValDerivationNode node) {
Ite iteExp = (Ite) node.getValue();

ValDerivationNode condNode = fold(new ValDerivationNode(iteExp.getCondition(), null));
ValDerivationNode thenNode = fold(new ValDerivationNode(iteExp.getThen(), null));
ValDerivationNode elseNode = fold(new ValDerivationNode(iteExp.getElse(), null));

Expression condition = condNode.getValue();
Expression thenExp = thenNode.getValue();
Expression elseExp = elseNode.getValue();

iteExp.setChild(0, condition);
iteExp.setChild(1, thenExp);
iteExp.setChild(2, elseExp);

// if condition is a boolean literal, select the corresponding branch: true ? a : b => a, false ? a : b => b
if (condition instanceof LiteralBoolean boolCond) {
Expression selected = boolCond.isBooleanTrue() ? thenExp : elseExp;
DerivationNode origin = new IteDerivationNode(condNode, thenNode, elseNode);
return new ValDerivationNode(selected, origin);
}

// if both branches are the same, return one of them (e.g. cond ? b : b => b)
if (thenExp.equals(elseExp)) {
DerivationNode origin = new IteDerivationNode(condNode, thenNode, elseNode);
return new ValDerivationNode(thenExp, origin);
}

// no folding, but keep track of the folding steps in the origin
DerivationNode origin = hasIteChildOrigin(condNode, thenNode, elseNode)
? new IteDerivationNode(condNode, thenNode, elseNode) : node.getOrigin();
return new ValDerivationNode(iteExp, origin);
}

private static boolean hasIteChildOrigin(ValDerivationNode cond, ValDerivationNode then, ValDerivationNode els) {
return cond.getOrigin() != null || then.getOrigin() != null || els.getOrigin() != null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import liquidjava.rj_language.ast.Var;
import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode;
import liquidjava.rj_language.opt.derivation_node.DerivationNode;
import liquidjava.rj_language.opt.derivation_node.IteDerivationNode;
import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode;
import liquidjava.rj_language.opt.derivation_node.ValDerivationNode;
import liquidjava.rj_language.opt.derivation_node.VarDerivationNode;
Expand Down Expand Up @@ -134,6 +135,10 @@ private static void extractVarOrigins(ValDerivationNode node, Map<String, Deriva
extractVarOrigins(binOrigin.getRight(), varOrigins);
} else if (origin instanceof UnaryDerivationNode unaryOrigin) {
extractVarOrigins(unaryOrigin.getOperand(), varOrigins);
} else if (origin instanceof IteDerivationNode iteOrigin) {
extractVarOrigins(iteOrigin.getCondition(), varOrigins);
extractVarOrigins(iteOrigin.getThenBranch(), varOrigins);
extractVarOrigins(iteOrigin.getElseBranch(), varOrigins);
} else if (origin instanceof ValDerivationNode valOrigin) {
extractVarOrigins(valOrigin, varOrigins);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package liquidjava.rj_language.opt.derivation_node;

public class IteDerivationNode extends DerivationNode {

private final ValDerivationNode condition;
private final ValDerivationNode thenBranch;
private final ValDerivationNode elseBranch;

public IteDerivationNode(ValDerivationNode condition, ValDerivationNode thenBranch, ValDerivationNode elseBranch) {
this.condition = condition;
this.thenBranch = thenBranch;
this.elseBranch = elseBranch;
}

public ValDerivationNode getCondition() {
return condition;
}

public ValDerivationNode getThenBranch() {
return thenBranch;
}

public ValDerivationNode getElseBranch() {
return elseBranch;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@

import liquidjava.rj_language.ast.BinaryExpression;
import liquidjava.rj_language.ast.Expression;
import liquidjava.rj_language.ast.Ite;
import liquidjava.rj_language.ast.LiteralBoolean;
import liquidjava.rj_language.ast.LiteralInt;
import liquidjava.rj_language.ast.UnaryExpression;
import liquidjava.rj_language.ast.Var;
import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode;
import liquidjava.rj_language.opt.derivation_node.DerivationNode;
import liquidjava.rj_language.opt.derivation_node.IteDerivationNode;
import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode;
import liquidjava.rj_language.opt.derivation_node.ValDerivationNode;
import liquidjava.rj_language.opt.derivation_node.VarDerivationNode;
Expand Down Expand Up @@ -550,6 +552,76 @@ void testTransitive() {
assertEquals("a == 1", result.getValue().toString(), "Expected result to be a == 1");
}

@Test
void testIteTrueConditionSimplifiesToThenBranch() {
// Given: true ? a : b
// Expected: a

Expression expr = new Ite(new LiteralBoolean(true), new Var("a"), new Var("b"));

// When
ValDerivationNode result = ExpressionSimplifier.simplify(expr);

// Then
assertNotNull(result, "Result should not be null");
assertEquals("a", result.getValue().toString(), "Expected result to be a");

ValDerivationNode conditionNode = new ValDerivationNode(new LiteralBoolean(true), null);
ValDerivationNode thenNode = new ValDerivationNode(new Var("a"), null);
ValDerivationNode elseNode = new ValDerivationNode(new Var("b"), null);
IteDerivationNode iteOrigin = new IteDerivationNode(conditionNode, thenNode, elseNode);
ValDerivationNode expected = new ValDerivationNode(new Var("a"), iteOrigin);

assertDerivationEquals(expected, result, "");
}

@Test
void testIteFalseConditionSimplifiesToElseBranch() {
// Given: false ? a : b
// Expected: b

Expression expr = new Ite(new LiteralBoolean(false), new Var("a"), new Var("b"));

// When
ValDerivationNode result = ExpressionSimplifier.simplify(expr);

// Then
assertNotNull(result, "Result should not be null");
assertEquals("b", result.getValue().toString(), "Expected result to be b");

ValDerivationNode conditionNode = new ValDerivationNode(new LiteralBoolean(false), null);
ValDerivationNode thenNode = new ValDerivationNode(new Var("a"), null);
ValDerivationNode elseNode = new ValDerivationNode(new Var("b"), null);
IteDerivationNode iteOrigin = new IteDerivationNode(conditionNode, thenNode, elseNode);
ValDerivationNode expected = new ValDerivationNode(new Var("b"), iteOrigin);

assertDerivationEquals(expected, result, "");
}

@Test
void testIteEqualBranchesSimplifiesToBranch() {
// Given: cond ? b : b
// Expected: b

Expression branch = new Var("b");
Expression expr = new Ite(new Var("cond"), branch, branch.clone());

// When
ValDerivationNode result = ExpressionSimplifier.simplify(expr);

// Then
assertNotNull(result, "Result should not be null");
assertEquals("b", result.getValue().toString(), "Expected result to be b");

ValDerivationNode conditionNode = new ValDerivationNode(new Var("cond"), null);
ValDerivationNode thenNode = new ValDerivationNode(new Var("b"), null);
ValDerivationNode elseNode = new ValDerivationNode(new Var("b"), null);
IteDerivationNode iteOrigin = new IteDerivationNode(conditionNode, thenNode, elseNode);
ValDerivationNode expected = new ValDerivationNode(new Var("b"), iteOrigin);

assertDerivationEquals(expected, result, "");
}

/**
* Helper method to compare two derivation nodes recursively
*/
Expand All @@ -576,6 +648,11 @@ private void assertDerivationEquals(DerivationNode expected, DerivationNode actu
UnaryDerivationNode actualUnary = (UnaryDerivationNode) actual;
assertEquals(expectedUnary.getOp(), actualUnary.getOp(), message + ": operators should match");
assertDerivationEquals(expectedUnary.getOperand(), actualUnary.getOperand(), message + " > operand");
} else if (expected instanceof IteDerivationNode expectedIte) {
IteDerivationNode actualIte = (IteDerivationNode) actual;
assertDerivationEquals(expectedIte.getCondition(), actualIte.getCondition(), message + " > condition");
assertDerivationEquals(expectedIte.getThenBranch(), actualIte.getThenBranch(), message + " > then");
assertDerivationEquals(expectedIte.getElseBranch(), actualIte.getElseBranch(), message + " > else");
}
}
}
Loading