Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion dataframely/_base_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@ def __new__(
result = Metadata()
for base in bases:
result.update(mcs._get_metadata_recursively(base))
result.update(mcs._get_metadata(namespace))
namespace_metadata = mcs._get_metadata(namespace)
mcs._remove_overridden_columns(result, namespace, bases)
result.update(namespace_metadata)
namespace[_COLUMN_ATTR] = result.columns
cls = super().__new__(mcs, name, bases, namespace, *args, **kwargs)

Expand Down Expand Up @@ -207,6 +209,34 @@ def __getattribute__(cls, name: str) -> Any:
val._name = val.alias or name
return val

@staticmethod
def _remove_overridden_columns(
result: Metadata,
namespace: dict[str, Any],
bases: tuple[type[object], ...],
) -> None:
"""Remove inherited columns that the child namespace explicitly overrides.

Before merging the child namespace, we must drop any parent columns whose
attribute name is redefined in the child. This allows subclasses to redefine
inherited columns while still detecting genuine alias conflicts.

In multiple-inheritance scenarios, the same attribute name may appear in more
than one base with different aliases, so we walk all parent MROs and collect
every matching column key to remove.
"""
for attr, value in namespace.items():
if not isinstance(value, Column):
continue
keys_to_remove: set[str] = set()
for base in bases:
for parent_cls in base.__mro__:
parent_col = parent_cls.__dict__.get(attr)
if parent_col is not None and isinstance(parent_col, Column):
keys_to_remove.add(parent_col.alias or attr)
for parent_key in keys_to_remove:
result.columns.pop(parent_key, None)

@staticmethod
def _get_metadata_recursively(kls: type[object]) -> Metadata:
result = Metadata()
Expand Down
19 changes: 19 additions & 0 deletions tests/schema/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,22 @@ def test_user_error_polars_datatype_type() -> None:
class MySchemaWithPolarsDataTypeType(dy.Schema):
a = dy.Int32(nullable=False)
b = pl.String # User error: Used pl.String instead of dy.String()


def test_override() -> None:
class FirstSchema(dy.Schema):
x = dy.Int64()

class SecondSchema(FirstSchema):
x = dy.Int64(nullable=True)

first_columns = FirstSchema.columns()
second_columns = SecondSchema.columns()

assert set(first_columns) == {"x"}
assert set(second_columns) == {"x"}

assert first_columns["x"].nullable is False
assert second_columns["x"].nullable is True

assert type(second_columns["x"]) is type(first_columns["x"])
Loading