28
28
//! definitions once the rest of the types in the scope have been inferred.
29
29
use std:: num:: NonZeroU32 ;
30
30
31
- use rustc_hash:: FxHashMap ;
32
- use salsa;
33
- use salsa:: plumbing:: AsId ;
34
-
35
31
use ruff_db:: files:: File ;
36
32
use ruff_db:: parsed:: parsed_module;
37
33
use ruff_python_ast:: { self as ast, AnyNodeRef , ExprContext , UnaryOp } ;
38
34
use ruff_text_size:: Ranged ;
35
+ use rustc_hash:: FxHashMap ;
36
+ use salsa;
37
+ use salsa:: plumbing:: AsId ;
39
38
40
39
use crate :: module_name:: ModuleName ;
41
40
use crate :: module_resolver:: { file_to_module, resolve_module} ;
@@ -52,7 +51,7 @@ use crate::types::diagnostic::{TypeCheckDiagnostic, TypeCheckDiagnostics};
52
51
use crate :: types:: {
53
52
bindings_ty, builtins_symbol_ty, declarations_ty, global_symbol_ty, symbol_ty,
54
53
typing_extensions_symbol_ty, BytesLiteralType , ClassType , FunctionType , StringLiteralType ,
55
- TupleType , Type , TypeArrayDisplay , UnionType ,
54
+ Truthiness , TupleType , Type , TypeArrayDisplay , UnionType ,
56
55
} ;
57
56
use crate :: Db ;
58
57
@@ -2318,16 +2317,35 @@ impl<'db> TypeInferenceBuilder<'db> {
2318
2317
fn infer_boolean_expression ( & mut self , bool_op : & ast:: ExprBoolOp ) -> Type < ' db > {
2319
2318
let ast:: ExprBoolOp {
2320
2319
range : _,
2321
- op : _ ,
2320
+ op,
2322
2321
values,
2323
2322
} = bool_op;
2324
-
2325
- for value in values {
2326
- self . infer_expression ( value) ;
2327
- }
2328
-
2329
- // TODO resolve bool op
2330
- Type :: Unknown
2323
+ let mut done = false ;
2324
+ UnionType :: from_elements (
2325
+ self . db ,
2326
+ values. iter ( ) . enumerate ( ) . map ( |( i, value) | {
2327
+ // We need to infer the type of every expression (that's an invariant maintained by
2328
+ // type inference), even if we can short-circuit boolean evaluation of some of
2329
+ // those types.
2330
+ let value_ty = self . infer_expression ( value) ;
2331
+ if done {
2332
+ Type :: Never
2333
+ } else {
2334
+ let is_last = i == values. len ( ) - 1 ;
2335
+ match ( value_ty. bool ( self . db ) , is_last, op) {
2336
+ ( Truthiness :: Ambiguous , _, _) => value_ty,
2337
+ ( Truthiness :: AlwaysTrue , false , ast:: BoolOp :: And ) => Type :: Never ,
2338
+ ( Truthiness :: AlwaysFalse , false , ast:: BoolOp :: Or ) => Type :: Never ,
2339
+ ( Truthiness :: AlwaysFalse , _, ast:: BoolOp :: And )
2340
+ | ( Truthiness :: AlwaysTrue , _, ast:: BoolOp :: Or ) => {
2341
+ done = true ;
2342
+ value_ty
2343
+ }
2344
+ ( _, true , _) => value_ty,
2345
+ }
2346
+ }
2347
+ } ) ,
2348
+ )
2331
2349
}
2332
2350
2333
2351
fn infer_compare_expression ( & mut self , compare : & ast:: ExprCompare ) -> Type < ' db > {
@@ -6048,4 +6066,96 @@ mod tests {
6048
6066
) ;
6049
6067
Ok ( ( ) )
6050
6068
}
6069
+
6070
+ #[ test]
6071
+ fn boolean_or_expression ( ) -> anyhow:: Result < ( ) > {
6072
+ let mut db = setup_db ( ) ;
6073
+
6074
+ db. write_dedented (
6075
+ "/src/a.py" ,
6076
+ "
6077
+ def foo() -> str:
6078
+ pass
6079
+
6080
+ a = True or False
6081
+ b = 'x' or 'y' or 'z'
6082
+ c = '' or 'y' or 'z'
6083
+ d = False or 'z'
6084
+ e = False or True
6085
+ f = False or False
6086
+ g = foo() or False
6087
+ h = foo() or True
6088
+ " ,
6089
+ ) ?;
6090
+
6091
+ assert_public_ty ( & db, "/src/a.py" , "a" , "Literal[True]" ) ;
6092
+ assert_public_ty ( & db, "/src/a.py" , "b" , r#"Literal["x"]"# ) ;
6093
+ assert_public_ty ( & db, "/src/a.py" , "c" , r#"Literal["y"]"# ) ;
6094
+ assert_public_ty ( & db, "/src/a.py" , "d" , r#"Literal["z"]"# ) ;
6095
+ assert_public_ty ( & db, "/src/a.py" , "e" , "Literal[True]" ) ;
6096
+ assert_public_ty ( & db, "/src/a.py" , "f" , "Literal[False]" ) ;
6097
+ assert_public_ty ( & db, "/src/a.py" , "g" , "str | Literal[False]" ) ;
6098
+ assert_public_ty ( & db, "/src/a.py" , "h" , "str | Literal[True]" ) ;
6099
+
6100
+ Ok ( ( ) )
6101
+ }
6102
+
6103
+ #[ test]
6104
+ fn boolean_and_expression ( ) -> anyhow:: Result < ( ) > {
6105
+ let mut db = setup_db ( ) ;
6106
+
6107
+ db. write_dedented (
6108
+ "/src/a.py" ,
6109
+ "
6110
+ def foo() -> str:
6111
+ pass
6112
+
6113
+ a = True and False
6114
+ b = False and True
6115
+ c = foo() and False
6116
+ d = foo() and True
6117
+ e = 'x' and 'y' and 'z'
6118
+ f = 'x' and 'y' and ''
6119
+ g = '' and 'y'
6120
+ " ,
6121
+ ) ?;
6122
+
6123
+ assert_public_ty ( & db, "/src/a.py" , "a" , "Literal[False]" ) ;
6124
+ assert_public_ty ( & db, "/src/a.py" , "b" , "Literal[False]" ) ;
6125
+ assert_public_ty ( & db, "/src/a.py" , "c" , "str | Literal[False]" ) ;
6126
+ assert_public_ty ( & db, "/src/a.py" , "d" , "str | Literal[True]" ) ;
6127
+ assert_public_ty ( & db, "/src/a.py" , "e" , r#"Literal["z"]"# ) ;
6128
+ assert_public_ty ( & db, "/src/a.py" , "f" , r#"Literal[""]"# ) ;
6129
+ assert_public_ty ( & db, "/src/a.py" , "g" , r#"Literal[""]"# ) ;
6130
+ Ok ( ( ) )
6131
+ }
6132
+
6133
+ #[ test]
6134
+ fn boolean_complex_expression ( ) -> anyhow:: Result < ( ) > {
6135
+ let mut db = setup_db ( ) ;
6136
+
6137
+ db. write_dedented (
6138
+ "/src/a.py" ,
6139
+ r#"
6140
+ def foo() -> str:
6141
+ pass
6142
+
6143
+ a = "x" and "y" or "z"
6144
+ b = "x" or "y" and "z"
6145
+ c = "" and "y" or "z"
6146
+ d = "" or "y" and "z"
6147
+ e = "x" and "y" or ""
6148
+ f = "x" or "y" and ""
6149
+
6150
+ "# ,
6151
+ ) ?;
6152
+
6153
+ assert_public_ty ( & db, "/src/a.py" , "a" , r#"Literal["y"]"# ) ;
6154
+ assert_public_ty ( & db, "/src/a.py" , "b" , r#"Literal["x"]"# ) ;
6155
+ assert_public_ty ( & db, "/src/a.py" , "c" , r#"Literal["z"]"# ) ;
6156
+ assert_public_ty ( & db, "/src/a.py" , "d" , r#"Literal["z"]"# ) ;
6157
+ assert_public_ty ( & db, "/src/a.py" , "e" , r#"Literal["y"]"# ) ;
6158
+ assert_public_ty ( & db, "/src/a.py" , "f" , r#"Literal["x"]"# ) ;
6159
+ Ok ( ( ) )
6160
+ }
6051
6161
}
0 commit comments