Skip to content

Commit c32195b

Browse files
author
Chuyang Deng
committed
fix: update pytorch mnist script
1 parent 4722900 commit c32195b

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

src/sagemaker/pytorch/defaults.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15-
PYTORCH_VERSION = "0.4"
15+
PYTORCH_VERSION = '1.1.0'
1616
"""Default PyTorch version for when the framework version is not specified.
1717
The latest PyTorch version is 1.1.0, but the default version is no longer updated so as to not break existing workflows.
1818
"""

tests/data/pytorch_mnist/mnist.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
logger.addHandler(logging.StreamHandler(sys.stdout))
1818

1919

20+
2021
class Net(nn.Module):
2122
# Based on https://github.com/pytorch/examples/blob/master/mnist/main.py
2223
def __init__(self):
@@ -47,6 +48,7 @@ def _get_train_data_loader(training_dir, is_distributed, batch_size, **kwargs):
4748
transform=transforms.Compose(
4849
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
4950
),
51+
download=True
5052
)
5153
train_sampler = (
5254
torch.utils.data.distributed.DistributedSampler(dataset) if is_distributed else None
@@ -70,6 +72,7 @@ def _get_test_data_loader(training_dir, **kwargs):
7072
transform=transforms.Compose(
7173
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
7274
),
75+
download=True
7376
),
7477
batch_size=1000,
7578
shuffle=True,

0 commit comments

Comments
 (0)