diff --git a/adafruit_json_stream.py b/adafruit_json_stream.py index 0c08ae5..4b9da92 100644 --- a/adafruit_json_stream.py +++ b/adafruit_json_stream.py @@ -40,7 +40,9 @@ def read(self): self.i += 1 return char - def fast_forward(self, closer, *, return_object=False): + def fast_forward( + self, closer, *, return_object=False + ): # pylint: disable=too-many-branches """ Read through the stream until the character is ``closer``, ``]`` (ending a list) or ``}`` (ending an object.) Intermediate lists and @@ -62,6 +64,7 @@ def fast_forward(self, closer, *, return_object=False): # } = 125, { = 123 buffer[0] = closer - 2 + ignore_next = False while close_stack: char = self.read() count += 1 @@ -71,8 +74,14 @@ def fast_forward(self, closer, *, return_object=False): new_buffer[: len(buffer)] = buffer buffer = new_buffer buffer[count] = char - if char == close_stack[-1]: + if ignore_next: + # that character was escaped, skip it + ignore_next = False + elif char == close_stack[-1]: close_stack.pop() + elif char == ord("\\") and close_stack[-1] == ord('"'): + # if backslash, ignore the next character + ignore_next = True elif char == ord('"'): close_stack.append(ord('"')) elif close_stack[-1] == ord('"'): @@ -96,26 +105,41 @@ def next_value(self, endswith=None): if isinstance(endswith, str): endswith = ord(endswith) in_string = False + ignore_next = False while True: try: char = self.read() except EOFError: char = endswith - if not in_string and (char == endswith or char in (ord("]"), ord("}"))): - self.last_char = char - if len(buf) == 0: - return None - value_string = bytes(buf).decode("utf-8") - return json.loads(value_string) - if char == ord("{"): - return TransientObject(self) - if char == ord("["): - return TransientList(self) + in_string = False + ignore_next = False if not in_string: - in_string = char == ord('"') + # end character or object/list end + if char == endswith or char in (ord("]"), ord("}")): + self.last_char = char + if len(buf) == 0: + return None + value_string = bytes(buf).decode("utf-8") + return json.loads(value_string) + # string or sub object + if char == ord("{"): + return TransientObject(self) + if char == ord("["): + return TransientList(self) + # start a string + if char == ord('"'): + in_string = True else: - in_string = char != ord('"') + # skipping any closing or opening character if in a string + # also skipping escaped characters (like quotes in string) + if ignore_next: + ignore_next = False + elif char == ord("\\"): + ignore_next = True + elif char == ord('"'): + in_string = False + buf.append(char) diff --git a/tests/test_json_stream.py b/tests/test_json_stream.py index b8197fe..04f4faa 100644 --- a/tests/test_json_stream.py +++ b/tests/test_json_stream.py @@ -66,6 +66,38 @@ def dict_with_all_types(): """ +@pytest.fixture +def list_with_bad_strings(): + return r""" + [ + "\"}\"", + "{\"a\": 1, \"b\": [2,3]}", + "\"", + "\\\"", + "\\\\\"", + "\\x40\"", + "[[[{{{", + "]]]}}}" + ] + """ + + +@pytest.fixture +def dict_with_bad_strings(): + return r""" + { + "1": "\"}\"", + "2": "{\"a\": 1, \"b\": [2,3]}", + "3": "\"", + "4": "\\\"", + "5": "\\\\\"", + "6": "\\x40\"", + "7": "[[[{{{", + "8": "]]]}}}" + } + """ + + @pytest.fixture def list_with_values(): return """ @@ -308,6 +340,116 @@ def test_complex_dict(complex_dict): assert sub_counter == 12 +def test_bad_strings_in_list(list_with_bad_strings): + """Test loading different strings that can confuse the parser.""" + + bad_strings = [ + '"}"', + '{"a": 1, "b": [2,3]}', + '"', + '\\"', + '\\\\"', + '\\x40"', + "[[[{{{", + "]]]}}}", + ] + + assert json.loads(list_with_bad_strings) + + # get each separately + stream = adafruit_json_stream.load(BytesChunkIO(list_with_bad_strings.encode())) + for i, item in enumerate(stream): + assert item == bad_strings[i] + + +def test_bad_strings_in_list_iter(list_with_bad_strings): + """Test loading different strings that can confuse the parser.""" + + bad_strings = [ + '"}"', + '{"a": 1, "b": [2,3]}', + '"', + '\\"', + '\\\\"', + '\\x40"', + "[[[{{{", + "]]]}}}", + ] + + assert json.loads(list_with_bad_strings) + + # get each separately + stream = adafruit_json_stream.load(BytesChunkIO(list_with_bad_strings.encode())) + for i, item in enumerate(stream): + assert item == bad_strings[i] + + +def test_bad_strings_in_dict_as_object(dict_with_bad_strings): + """Test loading different strings that can confuse the parser.""" + + bad_strings = { + "1": '"}"', + "2": '{"a": 1, "b": [2,3]}', + "3": '"', + "4": '\\"', + "5": '\\\\"', + "6": '\\x40"', + "7": "[[[{{{", + "8": "]]]}}}", + } + + # read all at once + stream = adafruit_json_stream.load(BytesChunkIO(dict_with_bad_strings.encode())) + assert stream.as_object() == bad_strings + + +def test_bad_strings_in_dict_all_keys(dict_with_bad_strings): + """Test loading different strings that can confuse the parser.""" + + bad_strings = { + "1": '"}"', + "2": '{"a": 1, "b": [2,3]}', + "3": '"', + "4": '\\"', + "5": '\\\\"', + "6": '\\x40"', + "7": "[[[{{{", + "8": "]]]}}}", + } + + # read one after the other with keys + stream = adafruit_json_stream.load(BytesChunkIO(dict_with_bad_strings.encode())) + assert stream["1"] == bad_strings["1"] + assert stream["2"] == bad_strings["2"] + assert stream["3"] == bad_strings["3"] + assert stream["4"] == bad_strings["4"] + assert stream["5"] == bad_strings["5"] + assert stream["6"] == bad_strings["6"] + assert stream["7"] == bad_strings["7"] + assert stream["8"] == bad_strings["8"] + + +def test_bad_strings_in_dict_skip_some(dict_with_bad_strings): + """Test loading different strings that can confuse the parser.""" + + bad_strings = { + "1": '"}"', + "2": '{"a": 1, "b": [2,3]}', + "3": '"', + "4": '\\"', + "5": '\\\\"', + "6": '\\x40"', + "7": "[[[{{{", + "8": "]]]}}}", + } + + # read some, skip some + stream = adafruit_json_stream.load(BytesChunkIO(dict_with_bad_strings.encode())) + assert stream["2"] == bad_strings["2"] + assert stream["5"] == bad_strings["5"] + assert stream["8"] == bad_strings["8"] + + def test_complex_dict_grabbing(complex_dict): """Test loading a complex dict and grabbing specific keys."""