Skip to content

Allowing transport layer to be customized #169

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 37 commits into from
Apr 21, 2020
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
cbe98f3
Initial commit
abhidnya13 Feb 25, 2020
e64332f
Iteration 1
abhidnya13 Mar 5, 2020
2a7d5e3
Merge branch 'dev' of https://github.com/AzureAD/microsoft-authentica…
abhidnya13 Mar 5, 2020
aed0e8d
Iteration 2 rectifying tests
abhidnya13 Mar 6, 2020
30f7d99
Iteration 3 modifying some more failing tests
abhidnya13 Mar 6, 2020
34e8e16
Iteration 4
abhidnya13 Mar 7, 2020
f60bbb5
Removing tests whose implementation was removed
abhidnya13 Mar 7, 2020
d54fb8b
Iteration 5
abhidnya13 Mar 7, 2020
be389d5
Replacing generic exception to specific Http error
abhidnya13 Mar 9, 2020
dcec4af
Merge branch 'dev' of https://github.com/AzureAD/microsoft-authentica…
abhidnya13 Mar 16, 2020
e37f0c3
Merge branch 'dev' of https://github.com/AzureAD/microsoft-authentica…
abhidnya13 Mar 16, 2020
b3a4e09
Refactor according to new interface
abhidnya13 Mar 20, 2020
96988f8
Changing one reference to new interface left in the previous one
abhidnya13 Mar 20, 2020
1d05615
Modified one more missed change
abhidnya13 Mar 20, 2020
ccafcf9
Few more changes and refactor
abhidnya13 Mar 20, 2020
2670e25
Adding raw response to response object
abhidnya13 Mar 20, 2020
cb70a37
Cleaning None values
abhidnya13 Mar 23, 2020
53520fd
PR review iteration
abhidnya13 Mar 26, 2020
293b081
Removing default http client
abhidnya13 Mar 30, 2020
fcac05e
cleaning up
abhidnya13 Mar 30, 2020
229ad26
Adding deleted single empty line back
abhidnya13 Mar 30, 2020
340210f
Updating filtering of non values from dictionary
abhidnya13 Mar 30, 2020
ca8a129
Merge branch 'dev' of https://github.com/AzureAD/microsoft-authentica…
abhidnya13 Apr 7, 2020
d993950
Capturing editorial changes
abhidnya13 Apr 13, 2020
e5ebd28
Review changes 1
abhidnya13 Apr 16, 2020
8901269
Fixing broken test
abhidnya13 Apr 16, 2020
7774be8
Cleaning up
abhidnya13 Apr 16, 2020
42ca7a2
PR review changes part 1
abhidnya13 Apr 17, 2020
d38d38d
PR review changes part 2
abhidnya13 Apr 17, 2020
3226bc4
PR review changes part 3
abhidnya13 Apr 17, 2020
90afb94
Minor indent change
abhidnya13 Apr 17, 2020
a1ce0d1
PR review changes part 4
abhidnya13 Apr 20, 2020
ad245a7
Adding back accidentally deleted blank line
abhidnya13 Apr 20, 2020
34bb127
Removing extra indent
abhidnya13 Apr 20, 2020
a745949
Minor line add in authority.py
abhidnya13 Apr 20, 2020
b636baa
Making changes for backward compatibility
abhidnya13 Apr 20, 2020
d24d575
Some more editorial changes
abhidnya13 Apr 21, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 25 additions & 28 deletions msal/application.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import time
try: # Python 2
from urlparse import urljoin
Expand All @@ -8,16 +9,13 @@
import warnings
import uuid

import requests

from .oauth2cli import Client, JwtAssertionCreator
from .oauth2cli import Client, JwtAssertionCreator, DefaultHttpClient
from .authority import Authority
from .mex import send_request as mex_send_request
from .wstrust_request import send_request as wst_send_request
from .wstrust_response import *
from .token_cache import TokenCache


# The __init__.py will import this. Not the other way around.
__version__ = "1.1.0"

Expand Down Expand Up @@ -54,11 +52,11 @@ def decorate_scope(
CLIENT_CURRENT_TELEMETRY = 'x-client-current-telemetry'

def _get_new_correlation_id():
return str(uuid.uuid4())
return str(uuid.uuid4())


def _build_current_telemetry_request_header(public_api_id, force_refresh=False):
return "1|{},{}|".format(public_api_id, "1" if force_refresh else "0")
return "1|{},{}|".format(public_api_id, "1" if force_refresh else "0")


def extract_certs(public_cert_content):
Expand Down Expand Up @@ -91,7 +89,7 @@ class ClientApplication(object):
def __init__(
self, client_id,
client_credential=None, authority=None, validate_authority=True,
token_cache=None,
token_cache=None, http_client=None,
verify=True, proxies=None, timeout=None,
client_claims=None, app_name=None, app_version=None):
"""Create an instance of application.
Expand Down Expand Up @@ -139,6 +137,9 @@ def __init__(
:param TokenCache cache:
Sets the token cache used by this ClientApplication instance.
By default, an in-memory cache will be created and used.
:param http_client: (optional)
Your implementation of abstract class HttpClient <msal.oauth2cli.http.http_client>
Defaults to default http client implementation which uses requests
:param verify: (optional)
It will be passed to the
`verify parameter in the underlying requests library
Expand All @@ -161,14 +162,13 @@ def __init__(
self.client_id = client_id
self.client_credential = client_credential
self.client_claims = client_claims
self.verify = verify
self.proxies = proxies
self.http_client = http_client if http_client else DefaultHttpClient(verify=verify, proxies=proxies)
self.timeout = timeout
self.app_name = app_name
self.app_version = app_version
self.authority = Authority(
authority or "https://login.microsoftonline.com/common/",
validate_authority, verify=verify, proxies=proxies, timeout=timeout)
http_client=self.http_client, validate_authority=validate_authority, timeout=timeout)
# Here the self.authority is not the same type as authority in input
self.token_cache = token_cache or TokenCache()
self.client = self._build_client(client_credential, self.authority)
Expand Down Expand Up @@ -211,14 +211,15 @@ def _build_client(self, client_credential, authority):
return Client(
server_configuration,
self.client_id,
http_client=self.http_client,
default_headers=default_headers,
default_body=default_body,
client_assertion=client_assertion,
client_assertion_type=client_assertion_type,
on_obtaining_tokens=self.token_cache.add,
on_removing_rt=self.token_cache.remove_rt,
on_updating_rt=self.token_cache.update_rt,
verify=self.verify, proxies=self.proxies, timeout=self.timeout)
timeout=self.timeout)

def get_authorization_request_url(
self,
Expand Down Expand Up @@ -264,13 +265,12 @@ def get_authorization_request_url(
# The previous implementation is, it will use self.authority by default.
# Multi-tenant app can use new authority on demand
the_authority = Authority(
authority,
verify=self.verify, proxies=self.proxies, timeout=self.timeout,
authority, http_client=self.http_client, timeout=self.timeout
) if authority else self.authority

client = Client(
{"authorization_endpoint": the_authority.authorization_endpoint},
self.client_id)
self.client_id, self.http_client)
return client.build_auth_request_uri(
response_type=response_type,
redirect_uri=redirect_uri, state=state, login_hint=login_hint,
Expand Down Expand Up @@ -367,13 +367,13 @@ def _find_msal_accounts(self, environment):

def _get_authority_aliases(self, instance):
if not self.authority_groups:
resp = requests.get(
"https://login.microsoftonline.com/common/discovery/instance?api-version=1.1&authorization_endpoint=https://login.microsoftonline.com/common/oauth2/authorize",
resp = self.http_client.get("https://login.microsoftonline.com/common/discovery/instance?api-version=1.1&authorization_endpoint=https://login.microsoftonline.com/common/oauth2/authorize",
headers={'Accept': 'application/json'},
verify=self.verify, proxies=self.proxies, timeout=self.timeout)
timeout=self.timeout)
resp.raise_for_status()
resp = json.loads(resp.text)
self.authority_groups = [
set(group['aliases']) for group in resp.json()['metadata']]
set(group['aliases']) for group in resp['metadata']]
for group in self.authority_groups:
if instance in group:
return [alias for alias in group if alias != instance]
Expand Down Expand Up @@ -491,8 +491,8 @@ def acquire_token_silent_with_error(
if authority:
warnings.warn("We haven't decided how/if this method will accept authority parameter")
# the_authority = Authority(
# authority,
# verify=self.verify, proxies=self.proxies, timeout=self.timeout,
# authority, http_client=self.http_client,
# timeout=self.timeout
# ) if authority else self.authority
result = self._acquire_token_silent_from_cache_and_possibly_refresh_it(
scopes, account, self.authority, force_refresh=force_refresh,
Expand All @@ -504,8 +504,8 @@ def acquire_token_silent_with_error(
for alias in self._get_authority_aliases(self.authority.instance):
the_authority = Authority(
"https://" + alias + "/" + self.authority.tenant,
validate_authority=False,
verify=self.verify, proxies=self.proxies, timeout=self.timeout)
http_client=self.http_client, validate_authority=False,
timeout=self.timeout)
result = self._acquire_token_silent_from_cache_and_possibly_refresh_it(
scopes, account, the_authority, force_refresh=force_refresh,
correlation_id=correlation_id,
Expand Down Expand Up @@ -734,7 +734,7 @@ def acquire_token_by_username_password(
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_BY_USERNAME_PASSWORD_ID),
}
if not self.authority.is_adfs:
if not self.authority.is_adfs and not self.authority.is_b2c:
user_realm_result = self.authority.user_realm_discovery(
username, correlation_id=headers[CLIENT_REQUEST_ID])
if user_realm_result.get("account_type") == "Federated":
Expand All @@ -748,13 +748,10 @@ def acquire_token_by_username_password(

def _acquire_token_by_username_password_federated(
self, user_realm_result, username, password, scopes=None, **kwargs):
verify = kwargs.pop("verify", self.verify)
proxies = kwargs.pop("proxies", self.proxies)
wstrust_endpoint = {}
if user_realm_result.get("federation_metadata_url"):
wstrust_endpoint = mex_send_request(
user_realm_result["federation_metadata_url"],
verify=verify, proxies=proxies)
user_realm_result["federation_metadata_url"], http_client=self.http_client)
if wstrust_endpoint is None:
raise ValueError("Unable to find wstrust endpoint from MEX. "
"This typically happens when attempting MSA accounts. "
Expand All @@ -766,7 +763,7 @@ def _acquire_token_by_username_password_federated(
wstrust_endpoint.get("address",
# Fallback to an AAD supplied endpoint
user_realm_result.get("federation_active_auth_url")),
wstrust_endpoint.get("action"), verify=verify, proxies=proxies)
wstrust_endpoint.get("action"), http_client=self.http_client)
if not ("token" in wstrust_result and "type" in wstrust_result):
raise RuntimeError("Unsuccessful RSTR. %s" % wstrust_result)
GRANT_TYPE_SAML1_1 = 'urn:ietf:params:oauth:grant-type:saml1_1-bearer'
Expand Down
79 changes: 36 additions & 43 deletions msal/authority.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import json

from .oauth2cli.default_http_client import DefaultHttpClient

try:
from urllib.parse import urlparse
except ImportError: # Fall back to Python 2
from urlparse import urlparse
import logging

import requests

from .exceptions import MsalServiceError


Expand All @@ -25,6 +27,7 @@
"b2clogin.de",
]


class Authority(object):
"""This class represents an (already-validated) authority.

Expand All @@ -33,8 +36,8 @@ class Authority(object):
"""
_domains_without_user_realm_discovery = set([])

def __init__(self, authority_url, validate_authority=True,
verify=True, proxies=None, timeout=None,
def __init__(self, authority_url, http_client, validate_authority=True,
timeout=None
):
"""Creates an authority instance, and also validates it.

Expand All @@ -44,19 +47,18 @@ def __init__(self, authority_url, validate_authority=True,
This parameter only controls whether an instance discovery will be
performed.
"""
self.verify = verify
self.proxies = proxies
self.http_client = http_client
self.timeout = timeout
authority, self.instance, tenant = canonicalize(authority_url)
parts = authority.path.split('/')
is_b2c = any(self.instance.endswith("." + d) for d in WELL_KNOWN_B2C_HOSTS) or (
self.is_b2c = any(self.instance.endswith("." + d) for d in WELL_KNOWN_B2C_HOSTS) or (
len(parts) == 3 and parts[2].lower().startswith("b2c_"))
if (tenant != "adfs" and (not is_b2c) and validate_authority
if (tenant != "adfs" and (not self.is_b2c) and validate_authority
and self.instance not in WELL_KNOWN_AUTHORITY_HOSTS):
payload = instance_discovery(
payload = self.instance_discovery(
"https://{}{}/oauth2/v2.0/authorize".format(
self.instance, authority.path),
verify=verify, proxies=proxies, timeout=timeout)
timeout=timeout)
if payload.get("error") == "invalid_instance":
raise ValueError(
"invalid_instance: "
Expand All @@ -73,9 +75,9 @@ def __init__(self, authority_url, validate_authority=True,
authority.path, # In B2C scenario, it is "/tenant/policy"
"" if tenant == "adfs" else "/v2.0" # the AAD v2 endpoint
))
openid_config = tenant_discovery(
openid_config = self.tenant_discovery(
tenant_discovery_endpoint,
verify=verify, proxies=proxies, timeout=timeout)
timeout=timeout)
logger.debug("openid_config = %s", openid_config)
self.authorization_endpoint = openid_config['authorization_endpoint']
self.token_endpoint = openid_config['token_endpoint']
Expand All @@ -86,18 +88,28 @@ def user_realm_discovery(self, username, correlation_id=None, response=None):
# It will typically return a dict containing "ver", "account_type",
# "federation_protocol", "cloud_audience_urn",
# "federation_metadata_url", "federation_active_auth_url", etc.
if self.instance not in self.__class__._domains_without_user_realm_discovery:
resp = response or requests.get(
"https://{netloc}/common/userrealm/{username}?api-version=1.0".format(
netloc=self.instance, username=username),
headers={'Accept':'application/json',
'client-request-id': correlation_id},
verify=self.verify, proxies=self.proxies, timeout=self.timeout)
if resp.status_code != 404:
resp.raise_for_status()
return resp.json()
self.__class__._domains_without_user_realm_discovery.add(self.instance)
return {} # This can guide the caller to fall back normal ROPC flow
resp = response or self.http_client.get("https://{netloc}/common/userrealm/{username}?api-version=1.0".format(
netloc=self.instance, username=username),
headers={'Accept':'application/json', 'client-request-id': correlation_id},
timeout=self.timeout)
return json.loads(resp.text)

def instance_discovery(self, url, **kwargs):
resp = self.http_client.get('https://{}/common/discovery/instance'.format(
WORLD_WIDE # Historically using WORLD_WIDE. Could use self.instance too
# See https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadInstanceDiscovery.cs#L101-L103
# and https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadAuthority.cs#L19-L33
), params={'authorization_endpoint': url, 'api-version': '1.0'},
**kwargs)
return json.loads(resp.text)

def tenant_discovery(self, tenant_discovery_endpoint, **kwargs):
# Returns Openid Configuration
resp = self.http_client.get(tenant_discovery_endpoint, **kwargs)
payload = json.loads(resp.text)
if 'authorization_endpoint' in payload and 'token_endpoint' in payload:
return payload
raise MsalServiceError(status_code=resp.status_code, **payload)


def canonicalize(authority_url):
Expand All @@ -112,22 +124,3 @@ def canonicalize(authority_url):
"or https://<tenant_name>.b2clogin.com/<tenant_name>.onmicrosoft.com/policy"
% authority_url)
return authority, authority.hostname, parts[1]

def instance_discovery(url, **kwargs):
return requests.get( # Note: This URL seemingly returns V1 endpoint only
'https://{}/common/discovery/instance'.format(
WORLD_WIDE # Historically using WORLD_WIDE. Could use self.instance too
# See https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadInstanceDiscovery.cs#L101-L103
# and https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadAuthority.cs#L19-L33
),
params={'authorization_endpoint': url, 'api-version': '1.0'},
**kwargs).json()

def tenant_discovery(tenant_discovery_endpoint, **kwargs):
# Returns Openid Configuration
resp = requests.get(tenant_discovery_endpoint, **kwargs)
payload = resp.json()
if 'authorization_endpoint' in payload and 'token_endpoint' in payload:
return payload
raise MsalServiceError(status_code=resp.status_code, **payload)

11 changes: 5 additions & 6 deletions msal/mex.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,16 @@
except ImportError:
from xml.etree import ElementTree as ET

import requests


def _xpath_of_root(route_to_leaf):
# Construct an xpath suitable to find a root node which has a specified leaf
return '/'.join(route_to_leaf + ['..'] * (len(route_to_leaf)-1))

def send_request(mex_endpoint, **kwargs):
mex_document = requests.get(
mex_endpoint, headers={'Content-Type': 'application/soap+xml'},
**kwargs).text

def send_request(mex_endpoint, http_client, **kwargs):
resp = http_client.get(mex_endpoint, headers={'Content-Type': 'application/soap+xml'},
**kwargs)
mex_document = resp.text
return Mex(mex_document).get_wstrust_username_password_endpoint()


Expand Down
3 changes: 2 additions & 1 deletion msal/oauth2cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
from .oidc import Client
from .assertion import JwtAssertionCreator
from .assertion import JwtSigner # Obsolete. For backward compatibility.

from .http import HttpClient, Response
from .default_http_client import DefaultHttpClient
46 changes: 46 additions & 0 deletions msal/oauth2cli/default_http_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import requests

from .http import HttpClient, Response


class DefaultHttpClient(HttpClient):
"""
Default HTTP Client
"""
def __init__(self, verify=True, proxies=None):
"""
Constructor for the DefaultHttpClient

verify=True, # type: Union[str, True, False, None]
proxies=None, # type: Optional[dict]
"""
self.session = requests.Session()
if verify:
self.session.verify = verify
if proxies:
self.session.proxies = proxies

def post(self, url, params=None, data=None, headers=None, **kwargs):

response = self.session.post(url=url, params=params, headers=headers, data=data, **kwargs)
return Response(response.status_code, response.text, response)

def get(self, url, params=None, headers=None, **kwargs):
response = self.session.get(url=url, params=params, headers=headers, **kwargs)
return Response(response.status_code, response.text, response)


class Response(Response):

def __init__(self, status_code, text, response):
"""Constructor for DefaultResponseObject
status, # type: int
text, # type: str response in string format
response, # type: Raw response from requests
"""
self.status_code = status_code
self.text = text
self.response = response

def raise_for_status(self):
self.response.raise_for_status()
Loading