# --- added to dataframely/_base_schema.py (invoked from the metaclass ``__new__``
# --- as ``mcs._remove_overridden_columns(result, namespace, bases)``) ---------
@staticmethod
def _remove_overridden_columns(
    result: Metadata,
    namespace: dict[str, Any],
    bases: tuple[type[object], ...],
) -> None:
    """Drop inherited columns that the class body in ``namespace`` redefines.

    This runs after the metadata of all bases has been collected into ``result``
    and before the namespace's own metadata is merged in. Every parent column
    whose *attribute name* is reused for a column in ``namespace`` is removed, so
    that a subclass can redefine an inherited column while genuine alias
    conflicts are still detected by the subsequent merge.

    With multiple inheritance, the same attribute name can be defined in several
    ancestors under different aliases. For that reason, the full MRO of every
    base is inspected and each matching column key is removed.

    Args:
        result: Metadata collected from the bases. Modified in place.
        namespace: The namespace of the class that is being created.
        bases: The direct base classes of the class that is being created.
    """
    redefined_attrs = [
        attr for attr, value in namespace.items() if isinstance(value, Column)
    ]
    if not redefined_attrs:
        # Nothing is overridden, hence there is no need to look at the bases.
        return

    # The ancestors to inspect are the same for every attribute: gather them once.
    ancestor_dicts = [
        ancestor.__dict__ for base in bases for ancestor in base.__mro__
    ]
    for attr in redefined_attrs:
        for ancestor_dict in ancestor_dicts:
            inherited = ancestor_dict.get(attr)
            if isinstance(inherited, Column):
                # Columns are keyed by their alias, falling back to the attribute
                # name. Keys that are not (or no longer) present are ignored.
                result.columns.pop(inherited.alias or attr, None)


# --- added to tests/schema/test_base.py ---------------------------------------
def test_override() -> None:
    """A subclass can redefine a column that it inherits from its parent."""

    class FirstSchema(dy.Schema):
        x = dy.Int64()

    class SecondSchema(FirstSchema):
        x = dy.Int64(nullable=True)

    parent_columns = FirstSchema.columns()
    child_columns = SecondSchema.columns()

    # Both schemas expose exactly one column named "x" ...
    assert set(parent_columns) == {"x"}
    assert set(child_columns) == {"x"}

    # ... but only the child's definition is nullable ...
    assert parent_columns["x"].nullable is False
    assert child_columns["x"].nullable is True

    # ... while the column type itself is unchanged.
    assert type(child_columns["x"]) is type(parent_columns["x"])