diff --git a/scim2_models/base.py b/scim2_models/base.py index d427673..396bd47 100644 --- a/scim2_models/base.py +++ b/scim2_models/base.py @@ -1,6 +1,7 @@ import warnings from inspect import isclass from typing import Any +from typing import ClassVar from typing import Optional from typing import get_args from typing import get_origin @@ -112,6 +113,14 @@ class BaseModel(PydanticBaseModel): extra="forbid", ) + _allow_bulk_id: ClassVar[bool] = False + """Allow bulkId field for the model""" + + @classmethod + def __pydantic_init_subclass__(cls, **kwargs: Any) -> None: + # Validate field names + cls._check_bulk_id() + @classmethod def get_field_annotation(cls, field_name: str, annotation_type: type) -> Any: """Return the annotation of type 'annotation_type' of the field 'field_name'. @@ -538,6 +547,19 @@ def _set_complex_attribute_urns(self) -> None: else: attr_value._attribute_urn = schema + @classmethod + def _check_bulk_id(cls) -> None: + """Enforce bulkId as reserved field per RFC 7643 ยง3.1. + + Check if a bulkdId field is defined and + raise error if `_allow_bulk_id` is set to `False` + """ + if cls._allow_bulk_id: + return + for info in cls.model_fields.values(): + if info.serialization_alias == "bulkId": + raise TypeError(f"{cls.__name__}: bulkId is reserved for BulkOperation") + @field_serializer("*", mode="wrap") def scim_serializer( self, diff --git a/scim2_models/messages/bulk.py b/scim2_models/messages/bulk.py index 7c2f19f..1468ba4 100644 --- a/scim2_models/messages/bulk.py +++ b/scim2_models/messages/bulk.py @@ -12,6 +12,8 @@ class BulkOperation(ComplexAttribute): + _allow_bulk_id = True + class Method(str, Enum): post = "POST" put = "PUT" diff --git a/tests/test_model_attributes.py b/tests/test_model_attributes.py index 418b88c..428bfa9 100644 --- a/tests/test_model_attributes.py +++ b/tests/test_model_attributes.py @@ -1,6 +1,8 @@ import uuid from typing import Annotated +import pytest + from scim2_models import URN from scim2_models.annotations import Returned from scim2_models.attributes import ComplexAttribute @@ -377,3 +379,14 @@ def test_short_attr_path_with_plain_name(): assert _short_attr_path("userName") == "userName" assert _short_attr_path("name.familyName") == "name.familyName" + + +def test_forbid_bulk_id(): + """Forbid bulkId from class definition.""" + with pytest.raises(TypeError) as exc_info: + + class CustomModel(Resource): + __schema__ = URN("urn:example:schemas:CustomModel") + bulk_id: str | None = None + + assert str(exc_info.value) == "CustomModel: bulkId is reserved for BulkOperation"