@@ -1299,66 +1299,32 @@ def test_make_time(df):
12991299 assert result .column (0 )[0 ].as_py () == time (12 , 30 )
13001300
13011301
1302- def test_arrow_cast (df ):
1303- df = df .select (
1304- f .arrow_cast (column ("b" ), "Float64" ).alias ("b_as_float" ),
1305- f .arrow_cast (column ("b" ), "Int32" ).alias ("b_as_int" ),
1306- )
1307- result = df .collect ()
1308- assert len (result ) == 1
1309- result = result [0 ]
1310-
1311- assert result .column (0 ) == pa .array ([4.0 , 5.0 , 6.0 ], type = pa .float64 ())
1312- assert result .column (1 ) == pa .array ([4 , 5 , 6 ], type = pa .int32 ())
1313-
1314-
1315- def test_arrow_cast_with_pyarrow_type (df ):
1316- df = df .select (
1317- f .arrow_cast (column ("b" ), pa .float64 ()).alias ("b_as_float" ),
1318- f .arrow_cast (column ("b" ), pa .int32 ()).alias ("b_as_int" ),
1319- f .arrow_cast (column ("b" ), pa .string ()).alias ("b_as_str" ),
1320- )
1321- result = df .collect ()[0 ]
1322-
1323- assert result .column (0 ) == pa .array ([4.0 , 5.0 , 6.0 ], type = pa .float64 ())
1324- assert result .column (1 ) == pa .array ([4 , 5 , 6 ], type = pa .int32 ())
1325- assert result .column (2 ) == pa .array (["4" , "5" , "6" ], type = pa .string ())
1326-
1327-
1328- def test_arrow_try_cast (df ):
1329- df = df .select (
1330- f .arrow_try_cast (column ("b" ), "Float64" ).alias ("b_as_float" ),
1331- f .arrow_try_cast (column ("b" ), "Int32" ).alias ("b_as_int" ),
1332- )
1333- result = df .collect ()[0 ]
1334-
1335- assert result .column (0 ) == pa .array ([4.0 , 5.0 , 6.0 ], type = pa .float64 ())
1336- assert result .column (1 ) == pa .array ([4 , 5 , 6 ], type = pa .int32 ())
1337-
1338-
1339- def test_arrow_try_cast_with_pyarrow_type (df ):
1340- df = df .select (
1341- f .arrow_try_cast (column ("b" ), pa .float64 ()).alias ("b_as_float" ),
1342- f .arrow_try_cast (column ("b" ), pa .int32 ()).alias ("b_as_int" ),
1343- )
1344- result = df .collect ()[0 ]
1345-
1346- assert result .column (0 ) == pa .array ([4.0 , 5.0 , 6.0 ], type = pa .float64 ())
1347- assert result .column (1 ) == pa .array ([4 , 5 , 6 ], type = pa .int32 ())
1302+ @pytest .mark .parametrize ("cast_fn" , [f .arrow_cast , f .arrow_try_cast ])
1303+ @pytest .mark .parametrize (
1304+ ("data_type" , "expected" ),
1305+ [
1306+ ("Float64" , pa .array ([4.0 , 5.0 , 6.0 ], type = pa .float64 ())),
1307+ ("Int32" , pa .array ([4 , 5 , 6 ], type = pa .int32 ())),
1308+ (pa .float64 (), pa .array ([4.0 , 5.0 , 6.0 ], type = pa .float64 ())),
1309+ (pa .int32 (), pa .array ([4 , 5 , 6 ], type = pa .int32 ())),
1310+ (pa .string (), pa .array (["4" , "5" , "6" ], type = pa .string ())),
1311+ ],
1312+ )
1313+ def test_arrow_cast_variants (df , cast_fn , data_type , expected ):
1314+ """arrow_cast / arrow_try_cast accept str and pyarrow target types."""
1315+ result = df .select (cast_fn (column ("b" ), data_type ).alias ("c" )).collect ()[0 ]
1316+ assert result .column (0 ) == expected
13481317
13491318
1350- def test_arrow_try_cast_null_on_failure ():
1319+ @pytest .mark .parametrize ("data_type" , ["Float64" , pa .float64 ()])
1320+ def test_arrow_try_cast_null_on_failure (data_type ):
13511321 ctx = SessionContext ()
13521322 batch = pa .RecordBatch .from_arrays ([pa .array (["1.5" , "oops" , "3" ])], names = ["s" ])
13531323 df = ctx .create_dataframe ([[batch ]])
13541324
1355- result = df .select (
1356- f .arrow_try_cast (column ("s" ), "Float64" ).alias ("c" ),
1357- f .arrow_try_cast (column ("s" ), pa .float64 ()).alias ("c_pa" ),
1358- ).collect ()[0 ]
1325+ result = df .select (f .arrow_try_cast (column ("s" ), data_type ).alias ("c" )).collect ()[0 ]
13591326
13601327 assert result .column (0 ).to_pylist () == [1.5 , None , 3.0 ]
1361- assert result .column (1 ).to_pylist () == [1.5 , None , 3.0 ]
13621328
13631329
13641330def test_arrow_field ():
@@ -1381,34 +1347,26 @@ def test_arrow_field():
13811347 }
13821348
13831349
1384- def test_cast_to_type ():
1385- ctx = SessionContext ()
1386- batch = pa .RecordBatch .from_arrays (
1387- [pa .array ([4 , 5 , 6 ]), pa .array ([1.0 , 2.0 , 3.0 ])],
1388- names = ["b" , "fl" ],
1389- )
1390- df = ctx .create_dataframe ([[batch ]])
1391-
1392- result = df .select (f .cast_to_type (column ("b" ), column ("fl" )).alias ("c" )).collect ()[
1393- 0
1394- ]
1395-
1396- assert result .column (0 ) == pa .array ([4.0 , 5.0 , 6.0 ], type = pa .float64 ())
1397-
1398-
1399- def test_cast_to_type_try_cast_null_on_failure ():
1350+ @pytest .mark .parametrize (
1351+ ("values" , "try_cast" , "expected" ),
1352+ [
1353+ (pa .array ([4 , 5 , 6 ]), False , [4.0 , 5.0 , 6.0 ]),
1354+ (pa .array (["oops" , "2" , "3" ]), True , [None , 2.0 , 3.0 ]),
1355+ ],
1356+ )
1357+ def test_cast_to_type (values , try_cast , expected ):
1358+ """cast_to_type takes target type from ``type_ref``; try_cast nullifies failures."""
14001359 ctx = SessionContext ()
14011360 batch = pa .RecordBatch .from_arrays (
1402- [pa .array (["oops" , "2" , "3" ]), pa .array ([1.0 , 2.0 , 3.0 ])],
1403- names = ["a" , "fl" ],
1361+ [values , pa .array ([1.0 , 2.0 , 3.0 ])], names = ["v" , "fl" ]
14041362 )
14051363 df = ctx .create_dataframe ([[batch ]])
14061364
14071365 result = df .select (
1408- f .cast_to_type (column ("a " ), column ("fl" ), try_cast = True ).alias ("c" )
1366+ f .cast_to_type (column ("v " ), column ("fl" ), try_cast = try_cast ).alias ("c" )
14091367 ).collect ()[0 ]
14101368
1411- assert result .column (0 ).to_pylist () == [ None , 2.0 , 3.0 ]
1369+ assert result .column (0 ).to_pylist () == expected
14121370 assert result .column (0 ).type == pa .float64 ()
14131371
14141372
@@ -1425,7 +1383,7 @@ def test_with_metadata_empty_dict_noop(df):
14251383 assert out .column (0 ) == pa .array ([4 , 5 , 6 ])
14261384
14271385
1428- def test_with_metadata_empty_key_raises (df ):
1386+ def test_with_metadata_empty_key_raises ():
14291387 with pytest .raises (ValueError , match = "non-empty" ):
14301388 f .with_metadata (column ("b" ), {"" : "v" })
14311389
0 commit comments