Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ repos:
language: unsupported
types: [python]

- id: local-mypy
name: mypy check
entry: uv run mypy sqlmodel tests/test_select_typing.py
- id: local-ty
name: ty check
entry: uv run ty check sqlmodel tests/test_select_typing.py
require_serial: true
language: unsupported
pass_filenames: false
Expand Down
15 changes: 4 additions & 11 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ tests = [
"fastapi >=0.128.0",
"httpx >=0.28.1",
"jinja2 >=3.1.6",
"mypy >=1.19.1",
"pytest >=7.0.1",
"ruff >=0.15.6",
"ty>=0.0.25",
"typing-extensions >=4.15.0",
]

Expand Down Expand Up @@ -125,16 +125,6 @@ exclude_lines = [
[tool.coverage.html]
show_contexts = true

[tool.mypy]
strict = true
exclude = "sqlmodel.sql._expression_select_gen"

[[tool.mypy.overrides]]
module = "docs_src.*"
disallow_incomplete_defs = false
disallow_untyped_defs = false
disallow_untyped_calls = false

[tool.ruff.lint]
select = [
"E", # pycodestyle errors
Expand All @@ -161,3 +151,6 @@ known-third-party = ["sqlmodel", "sqlalchemy", "pydantic", "fastapi"]
[tool.ruff.lint.pyupgrade]
# Preserve types, even if a file imports `from __future__ import annotations`.
keep-runtime-typing = true

[tool.ty.terminal]
error-on-warning = true
2 changes: 1 addition & 1 deletion scripts/generate_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class Arg(BaseModel):
else:
t_type = f"_T{i}"
t_var = f"_TCCA[{t_type}]"
arg = Arg(name=f"__ent{i}", annotation=t_var)
arg = Arg(name=f"ent{i}", annotation=t_var)
ret_type = t_type
args.append(arg)
return_types.append(ret_type)
Expand Down
4 changes: 2 additions & 2 deletions scripts/lint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
set -e
set -x

mypy sqlmodel
mypy tests/test_select_typing.py
ty check sqlmodel
ty check tests/test_select_typing.py
ruff check sqlmodel tests docs_src scripts
ruff format sqlmodel tests docs_src scripts --check
21 changes: 9 additions & 12 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import builtins
import ipaddress
import uuid
import weakref
from collections.abc import Callable, Mapping, Sequence, Set
from dataclasses import dataclass
from datetime import date, datetime, time, timedelta
Expand Down Expand Up @@ -52,7 +51,7 @@
from sqlalchemy.sql.sqltypes import LargeBinary, Time, Uuid
from typing_extensions import deprecated

from ._compat import ( # type: ignore[attr-defined]
from ._compat import (
PYDANTIC_MINOR_VERSION,
BaseConfig,
ModelMetaclass,
Expand Down Expand Up @@ -177,7 +176,7 @@ def __init__(
cascade_delete: bool | None = False,
passive_deletes: bool | Literal["all"] | None = False,
link_model: Any | None = None,
sa_relationship: RelationshipProperty | None = None, # type: ignore
sa_relationship: RelationshipProperty | None = None,
sa_relationship_args: Sequence[Any] | None = None,
sa_relationship_kwargs: Mapping[str, Any] | None = None,
) -> None:
Expand Down Expand Up @@ -398,7 +397,7 @@ def Field(
nullable: bool | UndefinedType = Undefined,
index: bool | UndefinedType = Undefined,
sa_type: type[Any] | UndefinedType = Undefined,
sa_column: Column | UndefinedType = Undefined, # type: ignore
sa_column: Column | UndefinedType = Undefined,
sa_column_args: Sequence[Any] | UndefinedType = Undefined,
sa_column_kwargs: Mapping[str, Any] | UndefinedType = Undefined,
schema_extra: dict[str, Any] | None = None,
Expand Down Expand Up @@ -525,13 +524,13 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
model_fields: ClassVar[dict[str, FieldInfo]]

# Replicate SQLAlchemy
def __setattr__(cls, name: str, value: Any) -> None:
def __setattr__(cls, name: str, value: Any) -> None: # ty: ignore[invalid-method-override]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For context: #1806 (comment)

if is_table_model_class(cls):
DeclarativeMeta.__setattr__(cls, name, value)
else:
super().__setattr__(name, value)

def __delattr__(cls, name: str) -> None:
def __delattr__(cls, name: str) -> None: # ty: ignore[invalid-method-override]
if is_table_model_class(cls):
DeclarativeMeta.__delattr__(cls, name)
else:
Expand Down Expand Up @@ -649,7 +648,7 @@ def __init__(
# Plain forward references, for models not yet defined, are not
# handled well by SQLAlchemy without Mapped, so, wrap the
# annotations in Mapped here
cls.__annotations__[rel_name] = Mapped[ann] # type: ignore[valid-type]
cls.__annotations__[rel_name] = Mapped[ann]
relationship_to = get_relationship_to(
name=rel_name, rel_info=rel_info, annotation=ann
)
Expand Down Expand Up @@ -738,7 +737,7 @@ def get_sqlalchemy_type(field: Any) -> Any:
raise ValueError(f"{type_} has no matching SQLAlchemy type")


def get_column_from_field(field: Any) -> Column: # type: ignore
def get_column_from_field(field: Any) -> Column:
field_info = field
sa_column = _get_sqlmodel_field_value(field_info, "sa_column", Undefined)
if isinstance(sa_column, Column):
Expand Down Expand Up @@ -773,7 +772,7 @@ def get_column_from_field(field: Any) -> Column: # type: ignore
assert isinstance(foreign_key, str)
assert isinstance(ondelete_value, (str, type(None))) # for typing
args.append(ForeignKey(foreign_key, ondelete=ondelete_value))
kwargs = {
kwargs: dict[str, Any] = {
"primary_key": primary_key,
"nullable": nullable,
"index": index,
Expand All @@ -797,8 +796,6 @@ def get_column_from_field(field: Any) -> Column: # type: ignore
return Column(sa_type, *args, **kwargs)


class_registry = weakref.WeakValueDictionary() # type: ignore

default_registry = registry()

_TSQLModel = TypeVar("_TSQLModel", bound="SQLModel")
Expand Down Expand Up @@ -850,7 +847,7 @@ def __setattr__(self, name: str, value: Any) -> None:
return
else:
# Set in SQLAlchemy, before Pydantic to trigger events and updates
if is_table_model_class(self.__class__) and is_instrumented(self, name): # type: ignore[no-untyped-call]
if is_table_model_class(self.__class__) and is_instrumented(self, name):
set_attribute(self, name, value)
# Set in Pydantic model to trigger possible validation changes, only for
# non relationship values
Expand Down
4 changes: 2 additions & 2 deletions sqlmodel/sql/_expression_select_cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ def where(self, *whereclause: _ColumnExpressionArgument[bool] | bool) -> Self:
"""Return a new `Select` construct with the given expression added to
its `WHERE` clause, joined to the existing clause via `AND`, if any.
"""
return super().where(*whereclause) # type: ignore[arg-type]
return super().where(*whereclause)

def having(self, *having: _ColumnExpressionArgument[bool] | bool) -> Self:
"""Return a new `Select` construct with the given expression added to
its `HAVING` clause, joined to the existing clause via `AND`, if any.
"""
return super().having(*having) # type: ignore[arg-type]
return super().having(*having)


class Select(SelectBase[_T]):
Expand Down
Loading
Loading