Skip to content

Commit e3e9740

Browse files
committed
Merge branch 'oauth2cli' into dev
2 parents b088ca4 + 08fec9a commit e3e9740

File tree

2 files changed

+72
-10
lines changed

2 files changed

+72
-10
lines changed

msal/oauth2cli/oauth2.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749
173173
headers=None, # a dict to be sent as request headers
174174
post=None, # A callable to replace requests.post(), for testing.
175175
# Such as: lambda url, **kwargs:
176-
# Mock(status_code=200, json=Mock(return_value={}))
176+
# Mock(status_code=200, text='{}')
177177
**kwargs # Relay all extra parameters to underlying requests
178178
): # Returns the json object came from the OAUTH2 response
179179
_data = {'client_id': self.client_id, 'grant_type': grant_type}
@@ -454,17 +454,20 @@ def __init__(self,
454454
self.on_removing_rt = on_removing_rt
455455
self.on_updating_rt = on_updating_rt
456456

457-
def _obtain_token(self, grant_type, params=None, data=None, *args, **kwargs):
458-
RT = "refresh_token"
457+
def _obtain_token(
458+
self, grant_type, params=None, data=None,
459+
also_save_rt=False,
460+
*args, **kwargs):
459461
_data = data.copy() # to prevent side effect
460-
refresh_token = _data.get(RT)
461462
resp = super(Client, self)._obtain_token(
462463
grant_type, params, _data, *args, **kwargs)
463464
if "error" not in resp:
464465
_resp = resp.copy()
465-
if grant_type == RT and RT in _resp and isinstance(refresh_token, dict):
466-
_resp.pop(RT) # So we skip it in on_obtaining_tokens(); it will
467-
# be handled in self.obtain_token_by_refresh_token()
466+
RT = "refresh_token"
467+
if grant_type == RT and RT in _resp and not also_save_rt:
468+
# Then we skip it from on_obtaining_tokens();
469+
# Leave it to self.obtain_token_by_refresh_token()
470+
_resp.pop(RT, None)
468471
if "scope" in _resp:
469472
scope = _resp["scope"].split() # It is conceptually a set,
470473
# but we represent it as a list which can be persisted to JSON
@@ -486,6 +489,7 @@ def _obtain_token(self, grant_type, params=None, data=None, *args, **kwargs):
486489
def obtain_token_by_refresh_token(self, token_item, scope=None,
487490
rt_getter=lambda token_item: token_item["refresh_token"],
488491
on_removing_rt=None,
492+
on_updating_rt=None,
489493
**kwargs):
490494
# type: (Union[str, dict], Union[str, list, set, tuple], Callable) -> dict
491495
"""This is an overload which will trigger token storage callbacks.
@@ -503,16 +507,28 @@ def obtain_token_by_refresh_token(self, token_item, scope=None,
503507
according to https://tools.ietf.org/html/rfc6749#section-6
504508
:param rt_getter: A callable to translate the token_item to a raw RT string
505509
:param on_removing_rt: If absent, fall back to the one defined in initialization
510+
511+
:param on_updating_rt:
512+
Default to None, it will fall back to the one defined in initialization.
513+
This is the most common case.
514+
515+
As a special case, you can pass in a False,
516+
then this function will NOT trigger on_updating_rt() for RT UPDATE,
517+
instead it will allow the RT to be added by on_obtaining_tokens().
518+
This behavior is useful when you are migrating RTs from elsewhere
519+
into a token storage managed by this library.
506520
"""
507521
resp = super(Client, self).obtain_token_by_refresh_token(
508522
rt_getter(token_item)
509523
if not isinstance(token_item, string_types) else token_item,
510524
scope=scope,
525+
also_save_rt=on_updating_rt is False,
511526
**kwargs)
512527
if resp.get('error') == 'invalid_grant':
513528
(on_removing_rt or self.on_removing_rt)(token_item) # Discard old RT
514-
if 'refresh_token' in resp:
515-
self.on_updating_rt(token_item, resp['refresh_token'])
529+
RT = "refresh_token"
530+
if on_updating_rt is not False and RT in resp:
531+
(on_updating_rt or self.on_updating_rt)(token_item, resp[RT])
516532
return resp
517533

518534
def obtain_token_by_assertion(

tests/test_client.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from msal.oauth2cli import Client, JwtSigner
1313
from msal.oauth2cli.authcode import obtain_auth_code
1414
from tests import unittest, Oauth2TestCase
15-
from tests.http_client import MinimalHttpClient
15+
from tests.http_client import MinimalHttpClient, MinimalResponse
1616

1717

1818
logging.basicConfig(level=logging.DEBUG)
@@ -175,6 +175,52 @@ def test_device_flow(self):
175175
skippable_errors=self.client.DEVICE_FLOW_RETRIABLE_ERRORS)
176176

177177

178+
class TestRefreshTokenCallbacks(unittest.TestCase):
179+
180+
def _dummy(self, url, **kwargs):
181+
return MinimalResponse(status_code=200, text='{"refresh_token": "new"}')
182+
183+
def test_rt_being_added(self):
184+
client = Client(
185+
{"token_endpoint": "http://example.com/token"},
186+
"client_id",
187+
http_client=MinimalHttpClient(),
188+
on_obtaining_tokens=lambda event:
189+
self.assertEqual("new", event["response"].get("refresh_token")),
190+
on_updating_rt=lambda rt_item, new_rt:
191+
self.fail("This should not be called here"),
192+
)
193+
client.obtain_token_by_authorization_code("code", post=self._dummy)
194+
195+
def test_rt_being_updated(self):
196+
old_rt = {"refresh_token": "old"}
197+
client = Client(
198+
{"token_endpoint": "http://example.com/token"},
199+
"client_id",
200+
http_client=MinimalHttpClient(),
201+
on_obtaining_tokens=lambda event:
202+
self.assertNotIn("refresh_token", event["response"]),
203+
on_updating_rt=lambda old, new: # TODO: ensure it being called
204+
(self.assertEqual(old_rt, old), self.assertEqual("new", new)),
205+
)
206+
client.obtain_token_by_refresh_token(
207+
{"refresh_token": "old"}, post=self._dummy)
208+
209+
def test_rt_being_migrated(self):
210+
old_rt = {"refresh_token": "old"}
211+
client = Client(
212+
{"token_endpoint": "http://example.com/token"},
213+
"client_id",
214+
http_client=MinimalHttpClient(),
215+
on_obtaining_tokens=lambda event:
216+
self.assertEqual("new", event["response"].get("refresh_token")),
217+
on_updating_rt=lambda rt_item, new_rt:
218+
self.fail("This should not be called here"),
219+
)
220+
client.obtain_token_by_refresh_token(
221+
{"refresh_token": "old"}, on_updating_rt=False, post=self._dummy)
222+
223+
178224
class TestSessionAccessibility(unittest.TestCase):
179225
def test_accessing_session_property_for_backward_compatibility(self):
180226
client = Client({}, "client_id")

0 commit comments

Comments
 (0)