Skip to content

Commit 0366657

Browse files
committed
[CALCITE-6636] Support CNF conditions in the Arrow adapter
1 parent ea7fb17 commit 0366657

9 files changed

Lines changed: 342 additions & 116 deletions

File tree

arrow/src/main/java/org/apache/calcite/adapter/arrow/ArrowFilter.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
* relational expression in Arrow.
3535
*/
3636
class ArrowFilter extends Filter implements ArrowRel {
37-
private final List<String> match;
37+
private final List<List<ConditionToken>> match;
3838

3939
ArrowFilter(RelOptCluster cluster, RelTraitSet traitSet, RelNode input, RexNode condition) {
4040
super(cluster, traitSet, input, condition);

arrow/src/main/java/org/apache/calcite/adapter/arrow/ArrowRel.java

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,24 @@ public interface ArrowRel extends RelNode {
4141
* {@link ArrowRel} nodes into a SQL query. */
4242
class Implementor {
4343
@Nullable List<Integer> selectFields;
44-
final List<String> whereClause = new ArrayList<>();
44+
final List<List<ConditionToken>> whereClause = new ArrayList<>();
4545
@Nullable RelOptTable table;
4646
@Nullable ArrowTable arrowTable;
4747

4848
/** Adds new predicates.
4949
*
50-
* @param predicates Predicates
50+
* <p>The structure is two levels of nesting:
51+
* <ul>
52+
* <li>Outer list: conjunction (AND) of clauses
53+
* <li>Inner list: disjunction (OR) of conditions within a clause
54+
* </ul>
55+
*
56+
* <p>Each {@link ConditionToken} represents a single unary or binary
57+
* predicate condition.
58+
*
59+
* @param predicates Predicates in CNF form
5160
*/
52-
void addFilters(List<String> predicates) {
61+
void addFilters(List<List<ConditionToken>> predicates) {
5362
whereClause.addAll(predicates);
5463
}
5564

arrow/src/main/java/org/apache/calcite/adapter/arrow/ArrowRules.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
import org.apache.calcite.rel.logical.LogicalFilter;
3030
import org.apache.calcite.rel.logical.LogicalProject;
3131
import org.apache.calcite.rel.type.RelDataType;
32+
import org.apache.calcite.rex.RexNode;
33+
import org.apache.calcite.rex.RexUtil;
3234
import org.apache.calcite.sql.validate.SqlValidatorUtil;
3335

3436
import com.google.common.collect.ImmutableList;
@@ -97,9 +99,13 @@ protected ArrowFilterRule(Config config) {
9799
RelNode convert(Filter filter) {
98100
final RelTraitSet traitSet =
99101
filter.getTraitSet().replace(ArrowRel.CONVENTION);
102+
// Expand SEARCH (e.g. IN, BETWEEN) before pushing to Arrow,
103+
// since Gandiva does not support SEARCH natively.
104+
final RexNode condition =
105+
RexUtil.expandSearch(filter.getCluster().getRexBuilder(), null, filter.getCondition());
100106
return new ArrowFilter(filter.getCluster(), traitSet,
101107
convert(filter.getInput(), ArrowRel.CONVENTION),
102-
filter.getCondition());
108+
condition);
103109
}
104110

105111
/** Rule configuration. */

arrow/src/main/java/org/apache/calcite/adapter/arrow/ArrowTable.java

Lines changed: 37 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ public class ArrowTable extends AbstractTable
9797
* {@link org.apache.calcite.adapter.arrow.ArrowMethod#ARROW_QUERY}. */
9898
@SuppressWarnings("unused")
9999
public Enumerable<Object> query(DataContext root, ImmutableIntList fields,
100-
List<String> conditions) {
100+
List<List<List<String>>> conditions) {
101101
requireNonNull(fields, "fields");
102102
final Projector projector;
103103
final Filter filter;
@@ -119,30 +119,26 @@ public Enumerable<Object> query(DataContext root, ImmutableIntList fields,
119119
} else {
120120
projector = null;
121121

122-
final List<TreeNode> conditionNodes = new ArrayList<>(conditions.size());
123-
for (String condition : conditions) {
124-
String[] data = condition.split(" ");
125-
List<TreeNode> treeNodes = new ArrayList<>(2);
126-
treeNodes.add(
127-
TreeBuilder.makeField(schema.getFields()
128-
.get(schema.getFields().indexOf(schema.findField(data[0])))));
129-
130-
// if the split condition has more than two parts it's a binary operator
131-
// with an additional literal node
132-
if (data.length > 2) {
133-
treeNodes.add(makeLiteralNode(data[2], data[3]));
122+
final List<TreeNode> conjuncts = new ArrayList<>(conditions.size());
123+
for (List<List<String>> orGroup : conditions) {
124+
final List<TreeNode> disjuncts = new ArrayList<>(orGroup.size());
125+
for (List<String> conditionParts : orGroup) {
126+
disjuncts.add(
127+
convertConditionToGandiva(
128+
ConditionToken.fromTokenList(conditionParts)));
129+
}
130+
if (disjuncts.size() == 1) {
131+
conjuncts.add(disjuncts.get(0));
132+
} else {
133+
conjuncts.add(TreeBuilder.makeOr(disjuncts));
134134
}
135-
136-
String operator = data[1];
137-
conditionNodes.add(
138-
TreeBuilder.makeFunction(operator, treeNodes, new ArrowType.Bool()));
139135
}
140136
final Condition filterCondition;
141-
if (conditionNodes.size() == 1) {
142-
filterCondition = TreeBuilder.makeCondition(conditionNodes.get(0));
137+
if (conjuncts.size() == 1) {
138+
filterCondition = TreeBuilder.makeCondition(conjuncts.get(0));
143139
} else {
144-
TreeNode treeNode = TreeBuilder.makeAnd(conditionNodes);
145-
filterCondition = TreeBuilder.makeCondition(treeNode);
140+
filterCondition =
141+
TreeBuilder.makeCondition(TreeBuilder.makeAnd(conjuncts));
146142
}
147143

148144
try {
@@ -184,6 +180,26 @@ private static RelDataType deduceRowType(Schema schema,
184180
return builder.build();
185181
}
186182

183+
/** Converts a single {@link ConditionToken} into a Gandiva {@link TreeNode}. */
184+
private TreeNode convertConditionToGandiva(ConditionToken token) {
185+
final List<TreeNode> treeNodes = new ArrayList<>(2);
186+
treeNodes.add(
187+
TreeBuilder.makeField(schema.getFields()
188+
.get(
189+
schema.getFields().indexOf(
190+
schema.findField(token.fieldName)))));
191+
192+
if (token.isBinary()) {
193+
treeNodes.add(
194+
makeLiteralNode(
195+
requireNonNull(token.value, "value"),
196+
requireNonNull(token.valueType, "valueType")));
197+
}
198+
199+
return TreeBuilder.makeFunction(
200+
token.operator, treeNodes, new ArrowType.Bool());
201+
}
202+
187203
private static TreeNode makeLiteralNode(String literal, String type) {
188204
if (type.startsWith("decimal")) {
189205
String[] typeParts =

arrow/src/main/java/org/apache/calcite/adapter/arrow/ArrowToEnumerableConverter.java

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
import com.google.common.primitives.Ints;
3737

38+
import java.util.ArrayList;
3839
import java.util.List;
3940

4041
import static java.util.Objects.requireNonNull;
@@ -84,6 +85,23 @@ protected ArrowToEnumerableConverter(RelOptCluster cluster,
8485
: Expressions.call(
8586
BuiltInMethod.IMMUTABLE_INT_LIST_IDENTITY.method,
8687
Expressions.constant(fieldCount)),
87-
Expressions.constant(arrowImplementor.whereClause))));
88+
Expressions.constant(
89+
toTokenLists(arrowImplementor.whereClause)))));
90+
}
91+
92+
/** Converts structured {@link ConditionToken} conditions to nested string
93+
* lists for serialization through {@link Expressions#constant}. */
94+
private static List<List<List<String>>> toTokenLists(
95+
List<List<ConditionToken>> conditions) {
96+
final List<List<List<String>>> result =
97+
new ArrayList<>(conditions.size());
98+
for (List<ConditionToken> orGroup : conditions) {
99+
final List<List<String>> group = new ArrayList<>(orGroup.size());
100+
for (ConditionToken token : orGroup) {
101+
group.add(token.toTokenList());
102+
}
103+
result.add(group);
104+
}
105+
return result;
88106
}
89107
}

arrow/src/main/java/org/apache/calcite/adapter/arrow/ArrowTranslator.java

Lines changed: 42 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
import static java.util.Objects.requireNonNull;
4242

4343
/**
44-
* Translates a {@link RexNode} expression to a Gandiva string.
44+
* Translates a {@link RexNode} expression to Gandiva predicate tokens.
4545
*/
4646
class ArrowTranslator {
4747
final RexBuilder rexBuilder;
@@ -61,13 +61,30 @@ public static ArrowTranslator create(RexBuilder rexBuilder,
6161
return new ArrowTranslator(rexBuilder, rowType);
6262
}
6363

64-
List<String> translateMatch(RexNode condition) {
65-
List<RexNode> disjunctions = RelOptUtil.disjunctions(condition);
66-
if (disjunctions.size() == 1) {
67-
return translateAnd(disjunctions.get(0));
68-
} else {
69-
throw new UnsupportedOperationException("Unsupported disjunctive condition " + condition);
64+
/** The maximum number of nodes allowed during CNF conversion.
65+
*
66+
* <p>If exceeded, {@link RexUtil#toCnf(RexBuilder, int, RexNode)} returns
67+
* the original expression unchanged, which may cause the subsequent
68+
* translation to Gandiva predicates to fail with an
69+
* {@link UnsupportedOperationException}. When invoked by the Arrow adapter
70+
* module, the exception is caught and the plan falls back to
71+
* an Enumerable convention. */
72+
private static final int MAX_CNF_NODE_COUNT = 256;
73+
74+
List<List<ConditionToken>> translateMatch(RexNode condition) {
75+
// Convert to CNF; SEARCH nodes are already expanded
76+
// by ArrowFilterRule before reaching here.
77+
final RexNode cnf = RexUtil.toCnf(rexBuilder, MAX_CNF_NODE_COUNT, condition);
78+
79+
final List<List<ConditionToken>> result = new ArrayList<>();
80+
for (RexNode conjunct : RelOptUtil.conjunctions(cnf)) {
81+
final List<ConditionToken> orGroup = new ArrayList<>();
82+
for (RexNode disjunct : RelOptUtil.disjunctions(conjunct)) {
83+
orGroup.add(translateMatch2(disjunct));
84+
}
85+
result.add(orGroup);
7086
}
87+
return result;
7188
}
7289

7390
/**
@@ -93,34 +110,14 @@ private static Object literalValue(RexLiteral literal) {
93110
}
94111
}
95112

96-
/**
97-
* Translate a conjunctive predicate to a SQL string.
98-
*
99-
* @param condition A conjunctive predicate
100-
*
101-
* @return SQL string for the predicate
102-
*/
103-
private List<String> translateAnd(RexNode condition) {
104-
List<String> predicates = new ArrayList<>();
105-
for (RexNode node : RelOptUtil.conjunctions(condition)) {
106-
if (node.getKind() == SqlKind.SEARCH) {
107-
final RexNode node2 = RexUtil.expandSearch(rexBuilder, null, node);
108-
predicates.addAll(translateMatch(node2));
109-
} else {
110-
predicates.add(translateMatch2(node));
111-
}
112-
}
113-
return predicates;
114-
}
115-
116113
/**
117114
* Translates a binary or unary relation.
118115
*
119116
* @param node A RexNode that always evaluates to a boolean expression.
120117
* Currently, this method is only called from translateMatch.
121-
* @return The translated SQL string for the relation.
118+
* @return The translated condition token for the relation.
122119
*/
123-
private String translateMatch2(RexNode node) {
120+
private ConditionToken translateMatch2(RexNode node) {
124121
switch (node.getKind()) {
125122
case EQUALS:
126123
return translateBinary("equal", "=", (RexCall) node);
@@ -144,7 +141,7 @@ private String translateMatch2(RexNode node) {
144141
return translateUnary("isnotfalse", (RexCall) node);
145142
case INPUT_REF:
146143
final RexInputRef inputRef = (RexInputRef) node;
147-
return fieldNames.get(inputRef.getIndex()) + " istrue";
144+
return ConditionToken.unary(fieldNames.get(inputRef.getIndex()), "istrue");
148145
case NOT:
149146
return translateUnary("isfalse", (RexCall) node);
150147
default:
@@ -156,10 +153,10 @@ private String translateMatch2(RexNode node) {
156153
* Translates a call to a binary operator, reversing arguments if
157154
* necessary.
158155
*/
159-
private String translateBinary(String op, String rop, RexCall call) {
156+
private ConditionToken translateBinary(String op, String rop, RexCall call) {
160157
final RexNode left = call.operands.get(0);
161158
final RexNode right = call.operands.get(1);
162-
@Nullable String expression = translateBinary2(op, left, right);
159+
@Nullable ConditionToken expression = translateBinary2(op, left, right);
163160
if (expression != null) {
164161
return expression;
165162
}
@@ -171,7 +168,8 @@ private String translateBinary(String op, String rop, RexCall call) {
171168
}
172169

173170
/** Translates a call to a binary operator. Returns null on failure. */
174-
private @Nullable String translateBinary2(String op, RexNode left, RexNode right) {
171+
private @Nullable ConditionToken translateBinary2(String op, RexNode left,
172+
RexNode right) {
175173
if (right.getKind() != SqlKind.LITERAL) {
176174
return null;
177175
}
@@ -189,26 +187,29 @@ private String translateBinary(String op, String rop, RexCall call) {
189187
}
190188
}
191189

192-
/** Combines a field name, operator, and literal to produce a predicate string. */
193-
private String translateOp2(String op, String name, RexLiteral right) {
190+
/** Combines a field name, operator, and literal to produce a binary
191+
* condition token. */
192+
private ConditionToken translateOp2(String op, String name,
193+
RexLiteral right) {
194194
Object value = literalValue(right);
195195
String valueString = value.toString();
196196
String valueType = getLiteralType(right.getType());
197197

198198
if (value instanceof String) {
199-
final RelDataTypeField field = requireNonNull(rowType.getField(name, true, false), "field");
199+
final RelDataTypeField field =
200+
requireNonNull(rowType.getField(name, true, false), "field");
200201
SqlTypeName typeName = field.getType().getSqlTypeName();
201202
if (typeName != SqlTypeName.CHAR) {
202203
valueString = "'" + valueString + "'";
203204
}
204205
}
205-
return name + " " + op + " " + valueString + " " + valueType;
206+
return ConditionToken.binary(name, op, valueString, valueType);
206207
}
207208

208209
/** Translates a call to a unary operator. */
209-
private String translateUnary(String op, RexCall call) {
210+
private ConditionToken translateUnary(String op, RexCall call) {
210211
final RexNode opNode = call.operands.get(0);
211-
@Nullable String expression = translateUnary2(op, opNode);
212+
@Nullable ConditionToken expression = translateUnary2(op, opNode);
212213

213214
if (expression != null) {
214215
return expression;
@@ -218,21 +219,16 @@ private String translateUnary(String op, RexCall call) {
218219
}
219220

220221
/** Translates a call to a unary operator. Returns null on failure. */
221-
private @Nullable String translateUnary2(String op, RexNode opNode) {
222+
private @Nullable ConditionToken translateUnary2(String op, RexNode opNode) {
222223
if (opNode.getKind() == SqlKind.INPUT_REF) {
223224
final RexInputRef inputRef = (RexInputRef) opNode;
224225
final String name = fieldNames.get(inputRef.getIndex());
225-
return translateUnaryOp(op, name);
226+
return ConditionToken.unary(name, op);
226227
}
227228

228229
return null;
229230
}
230231

231-
/** Combines a field name and a unary operator to produce a predicate string. */
232-
private static String translateUnaryOp(String op, String name) {
233-
return name + " " + op;
234-
}
235-
236232
private static String getLiteralType(RelDataType type) {
237233
if (type.getSqlTypeName() == SqlTypeName.DECIMAL) {
238234
return "decimal" + "(" + type.getPrecision() + "," + type.getScale() + ")";

0 commit comments

Comments
 (0)