@@ -173,7 +173,7 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749
173
173
headers = None , # a dict to be sent as request headers
174
174
post = None , # A callable to replace requests.post(), for testing.
175
175
# Such as: lambda url, **kwargs:
176
- # Mock(status_code=200, json=Mock(return_value={}) )
176
+ # Mock(status_code=200, text='{}' )
177
177
** kwargs # Relay all extra parameters to underlying requests
178
178
): # Returns the json object came from the OAUTH2 response
179
179
_data = {'client_id' : self .client_id , 'grant_type' : grant_type }
@@ -454,17 +454,20 @@ def __init__(self,
454
454
self .on_removing_rt = on_removing_rt
455
455
self .on_updating_rt = on_updating_rt
456
456
457
- def _obtain_token (self , grant_type , params = None , data = None , * args , ** kwargs ):
458
- RT = "refresh_token"
457
+ def _obtain_token (
458
+ self , grant_type , params = None , data = None ,
459
+ also_save_rt = False ,
460
+ * args , ** kwargs ):
459
461
_data = data .copy () # to prevent side effect
460
- refresh_token = _data .get (RT )
461
462
resp = super (Client , self )._obtain_token (
462
463
grant_type , params , _data , * args , ** kwargs )
463
464
if "error" not in resp :
464
465
_resp = resp .copy ()
465
- if grant_type == RT and RT in _resp and isinstance (refresh_token , dict ):
466
- _resp .pop (RT ) # So we skip it in on_obtaining_tokens(); it will
467
- # be handled in self.obtain_token_by_refresh_token()
466
+ RT = "refresh_token"
467
+ if grant_type == RT and RT in _resp and not also_save_rt :
468
+ # Then we skip it from on_obtaining_tokens();
469
+ # Leave it to self.obtain_token_by_refresh_token()
470
+ _resp .pop (RT , None )
468
471
if "scope" in _resp :
469
472
scope = _resp ["scope" ].split () # It is conceptually a set,
470
473
# but we represent it as a list which can be persisted to JSON
@@ -486,6 +489,7 @@ def _obtain_token(self, grant_type, params=None, data=None, *args, **kwargs):
486
489
def obtain_token_by_refresh_token (self , token_item , scope = None ,
487
490
rt_getter = lambda token_item : token_item ["refresh_token" ],
488
491
on_removing_rt = None ,
492
+ on_updating_rt = None ,
489
493
** kwargs ):
490
494
# type: (Union[str, dict], Union[str, list, set, tuple], Callable) -> dict
491
495
"""This is an overload which will trigger token storage callbacks.
@@ -503,16 +507,28 @@ def obtain_token_by_refresh_token(self, token_item, scope=None,
503
507
according to https://tools.ietf.org/html/rfc6749#section-6
504
508
:param rt_getter: A callable to translate the token_item to a raw RT string
505
509
:param on_removing_rt: If absent, fall back to the one defined in initialization
510
+
511
+ :param on_updating_rt:
512
+ Default to None, it will fall back to the one defined in initialization.
513
+ This is the most common case.
514
+
515
+ As a special case, you can pass in a False,
516
+ then this function will NOT trigger on_updating_rt() for RT UPDATE,
517
+ instead it will allow the RT to be added by on_obtaining_tokens().
518
+ This behavior is useful when you are migrating RTs from elsewhere
519
+ into a token storage managed by this library.
506
520
"""
507
521
resp = super (Client , self ).obtain_token_by_refresh_token (
508
522
rt_getter (token_item )
509
523
if not isinstance (token_item , string_types ) else token_item ,
510
524
scope = scope ,
525
+ also_save_rt = on_updating_rt is False ,
511
526
** kwargs )
512
527
if resp .get ('error' ) == 'invalid_grant' :
513
528
(on_removing_rt or self .on_removing_rt )(token_item ) # Discard old RT
514
- if 'refresh_token' in resp :
515
- self .on_updating_rt (token_item , resp ['refresh_token' ])
529
+ RT = "refresh_token"
530
+ if on_updating_rt is not False and RT in resp :
531
+ (on_updating_rt or self .on_updating_rt )(token_item , resp [RT ])
516
532
return resp
517
533
518
534
def obtain_token_by_assertion (
0 commit comments