52 changes: 52 additions & 0 deletions python/psqlpy/_internal/__init__.pyi
@@ -820,6 +820,32 @@ class Transaction:
number of inserted rows;
"""

async def copy_records_to_table(
self: Self,
table_name: str,
records: typing.Iterable[Sequence[Any]],
columns: Sequence[str] | None = None,
schema_name: str | None = None,
) -> int:
"""Copy records into a table using the binary COPY FROM STDIN protocol.

Column types are introspected from the target table, so each record
may contain raw Python values (the same conversions used by
`execute`). Mirrors `asyncpg.Connection.copy_records_to_table`.

### Parameters:
- `table_name`: name of the table.
- `records`: iterable of records (each a sequence of column values
matching the order of `columns`, or of the table's columns when
`columns` is `None`).
- `columns`: sequence of column names to load into. When `None`,
all columns of the table are used in their declared order.
- `schema_name`: optional schema for `table_name`.

### Returns:
number of inserted rows;
"""

async def connect(
dsn: str | None = None,
username: str | None = None,
@@ -1146,6 +1172,32 @@ class Connection:
number of inserted rows;
"""

async def copy_records_to_table(
self: Self,
table_name: str,
records: typing.Iterable[Sequence[Any]],
columns: Sequence[str] | None = None,
schema_name: str | None = None,
) -> int:
"""Copy records into a table using the binary COPY FROM STDIN protocol.

Column types are introspected from the target table, so each record
may contain raw Python values (the same conversions used by
`execute`). Mirrors `asyncpg.Connection.copy_records_to_table`.

### Parameters:
- `table_name`: name of the table.
- `records`: iterable of records (each a sequence of column values
matching the order of `columns`, or of the table's columns when
`columns` is `None`).
- `columns`: sequence of column names to load into. When `None`,
all columns of the table are used in their declared order.
- `schema_name`: optional schema for `table_name`.

### Returns:
number of inserted rows;
"""

class ConnectionPoolStatus:
max_size: int
size: int
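For context, a minimal usage sketch of the new API (not part of the diff). The DSN, the pool setup, and the `items` table are illustrative assumptions; the method calls are the ones declared in the stubs above.

import asyncio

from psqlpy import ConnectionPool


async def main() -> None:
    # Placeholder DSN; "items" is assumed to have columns (id INTEGER, label TEXT).
    pool = ConnectionPool(dsn="postgres://postgres:postgres@localhost:5432/postgres")

    async with pool.acquire() as connection:
        # Records are plain Python sequences; column types are introspected from the table.
        inserted = await connection.copy_records_to_table(
            table_name="items",
            records=[(1, "alpha"), (2, "beta")],
            columns=["id", "label"],
        )
        assert inserted == 2

        # The same method is exposed on transactions.
        async with connection.transaction() as tx:
            await tx.copy_records_to_table(
                table_name="items",
                records=[(3, "gamma")],
                columns=["id", "label"],
            )


asyncio.run(main())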
174 changes: 174 additions & 0 deletions python/tests/test_copy_records.py
@@ -0,0 +1,174 @@
import typing
from datetime import datetime, timezone

import pytest
from psqlpy import ConnectionPool
from psqlpy.exceptions import PyToRustValueMappingError

pytestmark = pytest.mark.anyio


async def _setup_target_table(psql_pool: ConnectionPool, name: str) -> None:
async with psql_pool.acquire() as connection:
await connection.execute(f"DROP TABLE IF EXISTS {name}")
await connection.execute(
f"""
CREATE TABLE {name} (
id INTEGER,
label TEXT,
weight DOUBLE PRECISION,
created_at TIMESTAMPTZ
)
""",
)


async def _drop_target_table(psql_pool: ConnectionPool, name: str) -> None:
async with psql_pool.acquire() as connection:
await connection.execute(f"DROP TABLE IF EXISTS {name}")


async def test_copy_records_to_table_on_connection(
psql_pool: ConnectionPool,
) -> None:
target: typing.Final = "copy_records_conn"
await _setup_target_table(psql_pool, target)
try:
records = [
(1, "alpha", 1.5, datetime(2026, 1, 1, tzinfo=timezone.utc)),
(2, "beta", 2.25, datetime(2026, 1, 2, tzinfo=timezone.utc)),
(3, "gamma", None, datetime(2026, 1, 3, tzinfo=timezone.utc)),
]

async with psql_pool.acquire() as connection:
inserted = await connection.copy_records_to_table(
table_name=target,
records=records,
)

assert inserted == len(records)

async with psql_pool.acquire() as connection:
result = await connection.execute(
f"SELECT id, label, weight FROM {target} ORDER BY id",
)
rows = result.result()
assert [(r["id"], r["label"], r["weight"]) for r in rows] == [
(1, "alpha", 1.5),
(2, "beta", 2.25),
(3, "gamma", None),
]
finally:
await _drop_target_table(psql_pool, target)


async def test_copy_records_to_table_with_columns_subset(
psql_pool: ConnectionPool,
) -> None:
target: typing.Final = "copy_records_subset"
await _setup_target_table(psql_pool, target)
try:
records = [(10, "only-label"), (11, "another")]

async with psql_pool.acquire() as connection:
inserted = await connection.copy_records_to_table(
table_name=target,
records=records,
columns=["id", "label"],
)

assert inserted == len(records)

async with psql_pool.acquire() as connection:
result = await connection.execute(
f"SELECT id, label, weight, created_at FROM {target} ORDER BY id",
)
rows = result.result()
assert [(r["id"], r["label"]) for r in rows] == [
(10, "only-label"),
(11, "another"),
]
# Untouched columns must remain NULL
assert all(r["weight"] is None and r["created_at"] is None for r in rows)
finally:
await _drop_target_table(psql_pool, target)


async def test_copy_records_to_table_in_transaction(
psql_pool: ConnectionPool,
) -> None:
target: typing.Final = "copy_records_tx"
await _setup_target_table(psql_pool, target)
try:
records = [(100, "tx-row", 0.0, datetime(2026, 5, 1, tzinfo=timezone.utc))]

async with (
psql_pool.acquire() as connection,
connection.transaction() as tx,
):
inserted = await tx.copy_records_to_table(
table_name=target,
records=records,
)

assert inserted == 1

async with psql_pool.acquire() as connection:
result = await connection.execute(
f"SELECT COUNT(*) AS c FROM {target}",
)
assert result.result()[0]["c"] == 1
finally:
await _drop_target_table(psql_pool, target)


async def test_copy_records_to_table_rejects_record_arity_mismatch(
psql_pool: ConnectionPool,
) -> None:
target: typing.Final = "copy_records_mismatch"
await _setup_target_table(psql_pool, target)
try:
records = [(1, "missing-rest")] # table has 4 columns

async with psql_pool.acquire() as connection:
with pytest.raises(PyToRustValueMappingError):
await connection.copy_records_to_table(
table_name=target,
records=records,
)
finally:
await _drop_target_table(psql_pool, target)


async def test_copy_records_to_table_uses_schema_qualifier(
psql_pool: ConnectionPool,
) -> None:
schema: typing.Final = "copy_records_schema"
target: typing.Final = "tbl"

async with psql_pool.acquire() as connection:
await connection.execute(f"DROP SCHEMA IF EXISTS {schema} CASCADE")
await connection.execute(f"CREATE SCHEMA {schema}")
await connection.execute(
f"CREATE TABLE {schema}.{target} (id INTEGER, label TEXT)",
)

try:
records = [(1, "schema-a"), (2, "schema-b")]
async with psql_pool.acquire() as connection:
inserted = await connection.copy_records_to_table(
table_name=target,
records=records,
schema_name=schema,
)

assert inserted == len(records)

async with psql_pool.acquire() as connection:
result = await connection.execute(
f"SELECT id, label FROM {schema}.{target} ORDER BY id",
)
assert [(r["id"], r["label"]) for r in result.result()] == records
finally:
async with psql_pool.acquire() as connection:
await connection.execute(f"DROP SCHEMA IF EXISTS {schema} CASCADE")
142 changes: 139 additions & 3 deletions src/driver/common.rs
@@ -10,14 +10,15 @@ use super::{
use pyo3::{pymethods, Py, PyAny};

use crate::{
connection::traits::CloseTransaction,
connection::traits::{CloseTransaction, Connection as _},
exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError},
value_converter::{dto::enums::PythonDTO, from_python::from_python_typed},
};

use bytes::BytesMut;
use futures_util::pin_mut;
use pyo3::{buffer::PyBuffer, Python};
use tokio_postgres::binary_copy::BinaryCopyInWriter;
use pyo3::{buffer::PyBuffer, types::PyAnyMethods, Python};
use tokio_postgres::{binary_copy::BinaryCopyInWriter, types::ToSql};

use crate::format_helpers::quote_ident;

@@ -320,3 +321,138 @@ macro_rules! impl_binary_copy_method {

impl_binary_copy_method!(Connection);
impl_binary_copy_method!(Transaction);

macro_rules! impl_copy_records_method {
($name:ident) => {
#[pymethods]
impl $name {
/// Copy a list of records into a table using the COPY FROM STDIN
/// binary protocol.
///
/// Column types are introspected from the target table, so callers
/// pass Python values directly (the same conversions used by
/// `execute`). Mirrors `asyncpg.Connection.copy_records_to_table`.
///
/// # Errors
/// May return an error if there is a problem with DB communication,
/// the table cannot be introspected, or a value cannot be converted.
#[pyo3(signature = (table_name, records, columns=None, schema_name=None))]
pub async fn copy_records_to_table(
self_: pyo3::Py<Self>,
table_name: String,
records: Py<PyAny>,
columns: Option<Vec<String>>,
schema_name: Option<String>,
) -> PSQLPyResult<u64> {
let (db_client, raw_records) = Python::with_gil(
|gil| -> PSQLPyResult<(Option<_>, Vec<Vec<pyo3::Py<PyAny>>>)> {
let db_client = self_.borrow(gil).conn.clone();

let Some(db_client) = db_client else {
return Ok((None, Vec::new()));
};

let bound = records.bind(gil);
let mut rows: Vec<Vec<pyo3::Py<PyAny>>> = Vec::new();
for item in bound.try_iter()? {
let row = item?;
let mut row_vec: Vec<pyo3::Py<PyAny>> = Vec::new();
for cell in row.try_iter()? {
row_vec.push(cell?.unbind());
}
rows.push(row_vec);
}

Ok((Some(db_client), rows))
},
)?;

let Some(db_client) = db_client else {
return Ok(0);
};

let full_table_name = match schema_name {
Some(ref schema) => {
format!("{}.{}", quote_ident(schema), quote_ident(&table_name))
}
None => quote_ident(&table_name),
};

let columns_sql = match columns {
Some(ref cols) if !cols.is_empty() => Some(
cols.iter()
.map(|c| quote_ident(c))
.collect::<Vec<_>>()
.join(", "),
),
_ => None,
};

let introspect_qs = match &columns_sql {
Some(cols) => format!("SELECT {} FROM {} WHERE false", cols, full_table_name),
None => format!("SELECT * FROM {} WHERE false", full_table_name),
};

let read_conn_g = db_client.read().await;

let stmt = read_conn_g.prepare(&introspect_qs, false).await?;
let column_types: Vec<tokio_postgres::types::Type> =
stmt.columns().iter().map(|c| c.type_().clone()).collect();

if column_types.is_empty() {
return Err(RustPSQLDriverError::PyToRustValueConversionError(
"Cannot introspect column types from target table".into(),
));
}

let typed_rows: Vec<Vec<PythonDTO>> =
Python::with_gil(|gil| -> PSQLPyResult<Vec<Vec<PythonDTO>>> {
let mut typed: Vec<Vec<PythonDTO>> = Vec::with_capacity(raw_records.len());
for (row_idx, row) in raw_records.iter().enumerate() {
if row.len() != column_types.len() {
return Err(RustPSQLDriverError::PyToRustValueConversionError(
format!(
"Record at index {} has {} fields, expected {}",
row_idx,
row.len(),
column_types.len()
),
));
}
let mut row_dto: Vec<PythonDTO> = Vec::with_capacity(row.len());
for (cell, ty) in row.iter().zip(column_types.iter()) {
row_dto.push(from_python_typed(cell.bind(gil), ty)?);
}
typed.push(row_dto);
}
Ok(typed)
})?;

let copy_qs = match &columns_sql {
Some(cols) => format!(
"COPY {}({}) FROM STDIN (FORMAT binary)",
full_table_name, cols
),
None => format!("COPY {} FROM STDIN (FORMAT binary)", full_table_name),
};

let sink = read_conn_g.copy_in(&copy_qs).await?;
let writer = BinaryCopyInWriter::new(sink, &column_types);
pin_mut!(writer);

for row in &typed_rows {
let row_refs: Vec<&(dyn ToSql + Sync)> =
row.iter().map(|v| v as &(dyn ToSql + Sync)).collect();
writer.as_mut().write(&row_refs).await?;
}

let rows_created = writer.as_mut().finish().await?;

Ok(rows_created)
}
}
};
}

impl_copy_records_method!(Connection);
impl_copy_records_method!(Transaction);
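For reviewers, a short Python sketch of the two statements the macro constructs: the zero-row SELECT used only to introspect column types for `BinaryCopyInWriter`, and the binary COPY statement. Identifier quoting is simplified here; the Rust side delegates to `quote_ident`.

def build_copy_statements(
    table_name: str,
    columns: list[str] | None = None,
    schema_name: str | None = None,
) -> tuple[str, str]:
    # Simplified double-quote escaping; quote_ident handles this on the Rust side.
    def quote(ident: str) -> str:
        return '"' + ident.replace('"', '""') + '"'

    full_table = f"{quote(schema_name)}.{quote(table_name)}" if schema_name else quote(table_name)
    cols_sql = ", ".join(quote(c) for c in columns) if columns else None

    # Prepared with WHERE false so no rows are read; only the column types matter.
    introspect = (
        f"SELECT {cols_sql} FROM {full_table} WHERE false"
        if cols_sql
        else f"SELECT * FROM {full_table} WHERE false"
    )
    # Statement that opens the binary COPY FROM STDIN stream.
    copy = (
        f"COPY {full_table}({cols_sql}) FROM STDIN (FORMAT binary)"
        if cols_sql
        else f"COPY {full_table} FROM STDIN (FORMAT binary)"
    )
    return introspect, copy


# build_copy_statements("tbl", ["id", "label"], "copy_records_schema") returns:
#   ('SELECT "id", "label" FROM "copy_records_schema"."tbl" WHERE false',
#    'COPY "copy_records_schema"."tbl"("id", "label") FROM STDIN (FORMAT binary)')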