Skip to content

Commit b0ed50e

Browse files
AvasamAlexWaygoodJelleZijlstra
authored
Fix all fixable stubtest_allowlist entries in SQLAlchemy (#9596)
Co-authored-by: Alex Waygood <[email protected]> Co-authored-by: Jelle Zijlstra <[email protected]>
1 parent 08e6e4c commit b0ed50e

23 files changed

+379
-1268
lines changed

stubs/SQLAlchemy/@tests/stubtest_allowlist.txt

+22-1,151
Large diffs are not rendered by default.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from __future__ import annotations
2+
3+
from typing_extensions import assert_type
4+
5+
from sqlalchemy.orm.strategy_options import (
6+
Load,
7+
contains_eager,
8+
defaultload,
9+
defer,
10+
immediateload,
11+
joinedload,
12+
lazyload,
13+
load_only,
14+
loader_option,
15+
noload,
16+
raiseload,
17+
selectin_polymorphic,
18+
selectinload,
19+
subqueryload,
20+
undefer,
21+
undefer_group,
22+
with_expression,
23+
)
24+
25+
26+
def fn(loadopt: Load, *args: object) -> loader_option:
27+
return loader_option()
28+
29+
30+
# Testing that the function and return type of function are actually all instances of "loader_option"
31+
assert_type(contains_eager, loader_option)
32+
assert_type(contains_eager(fn), loader_option)
33+
assert_type(load_only, loader_option)
34+
assert_type(load_only(fn), loader_option)
35+
assert_type(joinedload, loader_option)
36+
assert_type(joinedload(fn), loader_option)
37+
assert_type(subqueryload, loader_option)
38+
assert_type(subqueryload(fn), loader_option)
39+
assert_type(selectinload, loader_option)
40+
assert_type(selectinload(fn), loader_option)
41+
assert_type(lazyload, loader_option)
42+
assert_type(lazyload(fn), loader_option)
43+
assert_type(immediateload, loader_option)
44+
assert_type(immediateload(fn), loader_option)
45+
assert_type(noload, loader_option)
46+
assert_type(noload(fn), loader_option)
47+
assert_type(raiseload, loader_option)
48+
assert_type(raiseload(fn), loader_option)
49+
assert_type(defaultload, loader_option)
50+
assert_type(defaultload(fn), loader_option)
51+
assert_type(defer, loader_option)
52+
assert_type(defer(fn), loader_option)
53+
assert_type(undefer, loader_option)
54+
assert_type(undefer(fn), loader_option)
55+
assert_type(undefer_group, loader_option)
56+
assert_type(undefer_group(fn), loader_option)
57+
assert_type(with_expression, loader_option)
58+
assert_type(with_expression(fn), loader_option)
59+
assert_type(selectin_polymorphic, loader_option)
60+
assert_type(selectin_polymorphic(fn), loader_option)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from __future__ import annotations
2+
3+
from _typeshed.dbapi import DBAPIConnection
4+
from typing import cast
5+
6+
from sqlalchemy.engine.base import Engine
7+
from sqlalchemy.engine.default import DefaultDialect
8+
from sqlalchemy.engine.url import URL
9+
from sqlalchemy.pool.base import Pool
10+
from sqlalchemy.testing import config as ConfigModule
11+
from sqlalchemy.testing.provision import (
12+
configure_follower,
13+
create_db,
14+
drop_all_schema_objects_post_tables,
15+
drop_all_schema_objects_pre_tables,
16+
drop_db,
17+
follower_url_from_main,
18+
generate_driver_url,
19+
get_temp_table_name,
20+
post_configure_engine,
21+
prepare_for_drop_tables,
22+
register,
23+
run_reap_dbs,
24+
set_default_schema_on_connection,
25+
stop_test_class_outside_fixtures,
26+
temp_table_keyword_args,
27+
update_db_opts,
28+
)
29+
from sqlalchemy.util import immutabledict
30+
31+
url = URL("", "", "", "", 0, "", immutabledict())
32+
engine = Engine(Pool(lambda: cast(DBAPIConnection, object())), DefaultDialect(), "")
33+
config = cast(ConfigModule.Config, object())
34+
unused = None
35+
36+
37+
class Foo:
38+
pass
39+
40+
41+
# Test that the decorator changes the first parameter to "cfg: str | URL | _ConfigProtocol"
42+
@register.init
43+
def no_args(__foo: Foo) -> None:
44+
pass
45+
46+
47+
no_args(cfg="")
48+
no_args(cfg=url)
49+
no_args(cfg=config)
50+
51+
# Test pre-decorated functions
52+
generate_driver_url(url, "", "")
53+
drop_all_schema_objects_pre_tables(url, unused)
54+
drop_all_schema_objects_post_tables(url, unused)
55+
create_db(url, engine, unused)
56+
drop_db(url, engine, unused)
57+
update_db_opts(url, unused)
58+
post_configure_engine(url, unused, unused)
59+
follower_url_from_main(url, "")
60+
configure_follower(url, unused)
61+
run_reap_dbs(url, unused)
62+
temp_table_keyword_args(url, engine)
63+
prepare_for_drop_tables(url, unused)
64+
stop_test_class_outside_fixtures(url, unused, type)
65+
get_temp_table_name(url, unused, "")
66+
set_default_schema_on_connection(ConfigModule, unused, unused)
67+
set_default_schema_on_connection(config, unused, unused)

stubs/SQLAlchemy/sqlalchemy/dialects/mssql/base.pyi

-1
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,6 @@ class MSExecutionContext(default.DefaultExecutionContext):
197197
@property
198198
def rowcount(self): ...
199199
def handle_dbapi_exception(self, e) -> None: ...
200-
def get_result_cursor_strategy(self, result): ...
201200
def fire_sequence(self, seq, type_): ...
202201
def get_insert_default(self, column): ...
203202

stubs/SQLAlchemy/sqlalchemy/dialects/postgresql/base.pyi

+1
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ class PGCompiler(compiler.SQLCompiler):
163163
class PGDDLCompiler(compiler.DDLCompiler):
164164
def get_column_specification(self, column, **kwargs): ...
165165
def visit_check_constraint(self, constraint): ...
166+
def visit_foreign_key_constraint(self, constraint) -> str: ... # type: ignore[override] # Different params
166167
def visit_drop_table_comment(self, drop): ...
167168
def visit_create_enum_type(self, create): ...
168169
def visit_drop_enum_type(self, drop): ...

stubs/SQLAlchemy/sqlalchemy/engine/interfaces.pyi

-1
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,6 @@ class ExecutionContext:
116116
def pre_exec(self) -> None: ...
117117
def get_out_parameter_values(self, out_param_names) -> None: ...
118118
def post_exec(self) -> None: ...
119-
def get_result_cursor_strategy(self, result) -> None: ...
120119
def handle_dbapi_exception(self, e) -> None: ...
121120
def should_autocommit_text(self, statement) -> None: ...
122121
def lastrow_has_defaults(self) -> None: ...

stubs/SQLAlchemy/sqlalchemy/engine/url.pyi

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class _URLTuple(NamedTuple):
2222
_Query: TypeAlias = Mapping[str, str | Sequence[str]] | Sequence[tuple[str, str | Sequence[str]]]
2323

2424
class URL(_URLTuple):
25+
def __new__(self, *arg, **kw) -> Self | URL: ...
2526
@classmethod
2627
def create(
2728
cls,

stubs/SQLAlchemy/sqlalchemy/event/base.pyi

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ class _Dispatch:
1212
class _EventMeta(type):
1313
def __init__(cls, classname, bases, dict_) -> None: ...
1414

15-
class Events:
15+
class Events(metaclass=_EventMeta):
1616
dispatch: Any
1717

1818
class _JoinedDispatcher:

stubs/SQLAlchemy/sqlalchemy/orm/collections.pyi

+46-8
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
1-
from _typeshed import Incomplete
2-
from typing import Any
1+
from _typeshed import Incomplete, SupportsKeysAndGetItem
2+
from collections.abc import Iterable
3+
from typing import Any, TypeVar, overload
4+
from typing_extensions import Literal, SupportsIndex
5+
6+
from ..orm.attributes import Event
7+
from ..util.langhelpers import _symbol, symbol
8+
9+
_T = TypeVar("_T")
10+
_KT = TypeVar("_KT")
11+
_VT = TypeVar("_VT")
312

413
class _PlainColumnGetter:
514
cols: Any
@@ -81,12 +90,41 @@ class CollectionAdapter:
8190
def fire_remove_event(self, item, initiator: Incomplete | None = None) -> None: ...
8291
def fire_pre_remove_event(self, initiator: Incomplete | None = None) -> None: ...
8392

84-
class InstrumentedList(list[Any]): ...
85-
class InstrumentedSet(set[Any]): ...
86-
class InstrumentedDict(dict[Any, Any]): ...
93+
class InstrumentedList(list[_T]):
94+
def append(self, item, _sa_initiator: Event | Literal[False] | None = None) -> None: ...
95+
def clear(self, index: SupportsIndex = -1) -> None: ...
96+
def extend(self, iterable: Iterable[_T]) -> None: ...
97+
def insert(self, index: SupportsIndex, value: _T) -> None: ...
98+
def pop(self, index: SupportsIndex = -1) -> _T: ...
99+
def remove(self, value: _T, _sa_initiator: Event | Literal[False] | None = None) -> None: ...
100+
101+
class InstrumentedSet(set[_T]):
102+
def add(self, value: _T, _sa_initiator: Event | Literal[False] | None = None) -> None: ...
103+
def difference_update(self, value: Iterable[_T]) -> None: ... # type: ignore[override]
104+
def discard(self, value: _T, _sa_initiator: Event | Literal[False] | None = None) -> None: ...
105+
def intersection_update(self, other: Iterable[_T]) -> None: ... # type: ignore[override]
106+
def remove(self, value: _T, _sa_initiator: Event | Literal[False] | None = None) -> None: ...
107+
def symmetric_difference_update(self, other: Iterable[_T]) -> None: ...
108+
def update(self, value: Iterable[_T]) -> None: ... # type: ignore[override]
109+
110+
class InstrumentedDict(dict[_KT, _VT]): ...
87111

88-
class MappedCollection(dict[Any, Any]):
112+
class MappedCollection(dict[_KT, _VT]):
89113
keyfunc: Any
90114
def __init__(self, keyfunc) -> None: ...
91-
def set(self, value, _sa_initiator: Incomplete | None = None) -> None: ...
92-
def remove(self, value, _sa_initiator: Incomplete | None = None) -> None: ...
115+
def set(self, value: _VT, _sa_initiator: Event | Literal[False] | None = None) -> None: ...
116+
def remove(self, value: _VT, _sa_initiator: Event | Literal[False] | None = None) -> None: ...
117+
def __delitem__(self, key: _KT, _sa_initiatorEvent: Event | Literal[False] | None = None) -> None: ...
118+
def __setitem__(self, key: _KT, value: _VT, _sa_initiator: Event | Literal[False] | None = None) -> None: ...
119+
@overload
120+
def pop(self, key: _KT) -> _VT: ...
121+
@overload
122+
def pop(self, key: _KT, default: _VT | _T | _symbol | symbol = ...) -> _VT | _T: ...
123+
@overload # type: ignore[override]
124+
def setdefault(self, key: _KT, default: _T) -> _VT | _T: ...
125+
@overload
126+
def setdefault(self, key: _KT, default: None = None) -> _VT | None: ...
127+
@overload
128+
def update(self, __other: SupportsKeysAndGetItem[_KT, _VT] = ..., **kwargs: _VT) -> None: ...
129+
@overload
130+
def update(self, __other: Iterable[tuple[_KT, _VT]] = ..., **kwargs: _VT) -> None: ...

stubs/SQLAlchemy/sqlalchemy/orm/decl_api.pyi

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ _DeclT = TypeVar("_DeclT", bound=type[_DeclarativeBase])
1313

1414
# Dynamic class as created by registry.generate_base() via DeclarativeMeta
1515
# or another metaclass. This class does not exist at runtime.
16-
class _DeclarativeBase(Any): # super classes are dynamic
16+
class _DeclarativeBase(Any): # type: ignore[misc] # super classes are dynamic
1717
registry: ClassVar[registry]
1818
metadata: ClassVar[MetaData]
1919
__abstract__: ClassVar[bool]
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from _typeshed import Incomplete
2+
from collections.abc import Callable
23
from typing import Any
4+
from typing_extensions import Self
35

46
from ..sql.base import Generative
57
from .interfaces import LoaderOption
@@ -17,50 +19,67 @@ class Load(Generative, LoaderOption):
1719
propagate_to_loaders: bool
1820
def process_compile_state_replaced_entities(self, compile_state, mapper_entities) -> None: ...
1921
def process_compile_state(self, compile_state) -> None: ...
20-
def options(self, *opts) -> None: ...
21-
def set_relationship_strategy(self, attr, strategy, propagate_to_loaders: bool = True) -> None: ...
22-
def set_column_strategy(self, attrs, strategy, opts: Incomplete | None = None, opts_only: bool = False) -> None: ...
23-
def set_generic_strategy(self, attrs, strategy) -> None: ...
24-
def set_class_strategy(self, strategy, opts) -> None: ...
25-
# added dynamically at runtime
26-
def contains_eager(self, attr, alias: Incomplete | None = None): ...
27-
def load_only(self, *attrs): ...
28-
def joinedload(self, attr, innerjoin: Incomplete | None = None): ...
29-
def subqueryload(self, attr): ...
30-
def selectinload(self, attr): ...
31-
def lazyload(self, attr): ...
32-
def immediateload(self, attr): ...
33-
def noload(self, attr): ...
34-
def raiseload(self, attr, sql_only: bool = False): ...
35-
def defaultload(self, attr): ...
36-
def defer(self, key, raiseload: bool = False): ...
37-
def undefer(self, key): ...
38-
def undefer_group(self, name): ...
39-
def with_expression(self, key, expression): ...
40-
def selectin_polymorphic(self, classes): ...
22+
def options(self, *opts) -> Self: ...
23+
def set_relationship_strategy(self, attr, strategy, propagate_to_loaders: bool = True) -> Self: ...
24+
def set_column_strategy(self, attrs, strategy, opts: Incomplete | None = None, opts_only: bool = False) -> Self: ...
25+
def set_generic_strategy(self, attrs, strategy) -> Self: ...
26+
def set_class_strategy(self, strategy, opts) -> Self: ...
27+
# Added dynamically at runtime
28+
def contains_eager(loadopt, attr, alias: Incomplete | None = None) -> Self: ...
29+
def load_only(loadopt, *attrs) -> Self: ...
30+
def joinedload(loadopt, attr, innerjoin: Incomplete | None = None) -> Self: ...
31+
def subqueryload(loadopt, attr) -> Self: ...
32+
def selectinload(loadopt, attr) -> Self: ...
33+
def lazyload(loadopt, attr) -> Self: ...
34+
def immediateload(loadopt, attr) -> Self: ...
35+
def noload(loadopt, attr) -> Self: ...
36+
def raiseload(loadopt, attr, sql_only: bool = False) -> Self: ...
37+
def defaultload(loadopt, attr) -> Self: ...
38+
def defer(loadopt, key, raiseload: bool = False) -> Self: ...
39+
def undefer(loadopt, key) -> Self: ...
40+
def undefer_group(loadopt, name) -> Self: ...
41+
def with_expression(loadopt, key, expression) -> Self: ...
42+
def selectin_polymorphic(loadopt, classes) -> Self: ...
4143

4244
class _UnboundLoad(Load):
4345
path: Any
4446
local_opts: Any
4547
def __init__(self) -> None: ...
4648

4749
class loader_option:
48-
name: Any
49-
fn: Any
50-
def __call__(self, fn): ...
50+
name: str
51+
# The first parameter of this Callable should always be `loadopt: Load`
52+
fn: Callable[..., loader_option]
53+
def __call__(self, fn: Callable[..., loader_option]) -> Self: ...
5154

52-
def contains_eager(loadopt, attr, alias: Incomplete | None = ...): ...
53-
def load_only(loadopt, *attrs): ...
54-
def joinedload(loadopt, attr, innerjoin: Incomplete | None = ...): ...
55-
def subqueryload(loadopt, attr): ...
56-
def selectinload(loadopt, attr): ...
57-
def lazyload(loadopt, attr): ...
58-
def immediateload(loadopt, attr): ...
59-
def noload(loadopt, attr): ...
60-
def raiseload(loadopt, attr, sql_only: bool = ...): ...
61-
def defaultload(loadopt, attr): ...
62-
def defer(loadopt, key, raiseload: bool = ...): ...
63-
def undefer(loadopt, key): ...
64-
def undefer_group(loadopt, name): ...
65-
def with_expression(loadopt, key, expression): ...
66-
def selectin_polymorphic(loadopt, classes): ...
55+
# loader_option instances that can be used to dynamically add methods to Load at runtime
56+
@loader_option()
57+
def contains_eager(loadopt: Load, attr, alias: Incomplete | None = ...) -> loader_option: ...
58+
@loader_option()
59+
def load_only(loadopt: Load, *attrs) -> loader_option: ...
60+
@loader_option()
61+
def joinedload(loadopt, attr, innerjoin=None): ...
62+
@loader_option()
63+
def subqueryload(loadopt: Load, attr) -> loader_option: ...
64+
@loader_option()
65+
def selectinload(loadopt: Load, attr) -> loader_option: ...
66+
@loader_option()
67+
def lazyload(loadopt: Load, attr) -> loader_option: ...
68+
@loader_option()
69+
def immediateload(loadopt: Load, attr) -> loader_option: ...
70+
@loader_option()
71+
def noload(loadopt: Load, attr) -> loader_option: ...
72+
@loader_option()
73+
def raiseload(loadopt: Load, attr, sql_only: bool = ...) -> loader_option: ...
74+
@loader_option()
75+
def defaultload(loadopt: Load, attr) -> loader_option: ...
76+
@loader_option()
77+
def defer(loadopt: Load, key, raiseload: bool = ...) -> loader_option: ...
78+
@loader_option()
79+
def undefer(loadopt: Load, key) -> loader_option: ...
80+
@loader_option()
81+
def undefer_group(loadopt: Load, name) -> loader_option: ...
82+
@loader_option()
83+
def with_expression(loadopt: Load, key) -> loader_option: ...
84+
@loader_option()
85+
def selectin_polymorphic(loadopt: Load, classes) -> loader_option: ...

stubs/SQLAlchemy/sqlalchemy/pool/base.pyi

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from _typeshed import Incomplete
2+
from _typeshed.dbapi import DBAPIConnection
3+
from collections.abc import Callable
24
from typing import Any
35

46
from .. import log
@@ -24,7 +26,7 @@ class Pool(log.Identified):
2426
echo: Any
2527
def __init__(
2628
self,
27-
creator,
29+
creator: Callable[[], DBAPIConnection],
2830
recycle: int = -1,
2931
echo: Incomplete | None = None,
3032
logging_name: Incomplete | None = None,

stubs/SQLAlchemy/sqlalchemy/sql/base.pyi

+1-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ class _MetaOptions(type):
7373
def __init__(cls, classname, bases, dict_) -> None: ...
7474
def __add__(self, other): ...
7575

76-
class Options:
76+
class Options(metaclass=_MetaOptions):
7777
def __init__(self, **kw) -> None: ...
7878
def __add__(self, other): ...
7979
def __eq__(self, other): ...

0 commit comments

Comments
 (0)