1
+ from __future__ import annotations
2
+
1
3
import gzip
2
4
import io
3
5
import struct
4
6
7
+ from typing_extensions import Buffer
8
+
5
9
# Header values for a xerial-style snappy stream, one value per struct code in
# _XERIAL_V1_FORMAT: a signed byte (-126), the six characters b"SNAPPY", an
# unsigned byte (0) and two ints (1, 1).
# NOTE(review): the two trailing ints look like version fields — confirm
# against the xerial snappy format description.
_XERIAL_V1_HEADER = (- 126 , b"S" , b"N" , b"A" , b"P" , b"P" , b"Y" , 0 , 1 , 1 )
6
10
# struct format codes for the header fields. Combined with the "!" prefix
# (network byte order, standard sizes, no padding) the format spans
# 1 + 6 * 1 + 1 + 4 + 4 = 16 bytes.
_XERIAL_V1_FORMAT = "bccccccBii"
7
11
ZSTD_MAX_OUTPUT_SIZE = 1024 * 1024
12
16
cramjam = None
13
17
14
18
15
def has_gzip() -> bool:
    """Report whether the gzip codec can be used.

    Returns:
        Always ``True``; this function performs no availability check.
    """
    return True
17
21
18
22
19
def has_snappy() -> bool:
    """Report whether the snappy codec can be used.

    Returns:
        ``True`` when the module-level ``cramjam`` name is not ``None``.
    """
    # NOTE(review): ``cramjam`` is presumably left as None when the optional
    # dependency could not be imported — confirm against the import block.
    return cramjam is not None
21
25
22
26
23
def has_zstd() -> bool:
    """Report whether the zstd codec can be used.

    Returns:
        ``True`` when the module-level ``cramjam`` name is not ``None``.
    """
    # NOTE(review): ``cramjam`` is presumably left as None when the optional
    # dependency could not be imported — confirm against the import block.
    return cramjam is not None
25
29
26
30
27
def has_lz4() -> bool:
    """Report whether the LZ4 codec can be used.

    Returns:
        ``True`` when the module-level ``cramjam`` name is not ``None``.
    """
    # NOTE(review): ``cramjam`` is presumably left as None when the optional
    # dependency could not be imported — confirm against the import block.
    return cramjam is not None
29
33
30
34
31
- def gzip_encode (payload , compresslevel = None ):
35
+ def gzip_encode (payload : Buffer , compresslevel : int | None = None ) -> bytes :
32
36
if not compresslevel :
33
37
compresslevel = 9
34
38
@@ -45,7 +49,7 @@ def gzip_encode(payload, compresslevel=None):
45
49
return buf .getvalue ()
46
50
47
51
48
- def gzip_decode (payload ) :
52
+ def gzip_decode (payload : Buffer ) -> bytes :
49
53
buf = io .BytesIO (payload )
50
54
51
55
# Gzip context manager introduced in python 2.7
@@ -57,7 +61,9 @@ def gzip_decode(payload):
57
61
gzipper .close ()
58
62
59
63
60
- def snappy_encode (payload , xerial_compatible = True , xerial_blocksize = 32 * 1024 ):
64
+ def snappy_encode (
65
+ payload : Buffer , xerial_compatible : bool = True , xerial_blocksize : int = 32 * 1024
66
+ ) -> bytes :
61
67
"""Encodes the given data with snappy compression.
62
68
63
69
If xerial_compatible is set then the stream is encoded in a fashion
@@ -93,12 +99,9 @@ def snappy_encode(payload, xerial_compatible=True, xerial_blocksize=32 * 1024):
93
99
for fmt , dat in zip (_XERIAL_V1_FORMAT , _XERIAL_V1_HEADER ):
94
100
out .write (struct .pack ("!" + fmt , dat ))
95
101
96
- # Chunk through buffers to avoid creating intermediate slice copies
97
- def chunker (payload , i , size ):
98
- return memoryview (payload )[i : size + i ]
99
-
102
+ payload = memoryview (payload )
100
103
for chunk in (
101
- chunker ( payload , i , xerial_blocksize )
104
+ payload [ i : i + xerial_blocksize ]
102
105
for i in range (0 , len (payload ), xerial_blocksize )
103
106
):
104
107
block = cramjam .snappy .compress_raw (chunk )
@@ -109,7 +112,7 @@ def chunker(payload, i, size):
109
112
return out .getvalue ()
110
113
111
114
112
- def _detect_xerial_stream (payload ) :
115
+ def _detect_xerial_stream (payload : Buffer ) -> bool :
113
116
"""Detects if the data given might have been encoded with the blocking mode
114
117
of the xerial snappy library.
115
118
@@ -131,20 +134,21 @@ def _detect_xerial_stream(payload):
131
134
1.
132
135
"""
133
136
137
+ payload = memoryview (payload )
134
138
if len (payload ) > 16 :
135
- header = struct .unpack ("!" + _XERIAL_V1_FORMAT , memoryview ( payload ) [:16 ])
139
+ header = struct .unpack ("!" + _XERIAL_V1_FORMAT , payload [:16 ])
136
140
return header == _XERIAL_V1_HEADER
137
141
return False
138
142
139
143
140
- def snappy_decode (payload ) :
144
+ def snappy_decode (payload : Buffer ) -> bytes :
141
145
if not has_snappy ():
142
146
raise NotImplementedError ("Snappy codec is not available" )
143
147
144
148
if _detect_xerial_stream (payload ):
145
149
# TODO ? Should become a fileobj ?
146
150
out = io .BytesIO ()
147
- byt = payload [16 :]
151
+ byt = memoryview ( payload ) [16 :]
148
152
length = len (byt )
149
153
cursor = 0
150
154
@@ -162,7 +166,7 @@ def snappy_decode(payload):
162
166
return bytes (cramjam .snappy .decompress_raw (payload ))
163
167
164
168
165
- def lz4_encode (payload , level = 9 ) :
169
+ def lz4_encode (payload : Buffer , level : int = 9 ) -> bytes :
166
170
# level=9 is used by default by broker itself
167
171
# https://cwiki.apache.org/confluence/display/KAFKA/KIP-390%3A+Support+Compression+Level
168
172
if not has_lz4 ():
@@ -177,14 +181,14 @@ def lz4_encode(payload, level=9):
177
181
return bytes (compressor .finish ())
178
182
179
183
180
def lz4_decode(payload: Buffer) -> bytes:
    """Decompress ``payload`` with ``cramjam.lz4.decompress``.

    Args:
        payload: Buffer holding the LZ4-compressed data.

    Returns:
        The decompressed data, copied into a ``bytes`` object.

    Raises:
        NotImplementedError: If ``has_lz4()`` reports the codec as unavailable.
    """
    if not has_lz4():
        raise NotImplementedError("LZ4 codec is not available")

    # Convert the decompressor's result to ``bytes`` to match the return type.
    return bytes(cramjam.lz4.decompress(payload))
185
189
186
190
187
- def zstd_encode (payload , level = None ):
191
+ def zstd_encode (payload : Buffer , level : int | None = None ) -> bytes :
188
192
if not has_zstd ():
189
193
raise NotImplementedError ("Zstd codec is not available" )
190
194
@@ -196,7 +200,7 @@ def zstd_encode(payload, level=None):
196
200
return bytes (cramjam .zstd .compress (payload , level = level ))
197
201
198
202
199
- def zstd_decode (payload ) :
203
+ def zstd_decode (payload : Buffer ) -> bytes :
200
204
if not has_zstd ():
201
205
raise NotImplementedError ("Zstd codec is not available" )
202
206
0 commit comments