@@ -36,8 +36,17 @@ def train(current_host, hosts, num_cpus, num_gpus, channel_input_dirs, model_dir
36
36
# load training and validation data
37
37
# we use the gluon.data.vision.CIFAR10 class because of its built in pre-processing logic,
38
38
# but point it at the location where SageMaker placed the data files, so it doesn't download them again.
39
+
40
+ part_index = 0
41
+ for i , host in enumerate (hosts ):
42
+ if host == current_host :
43
+ part_index = i
44
+ break
45
+
46
+
39
47
data_dir = channel_input_dirs ['training' ]
40
- train_data = get_train_data (num_cpus , data_dir , batch_size , (3 , 32 , 32 ))
48
+ train_data = get_train_data (num_cpus , data_dir , batch_size , (3 , 32 , 32 ),
49
+ num_parts = len (hosts ), part_index = part_index )
41
50
test_data = get_test_data (num_cpus , data_dir , batch_size , (3 , 32 , 32 ))
42
51
43
52
# Collect all parameters from net and its children, then initialize them.
@@ -104,23 +113,26 @@ def save(net, model_dir):
104
113
os .rename (os .path .join (model_dir , best ), os .path .join (model_dir , 'model.params' ))
105
114
106
115
107
- def get_data (path , augment , num_cpus , batch_size , data_shape , resize = - 1 ):
116
+ def get_data (path , augment , num_cpus , batch_size , data_shape , resize = - 1 , num_parts = 1 , part_index = 0 ):
108
117
return mx .io .ImageRecordIter (
109
118
path_imgrec = path ,
110
119
resize = resize ,
111
120
data_shape = data_shape ,
112
121
batch_size = batch_size ,
113
122
rand_crop = augment ,
114
123
rand_mirror = augment ,
115
- preprocess_threads = num_cpus )
124
+ preprocess_threads = num_cpus ,
125
+ num_parts = num_parts ,
126
+ part_index = part_index )
116
127
117
128
118
129
def get_test_data (num_cpus , data_dir , batch_size , data_shape , resize = - 1 ):
119
- return get_data (os .path .join (data_dir , "test.rec" ), False , num_cpus , batch_size , data_shape , resize )
130
+ return get_data (os .path .join (data_dir , "test.rec" ), False , num_cpus , batch_size , data_shape , resize , 1 , 0 )
120
131
121
132
122
- def get_train_data (num_cpus , data_dir , batch_size , data_shape , resize = - 1 ):
123
- return get_data (os .path .join (data_dir , "train.rec" ), True , num_cpus , batch_size , data_shape , resize )
133
+ def get_train_data (num_cpus , data_dir , batch_size , data_shape , resize = - 1 , num_parts = 1 , part_index = 0 ):
134
+ return get_data (os .path .join (data_dir , "train.rec" ), True , num_cpus , batch_size , data_shape , resize , num_parts ,
135
+ part_index )
124
136
125
137
126
138
def test (ctx , net , test_data ):
0 commit comments