Skip to content

Commit cb676eb

Browse files
committed
Add correct file format to model file.
1 parent 55a7acb commit cb676eb

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tests/data/pytorch_mnist/mnist.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,6 @@ def test(model, test_loader, device):
165165

166166
def model_fn(model_dir):
167167
model = torch.nn.DataParallel(Net())
168-
with open(os.path.join(model_dir, 'model'), 'rb') as f:
168+
with open(os.path.join(model_dir, 'model.pth'), 'rb') as f:
169169
model.load_state_dict(torch.load(f))
170170
return model

0 commit comments

Comments
 (0)