diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py
index 8fc4fa5cc0cc7..66af57533f9c9 100644
--- a/python/pyspark/sql/conversion.py
+++ b/python/pyspark/sql/conversion.py
@@ -188,6 +188,7 @@ def to_pandas(
         ndarray_as_list: bool = False,
         prefer_int_ext_dtype: bool = False,
         df_for_struct: bool = False,
+        arrow_dtype_types: Optional[tuple] = None,
     ) -> List[Union["pd.Series", "pd.DataFrame"]]:
         """
         Convert a RecordBatch or Table to a list of pandas Series.
@@ -208,6 +209,10 @@ def to_pandas(
             Whether to convert integers to Pandas ExtensionDType.
         df_for_struct : bool
             If True, convert struct columns to DataFrame instead of Series.
+        arrow_dtype_types : tuple of DataType classes, optional
+            If provided, columns whose Spark type matches one of these classes will be
+            converted via convert_pyarrow (ArrowDtype-backed). Unsupported types fall
+            through to convert_numpy/convert_legacy. Default is None (disabled).
 
         Returns
         -------
@@ -232,6 +237,7 @@ def to_pandas(
                 ndarray_as_list=ndarray_as_list,
                 prefer_int_ext_dtype=prefer_int_ext_dtype,
                 df_for_struct=df_for_struct,
+                arrow_dtype_types=arrow_dtype_types,
             )
             for i in range(batch.num_columns)
         ]
@@ -1459,6 +1465,29 @@ class ArrowArrayToPandasConversion:
     where Arrow data needs to be converted to pandas for Python UDF processing.
     """
 
+    # Types supported by convert_pyarrow (ArrowDtype-backed pandas Series).
+    # This tuple controls which types are routed to the pyarrow path when
+    # arrow_cast is enabled. Expand as more types are supported.
+    ARROW_DTYPE_TYPES = (
+        NullType,
+        BinaryType,
+        BooleanType,
+        FloatType,
+        DoubleType,
+        ByteType,
+        ShortType,
+        IntegerType,
+        LongType,
+        DecimalType,
+        StringType,
+        DateType,
+        TimeType,
+        TimestampType,
+        TimestampNTZType,
+        DayTimeIntervalType,
+        YearMonthIntervalType,
+    )
+
     @classmethod
     def convert(
         cls,
@@ -1471,6 +1500,7 @@ def convert(
         ndarray_as_list: bool = False,
         prefer_int_ext_dtype: bool = False,
         df_for_struct: bool = False,
+        arrow_dtype_types: Optional[tuple] = None,
     ) -> Union["pd.Series", "pd.DataFrame"]:
         """
         Convert a PyArrow Array or ChunkedArray to a pandas Series or DataFrame.
@@ -1495,6 +1525,10 @@ def convert(
         df_for_struct : bool, optional
             If True, convert struct columns to a DataFrame with columns corresponding
             to struct fields instead of a Series. Default is False.
+        arrow_dtype_types : tuple of DataType classes, optional
+            If provided, columns whose Spark type matches one of these classes will be
+            converted via convert_pyarrow (ArrowDtype-backed). Unsupported types fall
+            through to convert_numpy/convert_legacy. Default is None (disabled).
 
         Returns
         -------
@@ -1502,6 +1536,13 @@ def convert(
             Converted pandas Series. If df_for_struct is True and the type is StructType,
             returns a DataFrame with columns corresponding to struct fields.
         """
+        if arrow_dtype_types is not None and isinstance(spark_type, arrow_dtype_types):
+            return cls.convert_pyarrow(
+                arr,
+                spark_type,
+                ser_name=ser_name,
+            )
+
         if cls._prefer_convert_numpy(spark_type, df_for_struct):
             return cls.convert_numpy(
                 arr,
@@ -1780,3 +1821,70 @@ def convert_numpy(
             assert False, f"Need converter for {spark_type} but failed to find one."
 
         return series.rename(ser_name)
+
+    @classmethod
+    def convert_pyarrow(
+        cls,
+        arr: Union["pa.Array", "pa.ChunkedArray"],
+        spark_type: DataType,
+        *,
+        ser_name: Optional[str] = None,
+    ) -> "pd.Series":
+        """
+        Convert a PyArrow Array or ChunkedArray to a pandas Series backed by ArrowDtype.
+
+        This is similar to :meth:`convert_numpy`, but instead of producing
+        numpy-backed pandas Series, it produces ArrowDtype-backed Series via
+        ``arr.to_pandas(types_mapper=pd.ArrowDtype)``.
+
+        Parameters
+        ----------
+        arr : pa.Array or pa.ChunkedArray
+            The Arrow column to convert.
+        spark_type : DataType
+            The target Spark type for the column to be converted to.
+        ser_name : str, optional
+            The name of returned pd.Series. If not set, will try to get it from arr._name.
+
+        Returns
+        -------
+        pd.Series
+            Converted pandas Series backed by ArrowDtype.
+        """
+        import pyarrow as pa
+        import pandas as pd
+
+        assert isinstance(arr, (pa.Array, pa.ChunkedArray))
+
+        if ser_name is None:
+            ser_name = arr._name
+
+        arr = ArrowArrayConversion.preprocess_time(arr)
+
+        series: pd.Series
+
+        # Reuse the class-level tuple rather than re-listing the types here, so the
+        # types callers route to this method and the types it accepts cannot drift.
+        if isinstance(spark_type, cls.ARROW_DTYPE_TYPES):
+            series = arr.to_pandas(types_mapper=pd.ArrowDtype)
+        # elif isinstance(spark_type, UserDefinedType):
+        #     TODO: Support UserDefinedType
+        # elif isinstance(spark_type, VariantType):
+        #     TODO: Support VariantType
+        # elif isinstance(spark_type, GeographyType):
+        #     TODO: Support GeographyType
+        # elif isinstance(spark_type, GeometryType):
+        #     TODO: Support GeometryType
+        # elif isinstance(
+        #     spark_type,
+        #     (
+        #         ArrayType,
+        #         MapType,
+        #         StructType,
+        #     ),
+        # ):
+        #     TODO: Support complex types
+        else:  # pragma: no cover
+            assert False, f"Need converter for {spark_type} but failed to find one."
+
+        return series.rename(ser_name)
diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py
index 3f5d68d10452e..622b48f9fd16e 100644
--- a/python/pyspark/sql/pandas/conversion.py
+++ b/python/pyspark/sql/pandas/conversion.py
@@ -39,6 +39,7 @@
     ArrayType,
     MapType,
     TimestampType,
+    StructField,
     StructType,
     _has_type,
     DataType,
@@ -184,6 +185,7 @@ def _convert_arrow_table_to_pandas(
     struct_handling_mode: Optional[str] = None,
     date_as_object: bool = False,
     self_destruct: bool = False,
+    arrow_dtype: bool = False,
 ) -> "PandasDataFrameLike":
     """
     Helper function to convert Arrow table columns to a pandas DataFrame.
@@ -207,6 +209,9 @@ def _convert_arrow_table_to_pandas(
         Whether to convert date values to Python datetime.date objects (default: False)
     self_destruct : bool
         Whether to enable memory-efficient self-destruct mode for large tables (default: False)
+    arrow_dtype : bool
+        Whether to produce ArrowDtype-backed pandas Series for supported types
+        (default: False)
 
     Returns
     -------
@@ -254,23 +259,34 @@ def _convert_arrow_table_to_pandas(
         error_on_duplicated_field_names = True
         struct_handling_mode = "dict"
 
-    # Convert arrow columns to pandas Series
-    column_data = (arrow_col.to_pandas(**pandas_options) for arrow_col in arrow_table.columns)
+    if arrow_dtype:
+        from pyspark.sql.conversion import ArrowArrayToPandasConversion
+
+        arrow_dtype_types = ArrowArrayToPandasConversion.ARROW_DTYPE_TYPES
+
+    def _convert_column(arrow_col: "pa.ChunkedArray", field: "StructField") -> "pd.Series":
+        # arrow_dtype_types is only bound when arrow_dtype is True, so the flag has
+        # to be tested first; short-circuiting keeps the name from being evaluated.
+        if arrow_dtype and isinstance(field.dataType, arrow_dtype_types):
+            return ArrowArrayToPandasConversion.convert_pyarrow(
+                arrow_col, field.dataType, ser_name=field.name
+            )
+        series = arrow_col.to_pandas(**pandas_options)
+        return _create_converter_to_pandas(
+            field.dataType,
+            field.nullable,
+            timezone=timezone,
+            struct_in_pandas=struct_handling_mode,
+            error_on_duplicated_field_names=error_on_duplicated_field_names,
+        )(series)
 
-    # Apply Spark-specific type converters to each column
     pdf = pd.concat(
         objs=cast(
             Sequence[pd.Series],
-            (
-                _create_converter_to_pandas(
-                    field.dataType,
-                    field.nullable,
-                    timezone=timezone,
-                    struct_in_pandas=struct_handling_mode,
-                    error_on_duplicated_field_names=error_on_duplicated_field_names,
-                )(series)
-                for series, field in zip(column_data, schema.fields)
-            ),
+            [
+                _convert_column(arrow_col, field)
+                for arrow_col, field in zip(arrow_table.columns, schema.fields)
+            ],
         ),
         axis="columns",
     )
@@ -306,6 +322,7 @@ def _to_pandas(self, **kwargs: Any) -> "PandasDataFrameLike":
             arrowPySparkFallbackEnabled,
             arrowPySparkSelfDestructEnabled,
             pandasStructHandlingMode,
+            arrowPySparkArrowDtypeEnabled,
         ) = self.sparkSession._jconf.getConfs(
             [
                 "spark.sql.session.timeZone",
@@ -314,6 +331,7 @@ def _to_pandas(self, **kwargs: Any) -> "PandasDataFrameLike":
                 "spark.sql.execution.arrow.pyspark.fallback.enabled",
                 "spark.sql.execution.arrow.pyspark.selfDestruct.enabled",
                 "spark.sql.execution.pandas.structHandlingMode",
+                "spark.sql.execution.arrow.pyspark.arrowDtype.enabled",
             ]
         )
 
@@ -386,6 +404,7 @@ def _to_pandas(self, **kwargs: Any) -> "PandasDataFrameLike":
                         struct_handling_mode=pandasStructHandlingMode,
                         date_as_object=True,
                         self_destruct=arrowPySparkSelfDestructEnabled == "true",
+                        arrow_dtype=arrowPySparkArrowDtypeEnabled == "true",
                     )
 
                     return pdf
diff --git a/python/pyspark/sql/tests/test_conversion.py b/python/pyspark/sql/tests/test_conversion.py
index 304d8be740d41..0f494e83d89c4 100644
--- a/python/pyspark/sql/tests/test_conversion.py
+++ b/python/pyspark/sql/tests/test_conversion.py
@@ -30,8 +30,13 @@
 from pyspark.sql.types import (
     ArrayType,
     BinaryType,
+    BooleanType,
+    ByteType,
+    DateType,
+    DayTimeIntervalType,
     DecimalType,
     DoubleType,
+    FloatType,
     Geography,
     GeographyType,
     Geometry,
@@ -41,13 +46,17 @@
     MapType,
     NullType,
     Row,
+    ShortType,
     StringType,
     StructField,
     StructType,
+    TimeType,
+    TimestampNTZType,
     TimestampType,
     UserDefinedType,
     VariantType,
     VariantVal,
+    YearMonthIntervalType,
 )
 from pyspark.testing.objects import ExamplePoint, ExamplePointUDT, PythonOnlyPoint, PythonOnlyUDT
 from pyspark.testing.utils import (
@@ -709,6 +718,138 @@ def test_variant_convert_numpy(self):
         )
         self.assertEqual(len(result), 0)
 
+    def test_convert_pyarrow(self):
+        import pyarrow as pa
+        import pandas as pd
+
+        from decimal import Decimal
+
+        # Cases where input data equals expected output
+        cases = [
+            ([None, None], pa.null(), NullType()),
+            ([b"\x01", None], pa.binary(), BinaryType()),
+            ([True, None, False], pa.bool_(), BooleanType()),
+            ([1.0, None], pa.float32(), FloatType()),
+            ([1.0, None], pa.float64(), DoubleType()),
+            ([1, None, 3], pa.int8(), ByteType()),
+            ([1, None, 3], pa.int16(), ShortType()),
+            ([1, None, 3], pa.int32(), IntegerType()),
+            ([1, None, 3], pa.int64(), LongType()),
+            ([Decimal("1.23"), None], pa.decimal128(10, 2), DecimalType(10, 2)),
+            (["a", None, "c"], pa.string(), StringType()),
+            ([1, None], pa.int32(), YearMonthIntervalType()),
+        ]
+        for data, arrow_type, spark_type in cases:
+            arr = pa.array(data, type=arrow_type)
+            result = ArrowArrayToPandasConversion.convert_pyarrow(arr, spark_type)
+            self.assertIsInstance(result.dtype, pd.ArrowDtype, f"Failed for {spark_type}")
+            for i, val in enumerate(data):
+                msg = f"Failed for {spark_type} at index {i}: expected {val}, got {result.iloc[i]}"
+                if val is None:
+                    self.assertTrue(pd.isna(result.iloc[i]), msg)
+                else:
+                    self.assertEqual(result.iloc[i], val, msg)
+
+    def test_convert_pyarrow_temporal(self):
+        import pyarrow as pa
+        import pandas as pd
+
+        cases = [
+            ([1, None], pa.date32(), DateType(), [datetime.date(1970, 1, 2), None]),
+            ([1000000, None], pa.time64("us"), TimeType(), [datetime.time(0, 0, 1), None]),
+            (
+                [1000000, None],
+                pa.timestamp("us", tz="UTC"),
+                TimestampType(),
+                [datetime.datetime(1970, 1, 1, 0, 0, 1), None],
+            ),
+            (
+                [1000000, None],
+                pa.timestamp("us"),
+                TimestampNTZType(),
+                [datetime.datetime(1970, 1, 1, 0, 0, 1), None],
+            ),
+            (
+                [1000000, None],
+                pa.duration("us"),
+                DayTimeIntervalType(),
+                [datetime.timedelta(seconds=1), None],
+            ),
+        ]
+        for data, arrow_type, spark_type, expected in cases:
+            arr = pa.array(data, type=arrow_type)
+            result = ArrowArrayToPandasConversion.convert_pyarrow(arr, spark_type)
+            self.assertIsInstance(result.dtype, pd.ArrowDtype, f"Failed for {spark_type}")
+            for i, exp in enumerate(expected):
+                msg = f"Failed for {spark_type} at index {i}: expected {exp}, got {result.iloc[i]}"
+                if exp is None:
+                    self.assertTrue(pd.isna(result.iloc[i]), msg)
+                else:
+                    self.assertEqual(result.iloc[i], exp, msg)
+
+    def test_convert_pyarrow_ser_name(self):
+        import pyarrow as pa
+        import pandas as pd
+
+        # explicit ser_name
+        arr = pa.array([1, 2, 3], type=pa.int64())
+        result = ArrowArrayToPandasConversion.convert_pyarrow(arr, LongType(), ser_name="col")
+        self.assertEqual(result.name, "col")
+        self.assertIsInstance(result.dtype, pd.ArrowDtype)
+
+        # default name from arrow array (set via RecordBatch column extraction)
+        batch = pa.record_batch({"my_col": [1, 2, 3]})
+        arr = batch.column("my_col")
+        result = ArrowArrayToPandasConversion.convert_pyarrow(arr, LongType())
+        self.assertEqual(result.name, "my_col")
+
+    def test_convert_arrow_dtype_types(self):
+        """Test that arrow_dtype_types routes matching types to convert_pyarrow."""
+        import pyarrow as pa
+        import pandas as pd
+
+        arr = pa.array([1, 2, 3], type=pa.int64())
+
+        # With arrow_dtype_types including LongType: should get ArrowDtype
+        result = ArrowArrayToPandasConversion.convert(
+            arr, LongType(), arrow_dtype_types=(LongType,)
+        )
+        self.assertIsInstance(result.dtype, pd.ArrowDtype)
+
+        # With arrow_dtype_types not including LongType: should get numpy dtype
+        result = ArrowArrayToPandasConversion.convert(
+            arr, LongType(), arrow_dtype_types=(StringType,)
+        )
+        self.assertNotIsInstance(result.dtype, pd.ArrowDtype)
+
+        # With arrow_dtype_types=None (default): should get numpy dtype
+        result = ArrowArrayToPandasConversion.convert(arr, LongType())
+        self.assertNotIsInstance(result.dtype, pd.ArrowDtype)
+
+    def test_convert_arrow_table_to_pandas_arrow_dtype(self):
+        """Test _convert_arrow_table_to_pandas with arrow_dtype flag."""
+        import pyarrow as pa
+        import pandas as pd
+
+        from pyspark.sql.pandas.conversion import _convert_arrow_table_to_pandas
+
+        table = pa.table({"a": [1, 2, 3], "b": ["x", "y", "z"]})
+        schema = StructType([StructField("a", LongType()), StructField("b", StringType())])
+
+        # arrow_dtype=False: numpy-backed
+        pdf_numpy = _convert_arrow_table_to_pandas(table, schema, timezone="UTC", arrow_dtype=False)
+        self.assertNotIsInstance(pdf_numpy["a"].dtype, pd.ArrowDtype)
+        self.assertNotIsInstance(pdf_numpy["b"].dtype, pd.ArrowDtype)
+
+        # arrow_dtype=True: ArrowDtype-backed for supported types
+        pdf_arrow = _convert_arrow_table_to_pandas(table, schema, timezone="UTC", arrow_dtype=True)
+        self.assertIsInstance(pdf_arrow["a"].dtype, pd.ArrowDtype)
+        self.assertIsInstance(pdf_arrow["b"].dtype, pd.ArrowDtype)
+
+        # Values should be equal
+        self.assertEqual(pdf_numpy["a"].tolist(), pdf_arrow["a"].tolist())
+        self.assertEqual(pdf_numpy["b"].tolist(), pdf_arrow["b"].tolist())
+
     def test_geography_convert_numpy(self):
         import pyarrow as pa
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 05fdb6bca9320..8c47af4ee88d7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -4183,6 +4183,19 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val ARROW_PYSPARK_ARROW_DTYPE_ENABLED =
+    buildConf("spark.sql.execution.arrow.pyspark.arrowDtype.enabled")
+      .doc("(Experimental) When true, use ArrowDtype-backed pandas Series in " +
+        "pyspark.sql.DataFrame.toPandas for supported data types. This keeps data in Arrow " +
+        "format without converting to numpy, which handles nulls natively via pd.NA and " +
+        "avoids type coercion issues. " +
+        "This optimization applies to: pyspark.sql.DataFrame.toPandas " +
+        "when 'spark.sql.execution.arrow.pyspark.enabled' is set.")
+      .version("4.2.0")
+      .withBindingPolicy(ConfigBindingPolicy.SESSION)
+      .booleanConf
+      .createWithDefault(false)
+
   val PYSPARK_BINARY_AS_BYTES =
     buildConf("spark.sql.execution.pyspark.binaryAsBytes")
       .doc("When true, BinaryType is consistently mapped to bytes in PySpark. " +
@@ -7975,6 +7988,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
 
   def arrowPySparkSelfDestructEnabled: Boolean = getConf(ARROW_PYSPARK_SELF_DESTRUCT_ENABLED)
 
+  def arrowPySparkArrowDtypeEnabled: Boolean = getConf(ARROW_PYSPARK_ARROW_DTYPE_ENABLED)
+
   def pysparkBinaryAsBytes: Boolean = getConf(PYSPARK_BINARY_AS_BYTES)
 
   def pysparkToJSONReturnDataFrame: Boolean = getConf(PYSPARK_TOJSON_RETURN_DATAFRAME)