Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions bindings/python/python/pypaimon_rust/datafusion.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,20 @@ class Table:
def location(self) -> str: ...
def schema(self) -> TableSchema: ...
def new_read_builder(self) -> ReadBuilder: ...
def new_write_builder(self) -> "WriteBuilder": ...

class CommitMessage: ...

class TableWrite:
def write_arrow(self, batch: pyarrow.RecordBatch) -> None: ...
def prepare_commit(self) -> List[CommitMessage]: ...

class TableCommit:
def commit(self, messages: Sequence[CommitMessage]) -> None: ...

class WriteBuilder:
def new_write(self) -> TableWrite: ...
def new_commit(self) -> TableCommit: ...

class PaimonCatalog:
def __init__(self, catalog_options: Dict[str, str]) -> None: ...
Expand Down
4 changes: 4 additions & 0 deletions bindings/python/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,10 @@ pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()>
this.add_class::<crate::schema::PyDataField>()?;
this.add_class::<PyPythonScalarUDFObject>()?;
this.add_class::<PySQLContext>()?;
this.add_class::<crate::write::PyWriteBuilder>()?;
this.add_class::<crate::write::PyTableWrite>()?;
this.add_class::<crate::write::PyTableCommit>()?;
this.add_class::<crate::write::PyCommitMessage>()?;
this.add_function(wrap_pyfunction!(udf, &this)?)?;
m.add_submodule(&this)?;
py.import("sys")?
Expand Down
1 change: 1 addition & 0 deletions bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ mod read;
mod schema;
mod table;
mod udf;
mod write;

#[pymodule]
fn pypaimon_rust(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
Expand Down
6 changes: 6 additions & 0 deletions bindings/python/src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use pyo3::prelude::*;

use crate::read::PyReadBuilder;
use crate::schema::PyTableSchema;
use crate::write::PyWriteBuilder;

#[pyclass(name = "Table", module = "pypaimon_rust.datafusion")]
pub struct PyTable {
Expand Down Expand Up @@ -52,4 +53,9 @@ impl PyTable {
fn new_read_builder(&self) -> PyReadBuilder {
PyReadBuilder::new(Arc::clone(&self.inner))
}

/// Create a [`PyWriteBuilder`] for the batch write loop.
fn new_write_builder(&self) -> PyWriteBuilder {
PyWriteBuilder::new(Arc::clone(&self.inner))
}
}
180 changes: 180 additions & 0 deletions bindings/python/src/write.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use arrow::datatypes::Schema as ArrowSchema;
use arrow::pyarrow::FromPyArrow;
use arrow::record_batch::RecordBatch;
use paimon::table::{CommitMessage, Table, TableCommit, TableWrite};
use paimon_datafusion::runtime::runtime;
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;

use crate::error::to_py_err;

/// Validate an incoming batch schema against the table's target Arrow schema:
/// field count, order, and names must match, and types must match exactly. The
/// nullable flag is intentionally NOT compared, since `build_target_arrow_schema`
/// derives nullability from the Paimon field while pyarrow-constructed batches
/// infer nullable=true. No cast — callers supply correctly-typed batches.
///
/// Type matching is strict (no binary-family interchange): the lower write path
/// downcasts to the exact Arrow array for each Paimon type (e.g. a `Binary` /
/// `VarBinary` field requires `arrow_array::BinaryArray`, not `LargeBinary` /
/// `FixedSizeBinary`). Accepting a near-equivalent type here would pass
/// validation but then fail deeper with a type-mismatch (or write files whose
/// Arrow schema differs from the table), so it is rejected up front.
fn validate_batch_schema(input: &ArrowSchema, target: &ArrowSchema) -> PyResult<()> {
let mismatch = || {
PyValueError::new_err(format!(
"Input schema is not consistent with the table schema. \
input: {input:?}, table: {target:?}"
))
};
if input.fields().len() != target.fields().len() {
return Err(mismatch());
}
for (i, t) in input.fields().iter().zip(target.fields().iter()) {
if i.name() != t.name() {
return Err(mismatch());
}
if i.data_type() != t.data_type() {
return Err(mismatch());
}
}
Ok(())
}

/// Builder for the batch write loop, created via [`crate::table::PyTable::new_write_builder`].
///
/// Holds the owning table plus a single fixed `commit_user`, generated once and
/// shared by both `new_write()` and `new_commit()` so that writers and the
/// committer agree on the commit user (Paimon uses it for duplicate-commit
/// detection). Creating a fresh `WriteBuilder` per call would otherwise mint a
/// new random UUID each time.
#[pyclass(name = "WriteBuilder", module = "pypaimon_rust.datafusion")]
pub struct PyWriteBuilder {
table: Arc<Table>,
commit_user: String,
}

impl PyWriteBuilder {
pub fn new(table: Arc<Table>) -> Self {
let commit_user = table.new_write_builder().commit_user().to_string();
Self { table, commit_user }
}
}

#[pymethods]
impl PyWriteBuilder {
/// Create a writer for accumulating Arrow batches.
fn new_write(&self) -> PyResult<PyTableWrite> {
let builder = self
.table
.new_write_builder()
.with_commit_user(self.commit_user.clone())
.map_err(to_py_err)?;
let target_schema = paimon::arrow::build_target_arrow_schema(self.table.schema().fields())
.map_err(to_py_err)?;
Ok(PyTableWrite {
inner: builder.new_write().map_err(to_py_err)?,
target_schema,
})
}

/// Create a committer for persisting prepared commit messages.
fn new_commit(&self) -> PyResult<PyTableCommit> {
let builder = self
.table
.new_write_builder()
.with_commit_user(self.commit_user.clone())
.map_err(to_py_err)?;
Ok(PyTableCommit {
inner: builder.new_commit(),
})
}
}

/// A stateful writer that accumulates Arrow batches until `prepare_commit`.
///
/// Marked `unsendable`: the underlying `TableWrite` holds file writers that are
/// not `Sync`, so the object enforces single-thread access at runtime.
#[pyclass(name = "TableWrite", module = "pypaimon_rust.datafusion", unsendable)]
pub struct PyTableWrite {
inner: TableWrite,
/// The table's target Arrow schema, used to validate incoming batches.
target_schema: Arc<ArrowSchema>,
}

#[pymethods]
impl PyTableWrite {
/// Write a single PyArrow RecordBatch into the table's writers.
fn write_arrow(&mut self, py: Python<'_>, batch: &Bound<'_, PyAny>) -> PyResult<()> {
let batch = RecordBatch::from_pyarrow_bound(batch)?;
validate_batch_schema(&batch.schema(), &self.target_schema)?;
let rt = runtime();
py.detach(|| rt.block_on(async { self.inner.write_arrow_batch(&batch).await }))
.map_err(to_py_err)
}

/// Close writers and return the commit messages (opaque; pass to commit()).
fn prepare_commit(&mut self, py: Python<'_>) -> PyResult<Vec<PyCommitMessage>> {
let rt = runtime();
let messages = py
.detach(|| rt.block_on(async { self.inner.prepare_commit().await }))
.map_err(to_py_err)?;
Ok(messages
.into_iter()
.map(|inner| PyCommitMessage { inner })
.collect())
}
}

/// A committer that persists prepared commit messages as a snapshot.
#[pyclass(name = "TableCommit", module = "pypaimon_rust.datafusion")]
pub struct PyTableCommit {
inner: TableCommit,
}

#[pymethods]
impl PyTableCommit {
/// Commit the given commit messages. Empty input is a no-op success.
fn commit(&self, py: Python<'_>, messages: &Bound<'_, PyAny>) -> PyResult<()> {
let mut inner_messages = Vec::new();
let iter = messages.try_iter().map_err(|_| {
PyTypeError::new_err("commit() expects a sequence of CommitMessage objects")
})?;
for item in iter {
let item = item?;
let msg: PyRef<PyCommitMessage> = item.extract().map_err(|_| {
PyTypeError::new_err("commit() expects a sequence of CommitMessage objects")
})?;
inner_messages.push(msg.inner.clone());
}
let rt = runtime();
py.detach(|| rt.block_on(async { self.inner.commit(inner_messages).await }))
.map_err(to_py_err)
}
}

/// An opaque commit message produced by `prepare_commit`, consumed by `commit`.
/// PR1 supports same-process transfer only (no pickle/serialization).
#[pyclass(name = "CommitMessage", module = "pypaimon_rust.datafusion")]
pub struct PyCommitMessage {
pub(crate) inner: CommitMessage,
}
139 changes: 139 additions & 0 deletions bindings/python/tests/test_write.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import tempfile

import pyarrow as pa
import pytest

from pypaimon_rust.datafusion import PaimonCatalog, SQLContext

# The table created by _make_empty_table is (id INT, name STRING). Paimon INT maps
# to Arrow int32, so batches must use int32 for id — pyarrow infers Python ints as
# int64, which write_arrow now (correctly, matching pypaimon) rejects as a type
# mismatch. Build batches against this explicit schema to match the table.
_TABLE_SCHEMA = pa.schema([("id", pa.int32()), ("name", pa.string())])


def _batch(ids, names):
return pa.record_batch([ids, names], schema=_TABLE_SCHEMA)


def _make_empty_table(warehouse):
ctx = SQLContext()
ctx.register_catalog("paimon", {"warehouse": warehouse})
ctx.sql("CREATE SCHEMA paimon.wdb")
ctx.sql("CREATE TABLE paimon.wdb.t (id INT, name STRING)")
return ctx


def _get_table(warehouse):
return PaimonCatalog({"warehouse": warehouse}).get_table("wdb.t")


def test_write_commit_read_roundtrip():
with tempfile.TemporaryDirectory() as warehouse:
ctx = _make_empty_table(warehouse)
table = _get_table(warehouse)
batch = _batch([1, 2, 3], ["a", "b", "c"])
wb = table.new_write_builder()
write = wb.new_write()
write.write_arrow(batch)
messages = write.prepare_commit()
assert len(messages) >= 1 # cover API shape in the first test
wb.new_commit().commit(messages) # same wb → shared commit_user
result = pa.Table.from_batches(
ctx.sql("SELECT id, name FROM paimon.wdb.t")
).sort_by("id").to_pydict()
assert result == {"id": [1, 2, 3], "name": ["a", "b", "c"]}


def test_write_multiple_batches():
with tempfile.TemporaryDirectory() as warehouse:
ctx = _make_empty_table(warehouse)
table = _get_table(warehouse)
wb = table.new_write_builder()
write = wb.new_write()
write.write_arrow(_batch([1], ["a"]))
write.write_arrow(_batch([2], ["b"]))
messages = write.prepare_commit()
wb.new_commit().commit(messages)
result = pa.Table.from_batches(
ctx.sql("SELECT id, name FROM paimon.wdb.t")
).sort_by("id").to_pydict()
assert result == {"id": [1, 2], "name": ["a", "b"]}


def test_prepare_commit_returns_messages():
with tempfile.TemporaryDirectory() as warehouse:
_make_empty_table(warehouse)
table = _get_table(warehouse)
write = table.new_write_builder().new_write()
write.write_arrow(_batch([1], ["a"]))
messages = write.prepare_commit()
assert len(messages) >= 1
assert all(type(m).__name__ == "CommitMessage" for m in messages)


def test_commit_empty_messages_noop():
with tempfile.TemporaryDirectory() as warehouse:
ctx = _make_empty_table(warehouse)
table = _get_table(warehouse)
wb = table.new_write_builder()
messages = wb.new_write().prepare_commit() # no write
assert messages == []
wb.new_commit().commit(messages) # no-op success
batches = ctx.sql("SELECT COUNT(*) AS cnt FROM paimon.wdb.t")
assert batches[0].column(0).to_pylist() == [0]


def test_write_arrow_type_mismatch_raises():
with tempfile.TemporaryDirectory() as warehouse:
_make_empty_table(warehouse) # table (id INT, name STRING)
table = _get_table(warehouse)
write = table.new_write_builder().new_write()
bad = pa.record_batch([["x", "y"], ["a", "b"]], names=["id", "name"]) # id as STRING
with pytest.raises(ValueError):
write.write_arrow(bad)


def test_write_arrow_binary_family_mismatch_raises():
# A BINARY column requires Arrow `binary`; a near-equivalent `large_binary`
# must be rejected at validation (it would otherwise fail deeper, since the
# write path downcasts binary fields to arrow_array::BinaryArray only).
with tempfile.TemporaryDirectory() as warehouse:
ctx = SQLContext()
ctx.register_catalog("paimon", {"warehouse": warehouse})
ctx.sql("CREATE SCHEMA paimon.wdb")
ctx.sql("CREATE TABLE paimon.wdb.bt (id INT, data BINARY)")
table = PaimonCatalog({"warehouse": warehouse}).get_table("wdb.bt")
write = table.new_write_builder().new_write()
schema = pa.schema([("id", pa.int32()), ("data", pa.large_binary())])
bad = pa.record_batch([[1], [b"x"]], schema=schema)
with pytest.raises(ValueError):
write.write_arrow(bad)


def test_commit_non_message_raises_typeerror():
with tempfile.TemporaryDirectory() as warehouse:
_make_empty_table(warehouse)
table = _get_table(warehouse)
with pytest.raises(TypeError):
table.new_write_builder().new_commit().commit([object()])
# A non-iterable argument also raises TypeError (not a raw PyO3 error).
with pytest.raises(TypeError):
table.new_write_builder().new_commit().commit(42)
Loading