From 87bf681ed14f9838bdf4f938a873a4fb83c317b9 Mon Sep 17 00:00:00 2001 From: Dev-iL <6509619+Dev-iL@users.noreply.github.com> Date: Thu, 14 May 2026 16:26:17 +0300 Subject: [PATCH 1/2] Add copy_records_to_table for COPY FROM STDIN bulk-load Closes #166. The existing binary_copy_to_table required callers to pre-encode PostgreSQL's binary COPY wire format, leaving no ergonomic bulk-load path comparable to asyncpg's copy_records_to_table or psycopg3's cursor.copy(...). The new method on Connection and Transaction accepts an iterable of records, introspects column types from the target table, and streams rows via tokio-postgres' BinaryCopyInWriter using the same PythonDTO conversions used by execute(). --- python/psqlpy/_internal/__init__.pyi | 52 ++++++++ python/tests/test_copy_records.py | 174 +++++++++++++++++++++++++++ src/driver/common.rs | 142 +++++++++++++++++++++- 3 files changed, 365 insertions(+), 3 deletions(-) create mode 100644 python/tests/test_copy_records.py diff --git a/python/psqlpy/_internal/__init__.pyi b/python/psqlpy/_internal/__init__.pyi index 336bc311..e0f4f794 100644 --- a/python/psqlpy/_internal/__init__.pyi +++ b/python/psqlpy/_internal/__init__.pyi @@ -820,6 +820,32 @@ class Transaction: number of inserted rows; """ + async def copy_records_to_table( + self: Self, + table_name: str, + records: typing.Iterable[Sequence[Any]], + columns: Sequence[str] | None = None, + schema_name: str | None = None, + ) -> int: + """Copy records into a table using the binary COPY FROM STDIN protocol. + + Column types are introspected from the target table, so each record + may contain raw Python values (the same conversions used by + `execute`). Mirrors `asyncpg.Connection.copy_records_to_table`. + + ### Parameters: + - `table_name`: name of the table. + - `records`: iterable of records (each a sequence of column values + matching the order of `columns`, or of the table's columns when + `columns` is `None`). + - `columns`: sequence of column names to load into. When `None`, + all columns of the table are used in their declared order. + - `schema_name`: optional schema for `table_name`. + + ### Returns: + number of inserted rows; + """ + async def connect( dsn: str | None = None, username: str | None = None, @@ -1146,6 +1172,32 @@ class Connection: number of inserted rows; """ + async def copy_records_to_table( + self: Self, + table_name: str, + records: typing.Iterable[Sequence[Any]], + columns: Sequence[str] | None = None, + schema_name: str | None = None, + ) -> int: + """Copy records into a table using the binary COPY FROM STDIN protocol. + + Column types are introspected from the target table, so each record + may contain raw Python values (the same conversions used by + `execute`). Mirrors `asyncpg.Connection.copy_records_to_table`. + + ### Parameters: + - `table_name`: name of the table. + - `records`: iterable of records (each a sequence of column values + matching the order of `columns`, or of the table's columns when + `columns` is `None`). + - `columns`: sequence of column names to load into. When `None`, + all columns of the table are used in their declared order. + - `schema_name`: optional schema for `table_name`. 
+ + ### Returns: + number of inserted rows; + """ + class ConnectionPoolStatus: max_size: int size: int diff --git a/python/tests/test_copy_records.py b/python/tests/test_copy_records.py new file mode 100644 index 00000000..28aba34a --- /dev/null +++ b/python/tests/test_copy_records.py @@ -0,0 +1,174 @@ +import typing +from datetime import datetime, timezone + +import pytest +from psqlpy import ConnectionPool +from psqlpy.exceptions import PyToRustValueMappingError + +pytestmark = pytest.mark.anyio + + +async def _setup_target_table(psql_pool: ConnectionPool, name: str) -> None: + async with psql_pool.acquire() as connection: + await connection.execute(f"DROP TABLE IF EXISTS {name}") + await connection.execute( + f""" + CREATE TABLE {name} ( + id INTEGER, + label TEXT, + weight DOUBLE PRECISION, + created_at TIMESTAMPTZ + ) + """, + ) + + +async def _drop_target_table(psql_pool: ConnectionPool, name: str) -> None: + async with psql_pool.acquire() as connection: + await connection.execute(f"DROP TABLE IF EXISTS {name}") + + +async def test_copy_records_to_table_on_connection( + psql_pool: ConnectionPool, +) -> None: + target: typing.Final = "copy_records_conn" + await _setup_target_table(psql_pool, target) + try: + records = [ + (1, "alpha", 1.5, datetime(2026, 1, 1, tzinfo=timezone.utc)), + (2, "beta", 2.25, datetime(2026, 1, 2, tzinfo=timezone.utc)), + (3, "gamma", None, datetime(2026, 1, 3, tzinfo=timezone.utc)), + ] + + async with psql_pool.acquire() as connection: + inserted = await connection.copy_records_to_table( + table_name=target, + records=records, + ) + + assert inserted == len(records) + + async with psql_pool.acquire() as connection: + result = await connection.execute( + f"SELECT id, label, weight FROM {target} ORDER BY id", + ) + rows = result.result() + assert [(r["id"], r["label"], r["weight"]) for r in rows] == [ + (1, "alpha", 1.5), + (2, "beta", 2.25), + (3, "gamma", None), + ] + finally: + await _drop_target_table(psql_pool, target) + + +async def test_copy_records_to_table_with_columns_subset( + psql_pool: ConnectionPool, +) -> None: + target: typing.Final = "copy_records_subset" + await _setup_target_table(psql_pool, target) + try: + records = [(10, "only-label"), (11, "another")] + + async with psql_pool.acquire() as connection: + inserted = await connection.copy_records_to_table( + table_name=target, + records=records, + columns=["id", "label"], + ) + + assert inserted == len(records) + + async with psql_pool.acquire() as connection: + result = await connection.execute( + f"SELECT id, label, weight, created_at FROM {target} ORDER BY id", + ) + rows = result.result() + assert [(r["id"], r["label"]) for r in rows] == [ + (10, "only-label"), + (11, "another"), + ] + # Untouched columns must remain NULL + assert all(r["weight"] is None and r["created_at"] is None for r in rows) + finally: + await _drop_target_table(psql_pool, target) + + +async def test_copy_records_to_table_in_transaction( + psql_pool: ConnectionPool, +) -> None: + target: typing.Final = "copy_records_tx" + await _setup_target_table(psql_pool, target) + try: + records = [(100, "tx-row", 0.0, datetime(2026, 5, 1, tzinfo=timezone.utc))] + + async with ( + psql_pool.acquire() as connection, + connection.transaction() as tx, + ): + inserted = await tx.copy_records_to_table( + table_name=target, + records=records, + ) + + assert inserted == 1 + + async with psql_pool.acquire() as connection: + result = await connection.execute( + f"SELECT COUNT(*) AS c FROM {target}", + ) + assert result.result()[0]["c"] 
== 1 + finally: + await _drop_target_table(psql_pool, target) + + +async def test_copy_records_to_table_rejects_record_arity_mismatch( + psql_pool: ConnectionPool, +) -> None: + target: typing.Final = "copy_records_mismatch" + await _setup_target_table(psql_pool, target) + try: + records = [(1, "missing-rest")] # table has 4 columns + + async with psql_pool.acquire() as connection: + with pytest.raises(PyToRustValueMappingError): + await connection.copy_records_to_table( + table_name=target, + records=records, + ) + finally: + await _drop_target_table(psql_pool, target) + + +async def test_copy_records_to_table_uses_schema_qualifier( + psql_pool: ConnectionPool, +) -> None: + schema: typing.Final = "copy_records_schema" + target: typing.Final = "tbl" + + async with psql_pool.acquire() as connection: + await connection.execute(f"DROP SCHEMA IF EXISTS {schema} CASCADE") + await connection.execute(f"CREATE SCHEMA {schema}") + await connection.execute( + f"CREATE TABLE {schema}.{target} (id INTEGER, label TEXT)", + ) + + try: + records = [(1, "schema-a"), (2, "schema-b")] + async with psql_pool.acquire() as connection: + inserted = await connection.copy_records_to_table( + table_name=target, + records=records, + schema_name=schema, + ) + + assert inserted == len(records) + + async with psql_pool.acquire() as connection: + result = await connection.execute( + f"SELECT id, label FROM {schema}.{target} ORDER BY id", + ) + assert [(r["id"], r["label"]) for r in result.result()] == records + finally: + async with psql_pool.acquire() as connection: + await connection.execute(f"DROP SCHEMA IF EXISTS {schema} CASCADE") diff --git a/src/driver/common.rs b/src/driver/common.rs index b2ff6a52..2794cba2 100644 --- a/src/driver/common.rs +++ b/src/driver/common.rs @@ -10,14 +10,15 @@ use super::{ use pyo3::{pymethods, Py, PyAny}; use crate::{ - connection::traits::CloseTransaction, + connection::traits::{CloseTransaction, Connection as _}, exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, + value_converter::{dto::enums::PythonDTO, from_python::from_python_typed}, }; use bytes::BytesMut; use futures_util::pin_mut; -use pyo3::{buffer::PyBuffer, Python}; -use tokio_postgres::binary_copy::BinaryCopyInWriter; +use pyo3::{buffer::PyBuffer, types::PyAnyMethods, Python}; +use tokio_postgres::{binary_copy::BinaryCopyInWriter, types::ToSql}; use crate::format_helpers::quote_ident; @@ -320,3 +321,138 @@ macro_rules! impl_binary_copy_method { impl_binary_copy_method!(Connection); impl_binary_copy_method!(Transaction); + +macro_rules! impl_copy_records_method { + ($name:ident) => { + #[pymethods] + impl $name { + /// Copy a list of records into a table using the COPY FROM STDIN + /// binary protocol. + /// + /// Column types are introspected from the target table, so callers + /// pass Python values directly (the same conversions used by + /// `execute`). Mirrors `asyncpg.Connection.copy_records_to_table`. + /// + /// # Errors + /// May return error if there is some problem with DB communication, + /// the table cannot be introspected, or a value cannot be converted. 
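+            ///
+            /// Implementation sketch (matches the body below): column types are
+            /// learned by preparing `SELECT <columns> FROM <table> WHERE false`
+            /// (no rows are fetched), then rows are streamed with
+            /// `COPY <table>(<columns>) FROM STDIN (FORMAT binary)`.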
+            #[pyo3(signature = (table_name, records, columns=None, schema_name=None))]
+            pub async fn copy_records_to_table(
+                self_: pyo3::Py<Self>,
+                table_name: String,
+                records: Py<PyAny>,
+                columns: Option<Vec<String>>,
+                schema_name: Option<String>,
+            ) -> PSQLPyResult<u64> {
+                let (db_client, raw_records) = Python::with_gil(
+                    |gil| -> PSQLPyResult<(Option<_>, Vec<Vec<Py<PyAny>>>)> {
+                        let db_client = self_.borrow(gil).conn.clone();
+
+                        let Some(db_client) = db_client else {
+                            return Ok((None, Vec::new()));
+                        };
+
+                        let bound = records.bind(gil);
+                        let mut rows: Vec<Vec<Py<PyAny>>> = Vec::new();
+                        for item in bound.try_iter()? {
+                            let row = item?;
+                            let mut row_vec: Vec<Py<PyAny>> = Vec::new();
+                            for cell in row.try_iter()? {
+                                row_vec.push(cell?.unbind());
+                            }
+                            rows.push(row_vec);
+                        }
+
+                        Ok((Some(db_client), rows))
+                    },
+                )?;
+
+                let Some(db_client) = db_client else {
+                    return Ok(0);
+                };
+
+                let full_table_name = match schema_name {
+                    Some(ref schema) => {
+                        format!("{}.{}", quote_ident(schema), quote_ident(&table_name))
+                    }
+                    None => quote_ident(&table_name),
+                };
+
+                let columns_sql = match columns {
+                    Some(ref cols) if !cols.is_empty() => Some(
+                        cols.iter()
+                            .map(|c| quote_ident(c))
+                            .collect::<Vec<_>>()
+                            .join(", "),
+                    ),
+                    _ => None,
+                };
+
+                let introspect_qs = match &columns_sql {
+                    Some(cols) => format!("SELECT {} FROM {} WHERE false", cols, full_table_name),
+                    None => format!("SELECT * FROM {} WHERE false", full_table_name),
+                };
+
+                let read_conn_g = db_client.read().await;
+
+                let stmt = read_conn_g.prepare(&introspect_qs, false).await?;
+                let column_types: Vec<tokio_postgres::types::Type> =
+                    stmt.columns().iter().map(|c| c.type_().clone()).collect();
+
+                if column_types.is_empty() {
+                    return Err(RustPSQLDriverError::PyToRustValueConversionError(
+                        "Cannot introspect column types from target table".into(),
+                    ));
+                }
+
+                let typed_rows: Vec<Vec<PythonDTO>> =
+                    Python::with_gil(|gil| -> PSQLPyResult<Vec<Vec<PythonDTO>>> {
+                        let mut typed: Vec<Vec<PythonDTO>> = Vec::with_capacity(raw_records.len());
+                        for (row_idx, row) in raw_records.iter().enumerate() {
+                            if row.len() != column_types.len() {
+                                return Err(RustPSQLDriverError::PyToRustValueConversionError(
+                                    format!(
+                                        "Record at index {} has {} fields, expected {}",
+                                        row_idx,
+                                        row.len(),
+                                        column_types.len()
+                                    ),
+                                ));
+                            }
+                            let mut row_dto: Vec<PythonDTO> = Vec::with_capacity(row.len());
+                            for (cell, ty) in row.iter().zip(column_types.iter()) {
+                                row_dto.push(from_python_typed(cell.bind(gil), ty)?);
+                            }
+                            typed.push(row_dto);
+                        }
+                        Ok(typed)
+                    })?;
+
+                let copy_qs = match &columns_sql {
+                    Some(cols) => format!(
+                        "COPY {}({}) FROM STDIN (FORMAT binary)",
+                        full_table_name, cols
+                    ),
+                    None => format!("COPY {} FROM STDIN (FORMAT binary)", full_table_name),
+                };
+
+                let sink = read_conn_g.copy_in(&copy_qs).await?;
+                let writer = BinaryCopyInWriter::new(sink, &column_types);
+                pin_mut!(writer);
+
+                for row in &typed_rows {
+                    let row_refs: Vec<&(dyn ToSql + Sync)> =
+                        row.iter().map(|v| v as &(dyn ToSql + Sync)).collect();
+                    writer.as_mut().write(&row_refs).await?;
+                }
+
+                let rows_created = writer.as_mut().finish().await?;
+
+                Ok(rows_created)
+            }
+        }
+    };
+}
+
+impl_copy_records_method!(Connection);
+impl_copy_records_method!(Transaction);

From a0a25946687e2b855ac36c19ff28fb28cdef1fe3 Mon Sep 17 00:00:00 2001
From: Dev-iL <6509619+Dev-iL@users.noreply.github.com>
Date: Sat, 16 May 2026 12:17:23 +0300
Subject: [PATCH 2/2] Add COPY FROM STDIN documentation

Documents binary_copy_to_table and copy_records_to_table on a dedicated
docs/components/copy.md page with tabbed Connection/Transaction examples.
connection.md and transaction.md reference that page. Sidebar updated.
Co-Authored-By: Claude Sonnet 4.6
---
 docs/.vuepress/sidebar.ts      |   1 +
 docs/components/connection.md  |   5 ++
 docs/components/copy.md        | 144 +++++++++++++++++++++++++++
 docs/components/transaction.md |   5 ++
 4 files changed, 155 insertions(+)
 create mode 100644 docs/components/copy.md

diff --git a/docs/.vuepress/sidebar.ts b/docs/.vuepress/sidebar.ts
index 1db34495..4687fdaa 100644
--- a/docs/.vuepress/sidebar.ts
+++ b/docs/.vuepress/sidebar.ts
@@ -22,6 +22,7 @@ export default sidebar({
       "connection_pool_builder",
       "connection",
       "transaction",
+      "copy",
       "cursor",
       "prepared_statement",
       "listener",
diff --git a/docs/components/connection.md b/docs/components/connection.md
index 6cd0ec82..bc5a8be6 100644
--- a/docs/components/connection.md
+++ b/docs/components/connection.md
@@ -245,6 +245,11 @@ async def main() -> None:
     )
 ```
 
+### COPY FROM STDIN
+
+`Connection` supports bulk-loading via `binary_copy_to_table` and `copy_records_to_table`.
+See the [COPY FROM STDIN](./copy.md) page for full documentation and examples.
+
 ### Close
 
 Returns connection to the pool. It's crucial to commit all transactions and close all cursor which are made from the connection.
diff --git a/docs/components/copy.md b/docs/components/copy.md
new file mode 100644
index 00000000..fafd2fe0
--- /dev/null
+++ b/docs/components/copy.md
@@ -0,0 +1,144 @@
+---
+title: COPY FROM STDIN
+---
+
+PSQLPy exposes two methods for bulk-loading data via PostgreSQL's `COPY FROM STDIN` protocol.
+Both are available on `Connection` and `Transaction`.
+
+## Binary Copy To Table
+
+Stream a pre-encoded PostgreSQL binary COPY payload directly into a table.
+Executes `COPY table_name (<columns>) FROM STDIN (FORMAT binary)`.
+
+#### Parameters:
+
+- `source`: bytes, bytearray, or `BytesIO` containing a PostgreSQL binary COPY stream.
+- `table_name`: name of the target table.
+- `columns`: sequence of column names to load into. When `None`, all table columns are used in their declared order.
+- `schema_name`: optional schema for `table_name`.
+
+::: warning
+You are responsible for encoding the bytes correctly. Passing an invalid binary COPY stream will result in a database error.
+:::
+
+::: tabs
+
+@tab Connection
+```python
+async def main() -> None:
+    ...
+    connection = await db_pool.connection()
+    with open("data.bin", "rb") as f:
+        inserted = await connection.binary_copy_to_table(
+            source=f.read(),
+            table_name="users",
+            columns=["id", "username"],
+        )
+    print(f"Inserted {inserted} rows")
+```
+
+@tab Transaction
+```python
+async def main() -> None:
+    ...
+    connection = await db_pool.connection()
+    async with connection.transaction() as transaction:
+        with open("data.bin", "rb") as f:
+            inserted = await transaction.binary_copy_to_table(
+                source=f.read(),
+                table_name="users",
+                columns=["id", "username"],
+            )
+    print(f"Inserted {inserted} rows")
+```
+
+:::
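+
+If you need to produce such a stream from Python, the sketch below hand-encodes rows for a
+two-column `(INTEGER, TEXT)` table. It is a minimal illustration of the binary COPY layout
+documented on PostgreSQL's [COPY page](https://www.postgresql.org/docs/current/sql-copy.html)
+(11-byte signature, two 32-bit header fields, length-prefixed fields per tuple, `-1` trailer);
+the `users` table and its columns are only an example:
+
+```python
+import struct
+from io import BytesIO
+
+
+def encode_int_text_rows(rows: list[tuple[int, str]]) -> bytes:
+    """Encode (INTEGER, TEXT) rows as a PostgreSQL binary COPY payload."""
+    buf = BytesIO()
+    buf.write(b"PGCOPY\n\xff\r\n\x00")   # fixed 11-byte signature
+    buf.write(struct.pack("!ii", 0, 0))  # flags field + header extension length
+    for id_value, text_value in rows:
+        buf.write(struct.pack("!h", 2))  # number of fields in this tuple
+        buf.write(struct.pack("!ii", 4, id_value))  # INTEGER: 4-byte length, then value
+        encoded = text_value.encode("utf-8")
+        buf.write(struct.pack("!i", len(encoded)))  # TEXT: byte length (-1 would mean NULL)
+        buf.write(encoded)
+    buf.write(struct.pack("!h", -1))     # end-of-data trailer
+    return buf.getvalue()
+
+
+# inserted = await connection.binary_copy_to_table(
+#     source=encode_int_text_rows([(1, "alice"), (2, "bob")]),
+#     table_name="users",
+#     columns=["id", "username"],
+# )
+```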
+
+## Copy Records To Table
+
+Bulk-load plain Python records into a table via the binary `COPY FROM STDIN` protocol.
+Column types are introspected from the target table automatically, so each record may contain ordinary Python values (the same types accepted by `execute()`).
+Returns the number of inserted rows.
+
+This is the ergonomic alternative to `binary_copy_to_table` when you have Python data rather than a pre-encoded binary stream.
+
+#### Parameters:
+
+- `table_name`: name of the target table.
+- `records`: iterable of records, where each record is a sequence of column values.
+- `columns`: sequence of column names to load into. When `None`, all table columns are used in their declared order.
+- `schema_name`: optional schema for `table_name`.
+
+::: tabs
+
+@tab Connection
+```python
+from datetime import datetime, timezone
+
+async def main() -> None:
+    ...
+    connection = await db_pool.connection()
+    records = [
+        (1, "alpha", 1.5, datetime(2026, 1, 1, tzinfo=timezone.utc)),
+        (2, "beta", 2.25, datetime(2026, 1, 2, tzinfo=timezone.utc)),
+        (3, "gamma", None, datetime(2026, 1, 3, tzinfo=timezone.utc)),
+    ]
+    inserted = await connection.copy_records_to_table(
+        table_name="measurements",
+        records=records,
+    )
+    print(f"Inserted {inserted} rows")
+```
+
+@tab Transaction
+```python
+from datetime import datetime, timezone
+
+async def main() -> None:
+    ...
+    connection = await db_pool.connection()
+    records = [
+        (1, "alpha", 1.5, datetime(2026, 1, 1, tzinfo=timezone.utc)),
+        (2, "beta", 2.25, datetime(2026, 1, 2, tzinfo=timezone.utc)),
+        (3, "gamma", None, datetime(2026, 1, 3, tzinfo=timezone.utc)),
+    ]
+    async with connection.transaction() as transaction:
+        inserted = await transaction.copy_records_to_table(
+            table_name="measurements",
+            records=records,
+        )
+    print(f"Inserted {inserted} rows")
+```
+
+:::
+
+You can load only a subset of columns by providing the `columns` argument:
+
+::: tabs
+
+@tab Connection
+```python
+async def main() -> None:
+    ...
+    connection = await db_pool.connection()
+    inserted = await connection.copy_records_to_table(
+        table_name="measurements",
+        records=[(1, "alpha"), (2, "beta")],
+        columns=["id", "label"],
+    )
+```
+
+@tab Transaction
+```python
+async def main() -> None:
+    ...
+    connection = await db_pool.connection()
+    async with connection.transaction() as transaction:
+        inserted = await transaction.copy_records_to_table(
+            table_name="measurements",
+            records=[(1, "alpha"), (2, "beta")],
+            columns=["id", "label"],
+        )
+```
+
+:::
diff --git a/docs/components/transaction.md b/docs/components/transaction.md
index 8f18a097..d80c9182 100644
--- a/docs/components/transaction.md
+++ b/docs/components/transaction.md
@@ -426,3 +426,8 @@ async def main() -> None:
     dict_result: List[Dict[Any, Any]] = fetched_result.result()
     ...  # do something with the result.
 ```
+
+### COPY FROM STDIN
+
+`Transaction` supports bulk-loading via `binary_copy_to_table` and `copy_records_to_table`.
+See the [COPY FROM STDIN](./copy.md) page for full documentation and examples.
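+
+A minimal sketch (the `measurements` table and its columns are illustrative; the full parameter reference lives on the COPY page):
+
+```python
+async def main() -> None:
+    ...
+    connection = await db_pool.connection()
+    async with connection.transaction() as transaction:
+        inserted = await transaction.copy_records_to_table(
+            table_name="measurements",
+            records=[(1, "alpha"), (2, "beta")],
+            columns=["id", "label"],
+        )
+    print(f"Inserted {inserted} rows")
+```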