Skip to content

Commit 31710fa

Browse files
author
Yuchen Nie
committed
reformat mnist_ei.py
1 parent d0afe51 commit 31710fa

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

tests/data/mxnet_mnist/mnist_ei.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,19 @@
2323

2424
def model_fn(model_dir):
2525
import eimx
26-
26+
2727
def read_data_shapes(path, preferred_batch_size=1):
28-
with open(path, 'r') as f:
28+
with open(path, "r") as f:
2929
signatures = json.load(f)
3030

3131
data_names = []
3232
data_shapes = []
3333

3434
for s in signatures:
35-
name = s['name']
35+
name = s["name"]
3636
data_names.append(name)
3737

38-
shape = s['shape']
38+
shape = s["shape"]
3939

4040
if preferred_batch_size:
4141
shape[0] = preferred_batch_size
@@ -44,15 +44,15 @@ def read_data_shapes(path, preferred_batch_size=1):
4444

4545
return data_names, data_shapes
4646

47-
shapes_file = os.path.join(model_dir, 'model-shapes.json')
47+
shapes_file = os.path.join(model_dir, "model-shapes.json")
4848
data_names, data_shapes = read_data_shapes(shapes_file)
4949

5050
ctx = mx.cpu()
51-
sym, args, aux = mx.model.load_checkpoint(os.path.join(model_dir, 'model'), 0)
52-
sym = sym.optimize_for('EIA')
51+
sym, args, aux = mx.model.load_checkpoint(os.path.join(model_dir, "model"), 0)
52+
sym = sym.optimize_for("EIA")
5353

5454
mod = mx.mod.Module(symbol=sym, context=ctx, data_names=data_names, label_names=None)
5555
mod.bind(for_training=False, data_shapes=data_shapes)
5656
mod.set_params(args, aux, allow_missing=True)
5757

58-
return mod
58+
return mod

0 commit comments

Comments
 (0)