Skip to content

Commit eb0ebb8

Browse files
committed
[SPARK-55325][PYTHON] Introduce ArrowArrayToPandasConversion.convert_pyarrow
1 parent 3665506 commit eb0ebb8

2 files changed

Lines changed: 180 additions & 0 deletions

File tree

python/pyspark/sql/conversion.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1750,3 +1750,89 @@ def convert_numpy(
17501750
assert False, f"Need converter for {spark_type} but failed to find one."
17511751

17521752
return series.rename(ser_name)
1753+
1754+
@classmethod
def convert_pyarrow(
    cls,
    arr: Union["pa.Array", "pa.ChunkedArray"],
    spark_type: DataType,
    *,
    ser_name: Optional[str] = None,
) -> "pd.Series":
    """
    Convert a PyArrow Array or ChunkedArray to a pandas Series backed by ArrowDtype.

    This is similar to :meth:`convert_numpy`, but instead of producing
    numpy-backed pandas Series, it produces ArrowDtype-backed Series via
    ``arr.to_pandas(types_mapper=pd.ArrowDtype)``.

    Parameters
    ----------
    arr : pa.Array or pa.ChunkedArray
        The Arrow column to convert.
    spark_type : DataType
        The target Spark type for the column to be converted to.
    ser_name : str, optional
        The name of returned pd.Series. If not set, will try to get it from arr._name.

    Returns
    -------
    pd.Series
        Converted pandas Series backed by ArrowDtype.
    """
    import pyarrow as pa
    import pandas as pd

    assert isinstance(arr, (pa.Array, pa.ChunkedArray))

    if ser_name is None:
        # ``_name`` is an internal pyarrow attribute that is only present on
        # columns extracted from a Table/RecordBatch. Fall back to ``None``
        # (an unnamed Series) instead of raising AttributeError when the
        # caller passes a plain, name-less array.
        ser_name = getattr(arr, "_name", None)

    # Normalize time values before conversion (delegated to the shared helper).
    arr = ArrowArrayConversion.preprocess_time(arr)

    series: pd.Series

    if isinstance(
        spark_type,
        (
            NullType,
            BinaryType,
            BooleanType,
            FloatType,
            DoubleType,
            ByteType,
            ShortType,
            IntegerType,
            LongType,
            DecimalType,
            StringType,
            DateType,
            TimeType,
            TimestampType,
            TimestampNTZType,
            DayTimeIntervalType,
            YearMonthIntervalType,
        ),
    ):
        # Simple scalar types map 1:1 onto ArrowDtype-backed pandas columns.
        series = arr.to_pandas(types_mapper=pd.ArrowDtype)
    # TODO: Support UserDefinedType, VariantType, GeographyType, GeometryType,
    #   and complex types (ArrayType, MapType, StructType).
    else:  # pragma: no cover
        assert False, f"Need converter for {spark_type} but failed to find one."

    return series.rename(ser_name)

python/pyspark/sql/tests/test_conversion.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,22 +30,31 @@
3030
from pyspark.sql.types import (
    ArrayType,
    BinaryType,
    BooleanType,
    ByteType,
    DateType,
    DayTimeIntervalType,
    DecimalType,
    DoubleType,
    FloatType,
    GeographyType,
    GeometryType,
    IntegerType,
    LongType,
    MapType,
    NullType,
    Row,
    ShortType,
    StringType,
    StructField,
    StructType,
    TimeType,
    TimestampNTZType,
    TimestampType,
    UserDefinedType,
    VariantType,
    VariantVal,
    YearMonthIntervalType,
)
5059
from pyspark.testing.objects import ExamplePoint, ExamplePointUDT, PythonOnlyPoint, PythonOnlyUDT
5160
from pyspark.testing.utils import (
@@ -656,6 +665,91 @@ def test_variant_convert_numpy(self):
656665
)
657666
self.assertEqual(len(result), 0)
658667

668+
def test_convert_pyarrow(self):
    """Simple scalar types round-trip through convert_pyarrow into ArrowDtype Series."""
    import pandas as pd
    import pyarrow as pa

    from decimal import Decimal

    # (input values, arrow type, spark type) — the converted output equals the input.
    scalar_cases = [
        ([None, None], pa.null(), NullType()),
        ([b"\x01", None], pa.binary(), BinaryType()),
        ([True, None, False], pa.bool_(), BooleanType()),
        ([1.0, None], pa.float32(), FloatType()),
        ([1.0, None], pa.float64(), DoubleType()),
        ([1, None, 3], pa.int8(), ByteType()),
        ([1, None, 3], pa.int16(), ShortType()),
        ([1, None, 3], pa.int32(), IntegerType()),
        ([1, None, 3], pa.int64(), LongType()),
        ([Decimal("1.23"), None], pa.decimal128(10, 2), DecimalType(10, 2)),
        (["a", None, "c"], pa.string(), StringType()),
        ([1, None], pa.int32(), YearMonthIntervalType()),
    ]
    for values, pa_type, spark_type in scalar_cases:
        series = ArrowArrayToPandasConversion.convert_pyarrow(
            pa.array(values, type=pa_type), spark_type
        )
        self.assertIsInstance(series.dtype, pd.ArrowDtype, f"Failed for {spark_type}")
        for idx, expected in enumerate(values):
            actual = series.iloc[idx]
            msg = f"Failed for {spark_type} at index {idx}: expected {expected}, got {actual}"
            # None inputs must surface as missing values, everything else compares equal.
            if expected is None:
                self.assertTrue(pd.isna(actual), msg)
            else:
                self.assertEqual(actual, expected, msg)
700+
def test_convert_pyarrow_temporal(self):
    """Temporal arrow types convert to the expected Python date/time values."""
    import pandas as pd
    import pyarrow as pa

    # (raw arrow values, arrow type, spark type, expected Python values).
    temporal_cases = [
        ([1, None], pa.date32(), DateType(), [datetime.date(1970, 1, 2), None]),
        ([1000000, None], pa.time64("us"), TimeType(), [datetime.time(0, 0, 1), None]),
        (
            [1000000, None],
            pa.timestamp("us", tz="UTC"),
            TimestampType(),
            [datetime.datetime(1970, 1, 1, 0, 0, 1), None],
        ),
        (
            [1000000, None],
            pa.timestamp("us"),
            TimestampNTZType(),
            [datetime.datetime(1970, 1, 1, 0, 0, 1), None],
        ),
        (
            [1000000, None],
            pa.duration("us"),
            DayTimeIntervalType(),
            [datetime.timedelta(seconds=1), None],
        ),
    ]
    for values, pa_type, spark_type, expected_values in temporal_cases:
        series = ArrowArrayToPandasConversion.convert_pyarrow(
            pa.array(values, type=pa_type), spark_type
        )
        self.assertIsInstance(series.dtype, pd.ArrowDtype, f"Failed for {spark_type}")
        for idx, expected in enumerate(expected_values):
            actual = series.iloc[idx]
            msg = f"Failed for {spark_type} at index {idx}: expected {expected}, got {actual}"
            # None inputs must surface as missing values, everything else compares equal.
            if expected is None:
                self.assertTrue(pd.isna(actual), msg)
            else:
                self.assertEqual(actual, expected, msg)
736+
737+
def test_convert_pyarrow_ser_name(self):
    """The Series name comes from ser_name when given, else from the arrow column."""
    import pandas as pd
    import pyarrow as pa

    # An explicitly passed ser_name takes precedence.
    explicit = ArrowArrayToPandasConversion.convert_pyarrow(
        pa.array([1, 2, 3], type=pa.int64()), LongType(), ser_name="col"
    )
    self.assertEqual(explicit.name, "col")
    self.assertIsInstance(explicit.dtype, pd.ArrowDtype)

    # Without ser_name, the name is picked up from the column extracted
    # out of a RecordBatch.
    column = pa.record_batch({"my_col": [1, 2, 3]}).column("my_col")
    implicit = ArrowArrayToPandasConversion.convert_pyarrow(column, LongType())
    self.assertEqual(implicit.name, "my_col")
659753

660754
if __name__ == "__main__":
661755
from pyspark.testing import main

0 commit comments

Comments
 (0)