-
Notifications
You must be signed in to change notification settings - Fork 69
/
Copy pathmock_server.py
93 lines (73 loc) · 2.47 KB
/
mock_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""
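Mock a GCP metadata server that serves only access tokens. Every GET request, regardless of path, is answered with a real OAuth2 access token (scope ``cloudkms``) minted from the service-account key file named by the GOOGLE_APPLICATION_CREDENTIALS environment variable.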
"""
import base64
import http.server
import json
import os
import textwrap
import time
import jwt
import requests
def b64_to_b64url(b64):
    """Convert standard base64 text into unpadded base64url text.

    ``+`` becomes ``-``, ``/`` becomes ``_``, and every ``=`` is removed.
    """
    table = str.maketrans({"+": "-", "/": "_", "=": None})
    return b64.translate(table)


def dict_to_b64url(arg):
    """Serialize *arg* as JSON and return it as unpadded base64url text."""
    raw = base64.b64encode(json.dumps(arg).encode("utf8"))
    return b64_to_b64url(raw.decode("utf8"))
def get_access_token():
    """
    Create a signed JSON Web Token (JWT) and exchange it for a GCP access token.

    Reads the service-account key file named by the
    GOOGLE_APPLICATION_CREDENTIALS environment variable, signs an RS256 JWT
    requesting the ``cloudkms`` scope, and POSTs it to Google's OAuth2 token
    endpoint using the JWT-bearer grant type.

    Returns:
        The token endpoint's JSON response, decoded into Python objects.

    Raises:
        Exception: if GOOGLE_APPLICATION_CREDENTIALS is unset, or if the token
            endpoint replies with a non-200 status.
    """
    # The key is still published as a module global for backward compatibility.
    # NOTE(review): nothing visible in this file reads it; consider removing.
    global private_key
    header = {"alg": "RS256", "typ": "JWT"}
    creds_path = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
    if creds_path is None:
        raise Exception(
            "please set GOOGLE_APPLICATION_CREDENTIALS environment variable to a JSON Service account key"
        )
    # Context manager so the key file is closed promptly; JSON is UTF-8.
    with open(creds_path, encoding="utf-8") as creds_file:
        creds = json.load(creds_file)
    private_key = creds["private_key"].encode("utf8")
    client_email = creds["client_email"]
    # Read the clock once so "iat" and "exp" are exactly 30 minutes apart.
    now = int(time.time())
    claims = {
        "iss": client_email,
        "aud": "https://oauth2.googleapis.com/token",
        "scope": "https://www.googleapis.com/auth/cloudkms",
        # Expiration can be at most one hour in the future. Let's say 30 minutes.
        "exp": now + 30 * 60,
        "iat": now,
    }
    assertion = jwt.encode(claims, private_key, algorithm="RS256", headers=header)
    resp = requests.post(
        url="https://oauth2.googleapis.com/token",
        data={
            "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
            "assertion": assertion,
        },
        # Without a timeout, requests may wait indefinitely for a reply.
        timeout=30,
    )
    if resp.status_code != 200:
        # Build the report line by line: dedenting an f-string breaks as soon
        # as an interpolated body contains unindented lines.
        msg = "\n" + "\n".join(
            [
                "Unexpected non-200 status.",
                f"Got status {resp.status_code} on HTTP response:",
                f"Headers:{resp.headers}",
                f"Body:{resp.text}",
                "Sent request:",
                f"Headers:{resp.request.headers}",
                # PreparedRequest has no ``text`` attribute; the payload is ``body``.
                f"Body:{resp.request.body}",
            ]
        )
        raise Exception(msg)
    return resp.json()
class Handler(http.server.BaseHTTPRequestHandler):
    """Answer every GET request, whatever its path, with an access token as JSON."""

    def do_GET(self):
        """Reply 200 with the token JSON, or 500 if no token could be obtained.

        The token is fetched before any status line is written. A failure can
        then be reported with an error status instead of a 200 followed by an
        empty body.
        """
        try:
            access_token = get_access_token()
        except Exception as exc:  # request boundary: log and report, don't leave a half-sent 200
            self.log_error("Failed to obtain access token: %s", exc)
            self.send_error(500, "Failed to obtain access token")
            return
        body = json.dumps(access_token).encode("utf8")
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)
def main(port=5000, host="localhost"):
    """Serve the mock metadata server on ``host:port`` until the process is stopped.

    Args:
        port: TCP port to listen on (default 5000).
        host: interface address to bind (default ``"localhost"``).
    """
    # The context manager closes the listening socket if serve_forever exits.
    with http.server.HTTPServer((host, port), Handler) as server:
        # Flush so a parent process reading our piped stdout sees this promptly.
        print(f"Listening on port {port}", flush=True)
        server.serve_forever()
# Script entry point: importing this module must not start the server.
if __name__ == "__main__":
    main()