Skip to content

Commit 427de2d

Browse files
authored
Merge pull request #74 from Yelp/u/dpopes/msk-auth-support
Support MSK IAM Authentication
2 parents 5e508ed + c062282 commit 427de2d

File tree

5 files changed

+391
-3
lines changed

5 files changed

+391
-3
lines changed

kafka/client.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ class SimpleClient(object):
7676
# socket timeout.
7777
def __init__(self, hosts, client_id=CLIENT_ID,
7878
timeout=DEFAULT_SOCKET_TIMEOUT_SECONDS,
79-
correlation_id=0, metrics=None):
79+
correlation_id=0, metrics=None, **kwargs):
8080
# We need one connection to bootstrap
8181
self.client_id = client_id
8282
self.timeout = timeout
@@ -90,6 +90,10 @@ def __init__(self, hosts, client_id=CLIENT_ID,
9090
self.topics_to_brokers = {} # TopicPartition -> BrokerMetadata
9191
self.topic_partitions = {} # topic -> partition -> leader
9292

93+
# Support arbitrary kwargs to be provided as config to BrokerConnection
94+
# This will allow advanced features like Authentication to work
95+
self.config = kwargs
96+
9397
self.load_metadata_for_topics() # bootstrap with all metadata
9498

9599
##################
@@ -108,6 +112,7 @@ def _get_conn(self, host, port, afi, node_id='bootstrap'):
108112
metrics=self._metrics_registry,
109113
metric_group_prefix='simple-client',
110114
node_id=node_id,
115+
**self.config,
111116
)
112117

113118
conn = self._conns[host_key]

kafka/conn.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import kafka.errors as Errors
2626
from kafka.future import Future
2727
from kafka.metrics.stats import Avg, Count, Max, Rate
28+
from kafka.msk import AwsMskIamClient
2829
from kafka.oauth.abstract import AbstractTokenProvider
2930
from kafka.protocol.admin import SaslHandShakeRequest
3031
from kafka.protocol.commit import OffsetFetchRequest
@@ -81,6 +82,12 @@ class SSLWantWriteError(Exception):
8182
gssapi = None
8283
GSSError = None
8384

85+
# needed for AWS_MSK_IAM authentication:
86+
try:
87+
from botocore.session import Session as BotoSession
88+
except ImportError:
89+
# no botocore available, will disable AWS_MSK_IAM mechanism
90+
BotoSession = None
8491

8592
AFI_NAMES = {
8693
socket.AF_UNSPEC: "unspecified",
@@ -224,7 +231,7 @@ class BrokerConnection(object):
224231
'sasl_oauth_token_provider': None
225232
}
226233
SECURITY_PROTOCOLS = ('PLAINTEXT', 'SSL', 'SASL_PLAINTEXT', 'SASL_SSL')
227-
SASL_MECHANISMS = ('PLAIN', 'GSSAPI', 'OAUTHBEARER')
234+
SASL_MECHANISMS = ('PLAIN', 'GSSAPI', 'OAUTHBEARER', 'AWS_MSK_IAM')
228235

229236
def __init__(self, host, port, afi, **configs):
230237
self.host = host
@@ -269,6 +276,11 @@ def __init__(self, host, port, afi, **configs):
269276
token_provider = self.config['sasl_oauth_token_provider']
270277
assert token_provider is not None, 'sasl_oauth_token_provider required for OAUTHBEARER sasl'
271278
assert callable(getattr(token_provider, "token", None)), 'sasl_oauth_token_provider must implement method #token()'
279+
280+
if self.config['sasl_mechanism'] == 'AWS_MSK_IAM':
281+
assert BotoSession is not None, 'AWS_MSK_IAM requires the "botocore" package'
282+
assert self.config['security_protocol'] == 'SASL_SSL', 'AWS_MSK_IAM requires SASL_SSL'
283+
272284
# This is not a general lock / this class is not generally thread-safe yet
273285
# However, to avoid pushing responsibility for maintaining
274286
# per-connection locks to the upstream client, we will use this lock to
@@ -552,6 +564,8 @@ def _handle_sasl_handshake_response(self, future, response):
552564
return self._try_authenticate_gssapi(future)
553565
elif self.config['sasl_mechanism'] == 'OAUTHBEARER':
554566
return self._try_authenticate_oauth(future)
567+
elif self.config['sasl_mechanism'] == 'AWS_MSK_IAM':
568+
return self._try_authenticate_aws_msk_iam(future)
555569
else:
556570
return future.failure(
557571
Errors.UnsupportedSaslMechanismError(
@@ -652,6 +666,40 @@ def _try_authenticate_plain(self, future):
652666
log.info('%s: Authenticated as %s via PLAIN', self, self.config['sasl_plain_username'])
653667
return future.success(True)
654668

669+
def _try_authenticate_aws_msk_iam(self, future):
670+
session = BotoSession()
671+
client = AwsMskIamClient(
672+
host=self.host,
673+
boto_session=session,
674+
)
675+
676+
msg = client.first_message()
677+
size = Int32.encode(len(msg))
678+
679+
err = None
680+
close = False
681+
with self._lock:
682+
if not self._can_send_recv():
683+
err = Errors.NodeNotReadyError(str(self))
684+
close = False
685+
else:
686+
try:
687+
self._send_bytes_blocking(size + msg)
688+
data = self._recv_bytes_blocking(4)
689+
data = self._recv_bytes_blocking(struct.unpack('4B', data)[-1])
690+
except (ConnectionError, TimeoutError) as e:
691+
log.exception("%s: Error receiving reply from server", self)
692+
err = Errors.KafkaConnectionError("%s: %s" % (self, e))
693+
close = True
694+
695+
if err is not None:
696+
if close:
697+
self.close(error=err)
698+
return future.failure(err)
699+
700+
log.info('%s: Authenticated via AWS_MSK_IAM %s', self, data.decode('utf-8'))
701+
return future.success(True)
702+
655703
def _try_authenticate_gssapi(self, future):
656704
kerberos_damin_name = self.config['sasl_kerberos_domain_name'] or self.host
657705
auth_id = self.config['sasl_kerberos_service_name'] + '@' + kerberos_damin_name

kafka/msk.py

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
import datetime
2+
import hashlib
3+
import hmac
4+
import json
5+
import string
6+
7+
from kafka.errors import IllegalArgumentError
8+
from kafka.vendor.six.moves import urllib
9+
10+
11+
class AwsMskIamClient:
12+
UNRESERVED_CHARS = string.ascii_letters + string.digits + '-._~'
13+
14+
def __init__(self, host, boto_session):
15+
"""
16+
Arguments:
17+
host (str): The hostname of the broker.
18+
boto_session (botocore.BotoSession) the boto session
19+
"""
20+
self.algorithm = 'AWS4-HMAC-SHA256'
21+
self.expires = '900'
22+
self.hashfunc = hashlib.sha256
23+
self.headers = [
24+
('host', host)
25+
]
26+
self.version = '2020_10_22'
27+
28+
self.service = 'kafka-cluster'
29+
self.action = '{}:Connect'.format(self.service)
30+
31+
now = datetime.datetime.utcnow()
32+
self.datestamp = now.strftime('%Y%m%d')
33+
self.timestamp = now.strftime('%Y%m%dT%H%M%SZ')
34+
35+
self.host = host
36+
self.boto_session = boto_session
37+
38+
# This will raise if the region can't be determined
39+
# Do this during init instead of waiting for failures downstream
40+
if self.region:
41+
pass
42+
43+
@property
44+
def access_key(self):
45+
return self.boto_session.get_credentials().access_key
46+
47+
@property
48+
def secret_key(self):
49+
return self.boto_session.get_credentials().secret_key
50+
51+
@property
52+
def token(self):
53+
return self.boto_session.get_credentials().token
54+
55+
@property
56+
def region(self):
57+
# Try to get the region information from the broker hostname
58+
for host in self.host.split(','):
59+
if 'amazonaws.com' in host:
60+
return host.split('.')[-3]
61+
62+
# If the region can't be determined from hostname, try the boto session
63+
# This will only have a value if:
64+
# - `AWS_DEFAULT_REGION` environment variable is set
65+
# - `~/.aws/config` region variable is set
66+
region = self.boto_session.get_config_variable('region')
67+
if region:
68+
return region
69+
70+
# Otherwise give up
71+
raise IllegalArgumentError('Could not determine region from broker host(s) or aws configuration')
72+
73+
@property
74+
def _credential(self):
75+
return '{0.access_key}/{0._scope}'.format(self)
76+
77+
@property
78+
def _scope(self):
79+
return '{0.datestamp}/{0.region}/{0.service}/aws4_request'.format(self)
80+
81+
@property
82+
def _signed_headers(self):
83+
"""
84+
Returns (str):
85+
An alphabetically sorted, semicolon-delimited list of lowercase
86+
request header names.
87+
"""
88+
return ';'.join(sorted(k.lower() for k, _ in self.headers))
89+
90+
@property
91+
def _canonical_headers(self):
92+
"""
93+
Returns (str):
94+
A newline-delited list of header names and values.
95+
Header names are lowercased.
96+
"""
97+
return '\n'.join(map(':'.join, self.headers)) + '\n'
98+
99+
@property
100+
def _canonical_request(self):
101+
"""
102+
Returns (str):
103+
An AWS Signature Version 4 canonical request in the format:
104+
<Method>\n
105+
<Path>\n
106+
<CanonicalQueryString>\n
107+
<CanonicalHeaders>\n
108+
<SignedHeaders>\n
109+
<HashedPayload>
110+
"""
111+
# The hashed_payload is always an empty string for MSK.
112+
hashed_payload = self.hashfunc(b'').hexdigest()
113+
return '\n'.join((
114+
'GET',
115+
'/',
116+
self._canonical_querystring,
117+
self._canonical_headers,
118+
self._signed_headers,
119+
hashed_payload,
120+
))
121+
122+
@property
123+
def _canonical_querystring(self):
124+
"""
125+
Returns (str):
126+
A '&'-separated list of URI-encoded key/value pairs.
127+
"""
128+
params = []
129+
params.append(('Action', self.action))
130+
params.append(('X-Amz-Algorithm', self.algorithm))
131+
params.append(('X-Amz-Credential', self._credential))
132+
params.append(('X-Amz-Date', self.timestamp))
133+
params.append(('X-Amz-Expires', self.expires))
134+
if self.token:
135+
params.append(('X-Amz-Security-Token', self.token))
136+
params.append(('X-Amz-SignedHeaders', self._signed_headers))
137+
138+
return '&'.join(self._uriencode(k) + '=' + self._uriencode(v) for k, v in params)
139+
140+
@property
141+
def _signing_key(self):
142+
"""
143+
Returns (bytes):
144+
An AWS Signature V4 signing key generated from the secret_key, date,
145+
region, service, and request type.
146+
"""
147+
key = self._hmac(('AWS4' + self.secret_key).encode('utf-8'), self.datestamp)
148+
key = self._hmac(key, self.region)
149+
key = self._hmac(key, self.service)
150+
key = self._hmac(key, 'aws4_request')
151+
return key
152+
153+
@property
154+
def _signing_str(self):
155+
"""
156+
Returns (str):
157+
A string used to sign the AWS Signature V4 payload in the format:
158+
<Algorithm>\n
159+
<Timestamp>\n
160+
<Scope>\n
161+
<CanonicalRequestHash>
162+
"""
163+
canonical_request_hash = self.hashfunc(self._canonical_request.encode('utf-8')).hexdigest()
164+
return '\n'.join((self.algorithm, self.timestamp, self._scope, canonical_request_hash))
165+
166+
def _uriencode(self, msg):
167+
"""
168+
Arguments:
169+
msg (str): A string to URI-encode.
170+
171+
Returns (str):
172+
The URI-encoded version of the provided msg, following the encoding
173+
rules specified: https://github.com/aws/aws-msk-iam-auth#uriencode
174+
"""
175+
return urllib.parse.quote(msg, safe=self.UNRESERVED_CHARS)
176+
177+
def _hmac(self, key, msg):
178+
"""
179+
Arguments:
180+
key (bytes): A key to use for the HMAC digest.
181+
msg (str): A value to include in the HMAC digest.
182+
Returns (bytes):
183+
An HMAC digest of the given key and msg.
184+
"""
185+
return hmac.new(key, msg.encode('utf-8'), digestmod=self.hashfunc).digest()
186+
187+
def first_message(self):
188+
"""
189+
Returns (bytes):
190+
An encoded JSON authentication payload that can be sent to the
191+
broker.
192+
"""
193+
signature = hmac.new(
194+
self._signing_key,
195+
self._signing_str.encode('utf-8'),
196+
digestmod=self.hashfunc,
197+
).hexdigest()
198+
msg = {
199+
'version': self.version,
200+
'host': self.host,
201+
'user-agent': 'kafka-python',
202+
'action': self.action,
203+
'x-amz-algorithm': self.algorithm,
204+
'x-amz-credential': self._credential,
205+
'x-amz-date': self.timestamp,
206+
'x-amz-signedheaders': self._signed_headers,
207+
'x-amz-expires': self.expires,
208+
'x-amz-signature': signature,
209+
}
210+
if self.token:
211+
msg['x-amz-security-token'] = self.token
212+
213+
return json.dumps(msg, separators=(',', ':')).encode('utf-8')

kafka/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.4.7.post4'
1+
__version__ = '1.4.7.post5'

0 commit comments

Comments
 (0)