Skip to content

Commit dd7dcb0

Browse files
authored
Type annotations in aiokafka/codec.py (#984)
1 parent bb15ecf commit dd7dcb0

File tree

4 files changed

+34
-29
lines changed

4 files changed

+34
-29
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ SCALA_VERSION?=2.13
55
KAFKA_VERSION?=2.8.1
66
DOCKER_IMAGE=aiolibs/kafka:$(SCALA_VERSION)_$(KAFKA_VERSION)
77
DIFF_BRANCH=origin/master
8-
FORMATTED_AREAS=aiokafka/util.py aiokafka/structs.py
8+
FORMATTED_AREAS=aiokafka/util.py aiokafka/structs.py aiokafka/codec.py tests/test_codec.py
99

1010
.PHONY: setup
1111
setup:

aiokafka/codec.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1+
from __future__ import annotations
2+
13
import gzip
24
import io
35
import struct
46

7+
from typing_extensions import Buffer
8+
59
_XERIAL_V1_HEADER = (-126, b"S", b"N", b"A", b"P", b"P", b"Y", 0, 1, 1)
610
_XERIAL_V1_FORMAT = "bccccccBii"
711
ZSTD_MAX_OUTPUT_SIZE = 1024 * 1024
@@ -12,23 +16,23 @@
1216
cramjam = None
1317

1418

15-
def has_gzip():
19+
def has_gzip() -> bool:
1620
return True
1721

1822

19-
def has_snappy():
23+
def has_snappy() -> bool:
2024
return cramjam is not None
2125

2226

23-
def has_zstd():
27+
def has_zstd() -> bool:
2428
return cramjam is not None
2529

2630

27-
def has_lz4():
31+
def has_lz4() -> bool:
2832
return cramjam is not None
2933

3034

31-
def gzip_encode(payload, compresslevel=None):
35+
def gzip_encode(payload: Buffer, compresslevel: int | None = None) -> bytes:
3236
if not compresslevel:
3337
compresslevel = 9
3438

@@ -45,7 +49,7 @@ def gzip_encode(payload, compresslevel=None):
4549
return buf.getvalue()
4650

4751

48-
def gzip_decode(payload):
52+
def gzip_decode(payload: Buffer) -> bytes:
4953
buf = io.BytesIO(payload)
5054

5155
# Gzip context manager introduced in python 2.7
@@ -57,7 +61,9 @@ def gzip_decode(payload):
5761
gzipper.close()
5862

5963

60-
def snappy_encode(payload, xerial_compatible=True, xerial_blocksize=32 * 1024):
64+
def snappy_encode(
65+
payload: Buffer, xerial_compatible: bool = True, xerial_blocksize: int = 32 * 1024
66+
) -> bytes:
6167
"""Encodes the given data with snappy compression.
6268
6369
If xerial_compatible is set then the stream is encoded in a fashion
@@ -93,12 +99,9 @@ def snappy_encode(payload, xerial_compatible=True, xerial_blocksize=32 * 1024):
9399
for fmt, dat in zip(_XERIAL_V1_FORMAT, _XERIAL_V1_HEADER):
94100
out.write(struct.pack("!" + fmt, dat))
95101

96-
# Chunk through buffers to avoid creating intermediate slice copies
97-
def chunker(payload, i, size):
98-
return memoryview(payload)[i : size + i]
99-
102+
payload = memoryview(payload)
100103
for chunk in (
101-
chunker(payload, i, xerial_blocksize)
104+
payload[i : i + xerial_blocksize]
102105
for i in range(0, len(payload), xerial_blocksize)
103106
):
104107
block = cramjam.snappy.compress_raw(chunk)
@@ -109,7 +112,7 @@ def chunker(payload, i, size):
109112
return out.getvalue()
110113

111114

112-
def _detect_xerial_stream(payload):
115+
def _detect_xerial_stream(payload: Buffer) -> bool:
113116
"""Detects if the data given might have been encoded with the blocking mode
114117
of the xerial snappy library.
115118
@@ -131,20 +134,21 @@ def _detect_xerial_stream(payload):
131134
1.
132135
"""
133136

137+
payload = memoryview(payload)
134138
if len(payload) > 16:
135-
header = struct.unpack("!" + _XERIAL_V1_FORMAT, memoryview(payload)[:16])
139+
header = struct.unpack("!" + _XERIAL_V1_FORMAT, payload[:16])
136140
return header == _XERIAL_V1_HEADER
137141
return False
138142

139143

140-
def snappy_decode(payload):
144+
def snappy_decode(payload: Buffer) -> bytes:
141145
if not has_snappy():
142146
raise NotImplementedError("Snappy codec is not available")
143147

144148
if _detect_xerial_stream(payload):
145149
# TODO ? Should become a fileobj ?
146150
out = io.BytesIO()
147-
byt = payload[16:]
151+
byt = memoryview(payload)[16:]
148152
length = len(byt)
149153
cursor = 0
150154

@@ -162,7 +166,7 @@ def snappy_decode(payload):
162166
return bytes(cramjam.snappy.decompress_raw(payload))
163167

164168

165-
def lz4_encode(payload, level=9):
169+
def lz4_encode(payload: Buffer, level: int = 9) -> bytes:
166170
# level=9 is used by default by broker itself
167171
# https://cwiki.apache.org/confluence/display/KAFKA/KIP-390%3A+Support+Compression+Level
168172
if not has_lz4():
@@ -177,14 +181,14 @@ def lz4_encode(payload, level=9):
177181
return bytes(compressor.finish())
178182

179183

180-
def lz4_decode(payload):
184+
def lz4_decode(payload: Buffer) -> bytes:
181185
if not has_lz4():
182186
raise NotImplementedError("LZ4 codec is not available")
183187

184188
return bytes(cramjam.lz4.decompress(payload))
185189

186190

187-
def zstd_encode(payload, level=None):
191+
def zstd_encode(payload: Buffer, level: int | None = None) -> bytes:
188192
if not has_zstd():
189193
raise NotImplementedError("Zstd codec is not available")
190194

@@ -196,7 +200,7 @@ def zstd_encode(payload, level=None):
196200
return bytes(cramjam.zstd.compress(payload, level=level))
197201

198202

199-
def zstd_decode(payload):
203+
def zstd_decode(payload: Buffer) -> bytes:
200204
if not has_zstd():
201205
raise NotImplementedError("Zstd codec is not available")
202206

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ dynamic = ["version"]
3131
dependencies = [
3232
"async-timeout",
3333
"packaging",
34+
"typing_extensions >=4.6.0",
3435
]
3536

3637
[project.optional-dependencies]

tests/test_codec.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,23 +20,23 @@
2020
from ._testutil import random_string
2121

2222

23-
def test_gzip():
23+
def test_gzip() -> None:
2424
for i in range(1000):
2525
b1 = random_string(100)
2626
b2 = gzip_decode(gzip_encode(b1))
2727
assert b1 == b2
2828

2929

3030
@pytest.mark.skipif(not has_snappy(), reason="Snappy not available")
31-
def test_snappy():
31+
def test_snappy() -> None:
3232
for i in range(1000):
3333
b1 = random_string(100)
3434
b2 = snappy_decode(snappy_encode(b1))
3535
assert b1 == b2
3636

3737

3838
@pytest.mark.skipif(not has_snappy(), reason="Snappy not available")
39-
def test_snappy_detect_xerial():
39+
def test_snappy_detect_xerial() -> None:
4040
_detect_xerial_stream = codecs._detect_xerial_stream
4141

4242
header = b"\x82SNAPPY\x00\x00\x00\x00\x01\x00\x00\x00\x01Some extra bytes"
@@ -55,7 +55,7 @@ def test_snappy_detect_xerial():
5555

5656

5757
@pytest.mark.skipif(not has_snappy(), reason="Snappy not available")
58-
def test_snappy_decode_xerial():
58+
def test_snappy_decode_xerial() -> None:
5959
header = b"\x82SNAPPY\x00\x00\x00\x00\x01\x00\x00\x00\x01"
6060
random_snappy = snappy_encode(b"SNAPPY" * 50, xerial_compatible=False)
6161
block_len = len(random_snappy)
@@ -73,7 +73,7 @@ def test_snappy_decode_xerial():
7373

7474

7575
@pytest.mark.skipif(not has_snappy(), reason="Snappy not available")
76-
def test_snappy_encode_xerial():
76+
def test_snappy_encode_xerial() -> None:
7777
to_ensure = (
7878
b"\x82SNAPPY\x00\x00\x00\x00\x01\x00\x00\x00\x01"
7979
b"\x00\x00\x00\x18\xac\x02\x14SNAPPY\xfe\x06\x00\xfe\x06\x00\xfe\x06\x00"
@@ -88,7 +88,7 @@ def test_snappy_encode_xerial():
8888

8989

9090
@pytest.mark.skipif(not has_lz4(), reason="LZ4 not available")
91-
def test_lz4():
91+
def test_lz4() -> None:
9292
for i in range(1000):
9393
b1 = random_string(100)
9494
b2 = lz4_decode(lz4_encode(b1))
@@ -97,7 +97,7 @@ def test_lz4():
9797

9898

9999
@pytest.mark.skipif(not has_lz4(), reason="LZ4 not available")
100-
def test_lz4_incremental():
100+
def test_lz4_incremental() -> None:
101101
for i in range(1000):
102102
# lz4 max single block size is 4MB
103103
# make sure we test with multiple-blocks
@@ -108,7 +108,7 @@ def test_lz4_incremental():
108108

109109

110110
@pytest.mark.skipif(not has_zstd(), reason="Zstd not available")
111-
def test_zstd():
111+
def test_zstd() -> None:
112112
for _ in range(1000):
113113
b1 = random_string(100)
114114
b2 = zstd_decode(zstd_encode(b1))

0 commit comments

Comments
 (0)