Skip to content

Commit d6b6487

Browse files
authored
Don't try to invoke str() on Python 2 unicode strings (aws#42)
1 parent 0e84048 commit d6b6487

File tree

3 files changed

+39
-8
lines changed

3 files changed

+39
-8
lines changed

src/sagemaker/tuner.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from sagemaker.estimator import Framework
2121
from sagemaker.job import _Job
2222
from sagemaker.session import Session
23-
from sagemaker.utils import base_name_from_image, name_from_base
23+
from sagemaker.utils import base_name_from_image, name_from_base, to_str
2424

2525
# TODO: probably move these somewhere to Amazon Estimator land after
2626
# the circular dependency issue is resolved
@@ -45,8 +45,8 @@ def __init__(self, min_value, max_value):
4545

4646
def as_tuning_range(self, name):
4747
return {'Name': name,
48-
'MinValue': str(self.min_value),
49-
'MaxValue': str(self.max_value)}
48+
'MinValue': to_str(self.min_value),
49+
'MaxValue': to_str(self.max_value)}
5050

5151

5252
class ContinuousParameter(_ParameterRange):
@@ -61,9 +61,9 @@ class CategoricalParameter(_ParameterRange):
6161

6262
def __init__(self, values):
6363
if isinstance(values, list):
64-
self.values = [str(v) for v in values]
64+
self.values = [to_str(v) for v in values]
6565
else:
66-
self.values = [str(values)]
66+
self.values = [to_str(values)]
6767

6868
def as_tuning_range(self, name):
6969
return {'Name': name,
@@ -108,8 +108,7 @@ def __init__(self, estimator, objective_metric_name, hyperparameter_ranges, metr
108108
self._validate_parameter_ranges()
109109

110110
def prepare_for_training(self):
111-
# TODO: Change this so that it can handle unicode in Python 2
112-
self.static_hyperparameters = {str(k): str(v) for (k, v) in self.estimator.hyperparameters().items()}
111+
self.static_hyperparameters = {to_str(k): to_str(v) for (k, v) in self.estimator.hyperparameters().items()}
113112
for hyperparameter_name in self._hyperparameter_ranges.keys():
114113
self.static_hyperparameters.pop(hyperparameter_name, None)
115114

src/sagemaker/utils.py

+20
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import sys
1516
import time
1617

1718
import re
@@ -100,6 +101,25 @@ def get_config_value(key_path, config):
100101
return current_section
101102

102103

104+
def to_str(value):
105+
"""Convert the input to a string, unless it is a unicode string in Python 2.
106+
107+
Unicode strings are supported as native strings in Python 3, but ``str()`` cannot be
108+
invoked on unicode strings in Python 2, so we need to check for that case when
109+
converting user-specified values to strings.
110+
111+
Args:
112+
value: The value to convert to a string.
113+
114+
Returns:
115+
str or unicode: The string representation of the value or the unicode string itself.
116+
"""
117+
if sys.version_info.major < 3 and isinstance(value, unicode): # noqa: F821
118+
return value
119+
else:
120+
return str(value)
121+
122+
103123
class DeferredError(object):
104124
"""Stores an exception and raises it at a later time anytime this
105125
object is accessed in any way. Useful to allow soft-dependencies on imports,

tests/unit/test_utils.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# -*- coding: utf-8 -*-
2+
13
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
24
#
35
# Licensed under the Apache License, Version 2.0 (the "License"). You
@@ -15,7 +17,7 @@
1517
import pytest
1618
from mock import patch
1719

18-
from sagemaker.utils import get_config_value, name_from_base, DeferredError
20+
from sagemaker.utils import get_config_value, name_from_base, to_str, DeferredError
1921

2022
NAME = 'base_name'
2123

@@ -65,3 +67,13 @@ def test_name_from_base(sagemaker_timestamp):
6567
def test_name_from_base_short(sagemaker_short_timestamp):
6668
name_from_base(NAME, short=True)
6769
assert sagemaker_short_timestamp.called_once
70+
71+
72+
def test_to_str_with_native_string():
73+
value = 'some string'
74+
assert to_str(value) == value
75+
76+
77+
def test_to_str_with_unicode_string():
78+
value = u'åñøthér strîng'
79+
assert to_str(value) == value

0 commit comments

Comments
 (0)