Skip to content

Commit 1f34950

Browse files
schinmayeepintaoz-aws
authored andcommitted
Feature: Resolve recipes correctly before launching (#1529)
* fix to work with launcher recipes * fix suffix for temp file * fix path and error message * fix for recipes from launcher * resolve recipes correctly * fix imports * reformat message to avoid code-doc test issue * code style fix * code style fix * code style fix * code style fix * code style fix * code style fix * code style fix * code style fix * code style fix * doc formatting * check if resolver exists before registering
1 parent 13e10c9 commit 1f34950

File tree

1 file changed

+41
-4
lines changed

1 file changed

+41
-4
lines changed

src/sagemaker/pytorch/estimator.py

+41-4
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414
from __future__ import absolute_import
1515

1616
import logging
17+
import math
1718
import os
1819
import shutil
1920
import tempfile
2021
from typing import Union, Optional, Dict
2122
from urllib.request import urlretrieve
2223

23-
from omegaconf import OmegaConf
24+
import omegaconf
25+
from omegaconf import OmegaConf, dictconfig
2426
from packaging.version import Version
2527

2628
from sagemaker.estimator import Framework, EstimatorBase
@@ -42,6 +44,19 @@
4244
logger = logging.getLogger("sagemaker")
4345

4446

47+
def _try_resolve_recipe(recipe, key=None):
48+
"""Try to resolve recipe and return resolved recipe."""
49+
if key is not None:
50+
recipe = dictconfig.DictConfig({key: recipe})
51+
try:
52+
OmegaConf.resolve(recipe)
53+
except omegaconf.errors.OmegaConfBaseException:
54+
return None
55+
if key is None:
56+
return recipe
57+
return recipe[key]
58+
59+
4560
class PyTorch(Framework):
4661
"""Handle end-to-end training and deployment of custom PyTorch code."""
4762

@@ -551,7 +566,9 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
551566
cls.recipe_train_dir = tempfile.TemporaryDirectory(prefix="training_")
552567
cls.recipe_launcher_dir = tempfile.TemporaryDirectory(prefix="launcher_")
553568

554-
temp_local_recipe = tempfile.NamedTemporaryFile(prefix="recipe").name
569+
temp_local_recipe = tempfile.NamedTemporaryFile(
570+
prefix="recipe_original", suffix=".yaml"
571+
).name
555572
if training_recipe.endswith(".yaml"):
556573
if os.path.isfile(training_recipe):
557574
shutil.copy(training_recipe, temp_local_recipe)
@@ -567,7 +584,7 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
567584
_run_clone_command(launcher_repo, cls.recipe_launcher_dir.name)
568585
recipe = os.path.join(
569586
cls.recipe_launcher_dir.name,
570-
"recipes-collection",
587+
"recipes_collection",
571588
"recipes",
572589
"training",
573590
training_recipe + ".yaml",
@@ -578,6 +595,7 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
578595
raise ValueError(f"Recipe {training_recipe} not found.")
579596

580597
recipe = OmegaConf.load(temp_local_recipe)
598+
os.unlink(temp_local_recipe)
581599

582600
if "instance_type" not in kwargs:
583601
raise ValueError("Must pass instance type to estimator when using training recipes.")
@@ -662,7 +680,26 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
662680
"Ignoring container from training_recipe. Use image_uri arg for estimator."
663681
)
664682

665-
OmegaConf.save(config=recipe, f=os.path.join(args["source_dir"], "recipe.yaml"))
683+
if not OmegaConf.has_resolver("multiply"):
684+
OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True)
685+
if not OmegaConf.has_resolver("divide_ceil"):
686+
OmegaConf.register_new_resolver(
687+
"divide_ceil", lambda x, y: int(math.ceil(x / y)), replace=True
688+
)
689+
if not OmegaConf.has_resolver("divide_floor"):
690+
OmegaConf.register_new_resolver(
691+
"divide_floor", lambda x, y: int(math.floor(x / y)), replace=True
692+
)
693+
if not OmegaConf.has_resolver("add"):
694+
OmegaConf.register_new_resolver("add", lambda *numbers: sum(numbers))
695+
final_recipe = _try_resolve_recipe(recipe)
696+
if final_recipe is None:
697+
final_recipe = _try_resolve_recipe(recipe, "recipes")
698+
if final_recipe is None:
699+
final_recipe = _try_resolve_recipe(recipe, "training")
700+
if final_recipe is None:
701+
raise RuntimeError("Could not resolve provided recipe.")
702+
OmegaConf.save(config=final_recipe, f=os.path.join(args["source_dir"], "recipe.yaml"))
666703
args["hyperparameters"] = {"config-path": ".", "config-name": "recipe.yaml"}
667704

668705
return args

0 commit comments

Comments
 (0)