@@ -114,6 +114,120 @@ impl SemanticSyntaxChecker {
114
114
}
115
115
116
116
Self :: debug_shadowing ( stmt, ctx) ;
117
+ Self :: check_annotation ( stmt, ctx) ;
118
+ }
119
+
120
+ fn check_annotation < Ctx : SemanticSyntaxContext > ( stmt : & ast:: Stmt , ctx : & Ctx ) {
121
+ match stmt {
122
+ Stmt :: FunctionDef ( ast:: StmtFunctionDef {
123
+ type_params,
124
+ parameters,
125
+ returns,
126
+ ..
127
+ } ) => {
128
+ // test_ok valid_annotation_function
129
+ // def f() -> (y := 3): ...
130
+ // def g(arg: (x := 1)): ...
131
+
132
+ // test_err invalid_annotation_function
133
+ // def f[T]() -> (y := 3): ...
134
+ // def g[T](arg: (x := 1)): ...
135
+ // def h[T](x: (yield 1)): ...
136
+ // def i(x: (yield 1)): ...
137
+ // def j[T]() -> (yield 1): ...
138
+ // def k() -> (yield 1): ...
139
+ // def l[T](x: (yield from 1)): ...
140
+ // def m(x: (yield from 1)): ...
141
+ // def n[T]() -> (yield from 1): ...
142
+ // def o() -> (yield from 1): ...
143
+ // def p[T: (yield 1)](): ... # yield in TypeVar bound
144
+ // def q[T = (yield 1)](): ... # yield in TypeVar default
145
+ // def r[*Ts = (yield 1)](): ... # yield in TypeVarTuple default
146
+ // def s[**Ts = (yield 1)](): ... # yield in ParamSpec default
147
+ // def t[T: (x := 1)](): ... # named expr in TypeVar bound
148
+ // def u[T = (x := 1)](): ... # named expr in TypeVar default
149
+ // def v[*Ts = (x := 1)](): ... # named expr in TypeVarTuple default
150
+ // def w[**Ts = (x := 1)](): ... # named expr in ParamSpec default
151
+ let is_generic = type_params. is_some ( ) ;
152
+ let mut visitor = InvalidExpressionVisitor {
153
+ allow_named_expr : !is_generic,
154
+ position : InvalidExpressionPosition :: TypeAnnotation ,
155
+ ctx,
156
+ } ;
157
+ if let Some ( type_params) = type_params {
158
+ visitor. visit_type_params ( type_params) ;
159
+ }
160
+ if is_generic {
161
+ visitor. position = InvalidExpressionPosition :: GenericDefinition ;
162
+ } else {
163
+ visitor. position = InvalidExpressionPosition :: TypeAnnotation ;
164
+ }
165
+ for param in parameters
166
+ . iter ( )
167
+ . filter_map ( ast:: AnyParameterRef :: annotation)
168
+ {
169
+ visitor. visit_expr ( param) ;
170
+ }
171
+ if let Some ( returns) = returns {
172
+ visitor. visit_expr ( returns) ;
173
+ }
174
+ }
175
+ Stmt :: ClassDef ( ast:: StmtClassDef {
176
+ type_params,
177
+ arguments,
178
+ ..
179
+ } ) => {
180
+ // test_ok valid_annotation_class
181
+ // class F(y := list): ...
182
+
183
+ // test_err invalid_annotation_class
184
+ // class F[T](y := list): ...
185
+ // class G((yield 1)): ...
186
+ // class H((yield from 1)): ...
187
+ // class I[T]((yield 1)): ...
188
+ // class J[T]((yield from 1)): ...
189
+ // class K[T: (yield 1)]: ... # yield in TypeVar
190
+ // class L[T: (x := 1)]: ... # named expr in TypeVar
191
+ let is_generic = type_params. is_some ( ) ;
192
+ let mut visitor = InvalidExpressionVisitor {
193
+ allow_named_expr : !is_generic,
194
+ position : InvalidExpressionPosition :: TypeAnnotation ,
195
+ ctx,
196
+ } ;
197
+ if let Some ( type_params) = type_params {
198
+ visitor. visit_type_params ( type_params) ;
199
+ }
200
+ if is_generic {
201
+ visitor. position = InvalidExpressionPosition :: GenericDefinition ;
202
+ } else {
203
+ visitor. position = InvalidExpressionPosition :: BaseClass ;
204
+ }
205
+ if let Some ( arguments) = arguments {
206
+ visitor. visit_arguments ( arguments) ;
207
+ }
208
+ }
209
+ Stmt :: TypeAlias ( ast:: StmtTypeAlias {
210
+ type_params, value, ..
211
+ } ) => {
212
+ // test_err invalid_annotation_type_alias
213
+ // type X[T: (yield 1)] = int # TypeVar bound
214
+ // type X[T = (yield 1)] = int # TypeVar default
215
+ // type X[*Ts = (yield 1)] = int # TypeVarTuple default
216
+ // type X[**Ts = (yield 1)] = int # ParamSpec default
217
+ // type Y = (yield 1) # yield in value
218
+ // type Y = (x := 1) # named expr in value
219
+ let mut visitor = InvalidExpressionVisitor {
220
+ allow_named_expr : false ,
221
+ position : InvalidExpressionPosition :: TypeAlias ,
222
+ ctx,
223
+ } ;
224
+ visitor. visit_expr ( value) ;
225
+ if let Some ( type_params) = type_params {
226
+ visitor. visit_type_params ( type_params) ;
227
+ }
228
+ }
229
+ _ => { }
230
+ }
117
231
}
118
232
119
233
/// Emit a [`SemanticSyntaxErrorKind::InvalidStarExpression`] if `expr` is starred.
@@ -511,6 +625,15 @@ impl Display for SemanticSyntaxError {
511
625
write ! ( f, "cannot delete `__debug__` on Python {python_version} (syntax was removed in 3.9)" )
512
626
}
513
627
} ,
628
+ SemanticSyntaxErrorKind :: InvalidExpression (
629
+ kind,
630
+ InvalidExpressionPosition :: BaseClass ,
631
+ ) => {
632
+ write ! ( f, "{kind} cannot be used as a base class" )
633
+ }
634
+ SemanticSyntaxErrorKind :: InvalidExpression ( kind, position) => {
635
+ write ! ( f, "{kind} cannot be used within a {position}" )
636
+ }
514
637
SemanticSyntaxErrorKind :: DuplicateMatchKey ( key) => {
515
638
write ! (
516
639
f,
@@ -641,6 +764,21 @@ pub enum SemanticSyntaxErrorKind {
641
764
/// [BPO 45000]: https://github.com/python/cpython/issues/89163
642
765
WriteToDebug ( WriteToDebugKind ) ,
643
766
767
+ /// Represents the use of an invalid expression kind in one of several locations.
768
+ ///
769
+ /// The kinds include `yield` and `yield from` expressions and named expressions, and locations
770
+ /// include type parameter bounds and defaults, type annotations, type aliases, and base class
771
+ /// lists.
772
+ ///
773
+ /// ## Examples
774
+ ///
775
+ /// ```python
776
+ /// type X[T: (yield 1)] = int
777
+ /// type Y = (yield 1)
778
+ /// def f[T](x: int) -> (y := 3): return x
779
+ /// ```
780
+ InvalidExpression ( InvalidExpressionKind , InvalidExpressionPosition ) ,
781
+
644
782
/// Represents a duplicate key in a `match` mapping pattern.
645
783
///
646
784
/// The [CPython grammar] allows keys in mapping patterns to be literals or attribute accesses:
@@ -713,6 +851,48 @@ pub enum SemanticSyntaxErrorKind {
713
851
InvalidStarExpression ,
714
852
}
715
853
854
+ #[ derive( Debug , Clone , Copy , PartialEq , Eq , Hash ) ]
855
+ pub enum InvalidExpressionPosition {
856
+ TypeVarBound ,
857
+ TypeVarDefault ,
858
+ TypeVarTupleDefault ,
859
+ ParamSpecDefault ,
860
+ TypeAnnotation ,
861
+ BaseClass ,
862
+ GenericDefinition ,
863
+ TypeAlias ,
864
+ }
865
+
866
+ impl Display for InvalidExpressionPosition {
867
+ fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
868
+ f. write_str ( match self {
869
+ InvalidExpressionPosition :: TypeVarBound => "TypeVar bound" ,
870
+ InvalidExpressionPosition :: TypeVarDefault => "TypeVar default" ,
871
+ InvalidExpressionPosition :: TypeVarTupleDefault => "TypeVarTuple default" ,
872
+ InvalidExpressionPosition :: ParamSpecDefault => "ParamSpec default" ,
873
+ InvalidExpressionPosition :: TypeAnnotation => "type annotation" ,
874
+ InvalidExpressionPosition :: GenericDefinition => "generic definition" ,
875
+ InvalidExpressionPosition :: BaseClass => "base class" ,
876
+ InvalidExpressionPosition :: TypeAlias => "type alias" ,
877
+ } )
878
+ }
879
+ }
880
+
881
+ #[ derive( Debug , Clone , PartialEq , Eq , Hash ) ]
882
+ pub enum InvalidExpressionKind {
883
+ Yield ,
884
+ NamedExpr ,
885
+ }
886
+
887
+ impl Display for InvalidExpressionKind {
888
+ fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
889
+ f. write_str ( match self {
890
+ InvalidExpressionKind :: Yield => "yield expression" ,
891
+ InvalidExpressionKind :: NamedExpr => "named expression" ,
892
+ } )
893
+ }
894
+ }
895
+
716
896
#[ derive( Debug , Clone , PartialEq , Eq , Hash ) ]
717
897
pub enum WriteToDebugKind {
718
898
Store ,
@@ -905,6 +1085,83 @@ impl<'a, Ctx: SemanticSyntaxContext> MatchPatternVisitor<'a, Ctx> {
905
1085
}
906
1086
}
907
1087
1088
+ struct InvalidExpressionVisitor < ' a , Ctx > {
1089
+ /// Allow named expressions (`x := ...`) to appear in annotations.
1090
+ ///
1091
+ /// These are allowed in non-generic functions, for example:
1092
+ ///
1093
+ /// ```python
1094
+ /// def foo(arg: (x := int)): ... # ok
1095
+ /// def foo[T](arg: (x := int)): ... # syntax error
1096
+ /// ```
1097
+ allow_named_expr : bool ,
1098
+
1099
+ /// Context used for emitting errors.
1100
+ ctx : & ' a Ctx ,
1101
+
1102
+ position : InvalidExpressionPosition ,
1103
+ }
1104
+
1105
+ impl < Ctx > Visitor < ' _ > for InvalidExpressionVisitor < ' _ , Ctx >
1106
+ where
1107
+ Ctx : SemanticSyntaxContext ,
1108
+ {
1109
+ fn visit_expr ( & mut self , expr : & Expr ) {
1110
+ match expr {
1111
+ Expr :: Named ( ast:: ExprNamed { range, .. } ) if !self . allow_named_expr => {
1112
+ SemanticSyntaxChecker :: add_error (
1113
+ self . ctx ,
1114
+ SemanticSyntaxErrorKind :: InvalidExpression (
1115
+ InvalidExpressionKind :: NamedExpr ,
1116
+ self . position ,
1117
+ ) ,
1118
+ * range,
1119
+ ) ;
1120
+ }
1121
+ Expr :: Yield ( ast:: ExprYield { range, .. } )
1122
+ | Expr :: YieldFrom ( ast:: ExprYieldFrom { range, .. } ) => {
1123
+ SemanticSyntaxChecker :: add_error (
1124
+ self . ctx ,
1125
+ SemanticSyntaxErrorKind :: InvalidExpression (
1126
+ InvalidExpressionKind :: Yield ,
1127
+ self . position ,
1128
+ ) ,
1129
+ * range,
1130
+ ) ;
1131
+ }
1132
+ _ => { }
1133
+ }
1134
+ ast:: visitor:: walk_expr ( self , expr) ;
1135
+ }
1136
+
1137
+ fn visit_type_param ( & mut self , type_param : & ast:: TypeParam ) {
1138
+ match type_param {
1139
+ ast:: TypeParam :: TypeVar ( ast:: TypeParamTypeVar { bound, default, .. } ) => {
1140
+ if let Some ( expr) = bound {
1141
+ self . position = InvalidExpressionPosition :: TypeVarBound ;
1142
+ self . visit_expr ( expr) ;
1143
+ }
1144
+ if let Some ( expr) = default {
1145
+ self . position = InvalidExpressionPosition :: TypeVarDefault ;
1146
+ self . visit_expr ( expr) ;
1147
+ }
1148
+ }
1149
+ ast:: TypeParam :: TypeVarTuple ( ast:: TypeParamTypeVarTuple { default, .. } ) => {
1150
+ if let Some ( expr) = default {
1151
+ self . position = InvalidExpressionPosition :: TypeVarTupleDefault ;
1152
+ self . visit_expr ( expr) ;
1153
+ }
1154
+ }
1155
+ ast:: TypeParam :: ParamSpec ( ast:: TypeParamParamSpec { default, .. } ) => {
1156
+ if let Some ( expr) = default {
1157
+ self . position = InvalidExpressionPosition :: ParamSpecDefault ;
1158
+ self . visit_expr ( expr) ;
1159
+ }
1160
+ }
1161
+ } ;
1162
+ }
1163
+ }
1164
+
908
1165
pub trait SemanticSyntaxContext {
909
1166
/// Returns `true` if a module's docstring boundary has been passed.
910
1167
fn seen_docstring_boundary ( & self ) -> bool ;
0 commit comments