diff --git a/bindings/python/python/pypaimon_rust/datafusion.pyi b/bindings/python/python/pypaimon_rust/datafusion.pyi index cd126099..4c14de64 100644 --- a/bindings/python/python/pypaimon_rust/datafusion.pyi +++ b/bindings/python/python/pypaimon_rust/datafusion.pyi @@ -57,6 +57,20 @@ class Table: def location(self) -> str: ... def schema(self) -> TableSchema: ... def new_read_builder(self) -> ReadBuilder: ... + def new_write_builder(self) -> "WriteBuilder": ... + +class CommitMessage: ... + +class TableWrite: + def write_arrow(self, batch: pyarrow.RecordBatch) -> None: ... + def prepare_commit(self) -> List[CommitMessage]: ... + +class TableCommit: + def commit(self, messages: Sequence[CommitMessage]) -> None: ... + +class WriteBuilder: + def new_write(self) -> TableWrite: ... + def new_commit(self) -> TableCommit: ... class PaimonCatalog: def __init__(self, catalog_options: Dict[str, str]) -> None: ... diff --git a/bindings/python/src/context.rs b/bindings/python/src/context.rs index 0880302d..d61b94a9 100644 --- a/bindings/python/src/context.rs +++ b/bindings/python/src/context.rs @@ -277,6 +277,10 @@ pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> this.add_class::()?; this.add_class::()?; this.add_class::()?; + this.add_class::()?; + this.add_class::()?; + this.add_class::()?; + this.add_class::()?; this.add_function(wrap_pyfunction!(udf, &this)?)?; m.add_submodule(&this)?; py.import("sys")? diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index c4db29af..dce23d3d 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -24,6 +24,7 @@ mod read; mod schema; mod table; mod udf; +mod write; #[pymodule] fn pypaimon_rust(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { diff --git a/bindings/python/src/table.rs b/bindings/python/src/table.rs index 263c821c..2864b02a 100644 --- a/bindings/python/src/table.rs +++ b/bindings/python/src/table.rs @@ -21,6 +21,7 @@ use pyo3::prelude::*; use crate::read::PyReadBuilder; use crate::schema::PyTableSchema; +use crate::write::PyWriteBuilder; #[pyclass(name = "Table", module = "pypaimon_rust.datafusion")] pub struct PyTable { @@ -52,4 +53,9 @@ impl PyTable { fn new_read_builder(&self) -> PyReadBuilder { PyReadBuilder::new(Arc::clone(&self.inner)) } + + /// Create a [`PyWriteBuilder`] for the batch write loop. + fn new_write_builder(&self) -> PyWriteBuilder { + PyWriteBuilder::new(Arc::clone(&self.inner)) + } } diff --git a/bindings/python/src/write.rs b/bindings/python/src/write.rs new file mode 100644 index 00000000..2b604296 --- /dev/null +++ b/bindings/python/src/write.rs @@ -0,0 +1,225 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::datatypes::Schema as ArrowSchema; +use arrow::pyarrow::FromPyArrow; +use arrow::record_batch::RecordBatch; +use paimon::table::{CommitMessage, Table, TableCommit, TableWrite}; +use paimon_datafusion::runtime::runtime; +use pyo3::exceptions::{PyTypeError, PyValueError}; +use pyo3::prelude::*; + +use crate::error::to_py_err; + +/// Validate an incoming batch schema against the table's target Arrow schema: +/// field count, order, and names must match, and types must match exactly. The +/// nullable flag is intentionally NOT compared, since `build_target_arrow_schema` +/// derives nullability from the Paimon field while pyarrow-constructed batches +/// infer nullable=true. No cast — callers supply correctly-typed batches. +/// +/// Type matching is strict (no binary-family interchange): the lower write path +/// downcasts to the exact Arrow array for each Paimon type (e.g. a `Binary` / +/// `VarBinary` field requires `arrow_array::BinaryArray`, not `LargeBinary` / +/// `FixedSizeBinary`). Accepting a near-equivalent type here would pass +/// validation but then fail deeper with a type-mismatch (or write files whose +/// Arrow schema differs from the table), so it is rejected up front. +fn validate_batch_schema(input: &ArrowSchema, target: &ArrowSchema) -> PyResult<()> { + let mismatch = || { + PyValueError::new_err(format!( + "Input schema is not consistent with the table schema. \ + input: {input:?}, table: {target:?}" + )) + }; + if input.fields().len() != target.fields().len() { + return Err(mismatch()); + } + for (i, t) in input.fields().iter().zip(target.fields().iter()) { + if i.name() != t.name() { + return Err(mismatch()); + } + if i.data_type() != t.data_type() { + return Err(mismatch()); + } + } + Ok(()) +} + +/// Builder for the batch write loop, created via [`crate::table::PyTable::new_write_builder`]. +/// +/// Holds the owning table plus a single fixed `commit_user`, generated once and +/// shared by both `new_write()` and `new_commit()` so that writers and the +/// committer agree on the commit user (Paimon uses it for duplicate-commit +/// detection). Creating a fresh `WriteBuilder` per call would otherwise mint a +/// new random UUID each time. +#[pyclass(name = "WriteBuilder", module = "pypaimon_rust.datafusion")] +pub struct PyWriteBuilder { + table: Arc, + commit_user: String, +} + +impl PyWriteBuilder { + pub fn new(table: Arc
) -> Self { + let commit_user = table.new_write_builder().commit_user().to_string(); + Self { table, commit_user } + } +} + +#[pymethods] +impl PyWriteBuilder { + /// Create a writer for accumulating Arrow batches. + fn new_write(&self) -> PyResult { + let builder = self + .table + .new_write_builder() + .with_commit_user(self.commit_user.clone()) + .map_err(to_py_err)?; + let target_schema = paimon::arrow::build_target_arrow_schema(self.table.schema().fields()) + .map_err(to_py_err)?; + Ok(PyTableWrite { + inner: builder.new_write().map_err(to_py_err)?, + target_schema, + table_location: self.table.location().to_string(), + commit_user: self.commit_user.clone(), + }) + } + + /// Create a committer for persisting prepared commit messages. + fn new_commit(&self) -> PyResult { + let builder = self + .table + .new_write_builder() + .with_commit_user(self.commit_user.clone()) + .map_err(to_py_err)?; + Ok(PyTableCommit { + inner: builder.new_commit(), + table_location: self.table.location().to_string(), + commit_user: self.commit_user.clone(), + }) + } +} + +/// A stateful writer that accumulates Arrow batches until `prepare_commit`. +/// +/// Marked `unsendable`: the underlying `TableWrite` holds file writers that are +/// not `Sync`, so the object enforces single-thread access at runtime. +#[pyclass(name = "TableWrite", module = "pypaimon_rust.datafusion", unsendable)] +pub struct PyTableWrite { + inner: TableWrite, + /// The table's target Arrow schema, used to validate incoming batches. + target_schema: Arc, + /// The owning table's location, stamped onto produced commit messages so a + /// committer can reject messages prepared for a different table. + table_location: String, + /// The originating builder's `commit_user`, stamped onto produced messages so + /// a committer can reject messages prepared by a different `WriteBuilder` + /// (writers and committers from the same builder must share one commit_user; + /// it drives snapshot duplicate detection and postpone-bucket file naming). + commit_user: String, +} + +#[pymethods] +impl PyTableWrite { + /// Write a single PyArrow RecordBatch into the table's writers. + fn write_arrow(&mut self, py: Python<'_>, batch: &Bound<'_, PyAny>) -> PyResult<()> { + let batch = RecordBatch::from_pyarrow_bound(batch)?; + validate_batch_schema(&batch.schema(), &self.target_schema)?; + let rt = runtime(); + py.detach(|| rt.block_on(async { self.inner.write_arrow_batch(&batch).await })) + .map_err(to_py_err) + } + + /// Close writers and return the commit messages (opaque; pass to commit()). + fn prepare_commit(&mut self, py: Python<'_>) -> PyResult> { + let rt = runtime(); + let messages = py + .detach(|| rt.block_on(async { self.inner.prepare_commit().await })) + .map_err(to_py_err)?; + Ok(messages + .into_iter() + .map(|inner| PyCommitMessage { + inner, + table_location: self.table_location.clone(), + commit_user: self.commit_user.clone(), + }) + .collect()) + } +} + +/// A committer that persists prepared commit messages as a snapshot. +#[pyclass(name = "TableCommit", module = "pypaimon_rust.datafusion")] +pub struct PyTableCommit { + inner: TableCommit, + /// The owning table's location, used to reject commit messages that were + /// prepared for a different table (which would otherwise persist a snapshot + /// referencing data files written under another table). + table_location: String, + /// The committer's `commit_user`, used to reject messages prepared by a + /// different `WriteBuilder` — even for the same table — since the writer and + /// committer must share one commit_user. + commit_user: String, +} + +#[pymethods] +impl PyTableCommit { + /// Commit the given commit messages. Empty input is a no-op success. + fn commit(&self, py: Python<'_>, messages: &Bound<'_, PyAny>) -> PyResult<()> { + let mut inner_messages = Vec::new(); + let iter = messages.try_iter().map_err(|_| { + PyTypeError::new_err("commit() expects a sequence of CommitMessage objects") + })?; + for item in iter { + let item = item?; + let msg: PyRef = item.extract().map_err(|_| { + PyTypeError::new_err("commit() expects a sequence of CommitMessage objects") + })?; + if msg.table_location != self.table_location { + return Err(PyValueError::new_err(format!( + "commit message was prepared for a different table \ + (message table '{}', committer table '{}')", + msg.table_location, self.table_location + ))); + } + if msg.commit_user != self.commit_user { + return Err(PyValueError::new_err( + "commit message was prepared by a different WriteBuilder \ + (writer and committer must come from the same \ + table.new_write_builder() so they share one commit_user)" + .to_string(), + )); + } + inner_messages.push(msg.inner.clone()); + } + let rt = runtime(); + py.detach(|| rt.block_on(async { self.inner.commit(inner_messages).await })) + .map_err(to_py_err) + } +} + +/// An opaque commit message produced by `prepare_commit`, consumed by `commit`. +/// PR1 supports same-process transfer only (no pickle/serialization). +/// +/// Carries the originating table's location and builder `commit_user` so a +/// committer can reject messages prepared for a different table or by a +/// different `WriteBuilder`. +#[pyclass(name = "CommitMessage", module = "pypaimon_rust.datafusion")] +pub struct PyCommitMessage { + pub(crate) inner: CommitMessage, + pub(crate) table_location: String, + pub(crate) commit_user: String, +} diff --git a/bindings/python/tests/test_write.py b/bindings/python/tests/test_write.py new file mode 100644 index 00000000..66d7d10d --- /dev/null +++ b/bindings/python/tests/test_write.py @@ -0,0 +1,179 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tempfile + +import pyarrow as pa +import pytest + +from pypaimon_rust.datafusion import PaimonCatalog, SQLContext + +# The table created by _make_empty_table is (id INT, name STRING). Paimon INT maps +# to Arrow int32, so batches must use int32 for id — pyarrow infers Python ints as +# int64, which write_arrow now (correctly, matching pypaimon) rejects as a type +# mismatch. Build batches against this explicit schema to match the table. +_TABLE_SCHEMA = pa.schema([("id", pa.int32()), ("name", pa.string())]) + + +def _batch(ids, names): + return pa.record_batch([ids, names], schema=_TABLE_SCHEMA) + + +def _make_empty_table(warehouse): + ctx = SQLContext() + ctx.register_catalog("paimon", {"warehouse": warehouse}) + ctx.sql("CREATE SCHEMA paimon.wdb") + ctx.sql("CREATE TABLE paimon.wdb.t (id INT, name STRING)") + return ctx + + +def _get_table(warehouse): + return PaimonCatalog({"warehouse": warehouse}).get_table("wdb.t") + + +def test_write_commit_read_roundtrip(): + with tempfile.TemporaryDirectory() as warehouse: + ctx = _make_empty_table(warehouse) + table = _get_table(warehouse) + batch = _batch([1, 2, 3], ["a", "b", "c"]) + wb = table.new_write_builder() + write = wb.new_write() + write.write_arrow(batch) + messages = write.prepare_commit() + assert len(messages) >= 1 # cover API shape in the first test + wb.new_commit().commit(messages) # same wb → shared commit_user + result = pa.Table.from_batches( + ctx.sql("SELECT id, name FROM paimon.wdb.t") + ).sort_by("id").to_pydict() + assert result == {"id": [1, 2, 3], "name": ["a", "b", "c"]} + + +def test_write_multiple_batches(): + with tempfile.TemporaryDirectory() as warehouse: + ctx = _make_empty_table(warehouse) + table = _get_table(warehouse) + wb = table.new_write_builder() + write = wb.new_write() + write.write_arrow(_batch([1], ["a"])) + write.write_arrow(_batch([2], ["b"])) + messages = write.prepare_commit() + wb.new_commit().commit(messages) + result = pa.Table.from_batches( + ctx.sql("SELECT id, name FROM paimon.wdb.t") + ).sort_by("id").to_pydict() + assert result == {"id": [1, 2], "name": ["a", "b"]} + + +def test_prepare_commit_returns_messages(): + with tempfile.TemporaryDirectory() as warehouse: + _make_empty_table(warehouse) + table = _get_table(warehouse) + write = table.new_write_builder().new_write() + write.write_arrow(_batch([1], ["a"])) + messages = write.prepare_commit() + assert len(messages) >= 1 + assert all(type(m).__name__ == "CommitMessage" for m in messages) + + +def test_commit_empty_messages_noop(): + with tempfile.TemporaryDirectory() as warehouse: + ctx = _make_empty_table(warehouse) + table = _get_table(warehouse) + wb = table.new_write_builder() + messages = wb.new_write().prepare_commit() # no write + assert messages == [] + wb.new_commit().commit(messages) # no-op success + batches = ctx.sql("SELECT COUNT(*) AS cnt FROM paimon.wdb.t") + assert batches[0].column(0).to_pylist() == [0] + + +def test_write_arrow_type_mismatch_raises(): + with tempfile.TemporaryDirectory() as warehouse: + _make_empty_table(warehouse) # table (id INT, name STRING) + table = _get_table(warehouse) + write = table.new_write_builder().new_write() + bad = pa.record_batch([["x", "y"], ["a", "b"]], names=["id", "name"]) # id as STRING + with pytest.raises(ValueError): + write.write_arrow(bad) + + +def test_write_arrow_binary_family_mismatch_raises(): + # A BINARY column requires Arrow `binary`; a near-equivalent `large_binary` + # must be rejected at validation (it would otherwise fail deeper, since the + # write path downcasts binary fields to arrow_array::BinaryArray only). + with tempfile.TemporaryDirectory() as warehouse: + ctx = SQLContext() + ctx.register_catalog("paimon", {"warehouse": warehouse}) + ctx.sql("CREATE SCHEMA paimon.wdb") + ctx.sql("CREATE TABLE paimon.wdb.bt (id INT, data BINARY)") + table = PaimonCatalog({"warehouse": warehouse}).get_table("wdb.bt") + write = table.new_write_builder().new_write() + schema = pa.schema([("id", pa.int32()), ("data", pa.large_binary())]) + bad = pa.record_batch([[1], [b"x"]], schema=schema) + with pytest.raises(ValueError): + write.write_arrow(bad) + + +def test_commit_non_message_raises_typeerror(): + with tempfile.TemporaryDirectory() as warehouse: + _make_empty_table(warehouse) + table = _get_table(warehouse) + with pytest.raises(TypeError): + table.new_write_builder().new_commit().commit([object()]) + # A non-iterable argument also raises TypeError (not a raw PyO3 error). + with pytest.raises(TypeError): + table.new_write_builder().new_commit().commit(42) + + +def test_commit_cross_table_messages_raises(): + # Messages prepared for one table must not be committed by another table's + # committer (would persist a snapshot referencing data files written + # elsewhere). The wrapper stamps each message with its source table location + # and the committer rejects mismatches. + with tempfile.TemporaryDirectory() as warehouse: + ctx = SQLContext() + ctx.register_catalog("paimon", {"warehouse": warehouse}) + ctx.sql("CREATE SCHEMA paimon.wdb") + ctx.sql("CREATE TABLE paimon.wdb.t1 (id INT, name STRING)") + ctx.sql("CREATE TABLE paimon.wdb.t2 (id INT, name STRING)") + catalog = PaimonCatalog({"warehouse": warehouse}) + t1 = catalog.get_table("wdb.t1") + t2 = catalog.get_table("wdb.t2") + batch = pa.record_batch( + [pa.array([1], pa.int32()), pa.array(["a"], pa.string())], + names=["id", "name"], + ) + w1 = t1.new_write_builder().new_write() + w1.write_arrow(batch) + messages = w1.prepare_commit() + with pytest.raises(ValueError): + t2.new_write_builder().new_commit().commit(messages) + + +def test_commit_different_builder_same_table_raises(): + # Even for the same table, a committer from a different WriteBuilder must + # reject the messages: each builder mints its own commit_user, and writers + # and committers must share one (snapshot duplicate detection / postpone + # bucket file naming depend on it). + with tempfile.TemporaryDirectory() as warehouse: + _make_empty_table(warehouse) + table = _get_table(warehouse) + write = table.new_write_builder().new_write() + write.write_arrow(_batch([1], ["a"])) + messages = write.prepare_commit() + with pytest.raises(ValueError): + table.new_write_builder().new_commit().commit(messages)