|
 from pyspark.sql.types import (
     ArrayType,
     BinaryType,
+    BooleanType,
+    ByteType,
+    DateType,
+    DayTimeIntervalType,
     DecimalType,
     DoubleType,
+    FloatType,
     GeographyType,
     GeometryType,
     IntegerType,
     LongType,
     MapType,
     NullType,
     Row,
+    ShortType,
     StringType,
     StructField,
     StructType,
+    TimeType,
+    TimestampNTZType,
     TimestampType,
     UserDefinedType,
     VariantType,
     VariantVal,
+    YearMonthIntervalType,
 )
 from pyspark.testing.objects import ExamplePoint, ExamplePointUDT, PythonOnlyPoint, PythonOnlyUDT
 from pyspark.testing.utils import (
@@ -656,6 +665,91 @@ def test_variant_convert_numpy(self): |
         )
         self.assertEqual(len(result), 0)

+    def test_convert_pyarrow(self):
+        import pyarrow as pa
+        import pandas as pd
+
+        from decimal import Decimal
+
+        # Cases where input data equals expected output
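+        # Each case is a (data, arrow_type, spark_type) tuple.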
+        cases = [
+            ([None, None], pa.null(), NullType()),
+            ([b"\x01", None], pa.binary(), BinaryType()),
+            ([True, None, False], pa.bool_(), BooleanType()),
+            ([1.0, None], pa.float32(), FloatType()),
+            ([1.0, None], pa.float64(), DoubleType()),
+            ([1, None, 3], pa.int8(), ByteType()),
+            ([1, None, 3], pa.int16(), ShortType()),
+            ([1, None, 3], pa.int32(), IntegerType()),
+            ([1, None, 3], pa.int64(), LongType()),
+            ([Decimal("1.23"), None], pa.decimal128(10, 2), DecimalType(10, 2)),
+            (["a", None, "c"], pa.string(), StringType()),
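+            # year-month intervals are carried as int32 month counts, so they round-trip unchanged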
+            ([1, None], pa.int32(), YearMonthIntervalType()),
+        ]
+        for data, arrow_type, spark_type in cases:
+            arr = pa.array(data, type=arrow_type)
+            result = ArrowArrayToPandasConversion.convert_pyarrow(arr, spark_type)
+            self.assertIsInstance(result.dtype, pd.ArrowDtype, f"Failed for {spark_type}")
+            for i, val in enumerate(data):
+                msg = f"Failed for {spark_type} at index {i}: expected {val}, got {result.iloc[i]}"
+                if val is None:
+                    self.assertTrue(pd.isna(result.iloc[i]), msg)
+                else:
+                    self.assertEqual(result.iloc[i], val, msg)
+
+    def test_convert_pyarrow_temporal(self):
+        import pyarrow as pa
+        import pandas as pd
+
+        cases = [
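+            # date32 values count days since the Unix epoch; the time64, timestamp,
+            # and duration inputs below are all in microseconds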
+            ([1, None], pa.date32(), DateType(), [datetime.date(1970, 1, 2), None]),
+            ([1000000, None], pa.time64("us"), TimeType(), [datetime.time(0, 0, 1), None]),
+            (
+                [1000000, None],
+                pa.timestamp("us", tz="UTC"),
+                TimestampType(),
+                [datetime.datetime(1970, 1, 1, 0, 0, 1), None],
+            ),
+            (
+                [1000000, None],
+                pa.timestamp("us"),
+                TimestampNTZType(),
+                [datetime.datetime(1970, 1, 1, 0, 0, 1), None],
+            ),
+            (
+                [1000000, None],
+                pa.duration("us"),
+                DayTimeIntervalType(),
+                [datetime.timedelta(seconds=1), None],
+            ),
+        ]
+        for data, arrow_type, spark_type, expected in cases:
+            arr = pa.array(data, type=arrow_type)
+            result = ArrowArrayToPandasConversion.convert_pyarrow(arr, spark_type)
+            self.assertIsInstance(result.dtype, pd.ArrowDtype, f"Failed for {spark_type}")
+            for i, exp in enumerate(expected):
+                msg = f"Failed for {spark_type} at index {i}: expected {exp}, got {result.iloc[i]}"
+                if exp is None:
+                    self.assertTrue(pd.isna(result.iloc[i]), msg)
+                else:
+                    self.assertEqual(result.iloc[i], exp, msg)
+
+    def test_convert_pyarrow_ser_name(self):
+        import pyarrow as pa
+        import pandas as pd
+
+        # explicit ser_name
+        arr = pa.array([1, 2, 3], type=pa.int64())
+        result = ArrowArrayToPandasConversion.convert_pyarrow(arr, LongType(), ser_name="col")
+        self.assertEqual(result.name, "col")
+        self.assertIsInstance(result.dtype, pd.ArrowDtype)
+
+        # default name from arrow array (set via RecordBatch column extraction)
+        batch = pa.record_batch({"my_col": [1, 2, 3]})
+        arr = batch.column("my_col")
+        result = ArrowArrayToPandasConversion.convert_pyarrow(arr, LongType())
+        self.assertEqual(result.name, "my_col")
+

 if __name__ == "__main__":
     from pyspark.testing import main