Skip to content

Commit a65b346

Browse files
committed
simplified and refactor diffie_hellman.py
Everything is a refactor except the method generate_shared_key(). Its hashing final step has been removed for simplicity
1 parent 13e4d3e commit a65b346

File tree

1 file changed

+14
-55
lines changed

1 file changed

+14
-55
lines changed

ciphers/diffie_hellman.py

+14-55
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
1-
from binascii import hexlify
2-
from hashlib import sha256
3-
from os import urandom
1+
import random
42

53
# RFC 3526 - More Modular Exponential (MODP) Diffie-Hellman groups for
64
# Internet Key Exchange (IKE) https://tools.ietf.org/html/rfc3526
75

8-
primes = {
6+
PRIMES = {
97
# 1536-bit
108
5: {
119
"prime": int(
@@ -187,44 +185,27 @@ class DiffieHellman:
187185
>>> alice = DiffieHellman()
188186
>>> bob = DiffieHellman()
189187
190-
>>> alice_private = alice.get_private_key()
191-
>>> alice_public = alice.generate_public_key()
188+
>>> alice_public = alice.public_key
189+
>>> bob_public = bob.public_key
192190
193-
>>> bob_private = bob.get_private_key()
194-
>>> bob_public = bob.generate_public_key()
195-
196-
>>> # generating shared key using the DH object
191+
Generating shared key using the DH object
197192
>>> alice_shared = alice.generate_shared_key(bob_public)
198193
>>> bob_shared = bob.generate_shared_key(alice_public)
199-
200-
>>> assert alice_shared == bob_shared
201-
202-
>>> # generating shared key using static methods
203-
>>> alice_shared = DiffieHellman.generate_shared_key_static(
204-
... alice_private, bob_public
205-
... )
206-
>>> bob_shared = DiffieHellman.generate_shared_key_static(
207-
... bob_private, alice_public
208-
... )
209-
210194
>>> assert alice_shared == bob_shared
211195
"""
212196

213197
# Current minimum recommendation is 2048 bit (group 14)
214198
def __init__(self, group: int = 14) -> None:
215-
if group not in primes:
199+
if group not in PRIMES:
216200
raise ValueError("Unsupported Group")
217-
self.prime = primes[group]["prime"]
218-
self.generator = primes[group]["generator"]
219-
220-
self.__private_key = int(hexlify(urandom(32)), base=16)
201+
self.prime = PRIMES[group]["prime"]
202+
self.generator = PRIMES[group]["generator"]
221203

222-
def get_private_key(self) -> str:
223-
return hex(self.__private_key)[2:]
204+
self.__private_key = random.getrandbits(256)
224205

225-
def generate_public_key(self) -> str:
226-
public_key = pow(self.generator, self.__private_key, self.prime)
227-
return hex(public_key)[2:]
206+
@property
207+
def public_key(self) -> int:
208+
return pow(self.generator, self.__private_key, self.prime)
228209

229210
def is_valid_public_key(self, key: int) -> bool:
230211
# check if the other public key is valid based on NIST SP800-56
@@ -233,32 +214,10 @@ def is_valid_public_key(self, key: int) -> bool:
233214
and pow(key, (self.prime - 1) // 2, self.prime) == 1
234215
)
235216

236-
def generate_shared_key(self, other_key_str: str) -> str:
237-
other_key = int(other_key_str, base=16)
217+
def generate_shared_key(self, other_key: int) -> int:
238218
if not self.is_valid_public_key(other_key):
239219
raise ValueError("Invalid public key")
240-
shared_key = pow(other_key, self.__private_key, self.prime)
241-
return sha256(str(shared_key).encode()).hexdigest()
242-
243-
@staticmethod
244-
def is_valid_public_key_static(remote_public_key_str: int, prime: int) -> bool:
245-
# check if the other public key is valid based on NIST SP800-56
246-
return (
247-
2 <= remote_public_key_str <= prime - 2
248-
and pow(remote_public_key_str, (prime - 1) // 2, prime) == 1
249-
)
250-
251-
@staticmethod
252-
def generate_shared_key_static(
253-
local_private_key_str: str, remote_public_key_str: str, group: int = 14
254-
) -> str:
255-
local_private_key = int(local_private_key_str, base=16)
256-
remote_public_key = int(remote_public_key_str, base=16)
257-
prime = primes[group]["prime"]
258-
if not DiffieHellman.is_valid_public_key_static(remote_public_key, prime):
259-
raise ValueError("Invalid public key")
260-
shared_key = pow(remote_public_key, local_private_key, prime)
261-
return sha256(str(shared_key).encode()).hexdigest()
220+
return pow(other_key, self.__private_key, self.prime)
262221

263222

264223
if __name__ == "__main__":

0 commit comments

Comments
 (0)