Skip to content

Commit e731185

Browse files
authored
[mypyc] Optimize str.encode with specializations for commonly used encodings (#18232)
Tested with:
```
import time
start = time.time()
for i in range(20000000):
    "test".encode('utf-8')
print(time.time() - start)
```
With the PR applied and the module compiled with mypyc, `python3 -c "import test"` runs in (seconds): 0.5383486747741699, 0.5224344730377197, 0.555696964263916.
Without the PR applied: 0.7315819263458252, 0.7105758190155029, 0.7471706867218018.
Similar times were observed for "ascii".
1 parent 222b104 commit e731185

File tree

4 files changed

+166
-12
lines changed

4 files changed

+166
-12
lines changed

mypyc/irbuild/specialize.py

+57
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,11 @@
8989
dict_values_op,
9090
)
9191
from mypyc.primitives.list_ops import new_list_set_item_op
92+
from mypyc.primitives.str_ops import (
93+
str_encode_ascii_strict,
94+
str_encode_latin1_strict,
95+
str_encode_utf8_strict,
96+
)
9297
from mypyc.primitives.tuple_ops import new_tuple_set_item_op
9398

9499
# Specializers are attempted before compiling the arguments to the
@@ -682,6 +687,58 @@ def translate_fstring(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Va
682687
return None
683688

684689

690+
@specialize_function("encode", str_rprimitive)
def str_encode_fast_path(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None:
    """Specialize str.encode() for the most commonly used encodings with strict errors.

    The specialization applies only when every argument is a string literal,
    so the encoding and the error handler are known at compile time. Calls
    that use UTF-8, ASCII or Latin-1 together with ``errors="strict"`` are
    compiled to a direct call of the matching dedicated C-level encoder.

    Args:
        builder: IR builder used to evaluate the receiver and emit the C call.
        expr: The ``<receiver>.encode(...)`` call expression being compiled.
        callee: The callee of ``expr``; only member expressions are handled.

    Returns:
        The value produced by the specialized C call, or None when the call
        is not specialized.
    """
    if not isinstance(callee, MemberExpr):
        return None

    # str.encode accepts at most (encoding, errors); do not guess at anything else.
    if len(expr.args) > 2:
        return None

    # Values used when the corresponding argument is omitted from the call.
    encoding = "utf8"
    errors = "strict"

    for index, (arg, kind, name) in enumerate(zip(expr.args, expr.arg_kinds, expr.arg_names)):
        # We can only specialize if we have string literals as args
        if not isinstance(arg, StrExpr):
            return None
        if kind == ARG_POS:
            # Positional order is (encoding, errors).
            name = "encoding" if index == 0 else "errors"
        elif kind != ARG_NAMED:
            # Star args and the like cannot be resolved at compile time.
            return None

        if name == "encoding":
            encoding = arg.value
        elif name == "errors":
            errors = arg.value
        else:
            # Unknown keyword: leave the call unspecialized instead of ignoring it.
            return None

    if errors != "strict":
        # We can only specialize strict errors
        return None

    # Canonicalize the way codec lookup does for the common spellings: case is
    # ignored and "-" is equivalent to "_". Separators are deliberately NOT
    # removed: names such as "u-8" or "l_1" are unknown encodings at runtime
    # and must not be mistaken for the aliases "u8" or "l1".
    encoding = encoding.lower().replace("-", "_")

    # Specialized encodings and their accepted aliases
    if encoding in ("u8", "utf", "utf8", "utf_8", "cp65001"):
        encode_op = str_encode_utf8_strict
    elif encoding in ("646", "ascii", "us_ascii"):
        encode_op = str_encode_ascii_strict
    elif encoding in (
        "8859",
        "cp819",
        "iso8859_1",
        "iso_8859_1",
        "l1",
        "latin",
        "latin1",
        "latin_1",
    ):
        encode_op = str_encode_latin1_strict
    else:
        return None

    return builder.call_c(encode_op, [builder.accept(callee.expr)], expr.line)
740+
741+
685742
@specialize_function("mypy_extensions.i64")
686743
def translate_i64(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None:
687744
if len(expr.args) != 1 or expr.arg_kinds[0] != ARG_POS:

mypyc/primitives/str_ops.py

+24
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,30 @@
219219
extra_int_constants=[(0, pointer_rprimitive)],
220220
)
221221

222+
# str.encode() specialized for UTF-8 with strict errors: takes only the str, returns bytes
223+
str_encode_utf8_strict = custom_op(
224+
arg_types=[str_rprimitive],
225+
return_type=bytes_rprimitive,
226+
c_function_name="PyUnicode_AsUTF8String",
227+
error_kind=ERR_MAGIC,
228+
)
229+
230+
# str.encode() specialized for ASCII with strict errors: takes only the str, returns bytes
231+
str_encode_ascii_strict = custom_op(
232+
arg_types=[str_rprimitive],
233+
return_type=bytes_rprimitive,
234+
c_function_name="PyUnicode_AsASCIIString",
235+
error_kind=ERR_MAGIC,
236+
)
237+
238+
# str.encode() specialized for Latin-1 with strict errors: takes only the str, returns bytes
239+
str_encode_latin1_strict = custom_op(
240+
arg_types=[str_rprimitive],
241+
return_type=bytes_rprimitive,
242+
c_function_name="PyUnicode_AsLatin1String",
243+
error_kind=ERR_MAGIC,
244+
)
245+
222246
# str.encode(encoding, errors)
223247
method_op(
224248
name="encode",

mypyc/test-data/fixtures/ir.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def upper(self) -> str: ...
110110
def startswith(self, x: str, start: int=..., end: int=...) -> bool: ...
111111
def endswith(self, x: str, start: int=..., end: int=...) -> bool: ...
112112
def replace(self, old: str, new: str, maxcount: int=...) -> str: ...
113-
def encode(self, x: str=..., y: str=...) -> bytes: ...
113+
def encode(self, encoding: str=..., errors: str=...) -> bytes: ...
114114

115115
class float:
116116
def __init__(self, x: object) -> None: pass

mypyc/test-data/irbuild-str.test

+84-11
Original file line numberDiff line numberDiff line change
@@ -293,20 +293,93 @@ L0:
293293
def f(s: str) -> None:
294294
s.encode()
295295
s.encode('utf-8')
296+
s.encode('utf8', 'strict')
297+
s.encode('latin1', errors='strict')
298+
s.encode(encoding='ascii')
299+
s.encode(errors='strict', encoding='latin-1')
300+
s.encode('utf-8', 'backslashreplace')
296301
s.encode('ascii', 'backslashreplace')
302+
encoding = 'utf8'
303+
s.encode(encoding)
304+
errors = 'strict'
305+
s.encode('utf8', errors)
306+
s.encode('utf8', errors=errors)
307+
s.encode(errors=errors)
308+
s.encode(encoding=encoding, errors=errors)
309+
s.encode('latin2')
310+
297311
[out]
298312
def f(s):
299313
s :: str
300-
r0 :: bytes
301-
r1 :: str
302-
r2 :: bytes
303-
r3, r4 :: str
304-
r5 :: bytes
314+
r0, r1, r2, r3, r4, r5 :: bytes
315+
r6, r7 :: str
316+
r8 :: bytes
317+
r9, r10 :: str
318+
r11 :: bytes
319+
r12, encoding :: str
320+
r13 :: bytes
321+
r14, errors, r15 :: str
322+
r16 :: bytes
323+
r17, r18 :: str
324+
r19 :: object
325+
r20 :: str
326+
r21 :: tuple
327+
r22 :: dict
328+
r23 :: object
329+
r24 :: str
330+
r25 :: object
331+
r26 :: str
332+
r27 :: tuple
333+
r28 :: dict
334+
r29 :: object
335+
r30 :: str
336+
r31 :: object
337+
r32, r33 :: str
338+
r34 :: tuple
339+
r35 :: dict
340+
r36 :: object
341+
r37 :: str
342+
r38 :: bytes
305343
L0:
306-
r0 = CPy_Encode(s, 0, 0)
307-
r1 = 'utf-8'
308-
r2 = CPy_Encode(s, r1, 0)
309-
r3 = 'ascii'
310-
r4 = 'backslashreplace'
311-
r5 = CPy_Encode(s, r3, r4)
344+
r0 = PyUnicode_AsUTF8String(s)
345+
r1 = PyUnicode_AsUTF8String(s)
346+
r2 = PyUnicode_AsUTF8String(s)
347+
r3 = PyUnicode_AsLatin1String(s)
348+
r4 = PyUnicode_AsASCIIString(s)
349+
r5 = PyUnicode_AsLatin1String(s)
350+
r6 = 'utf-8'
351+
r7 = 'backslashreplace'
352+
r8 = CPy_Encode(s, r6, r7)
353+
r9 = 'ascii'
354+
r10 = 'backslashreplace'
355+
r11 = CPy_Encode(s, r9, r10)
356+
r12 = 'utf8'
357+
encoding = r12
358+
r13 = CPy_Encode(s, encoding, 0)
359+
r14 = 'strict'
360+
errors = r14
361+
r15 = 'utf8'
362+
r16 = CPy_Encode(s, r15, errors)
363+
r17 = 'utf8'
364+
r18 = 'encode'
365+
r19 = CPyObject_GetAttr(s, r18)
366+
r20 = 'errors'
367+
r21 = PyTuple_Pack(1, r17)
368+
r22 = CPyDict_Build(1, r20, errors)
369+
r23 = PyObject_Call(r19, r21, r22)
370+
r24 = 'encode'
371+
r25 = CPyObject_GetAttr(s, r24)
372+
r26 = 'errors'
373+
r27 = PyTuple_Pack(0)
374+
r28 = CPyDict_Build(1, r26, errors)
375+
r29 = PyObject_Call(r25, r27, r28)
376+
r30 = 'encode'
377+
r31 = CPyObject_GetAttr(s, r30)
378+
r32 = 'encoding'
379+
r33 = 'errors'
380+
r34 = PyTuple_Pack(0)
381+
r35 = CPyDict_Build(2, r32, encoding, r33, errors)
382+
r36 = PyObject_Call(r31, r34, r35)
383+
r37 = 'latin2'
384+
r38 = CPy_Encode(s, r37, 0)
312385
return 1

0 commit comments

Comments
 (0)