Skip to content

Commit ce83065

Browse files
timsaucer and claude committed
test: parameterize arrow cast / try_cast tests
Folds the previous four cast tests (arrow_cast and arrow_try_cast, each with a str and a pyarrow target type) into a single parameterized test that runs both functions across all five target-type variants. Collapses the two cast_to_type tests (happy path and try_cast=True) into one parameterized test, and parameterizes the arrow_try_cast null-on-failure test over both target-type syntaxes.

7 test functions, 19 cases — net less code, same coverage.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 7d8a435 commit ce83065

1 file changed

Lines changed: 31 additions & 73 deletions

File tree

python/tests/test_functions.py

Lines changed: 31 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1299,66 +1299,32 @@ def test_make_time(df):
12991299
assert result.column(0)[0].as_py() == time(12, 30)
13001300

13011301

1302-
def test_arrow_cast(df):
1303-
df = df.select(
1304-
f.arrow_cast(column("b"), "Float64").alias("b_as_float"),
1305-
f.arrow_cast(column("b"), "Int32").alias("b_as_int"),
1306-
)
1307-
result = df.collect()
1308-
assert len(result) == 1
1309-
result = result[0]
1310-
1311-
assert result.column(0) == pa.array([4.0, 5.0, 6.0], type=pa.float64())
1312-
assert result.column(1) == pa.array([4, 5, 6], type=pa.int32())
1313-
1314-
1315-
def test_arrow_cast_with_pyarrow_type(df):
1316-
df = df.select(
1317-
f.arrow_cast(column("b"), pa.float64()).alias("b_as_float"),
1318-
f.arrow_cast(column("b"), pa.int32()).alias("b_as_int"),
1319-
f.arrow_cast(column("b"), pa.string()).alias("b_as_str"),
1320-
)
1321-
result = df.collect()[0]
1322-
1323-
assert result.column(0) == pa.array([4.0, 5.0, 6.0], type=pa.float64())
1324-
assert result.column(1) == pa.array([4, 5, 6], type=pa.int32())
1325-
assert result.column(2) == pa.array(["4", "5", "6"], type=pa.string())
1326-
1327-
1328-
def test_arrow_try_cast(df):
1329-
df = df.select(
1330-
f.arrow_try_cast(column("b"), "Float64").alias("b_as_float"),
1331-
f.arrow_try_cast(column("b"), "Int32").alias("b_as_int"),
1332-
)
1333-
result = df.collect()[0]
1334-
1335-
assert result.column(0) == pa.array([4.0, 5.0, 6.0], type=pa.float64())
1336-
assert result.column(1) == pa.array([4, 5, 6], type=pa.int32())
1337-
1338-
1339-
def test_arrow_try_cast_with_pyarrow_type(df):
1340-
df = df.select(
1341-
f.arrow_try_cast(column("b"), pa.float64()).alias("b_as_float"),
1342-
f.arrow_try_cast(column("b"), pa.int32()).alias("b_as_int"),
1343-
)
1344-
result = df.collect()[0]
1345-
1346-
assert result.column(0) == pa.array([4.0, 5.0, 6.0], type=pa.float64())
1347-
assert result.column(1) == pa.array([4, 5, 6], type=pa.int32())
1302+
@pytest.mark.parametrize("cast_fn", [f.arrow_cast, f.arrow_try_cast])
1303+
@pytest.mark.parametrize(
1304+
("data_type", "expected"),
1305+
[
1306+
("Float64", pa.array([4.0, 5.0, 6.0], type=pa.float64())),
1307+
("Int32", pa.array([4, 5, 6], type=pa.int32())),
1308+
(pa.float64(), pa.array([4.0, 5.0, 6.0], type=pa.float64())),
1309+
(pa.int32(), pa.array([4, 5, 6], type=pa.int32())),
1310+
(pa.string(), pa.array(["4", "5", "6"], type=pa.string())),
1311+
],
1312+
)
1313+
def test_arrow_cast_variants(df, cast_fn, data_type, expected):
1314+
"""arrow_cast / arrow_try_cast accept str and pyarrow target types."""
1315+
result = df.select(cast_fn(column("b"), data_type).alias("c")).collect()[0]
1316+
assert result.column(0) == expected
13481317

13491318

1350-
def test_arrow_try_cast_null_on_failure():
1319+
@pytest.mark.parametrize("data_type", ["Float64", pa.float64()])
1320+
def test_arrow_try_cast_null_on_failure(data_type):
13511321
ctx = SessionContext()
13521322
batch = pa.RecordBatch.from_arrays([pa.array(["1.5", "oops", "3"])], names=["s"])
13531323
df = ctx.create_dataframe([[batch]])
13541324

1355-
result = df.select(
1356-
f.arrow_try_cast(column("s"), "Float64").alias("c"),
1357-
f.arrow_try_cast(column("s"), pa.float64()).alias("c_pa"),
1358-
).collect()[0]
1325+
result = df.select(f.arrow_try_cast(column("s"), data_type).alias("c")).collect()[0]
13591326

13601327
assert result.column(0).to_pylist() == [1.5, None, 3.0]
1361-
assert result.column(1).to_pylist() == [1.5, None, 3.0]
13621328

13631329

13641330
def test_arrow_field():
@@ -1381,34 +1347,26 @@ def test_arrow_field():
13811347
}
13821348

13831349

1384-
def test_cast_to_type():
1385-
ctx = SessionContext()
1386-
batch = pa.RecordBatch.from_arrays(
1387-
[pa.array([4, 5, 6]), pa.array([1.0, 2.0, 3.0])],
1388-
names=["b", "fl"],
1389-
)
1390-
df = ctx.create_dataframe([[batch]])
1391-
1392-
result = df.select(f.cast_to_type(column("b"), column("fl")).alias("c")).collect()[
1393-
0
1394-
]
1395-
1396-
assert result.column(0) == pa.array([4.0, 5.0, 6.0], type=pa.float64())
1397-
1398-
1399-
def test_cast_to_type_try_cast_null_on_failure():
1350+
@pytest.mark.parametrize(
1351+
("values", "try_cast", "expected"),
1352+
[
1353+
(pa.array([4, 5, 6]), False, [4.0, 5.0, 6.0]),
1354+
(pa.array(["oops", "2", "3"]), True, [None, 2.0, 3.0]),
1355+
],
1356+
)
1357+
def test_cast_to_type(values, try_cast, expected):
1358+
"""cast_to_type takes target type from ``type_ref``; try_cast nullifies failures."""
14001359
ctx = SessionContext()
14011360
batch = pa.RecordBatch.from_arrays(
1402-
[pa.array(["oops", "2", "3"]), pa.array([1.0, 2.0, 3.0])],
1403-
names=["a", "fl"],
1361+
[values, pa.array([1.0, 2.0, 3.0])], names=["v", "fl"]
14041362
)
14051363
df = ctx.create_dataframe([[batch]])
14061364

14071365
result = df.select(
1408-
f.cast_to_type(column("a"), column("fl"), try_cast=True).alias("c")
1366+
f.cast_to_type(column("v"), column("fl"), try_cast=try_cast).alias("c")
14091367
).collect()[0]
14101368

1411-
assert result.column(0).to_pylist() == [None, 2.0, 3.0]
1369+
assert result.column(0).to_pylist() == expected
14121370
assert result.column(0).type == pa.float64()
14131371

14141372

@@ -1425,7 +1383,7 @@ def test_with_metadata_empty_dict_noop(df):
14251383
assert out.column(0) == pa.array([4, 5, 6])
14261384

14271385

1428-
def test_with_metadata_empty_key_raises(df):
1386+
def test_with_metadata_empty_key_raises():
14291387
with pytest.raises(ValueError, match="non-empty"):
14301388
f.with_metadata(column("b"), {"": "v"})
14311389

0 commit comments

Comments
 (0)