23
23
24
24
def model_fn (model_dir ):
25
25
import eimx
26
-
26
+
27
27
def read_data_shapes (path , preferred_batch_size = 1 ):
28
- with open (path , 'r' ) as f :
28
+ with open (path , "r" ) as f :
29
29
signatures = json .load (f )
30
30
31
31
data_names = []
32
32
data_shapes = []
33
33
34
34
for s in signatures :
35
- name = s [' name' ]
35
+ name = s [" name" ]
36
36
data_names .append (name )
37
37
38
- shape = s [' shape' ]
38
+ shape = s [" shape" ]
39
39
40
40
if preferred_batch_size :
41
41
shape [0 ] = preferred_batch_size
@@ -44,15 +44,15 @@ def read_data_shapes(path, preferred_batch_size=1):
44
44
45
45
return data_names , data_shapes
46
46
47
- shapes_file = os .path .join (model_dir , ' model-shapes.json' )
47
+ shapes_file = os .path .join (model_dir , " model-shapes.json" )
48
48
data_names , data_shapes = read_data_shapes (shapes_file )
49
49
50
50
ctx = mx .cpu ()
51
- sym , args , aux = mx .model .load_checkpoint (os .path .join (model_dir , ' model' ), 0 )
52
- sym = sym .optimize_for (' EIA' )
51
+ sym , args , aux = mx .model .load_checkpoint (os .path .join (model_dir , " model" ), 0 )
52
+ sym = sym .optimize_for (" EIA" )
53
53
54
54
mod = mx .mod .Module (symbol = sym , context = ctx , data_names = data_names , label_names = None )
55
55
mod .bind (for_training = False , data_shapes = data_shapes )
56
56
mod .set_params (args , aux , allow_missing = True )
57
57
58
- return mod
58
+ return mod
0 commit comments