Skip to content

Commit 9302e61

Browse files
committed
Address review feedback for Annotated relationships
- Detect annotated relationships without a default value by iterating over the union of class_dict and original_annotations. - Use elif and unwrap inner Mapped[T] when handling Annotated[Mapped[T], ...] to avoid double-wrapping in Mapped. - Split the relationship/pydantic partition into three single-purpose loops with comments for clarity. - Add tests covering Annotated relationships with default value, without default value, and with Annotated[Mapped[T], ...].
1 parent e91fb22 commit 9302e61

2 files changed

Lines changed: 106 additions & 10 deletions

File tree

sqlmodel/main.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -561,20 +561,29 @@ def __new__(
561561
original_annotations = get_annotations(class_dict)
562562
pydantic_annotations = {}
563563
relationship_annotations = {}
564-
for k, v in class_dict.items():
565-
a = original_annotations.get(k, None)
566-
r = get_annotated_relationshipinfo(a)
564+
565+
# find relationship info in both annotations and class dict
566+
for k in {**original_annotations, **class_dict}:
567+
v = class_dict.get(k)
568+
if isinstance(v, RelationshipInfo):
569+
relationships[k] = v
570+
continue
571+
r = get_annotated_relationshipinfo(original_annotations.get(k))
567572
if r is not None:
568573
relationships[k] = r
569-
elif isinstance(v, RelationshipInfo):
570-
relationships[k] = v
571-
else:
574+
575+
# populate dict passed to pydantic
576+
for k, v in class_dict.items():
577+
if k not in relationships:
572578
dict_for_pydantic[k] = v
573-
for k, v in original_annotations.items():
579+
580+
# split out pydantic annotations
581+
for k, a in original_annotations.items():
574582
if k in relationships:
575-
relationship_annotations[k] = v
583+
relationship_annotations[k] = a
576584
else:
577-
pydantic_annotations[k] = v
585+
pydantic_annotations[k] = a
586+
578587
dict_used = {
579588
**dict_for_pydantic,
580589
"__weakref__": None,
@@ -659,8 +668,10 @@ def __init__(
659668
origin: Any = get_origin(raw_ann)
660669
if origin is Mapped:
661670
ann = raw_ann.__args__[0]
662-
if origin is Annotated:
671+
elif origin is Annotated:
663672
ann = get_args(raw_ann)[0]
673+
if get_origin(ann) is Mapped:
674+
ann = ann.__args__[0]
664675
cls.__annotations__[rel_name] = Mapped[ann] # type: ignore[valid-type]
665676
else:
666677
ann = raw_ann
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from typing import Annotated
2+
3+
from sqlalchemy.orm import Mapped
4+
from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select
5+
6+
7+
def test_annotated_relationship_with_default() -> None:
8+
class Team(SQLModel, table=True):
9+
id: Annotated[int | None, Field(primary_key=True)] = None
10+
name: Annotated[str, Field(index=True)]
11+
12+
heroes: Annotated[list["Hero"], Relationship(back_populates="team")] = [] # noqa: RUF012
13+
14+
class Hero(SQLModel, table=True):
15+
id: Annotated[int | None, Field(primary_key=True)] = None
16+
name: Annotated[str, Field(index=True)]
17+
team_id: Annotated[int | None, Field(foreign_key="team.id")] = None
18+
team: Annotated[Team | None, Relationship(back_populates="heroes")] = None
19+
20+
engine = create_engine("sqlite://")
21+
SQLModel.metadata.create_all(engine)
22+
with Session(engine) as session:
23+
team = Team(name="Preventers")
24+
hero = Hero(name="Deadpond", team=team)
25+
session.add(hero)
26+
session.commit()
27+
session.refresh(hero)
28+
assert hero.team is not None
29+
assert hero.team.name == "Preventers"
30+
team_db = session.exec(select(Team)).one()
31+
assert [h.name for h in team_db.heroes] == ["Deadpond"]
32+
33+
34+
def test_annotated_relationship_without_default() -> None:
35+
class Team(SQLModel, table=True):
36+
id: Annotated[int | None, Field(primary_key=True)] = None
37+
name: Annotated[str, Field(index=True)]
38+
39+
heroes: Annotated[list["Hero"], Relationship(back_populates="team")]
40+
41+
class Hero(SQLModel, table=True):
42+
id: Annotated[int | None, Field(primary_key=True)] = None
43+
name: Annotated[str, Field(index=True)]
44+
team_id: Annotated[int | None, Field(foreign_key="team.id")] = None
45+
team: Annotated[Team | None, Relationship(back_populates="heroes")]
46+
47+
engine = create_engine("sqlite://")
48+
SQLModel.metadata.create_all(engine)
49+
with Session(engine) as session:
50+
team = Team(name="Z-Force")
51+
hero = Hero(name="Spider-Boy", team=team)
52+
session.add(hero)
53+
session.commit()
54+
session.refresh(hero)
55+
assert hero.team is not None
56+
assert hero.team.name == "Z-Force"
57+
58+
59+
def test_annotated_mapped_relationship() -> None:
60+
class Team(SQLModel, table=True):
61+
id: Annotated[int | None, Field(primary_key=True)] = None
62+
name: Annotated[str, Field(index=True)]
63+
64+
heroes: Annotated[
65+
Mapped[list["Hero"]], Relationship(back_populates="team")
66+
] = [] # noqa: RUF012
67+
68+
class Hero(SQLModel, table=True):
69+
id: Annotated[int | None, Field(primary_key=True)] = None
70+
name: Annotated[str, Field(index=True)]
71+
team_id: Annotated[int | None, Field(foreign_key="team.id")] = None
72+
team: Annotated[Mapped[Team | None], Relationship(back_populates="heroes")] = (
73+
None
74+
)
75+
76+
engine = create_engine("sqlite://")
77+
SQLModel.metadata.create_all(engine)
78+
with Session(engine) as session:
79+
team = Team(name="Avengers")
80+
hero = Hero(name="Iron Man", team=team)
81+
session.add(hero)
82+
session.commit()
83+
session.refresh(hero)
84+
assert hero.team is not None
85+
assert hero.team.name == "Avengers"

0 commit comments

Comments
 (0)