14
14
from __future__ import absolute_import
15
15
16
16
import logging
17
+ import math
17
18
import os
18
19
import shutil
19
20
import tempfile
20
21
from typing import Union , Optional , Dict
21
22
from urllib .request import urlretrieve
22
23
23
- from omegaconf import OmegaConf
24
+ import omegaconf
25
+ from omegaconf import OmegaConf , dictconfig
24
26
from packaging .version import Version
25
27
26
28
from sagemaker .estimator import Framework , EstimatorBase
42
44
logger = logging .getLogger ("sagemaker" )
43
45
44
46
47
+ def _try_resolve_recipe (recipe , key = None ):
48
+ """Try to resolve recipe and return resolved recipe."""
49
+ if key is not None :
50
+ recipe = dictconfig .DictConfig ({key : recipe })
51
+ try :
52
+ OmegaConf .resolve (recipe )
53
+ except omegaconf .errors .OmegaConfBaseException :
54
+ return None
55
+ if key is None :
56
+ return recipe
57
+ return recipe [key ]
58
+
59
+
45
60
class PyTorch (Framework ):
46
61
"""Handle end-to-end training and deployment of custom PyTorch code."""
47
62
@@ -551,7 +566,9 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
551
566
cls .recipe_train_dir = tempfile .TemporaryDirectory (prefix = "training_" )
552
567
cls .recipe_launcher_dir = tempfile .TemporaryDirectory (prefix = "launcher_" )
553
568
554
- temp_local_recipe = tempfile .NamedTemporaryFile (prefix = "recipe" ).name
569
+ temp_local_recipe = tempfile .NamedTemporaryFile (
570
+ prefix = "recipe_original" , suffix = ".yaml"
571
+ ).name
555
572
if training_recipe .endswith (".yaml" ):
556
573
if os .path .isfile (training_recipe ):
557
574
shutil .copy (training_recipe , temp_local_recipe )
@@ -567,7 +584,7 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
567
584
_run_clone_command (launcher_repo , cls .recipe_launcher_dir .name )
568
585
recipe = os .path .join (
569
586
cls .recipe_launcher_dir .name ,
570
- "recipes-collection " ,
587
+ "recipes_collection " ,
571
588
"recipes" ,
572
589
"training" ,
573
590
training_recipe + ".yaml" ,
@@ -578,6 +595,7 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
578
595
raise ValueError (f"Recipe { training_recipe } not found." )
579
596
580
597
recipe = OmegaConf .load (temp_local_recipe )
598
+ os .unlink (temp_local_recipe )
581
599
582
600
if "instance_type" not in kwargs :
583
601
raise ValueError ("Must pass instance type to estimator when using training recipes." )
@@ -662,7 +680,26 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
662
680
"Ignoring container from training_recipe. Use image_uri arg for estimator."
663
681
)
664
682
665
- OmegaConf .save (config = recipe , f = os .path .join (args ["source_dir" ], "recipe.yaml" ))
683
+ if not OmegaConf .has_resolver ("multiply" ):
684
+ OmegaConf .register_new_resolver ("multiply" , lambda x , y : x * y , replace = True )
685
+ if not OmegaConf .has_resolver ("divide_ceil" ):
686
+ OmegaConf .register_new_resolver (
687
+ "divide_ceil" , lambda x , y : int (math .ceil (x / y )), replace = True
688
+ )
689
+ if not OmegaConf .has_resolver ("divide_floor" ):
690
+ OmegaConf .register_new_resolver (
691
+ "divide_floor" , lambda x , y : int (math .floor (x / y )), replace = True
692
+ )
693
+ if not OmegaConf .has_resolver ("add" ):
694
+ OmegaConf .register_new_resolver ("add" , lambda * numbers : sum (numbers ))
695
+ final_recipe = _try_resolve_recipe (recipe )
696
+ if final_recipe is None :
697
+ final_recipe = _try_resolve_recipe (recipe , "recipes" )
698
+ if final_recipe is None :
699
+ final_recipe = _try_resolve_recipe (recipe , "training" )
700
+ if final_recipe is None :
701
+ raise RuntimeError ("Could not resolve provided recipe." )
702
+ OmegaConf .save (config = final_recipe , f = os .path .join (args ["source_dir" ], "recipe.yaml" ))
666
703
args ["hyperparameters" ] = {"config-path" : "." , "config-name" : "recipe.yaml" }
667
704
668
705
return args
0 commit comments