Skip to content

Commit 2bef8da

Browse files
timsaucerclaude
andcommitted
refactor: register physical optimizer rules via live add method
Drop the `physical_optimizer_rules` constructor argument on `SessionContext` and replace it with `add_physical_optimizer_rule`, matching the existing `register_*` shape on the same class. The new method rebuilds the session state via `SessionStateBuilder::new_from_existing` so previously registered tables, UDFs, and catalogs are preserved. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent a9621ee commit 2bef8da

3 files changed

Lines changed: 43 additions & 28 deletions

File tree

crates/core/src/context.rs

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -376,12 +376,11 @@ pub struct PySessionContext {
376376

377377
#[pymethods]
378378
impl PySessionContext {
379-
#[pyo3(signature = (config=None, runtime=None, physical_optimizer_rules=None))]
379+
#[pyo3(signature = (config=None, runtime=None))]
380380
#[new]
381381
pub fn new(
382382
config: Option<PySessionConfig>,
383383
runtime: Option<PyRuntimeEnvBuilder>,
384-
physical_optimizer_rules: Option<Vec<Bound<'_, PyAny>>>,
385384
) -> PyDataFusionResult<Self> {
386385
let config = if let Some(c) = config {
387386
c.config
@@ -394,15 +393,11 @@ impl PySessionContext {
394393
RuntimeEnvBuilder::default()
395394
};
396395
let runtime = Arc::new(runtime_env_builder.build()?);
397-
let mut state_builder = SessionStateBuilder::new()
396+
let session_state = SessionStateBuilder::new()
398397
.with_config(config)
399398
.with_runtime_env(runtime)
400-
.with_default_features();
401-
for rule in physical_optimizer_rules.unwrap_or_default() {
402-
let rule = physical_optimizer_rule_from_pycapsule(&rule)?;
403-
state_builder = state_builder.with_physical_optimizer_rule(rule);
404-
}
405-
let session_state = state_builder.build();
399+
.with_default_features()
400+
.build();
406401
let ctx = Arc::new(SessionContext::new_with_state(session_state));
407402
Ok(PySessionContext {
408403
ctx,
@@ -1151,6 +1146,17 @@ impl PySessionContext {
11511146
self.ctx.remove_optimizer_rule(name)
11521147
}
11531148

1149+
pub fn add_physical_optimizer_rule(&self, rule: Bound<'_, PyAny>) -> PyDataFusionResult<()> {
1150+
let rule = physical_optimizer_rule_from_pycapsule(&rule)?;
1151+
let state_ref = self.ctx.state_ref();
1152+
let mut guard = state_ref.write();
1153+
let new_state = SessionStateBuilder::new_from_existing(guard.clone())
1154+
.with_physical_optimizer_rule(rule)
1155+
.build();
1156+
*guard = new_state;
1157+
Ok(())
1158+
}
1159+
11541160
pub fn table_provider(&self, name: &str, py: Python) -> PyResult<PyTable> {
11551161
let provider = wait_for_future(py, self.ctx.table_provider(name))
11561162
// Outer error: runtime/async failure

examples/datafusion-ffi-example/python/tests/_test_physical_optimizer_rule.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,11 @@
2323

2424

2525
def test_ffi_physical_optimizer_rule_runs_during_planning():
26-
"""A rule supplied via physical_optimizer_rules is invoked while the
26+
"""A rule added via add_physical_optimizer_rule is invoked while the
2727
physical plan is built, and the query still returns correct results."""
2828
rule = MyPhysicalOptimizerRule()
29-
ctx = SessionContext(physical_optimizer_rules=[rule])
29+
ctx = SessionContext()
30+
ctx.add_physical_optimizer_rule(rule)
3031
batch = pa.RecordBatch.from_arrays(
3132
[pa.array([1, 2, 3])],
3233
names=["a"],

python/datafusion/context.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,6 @@ def __init__(
534534
self,
535535
config: SessionConfig | None = None,
536536
runtime: RuntimeEnvBuilder | None = None,
537-
physical_optimizer_rules: list[PhysicalOptimizerRuleExportable] | None = None,
538537
) -> None:
539538
"""Main interface for executing queries with DataFusion.
540539
@@ -545,11 +544,6 @@ def __init__(
545544
Args:
546545
config: Session configuration options.
547546
runtime: Runtime configuration options.
548-
physical_optimizer_rules: User-defined physical optimizer rules to
549-
append to the default set, each a
550-
:class:`PhysicalOptimizerRuleExportable`. There is no upstream
551-
API to add physical rules to a live context, so these can only
552-
be supplied at construction time.
553547
554548
Example usage:
555549
@@ -560,21 +554,11 @@ def __init__(
560554
561555
ctx = SessionContext()
562556
df = ctx.read_csv("data.csv")
563-
564-
To register a physical optimizer rule supplied by a compiled
565-
extension, pass it via ``physical_optimizer_rules``::
566-
567-
from datafusion import SessionContext
568-
from my_extension import MyPhysicalOptimizerRule
569-
570-
ctx = SessionContext(
571-
physical_optimizer_rules=[MyPhysicalOptimizerRule()]
572-
)
573557
"""
574558
config = config.config_internal if config is not None else None
575559
runtime = runtime.config_internal if runtime is not None else None
576560

577-
self.ctx = SessionContextInternal(config, runtime, physical_optimizer_rules)
561+
self.ctx = SessionContextInternal(config, runtime)
578562

579563
def __repr__(self) -> str:
580564
"""Print a string representation of the Session Context."""
@@ -1404,6 +1388,30 @@ def remove_optimizer_rule(self, name: str) -> bool:
14041388
"""
14051389
return self.ctx.remove_optimizer_rule(name)
14061390

1391+
def add_physical_optimizer_rule(
1392+
self, rule: PhysicalOptimizerRuleExportable
1393+
) -> None:
1394+
"""Append a user-defined physical optimizer rule to the session.
1395+
1396+
The rule is imported via its ``__datafusion_physical_optimizer_rule__``
1397+
PyCapsule, typically produced by a separate compiled extension. The
1398+
underlying :class:`SessionState` is rebuilt from its current state
1399+
with the new rule appended, so previously registered tables, UDFs,
1400+
and catalogs are preserved.
1401+
1402+
Args:
1403+
rule: Object exposing ``__datafusion_physical_optimizer_rule__``,
1404+
a :class:`PhysicalOptimizerRuleExportable`.
1405+
1406+
Examples:
1407+
>>> from datafusion import SessionContext
1408+
>>> ctx = SessionContext()
1409+
>>> from my_extension import MyPhysicalOptimizerRule # doctest: +SKIP
1410+
>>> rule = MyPhysicalOptimizerRule() # doctest: +SKIP
1411+
>>> ctx.add_physical_optimizer_rule(rule) # doctest: +SKIP
1412+
"""
1413+
self.ctx.add_physical_optimizer_rule(rule)
1414+
14071415
def table_provider(self, name: str) -> Table:
14081416
"""Return the :py:class:`~datafusion.catalog.Table` for the given table name.
14091417

0 commit comments

Comments
 (0)