@@ -238,6 +238,8 @@ pub enum Type<'db> {
238
238
None ,
239
239
/// a specific function object
240
240
Function ( FunctionType < ' db > ) ,
241
+ /// The `typing.reveal_type` function, which has special `__call__` behavior.
242
+ RevealTypeFunction ( FunctionType < ' db > ) ,
241
243
/// a specific module object
242
244
Module ( File ) ,
243
245
/// a specific class object
@@ -324,14 +326,16 @@ impl<'db> Type<'db> {
324
326
325
327
pub const fn into_function_type ( self ) -> Option < FunctionType < ' db > > {
326
328
match self {
327
- Type :: Function ( function_type) => Some ( function_type) ,
329
+ Type :: Function ( function_type) | Type :: RevealTypeFunction ( function_type) => {
330
+ Some ( function_type)
331
+ }
328
332
_ => None ,
329
333
}
330
334
}
331
335
332
336
pub fn expect_function ( self ) -> FunctionType < ' db > {
333
337
self . into_function_type ( )
334
- . expect ( "Expected a Type::Function variant" )
338
+ . expect ( "Expected a variant wrapping a FunctionType " )
335
339
}
336
340
337
341
pub const fn into_int_literal_type ( self ) -> Option < i64 > {
@@ -367,6 +371,16 @@ impl<'db> Type<'db> {
367
371
}
368
372
}
369
373
374
+ pub fn is_stdlib_symbol ( & self , db : & ' db dyn Db , module_name : & str , name : & str ) -> bool {
375
+ match self {
376
+ Type :: Class ( class) => class. is_stdlib_symbol ( db, module_name, name) ,
377
+ Type :: Function ( function) | Type :: RevealTypeFunction ( function) => {
378
+ function. is_stdlib_symbol ( db, module_name, name)
379
+ }
380
+ _ => false ,
381
+ }
382
+ }
383
+
370
384
/// Return true if this type is [assignable to] type `target`.
371
385
///
372
386
/// [assignable to]: https://typing.readthedocs.io/en/latest/spec/concepts.html#the-assignable-to-or-consistent-subtyping-relation
@@ -436,7 +450,7 @@ impl<'db> Type<'db> {
436
450
// TODO: attribute lookup on None type
437
451
Type :: Unknown
438
452
}
439
- Type :: Function ( _) => {
453
+ Type :: Function ( _) | Type :: RevealTypeFunction ( _ ) => {
440
454
// TODO: attribute lookup on function type
441
455
Type :: Unknown
442
456
}
@@ -482,26 +496,39 @@ impl<'db> Type<'db> {
482
496
///
483
497
/// Returns `None` if `self` is not a callable type.
484
498
#[ must_use]
485
- pub fn call ( & self , db : & ' db dyn Db ) -> Option < Type < ' db > > {
499
+ fn call ( self , db : & ' db dyn Db , arg_types : & [ Type < ' db > ] ) -> CallOutcome < ' db > {
486
500
match self {
487
- Type :: Function ( function_type) => Some ( function_type. return_type ( db) ) ,
501
+ // TODO validate typed call arguments vs callable signature
502
+ Type :: Function ( function_type) => CallOutcome :: callable ( function_type. return_type ( db) ) ,
503
+ Type :: RevealTypeFunction ( function_type) => CallOutcome :: revealed (
504
+ function_type. return_type ( db) ,
505
+ * arg_types. first ( ) . unwrap_or ( & Type :: Unknown ) ,
506
+ ) ,
488
507
489
508
// TODO annotated return type on `__new__` or metaclass `__call__`
490
- Type :: Class ( class) => Some ( Type :: Instance ( * class) ) ,
509
+ Type :: Class ( class) => CallOutcome :: callable ( Type :: Instance ( class) ) ,
491
510
492
- // TODO: handle classes which implement the Callable protocol
493
- Type :: Instance ( _instance_ty) => Some ( Type :: Unknown ) ,
511
+ // TODO: handle classes which implement the `__call__` protocol
512
+ Type :: Instance ( _instance_ty) => CallOutcome :: callable ( Type :: Unknown ) ,
494
513
495
514
// `Any` is callable, and its return type is also `Any`.
496
- Type :: Any => Some ( Type :: Any ) ,
515
+ Type :: Any => CallOutcome :: callable ( Type :: Any ) ,
497
516
498
- Type :: Unknown => Some ( Type :: Unknown ) ,
517
+ Type :: Unknown => CallOutcome :: callable ( Type :: Unknown ) ,
499
518
500
- // TODO: union and intersection types, if they reduce to `Callable`
501
- Type :: Union ( _) => Some ( Type :: Unknown ) ,
502
- Type :: Intersection ( _) => Some ( Type :: Unknown ) ,
519
+ Type :: Union ( union) => CallOutcome :: union (
520
+ self ,
521
+ union
522
+ . elements ( db)
523
+ . iter ( )
524
+ . map ( |elem| elem. call ( db, arg_types) )
525
+ . collect :: < Box < [ CallOutcome < ' db > ] > > ( ) ,
526
+ ) ,
503
527
504
- _ => None ,
528
+ // TODO: intersection types
529
+ Type :: Intersection ( _) => CallOutcome :: callable ( Type :: Unknown ) ,
530
+
531
+ _ => CallOutcome :: not_callable ( self ) ,
505
532
}
506
533
}
507
534
@@ -513,7 +540,7 @@ impl<'db> Type<'db> {
513
540
/// for y in x:
514
541
/// pass
515
542
/// ```
516
- fn iterate ( & self , db : & ' db dyn Db ) -> IterationOutcome < ' db > {
543
+ fn iterate ( self , db : & ' db dyn Db ) -> IterationOutcome < ' db > {
517
544
if let Type :: Tuple ( tuple_type) = self {
518
545
return IterationOutcome :: Iterable {
519
546
element_ty : UnionType :: from_elements ( db, & * * tuple_type. elements ( db) ) ,
@@ -526,18 +553,22 @@ impl<'db> Type<'db> {
526
553
527
554
let dunder_iter_method = iterable_meta_type. member ( db, "__iter__" ) ;
528
555
if !dunder_iter_method. is_unbound ( ) {
529
- let Some ( iterator_ty) = dunder_iter_method. call ( db) else {
556
+ let CallOutcome :: Callable {
557
+ return_ty : iterator_ty,
558
+ } = dunder_iter_method. call ( db, & [ ] )
559
+ else {
530
560
return IterationOutcome :: NotIterable {
531
- not_iterable_ty : * self ,
561
+ not_iterable_ty : self ,
532
562
} ;
533
563
} ;
534
564
535
565
let dunder_next_method = iterator_ty. to_meta_type ( db) . member ( db, "__next__" ) ;
536
566
return dunder_next_method
537
- . call ( db)
567
+ . call ( db, & [ ] )
568
+ . return_ty ( db)
538
569
. map ( |element_ty| IterationOutcome :: Iterable { element_ty } )
539
570
. unwrap_or ( IterationOutcome :: NotIterable {
540
- not_iterable_ty : * self ,
571
+ not_iterable_ty : self ,
541
572
} ) ;
542
573
}
543
574
@@ -550,10 +581,11 @@ impl<'db> Type<'db> {
550
581
let dunder_get_item_method = iterable_meta_type. member ( db, "__getitem__" ) ;
551
582
552
583
dunder_get_item_method
553
- . call ( db)
584
+ . call ( db, & [ ] )
585
+ . return_ty ( db)
554
586
. map ( |element_ty| IterationOutcome :: Iterable { element_ty } )
555
587
. unwrap_or ( IterationOutcome :: NotIterable {
556
- not_iterable_ty : * self ,
588
+ not_iterable_ty : self ,
557
589
} )
558
590
}
559
591
@@ -573,6 +605,7 @@ impl<'db> Type<'db> {
573
605
Type :: BooleanLiteral ( _)
574
606
| Type :: BytesLiteral ( _)
575
607
| Type :: Function ( _)
608
+ | Type :: RevealTypeFunction ( _)
576
609
| Type :: Instance ( _)
577
610
| Type :: Module ( _)
578
611
| Type :: IntLiteral ( _)
@@ -595,7 +628,7 @@ impl<'db> Type<'db> {
595
628
Type :: BooleanLiteral ( _) => builtins_symbol_ty ( db, "bool" ) ,
596
629
Type :: BytesLiteral ( _) => builtins_symbol_ty ( db, "bytes" ) ,
597
630
Type :: IntLiteral ( _) => builtins_symbol_ty ( db, "int" ) ,
598
- Type :: Function ( _) => types_symbol_ty ( db, "FunctionType" ) ,
631
+ Type :: Function ( _) | Type :: RevealTypeFunction ( _ ) => types_symbol_ty ( db, "FunctionType" ) ,
599
632
Type :: Module ( _) => types_symbol_ty ( db, "ModuleType" ) ,
600
633
Type :: None => typeshed_symbol_ty ( db, "NoneType" ) ,
601
634
// TODO not accurate if there's a custom metaclass...
@@ -619,6 +652,152 @@ impl<'db> From<&Type<'db>> for Type<'db> {
619
652
}
620
653
}
621
654
655
+ #[ derive( Debug , Clone , PartialEq , Eq ) ]
656
+ enum CallOutcome < ' db > {
657
+ Callable {
658
+ return_ty : Type < ' db > ,
659
+ } ,
660
+ RevealType {
661
+ return_ty : Type < ' db > ,
662
+ revealed_ty : Type < ' db > ,
663
+ } ,
664
+ NotCallable {
665
+ not_callable_ty : Type < ' db > ,
666
+ } ,
667
+ Union {
668
+ called_ty : Type < ' db > ,
669
+ outcomes : Box < [ CallOutcome < ' db > ] > ,
670
+ } ,
671
+ }
672
+
673
+ impl < ' db > CallOutcome < ' db > {
674
+ /// Create a new `CallOutcome::Callable` with given return type.
675
+ fn callable ( return_ty : Type < ' db > ) -> CallOutcome {
676
+ CallOutcome :: Callable { return_ty }
677
+ }
678
+
679
+ /// Create a new `CallOutcome::NotCallable` with given not-callable type.
680
+ fn not_callable ( not_callable_ty : Type < ' db > ) -> CallOutcome {
681
+ CallOutcome :: NotCallable { not_callable_ty }
682
+ }
683
+
684
+ /// Create a new `CallOutcome::RevealType` with given revealed and return types.
685
+ fn revealed ( return_ty : Type < ' db > , revealed_ty : Type < ' db > ) -> CallOutcome < ' db > {
686
+ CallOutcome :: RevealType {
687
+ return_ty,
688
+ revealed_ty,
689
+ }
690
+ }
691
+
692
+ /// Create a new `CallOutcome::Union` with given wrapped outcomes.
693
+ fn union ( called_ty : Type < ' db > , outcomes : impl Into < Box < [ CallOutcome < ' db > ] > > ) -> CallOutcome {
694
+ CallOutcome :: Union {
695
+ called_ty,
696
+ outcomes : outcomes. into ( ) ,
697
+ }
698
+ }
699
+
700
+ /// Get the return type of the call, or `None` if not callable.
701
+ fn return_ty ( & self , db : & ' db dyn Db ) -> Option < Type < ' db > > {
702
+ match self {
703
+ Self :: Callable { return_ty } => Some ( * return_ty) ,
704
+ Self :: RevealType {
705
+ return_ty,
706
+ revealed_ty : _,
707
+ } => Some ( * return_ty) ,
708
+ Self :: NotCallable { not_callable_ty : _ } => None ,
709
+ Self :: Union {
710
+ outcomes,
711
+ called_ty : _,
712
+ } => outcomes
713
+ . iter ( )
714
+ // If all outcomes are NotCallable, we return None; if some outcomes are callable
715
+ // and some are not, we return a union including Unknown.
716
+ . fold ( None , |acc, outcome| {
717
+ let ty = outcome. return_ty ( db) ;
718
+ match ( acc, ty) {
719
+ ( None , None ) => None ,
720
+ ( None , Some ( ty) ) => Some ( UnionBuilder :: new ( db) . add ( ty) ) ,
721
+ ( Some ( builder) , ty) => Some ( builder. add ( ty. unwrap_or ( Type :: Unknown ) ) ) ,
722
+ }
723
+ } )
724
+ . map ( UnionBuilder :: build) ,
725
+ }
726
+ }
727
+
728
+ /// Get the return type of the call, emitting diagnostics if needed.
729
+ fn unwrap_with_diagnostic < ' a > (
730
+ & self ,
731
+ db : & ' db dyn Db ,
732
+ node : ast:: AnyNodeRef ,
733
+ builder : & ' a mut TypeInferenceBuilder < ' db > ,
734
+ ) -> Type < ' db > {
735
+ match self {
736
+ Self :: Callable { return_ty } => * return_ty,
737
+ Self :: RevealType {
738
+ return_ty,
739
+ revealed_ty,
740
+ } => {
741
+ builder. add_diagnostic (
742
+ node,
743
+ "revealed-type" ,
744
+ format_args ! ( "Revealed type is '{}'." , revealed_ty. display( db) ) ,
745
+ ) ;
746
+ * return_ty
747
+ }
748
+ Self :: NotCallable { not_callable_ty } => {
749
+ builder. add_diagnostic (
750
+ node,
751
+ "call-non-callable" ,
752
+ format_args ! (
753
+ "Object of type '{}' is not callable." ,
754
+ not_callable_ty. display( db)
755
+ ) ,
756
+ ) ;
757
+ Type :: Unknown
758
+ }
759
+ Self :: Union {
760
+ outcomes,
761
+ called_ty,
762
+ } => {
763
+ let mut not_callable = vec ! [ ] ;
764
+ let mut union_builder = UnionBuilder :: new ( db) ;
765
+ for outcome in & * * outcomes {
766
+ let return_ty = if let Self :: NotCallable { not_callable_ty } = outcome {
767
+ not_callable. push ( * not_callable_ty) ;
768
+ Type :: Unknown
769
+ } else {
770
+ outcome. unwrap_with_diagnostic ( db, node, builder)
771
+ } ;
772
+ union_builder = union_builder. add ( return_ty) ;
773
+ }
774
+ match not_callable[ ..] {
775
+ [ ] => { }
776
+ [ elem] => builder. add_diagnostic (
777
+ node,
778
+ "call-non-callable" ,
779
+ format_args ! (
780
+ "Union element '{}' of type '{}' is not callable." ,
781
+ elem. display( db) ,
782
+ called_ty. display( db)
783
+ ) ,
784
+ ) ,
785
+ _ => builder. add_diagnostic (
786
+ node,
787
+ "call-non-callable" ,
788
+ format_args ! (
789
+ "Union elements {} of type '{}' are not callable." ,
790
+ not_callable. display( db) ,
791
+ called_ty. display( db)
792
+ ) ,
793
+ ) ,
794
+ }
795
+ union_builder. build ( )
796
+ }
797
+ }
798
+ }
799
+ }
800
+
622
801
#[ derive( Debug , Clone , Copy , PartialEq , Eq ) ]
623
802
enum IterationOutcome < ' db > {
624
803
Iterable { element_ty : Type < ' db > } ,
@@ -654,6 +833,14 @@ pub struct FunctionType<'db> {
654
833
}
655
834
656
835
impl < ' db > FunctionType < ' db > {
836
+ /// Return true if this is a standard library function with given module name and name.
837
+ pub ( crate ) fn is_stdlib_symbol ( self , db : & ' db dyn Db , module_name : & str , name : & str ) -> bool {
838
+ name == self . name ( db)
839
+ && file_to_module ( db, self . definition ( db) . file ( db) ) . is_some_and ( |module| {
840
+ module. search_path ( ) . is_standard_library ( ) && module. name ( ) == module_name
841
+ } )
842
+ }
843
+
657
844
pub fn has_decorator ( self , db : & dyn Db , decorator : Type < ' _ > ) -> bool {
658
845
self . decorators ( db) . contains ( & decorator)
659
846
}
0 commit comments