Skip to content

Commit d165e4b

Browse files
♻️ Refactor generate select template to isolate templated code to the minimum (fastapi#967)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent d5cba6e commit d165e4b

7 files changed

+544
-730
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ exclude_lines = [
9090
strict = true
9191

9292
[[tool.mypy.overrides]]
93-
module = "sqlmodel.sql.expression"
93+
module = "sqlmodel.sql._expression_select_gen"
9494
warn_unused_ignores = false
9595

9696
[[tool.mypy.overrides]]

scripts/generate_select.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77
from jinja2 import Template
88
from pydantic import BaseModel
99

10-
template_path = Path(__file__).parent.parent / "sqlmodel/sql/expression.py.jinja2"
11-
destiny_path = Path(__file__).parent.parent / "sqlmodel/sql/expression.py"
10+
template_path = (
11+
Path(__file__).parent.parent / "sqlmodel/sql/_expression_select_gen.py.jinja2"
12+
)
13+
destiny_path = Path(__file__).parent.parent / "sqlmodel/sql/_expression_select_gen.py"
1214

1315

1416
number_of_types = 4
@@ -48,7 +50,7 @@ class Arg(BaseModel):
4850

4951
result = (
5052
"# WARNING: do not modify this code, it is generated by "
51-
"expression.py.jinja2\n\n" + result
53+
"_expression_select_gen.py.jinja2\n\n" + result
5254
)
5355

5456
result = black.format_str(result, mode=black.Mode())
+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from typing import (
2+
Tuple,
3+
TypeVar,
4+
Union,
5+
)
6+
7+
from sqlalchemy.sql._typing import (
8+
_ColumnExpressionArgument,
9+
)
10+
from sqlalchemy.sql.expression import Select as _Select
11+
from typing_extensions import Self
12+
13+
_T = TypeVar("_T")
14+
15+
16+
# Separate this class in SelectBase, Select, and SelectOfScalar so that they can share
17+
# where and having without having type overlap incompatibility in session.exec().
18+
class SelectBase(_Select[Tuple[_T]]):
19+
inherit_cache = True
20+
21+
def where(self, *whereclause: Union[_ColumnExpressionArgument[bool], bool]) -> Self:
22+
"""Return a new `Select` construct with the given expression added to
23+
its `WHERE` clause, joined to the existing clause via `AND`, if any.
24+
"""
25+
return super().where(*whereclause) # type: ignore[arg-type]
26+
27+
def having(self, *having: Union[_ColumnExpressionArgument[bool], bool]) -> Self:
28+
"""Return a new `Select` construct with the given expression added to
29+
its `HAVING` clause, joined to the existing clause via `AND`, if any.
30+
"""
31+
return super().having(*having) # type: ignore[arg-type]
32+
33+
34+
class Select(SelectBase[_T]):
35+
inherit_cache = True
36+
37+
38+
# This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different
39+
# purpose. This is the same as a normal SQLAlchemy Select class where there's only one
40+
# entity, so the result will be converted to a scalar by default. This way writing
41+
# for loops on the results will feel natural.
42+
class SelectOfScalar(SelectBase[_T]):
43+
inherit_cache = True

0 commit comments

Comments
 (0)