Skip to content

Commit faadf8d

Browse files
Azmat Siddique
authored and committed
[SPARK-55242][PYSPARK] Handle np.ndarray elements in object-dtype columns when converting from pandas
When a pandas DataFrame contains list-valued columns (e.g. a column created via `[[e] for e in ...]`), pandas 3 stores each list element internally as a `np.ndarray` object rather than a plain Python list. The existing `DataTypeOps.prepare()` method calls: col.replace({np.nan: None}) on the pandas Series before passing it to Spark's `createDataFrame`. When the Series has dtype "object" and its elements are `np.ndarray` objects, pandas 3 raises: ValueError: The truth value of an array is ambiguous. Use a.any() or a.all() because numpy arrays cannot be compared with `==` in the way that `replace` needs. Fix: detect object-dtype columns whose non-null first element is a `np.ndarray` and convert each such element to a plain Python list via `.tolist()` before performing the NaN-to-None substitution. This also ensures PyArrow correctly infers the column type as `ArrayType` for the resulting Spark schema. ### Does this PR introduce _any_ user-facing change? No - this is a regression fix. Previously `ps.from_pandas(pdf)` with a list-valued column raised an error; after the fix it succeeds and the data round-trips correctly. ### How was this patch tested? Added `test_from_pandas_with_np_array_elements` in `pyspark/pandas/tests/data_type_ops/test_complex_ops.py`, which reproduces the exact scenario reported in SPARK-55242. Closes #SPARK-55242
1 parent 0ba9a2a commit faadf8d

2 files changed

Lines changed: 32 additions & 0 deletions

File tree

python/pyspark/pandas/data_type_ops/base.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,17 @@ def restore(self, col: pd.Series) -> pd.Series:
548548

549549
def prepare(self, col: pd.Series) -> pd.Series:
    """Prepare a pandas column for conversion to Spark via ``from_pandas``.

    Converts ``np.ndarray`` elements of an object-dtype Series to plain
    Python lists, then substitutes ``None`` for ``np.nan``, so the result
    can be handed to Spark's ``createDataFrame``.

    Parameters
    ----------
    col : pd.Series
        The pandas Series being converted.

    Returns
    -------
    pd.Series
        Series with ndarray elements replaced by lists and NaN replaced
        by ``None``.
    """
    # In pandas 3, list-valued columns store elements as np.ndarray objects.
    # col.replace({np.nan: None}) then raises "ValueError: The truth value
    # of an array is ambiguous" because replace compares elements with `==`,
    # which is elementwise (non-boolean) for ndarrays.
    # Convert ndarray elements to Python lists first so that:
    # 1. replace({np.nan: None}) can safely run on the scalar/null values, and
    # 2. PyArrow correctly infers ArrayType for the Spark schema.
    if col.dtype == np.dtype("object") and len(col) > 0:
        notnull = col[col.notnull()]
        # Scan all non-null values (short-circuiting on the first hit)
        # rather than inspecting only the first element: a mixed column
        # whose first non-null value is a scalar but which contains
        # ndarrays later would otherwise still crash in replace().
        if any(isinstance(v, np.ndarray) for v in notnull):
            col = col.map(lambda x: x.tolist() if isinstance(x, np.ndarray) else x)
    return col.replace({np.nan: None})
552563

553564
def isnull(self, index_ops: IndexOpsLike) -> IndexOpsLike:

python/pyspark/pandas/tests/data_type_ops/test_complex_ops.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import decimal
1919
import datetime
2020

21+
import numpy as np
2122
import pandas as pd
2223

2324
from pyspark import pandas as ps
@@ -247,6 +248,26 @@ def test_from_to_pandas(self):
247248
self.assert_eq(pser, psser._to_pandas(), check_exact=False)
248249
self.assert_eq(ps.from_pandas(pser), psser)
249250

251+
def test_from_pandas_with_np_array_elements(self):
    """SPARK-55242: ``from_pandas`` must handle object columns whose
    elements are ``np.ndarray`` (pandas 3 stores list-valued columns,
    e.g. ``[[e] for e in ...]``, this way).

    Previously ``DataTypeOps.prepare()`` raised
    "ValueError: The truth value of an array is ambiguous" when it
    called ``col.replace({np.nan: None})`` on such a column.
    """
    pdf = pd.DataFrame(
        {
            "a": [1, 2, 3, 4, 5, 6, 7, 8, 9],
            "b": [[e] for e in [4, 5, 6, 3, 2, 1, 0, 0, 0]],
        },
        index=np.random.rand(9),
    )
    # from_pandas must not raise; the resulting DataFrame must match the original.
    psdf = ps.from_pandas(pdf)
    self.assert_eq(pdf["a"].sort_values(), psdf["a"].sort_values())
    # Verify "b" round-trips by VALUE, not just by length: each element is a
    # 1-element list, so compare the sorted multiset of inner values.
    # (pandas-on-Spark sort_values does not support the `key` parameter,
    # so sort plain Python lists instead.)
    b_expected = sorted(v[0] for v in pdf["b"])
    b_actual = sorted(v[0] for v in psdf["b"]._to_pandas())
    self.assertEqual(b_actual, b_expected)
270+
250271
def test_isnull(self):
251272
pdf, psdf = self.array_pdf, self.array_psdf
252273
for col in self.array_df_cols:

0 commit comments

Comments
 (0)