Skip to content

Commit 04f51e6

Browse files
navinsoniAo Guo
authored and
Ao Guo
committed
fix: Update localmode code to decode urllib response as UTF8 (aws#3284)
1 parent b7996b9 commit 04f51e6

File tree

3 files changed

+23
-6
lines changed

3 files changed

+23
-6
lines changed

src/sagemaker/local/entities.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ def start(self, input_data, output_data, transform_resources, **kwargs):
314314
endpoint_url = "http://%s:%d/execution-parameters" % (get_docker_host(), serving_port)
315315
response, code = _perform_request(endpoint_url)
316316
if code == 200:
317-
execution_parameters = json.loads(response.read())
317+
execution_parameters = json.loads(response.data.decode("utf-8"))
318318
# MaxConcurrentTransforms is ignored because we currently only support 1
319319
for setting in ("BatchStrategy", "MaxPayloadInMB"):
320320
if setting not in kwargs and setting in execution_parameters:

src/sagemaker/workflow/conditions.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,15 @@
2020
import abc
2121

2222
from enum import Enum
23-
from typing import List, Union
23+
from typing import Dict, List, Union
2424

2525
import attr
2626

27+
from sagemaker.workflow import is_pipeline_variable
2728
from sagemaker.workflow.entities import (
2829
DefaultEnumMeta,
2930
Entity,
31+
Expression,
3032
PrimitiveType,
3133
RequestType,
3234
)
@@ -289,3 +291,18 @@ def _referenced_steps(self) -> List[str]:
289291
for condition in self.conditions:
290292
steps.extend(condition._referenced_steps)
291293
return steps
294+
295+
296+
def primitive_or_expr(
297+
value: Union[ExecutionVariable, Expression, PrimitiveType, Parameter, Properties]
298+
) -> Union[Dict[str, str], PrimitiveType]:
299+
"""Provide the expression of the value or return value if it is a primitive.
300+
301+
Args:
302+
value (Union[ConditionValueType, PrimitiveType]): The value to evaluate.
303+
Returns:
304+
Either the expression of the value or the primitive value.
305+
"""
306+
if is_pipeline_variable(value):
307+
return value.expr
308+
return value

tests/unit/test_local_entities.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def test_start_local_transform_job(_perform_batch_inference, _perform_request, l
106106

107107
response = Mock()
108108
_perform_request.return_value = (response, 200)
109-
response.read.return_value = '{"BatchStrategy": "SingleRecord"}'
109+
response.data = '{"BatchStrategy": "SingleRecord"}'.encode("UTF-8")
110110
local_transform_job.primary_container["ModelDataUrl"] = "file:///some/model"
111111
local_transform_job.start(input_data, output_data, transform_resources, Environment={})
112112

@@ -176,9 +176,9 @@ def test_start_local_transform_job_from_remote_docker_host(
176176
output_data = {}
177177
transform_resources = {"InstanceType": "local"}
178178
m_get_docker_host.return_value = "some_host"
179-
perform_request_mock = Mock()
180-
m_perform_request.return_value = (perform_request_mock, 200)
181-
perform_request_mock.read.return_value = '{"BatchStrategy": "SingleRecord"}'
179+
response = Mock()
180+
m_perform_request.return_value = (response, 200)
181+
response.data = '{"BatchStrategy": "SingleRecord"}'.encode("UTF-8")
182182
local_transform_job.primary_container["ModelDataUrl"] = "file:///some/model"
183183
local_transform_job.start(input_data, output_data, transform_resources, Environment={})
184184
endpoints = [

0 commit comments

Comments
 (0)