Skip to content

Commit d9fa036

Browse files
fix: pass all available model fields to the readers so that multiple models don't trigger the additional fields check for CSV
Also restructured the reader tests.
1 parent 46a4608 commit d9fa036

19 files changed

Lines changed: 436 additions & 445 deletions

File tree

src/dve/core_engine/backends/base/reader.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def read_to_py_iterator(
9090
resource: URI,
9191
entity_name: EntityName,
9292
schema: type[BaseModel],
93+
all_model_fields: Optional[set[str]] = None,
9394
) -> Iterator[dict[str, Any]]:
9495
"""Iterate through the contents of the resource, yielding dicts
9596
representing each record.
@@ -107,6 +108,7 @@ def read_to_entity_type(
107108
resource: URI,
108109
entity_name: EntityName,
109110
schema: type[BaseModel],
111+
all_model_fields: Optional[set[str]] = None,
110112
) -> EntityType:
111113
"""Read to the specified entity type, if supported.
112114
@@ -116,7 +118,12 @@ def read_to_entity_type(
116118
117119
"""
118120
if entity_name == Iterator[dict[str, Any]]:
119-
return self.read_to_py_iterator(resource, entity_name, schema) # type: ignore
121+
return self.read_to_py_iterator(
122+
resource,
123+
entity_name,
124+
schema, # type: ignore
125+
all_model_fields
126+
)
120127

121128
self.raise_if_not_sensible_file(resource, entity_name)
122129

@@ -125,7 +132,7 @@ def read_to_entity_type(
125132
except KeyError as err:
126133
raise ReaderLacksEntityTypeSupport(entity_type=entity_type) from err
127134

128-
return reader_func(self, resource, entity_name, schema)
135+
return reader_func(self, resource, entity_name, schema, all_model_fields=all_model_fields)
129136

130137
def add_record_index(self, entity: EntityType, **kwargs) -> EntityType:
131138
"""Add a record index to the entity"""

src/dve/core_engine/backends/implementations/duckdb/readers/csv.py

Lines changed: 46 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,15 @@
99
from duckdb import DuckDBPyConnection, DuckDBPyRelation, StarExpression, read_csv
1010
from pydantic import BaseModel
1111

12-
from dve.core_engine.backends.base.reader import BaseFileReader, read_function
12+
from dve.core_engine.backends.base.reader import read_function
1313
from dve.core_engine.backends.exceptions import EmptyFileError, MessageBearingError
1414
from dve.core_engine.backends.implementations.duckdb.duckdb_helpers import (
1515
duckdb_record_index,
1616
duckdb_write_parquet,
1717
get_duckdb_type_from_annotation,
1818
)
1919
from dve.core_engine.backends.implementations.duckdb.types import SQLType
20-
from dve.core_engine.backends.readers.utilities import (
21-
raise_message_bearing_error_on_header_differences,
22-
)
20+
from dve.core_engine.backends.readers.csv import CSVFileReader
2321
from dve.core_engine.backends.utilities import get_polars_type_from_annotation, polars_record_index
2422
from dve.core_engine.constants import RECORD_INDEX_COLUMN_NAME
2523
from dve.core_engine.message import FeedbackMessage
@@ -29,7 +27,7 @@
2927

3028
@duckdb_record_index
3129
@duckdb_write_parquet
32-
class DuckDBCSVReader(BaseFileReader):
30+
class DuckDBCSVReader(CSVFileReader):
3331
"""A reader for CSV files including the ability to compare the passed model
3432
to the file header, if it exists.
3533
@@ -54,55 +52,52 @@ def __init__(
5452
null_empty_strings: bool = False,
5553
**_,
5654
):
57-
self.header = header
58-
self.delim = delim
59-
self.quotechar = quotechar
6055
self._connection = connection if connection else ddb.connect(":memory:")
61-
self.field_check = field_check
62-
self.field_check_error_code = field_check_error_code
63-
self.field_check_error_message = field_check_error_message
6456
self.null_empty_strings = null_empty_strings
6557

66-
super().__init__()
67-
68-
def perform_field_check(
69-
self, resource: URI, entity_name: str, expected_schema: type[BaseModel]
70-
):
71-
"""Check that the header of the CSV aligns with the provided model"""
72-
if not self.header:
73-
raise ValueError("Cannot perform field check without a CSV header")
74-
75-
raise_message_bearing_error_on_header_differences(
76-
resource,
77-
entity_name,
78-
expected_schema,
79-
self.field_check_error_code,
80-
self.field_check_error_message,
81-
self.delim,
82-
self.quotechar,
58+
super().__init__(
59+
header=header,
60+
delimiter=delim,
61+
quote_char=quotechar,
62+
field_check=field_check,
63+
field_check_error_code=field_check_error_code,
64+
field_check_error_message=field_check_error_message
8365
)
8466

8567
def read_to_py_iterator(
86-
self, resource: URI, entity_name: EntityName, schema: type[BaseModel]
68+
self,
69+
resource: URI,
70+
entity_name: EntityName,
71+
schema: type[BaseModel],
72+
all_model_fields: Optional[set[str]] = None,
8773
) -> Iterator[dict[str, Any]]:
8874
"""Creates an iterable object of rows as dictionaries"""
89-
yield from self.read_to_relation(resource, entity_name, schema).pl().iter_rows(named=True)
75+
yield from self.read_to_relation(
76+
resource,
77+
entity_name,
78+
schema,
79+
all_model_fields,
80+
).pl().iter_rows(named=True)
9081

9182
@read_function(DuckDBPyRelation)
9283
def read_to_relation( # pylint: disable=unused-argument
93-
self, resource: URI, entity_name: EntityName, schema: type[BaseModel]
84+
self,
85+
resource: URI,
86+
entity_name: EntityName,
87+
schema: type[BaseModel],
88+
all_model_fields: Optional[set[str]] = None,
9489
) -> DuckDBPyRelation:
9590
"""Returns a relation object from the source csv"""
9691
if get_content_length(resource) == 0:
9792
raise EmptyFileError(f"File at {resource} is empty.")
9893

9994
if self.field_check:
100-
self.perform_field_check(resource, entity_name, schema)
95+
self.perform_field_check(resource, entity_name, schema, all_model_fields)
10196

10297
reader_options: dict[str, Any] = {
10398
"header": self.header,
104-
"delimiter": self.delim,
105-
"quotechar": self.quotechar,
99+
"delimiter": self.delimiter,
100+
"quotechar": self.quote_char,
106101
}
107102

108103
ddb_schema: dict[str, SQLType] = {
@@ -134,19 +129,23 @@ class PolarsToDuckDBCSVReader(DuckDBCSVReader):
134129

135130
@read_function(DuckDBPyRelation)
136131
def read_to_relation( # pylint: disable=unused-argument
137-
self, resource: URI, entity_name: EntityName, schema: type[BaseModel]
132+
self,
133+
resource: URI,
134+
entity_name: EntityName,
135+
schema: type[BaseModel],
136+
all_model_fields: Optional[set[str]] = None,
138137
) -> DuckDBPyRelation:
139138
"""Returns a relation object from the source csv"""
140139
if get_content_length(resource) == 0:
141140
raise EmptyFileError(f"File at {resource} is empty.")
142141

143142
if self.field_check:
144-
self.perform_field_check(resource, entity_name, schema)
143+
self.perform_field_check(resource, entity_name, schema, all_model_fields)
145144

146145
reader_options: dict[str, Any] = {
147146
"has_header": self.header,
148-
"separator": self.delim,
149-
"quote_char": self.quotechar,
147+
"separator": self.delimiter,
148+
"quote_char": self.quote_char,
150149
}
151150

152151
polars_types = {
@@ -212,10 +211,17 @@ def __init__(
212211

213212
@read_function(DuckDBPyRelation)
214213
def read_to_relation( # pylint: disable=unused-argument
215-
self, resource: URI, entity_name: EntityName, schema: type[BaseModel]
214+
self,
215+
resource: URI,
216+
entity_name: EntityName,
217+
schema: type[BaseModel],
218+
all_model_fields: Optional[set[str]] = None,
216219
) -> DuckDBPyRelation:
217220
entity: DuckDBPyRelation = super().read_to_relation(
218-
resource=resource, entity_name=entity_name, schema=schema
221+
resource=resource,
222+
entity_name=entity_name,
223+
schema=schema,
224+
all_model_fields=all_model_fields
219225
)
220226
entity = entity.select(StarExpression(exclude=[RECORD_INDEX_COLUMN_NAME])).distinct()
221227
no_records = entity.shape[0]

src/dve/core_engine/backends/implementations/duckdb/readers/json.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,22 @@ def __init__(
3636
super().__init__()
3737

3838
def read_to_py_iterator(
39-
self, resource: URI, entity_name: EntityName, schema: type[BaseModel]
39+
self,
40+
resource: URI,
41+
entity_name: EntityName,
42+
schema: type[BaseModel],
43+
all_model_fields: Optional[set[str]] = None,
4044
) -> Iterator[dict[str, Any]]:
4145
"""Creates an iterable object of rows as dictionaries"""
4246
return self.read_to_relation(resource, entity_name, schema).pl().iter_rows(named=True)
4347

4448
@read_function(DuckDBPyRelation)
4549
def read_to_relation( # pylint: disable=unused-argument
46-
self, resource: URI, entity_name: EntityName, schema: type[BaseModel]
50+
self,
51+
resource: URI,
52+
entity_name: EntityName,
53+
schema: type[BaseModel],
54+
**_,
4755
) -> DuckDBPyRelation:
4856
"""Returns a relation object from the source json"""
4957

src/dve/core_engine/backends/implementations/duckdb/readers/xml.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,13 @@ def __init__(self, *, connection: Optional[DuckDBPyConnection] = None, **kwargs)
3030
super().__init__(**kwargs)
3131

3232
@read_function(DuckDBPyRelation)
33-
def read_to_relation(self, resource: URI, entity_name: str, schema: type[BaseModel]):
33+
def read_to_relation(
34+
self,
35+
resource: URI,
36+
entity_name: str,
37+
schema: type[BaseModel],
38+
**_,
39+
):
3440
"""Returns a relation object from the source xml"""
3541
if self.xsd_location:
3642
msg = self._run_xmllint(file_uri=resource)

src/dve/core_engine/backends/implementations/spark/readers/csv.py

Lines changed: 20 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,21 @@
88
from pyspark.sql import DataFrame, SparkSession
99
from pyspark.sql.types import StructType
1010

11-
from dve.core_engine.backends.base.reader import BaseFileReader, read_function
11+
from dve.core_engine.backends.base.reader import read_function
12+
from dve.core_engine.backends.readers.csv import CSVFileReader
1213
from dve.core_engine.backends.exceptions import EmptyFileError
1314
from dve.core_engine.backends.implementations.spark.spark_helpers import (
1415
get_type_from_annotation,
1516
spark_record_index,
1617
spark_write_parquet,
1718
)
18-
from dve.core_engine.backends.readers.utilities import (
19-
raise_message_bearing_error_on_header_differences,
20-
)
2119
from dve.core_engine.type_hints import URI, EntityName
2220
from dve.parser.file_handling import get_content_length
2321

2422

2523
@spark_record_index
2624
@spark_write_parquet
27-
class SparkCSVReader(BaseFileReader):
25+
class SparkCSVReader(CSVFileReader):
2826
"""A Spark reader for CSV files."""
2927

3028
# pylint: disable=R0902
@@ -45,41 +43,29 @@ def __init__(
4543
**_,
4644
) -> None:
4745

48-
self.delimiter = delimiter
49-
self.escape_char = escape_char
50-
self.encoding = encoding
51-
self.quote_char = quote_char
52-
self.header = header
5346
self.multi_line = multi_line
5447
self.null_empty_strings = null_empty_strings
5548
self.spark_session = spark_session if spark_session else SparkSession.builder.getOrCreate() # type: ignore # pylint: disable=C0301
56-
self.field_check = field_check
57-
self.field_check_error_code = field_check_error_code
58-
self.field_check_error_message = field_check_error_message
59-
60-
super().__init__()
61-
62-
def perform_field_check(
63-
self, resource: URI, entity_name: str, expected_schema: type[BaseModel]
64-
):
65-
"""Check that the header of the CSV aligns with the provided model"""
66-
if not self.header:
67-
raise ValueError("Cannot perform field check without a CSV header")
6849

69-
raise_message_bearing_error_on_header_differences(
70-
resource,
71-
entity_name,
72-
expected_schema,
73-
self.field_check_error_code,
74-
self.field_check_error_message,
75-
self.delimiter,
76-
self.quote_char,
50+
super().__init__(
51+
delimiter=delimiter,
52+
escape_char=escape_char,
53+
encoding=encoding,
54+
quote_char=quote_char,
55+
header=header,
56+
field_check=field_check,
57+
field_check_error_code=field_check_error_code,
58+
field_check_error_message=field_check_error_message,
7759
)
7860

7961
def read_to_py_iterator(
80-
self, resource: URI, entity_name: EntityName, schema: type[BaseModel]
62+
self,
63+
resource: URI,
64+
entity_name: EntityName,
65+
schema: type[BaseModel],
66+
all_model_fields: Optional[set[str]] = None,
8167
) -> Iterator[dict[URI, Any]]:
82-
df = self.read_to_dataframe(resource, entity_name, schema)
68+
df = self.read_to_dataframe(resource, entity_name, schema, all_model_fields)
8369
yield from (record.asDict(True) for record in df.toLocalIterator())
8470

8571
@read_function(DataFrame)
@@ -88,13 +74,14 @@ def read_to_dataframe(
8874
resource: URI,
8975
entity_name: EntityName, # pylint: disable=unused-argument
9076
schema: type[BaseModel],
77+
all_model_fields: Optional[set[str]] = None,
9178
) -> DataFrame:
9279
"""Read a CSV file directly to a Spark DataFrame."""
9380
if get_content_length(resource) == 0:
9481
raise EmptyFileError(f"File at {resource} is empty.")
9582

9683
if self.field_check:
97-
self.perform_field_check(resource, entity_name, schema)
84+
self.perform_field_check(resource, entity_name, schema, all_model_fields)
9885

9986
spark_schema: StructType = get_type_from_annotation(schema)
10087
kwargs = {

src/dve/core_engine/backends/implementations/spark/readers/json.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,11 @@ def __init__(
3939
super().__init__()
4040

4141
def read_to_py_iterator(
42-
self, resource: URI, entity_name: EntityName, schema: type[BaseModel]
42+
self,
43+
resource: URI,
44+
entity_name: EntityName,
45+
schema: type[BaseModel],
46+
all_model_fields: Optional[set[str]] = None,
4347
) -> Iterator[dict[URI, Any]]:
4448
df = self.read_to_dataframe(resource, entity_name, schema)
4549
yield from (record.asDict(True) for record in df.toLocalIterator())
@@ -50,6 +54,7 @@ def read_to_dataframe(
5054
resource: URI,
5155
entity_name: EntityName, # pylint: disable=unused-argument
5256
schema: type[BaseModel],
57+
**_,
5358
) -> DataFrame:
5459
"""Read a JSON file directly to a Spark DataFrame."""
5560
if get_content_length(resource) == 0:

src/dve/core_engine/backends/implementations/spark/readers/xml.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,11 @@ def __init__(
104104
self.namespace = namespace
105105

106106
def read_to_py_iterator(
107-
self, resource: URI, entity_name: EntityName, schema: type[BaseModel]
107+
self,
108+
resource: URI,
109+
entity_name: EntityName,
110+
schema: type[BaseModel],
111+
all_model_fields: Optional[set[str]] = None,
108112
) -> Iterator[dict[URI, Any]]:
109113
df = self.read_to_dataframe(resource, entity_name, schema)
110114
yield from (record.asDict(True) for record in df.toLocalIterator())
@@ -115,6 +119,7 @@ def read_to_dataframe(
115119
resource: URI,
116120
entity_name: EntityName, # pylint: disable=unused-argument
117121
schema: type[BaseModel],
122+
**_,
118123
) -> DataFrame:
119124
"""Read an XML file directly to a Spark DataFrame using the Databricks
120125
XML reader package.

0 commit comments

Comments
 (0)