local mode: support output_path #449
Changes from 1 commit
@@ -33,6 +33,9 @@
 import yaml

 import sagemaker
+import sagemaker.local.data
+import sagemaker.local.utils
+import sagemaker.utils

 CONTAINER_PREFIX = 'algo'
 DOCKER_COMPOSE_FILENAME = 'docker-compose.yaml'
@@ -78,7 +81,7 @@ def __init__(self, instance_type, instance_count, image, sagemaker_session=None)
         self.container_root = None
         self.container = None

-    def train(self, input_data_config, hyperparameters, job_name):
+    def train(self, input_data_config, output_data_config, hyperparameters, job_name):
         """Run a training job locally using docker-compose.
         Args:
             input_data_config (dict): The Input Data Configuration, this contains data such as the
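For reference, the new output_data_config argument is the standard SageMaker output configuration dict; a minimal sketch of the two shapes this change distinguishes (bucket and prefix are hypothetical):

    # Artifacts are uploaded under the given S3 path:
    output_data_config = {'S3OutputPath': 's3://my-bucket/my-prefix'}

    # Empty S3OutputPath: artifacts stay on the local filesystem and a
    # file:// URI is returned (see retrieve_artifacts below).
    output_data_config = {'S3OutputPath': ''}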
@@ -126,23 +129,17 @@ def train(self, input_data_config, hyperparameters, job_name):
             msg = "Failed to run: %s, %s" % (compose_command, str(e))
             raise RuntimeError(msg)

-        s3_artifacts = self.retrieve_artifacts(compose_data)
+        artifacts = self.retrieve_artifacts(compose_data, output_data_config, job_name)

         # free up the training data directory as it may contain
         # lots of data downloaded from S3. This doesn't delete any local
         # data that was just mounted to the container.
-        _delete_tree(data_dir)
-        _delete_tree(shared_dir)
-        # Also free the container config files.
-        for host in self.hosts:
-            container_config_path = os.path.join(self.container_root, host)
-            _delete_tree(container_config_path)
-
-        self._cleanup()
-        # Print our Job Complete line to have a simmilar experience to training on SageMaker where you
+        dirs_to_delete = [data_dir, shared_dir]
+        self._cleanup(dirs_to_delete)
+        # Print our Job Complete line to have a similar experience to training on SageMaker where you
         # see this line at the end.
         print('===== Job Complete =====')
-        return s3_artifacts
+        return artifacts

     def serve(self, model_dir, environment):
         """Host a local endpoint using docker-compose.
@@ -188,7 +185,7 @@ def stop_serving(self):
         # for serving we can delete everything in the container root.
         _delete_tree(self.container_root)

-    def retrieve_artifacts(self, compose_data):
+    def retrieve_artifacts(self, compose_data, output_data_config, job_name):
         """Get the model artifacts from all the container nodes.

         Used after training completes to gather the data from all the individual containers. As the
@@ -201,26 +198,49 @@ def retrieve_artifacts(self, compose_data):
         Returns: Local path to the collected model artifacts.

         """
-        # Grab the model artifacts from all the Nodes.
-        s3_artifacts = os.path.join(self.container_root, 's3_artifacts')
-        os.mkdir(s3_artifacts)
+        # We need a directory to store the artifacts from all the nodes
+        # and another one to contain the compressed final artifacts
+        artifacts = os.path.join(self.container_root, 'artifacts')
+        compressed_artifacts = os.path.join(self.container_root, 'compressed_artifacts')
+        os.mkdir(artifacts)

+        model_artifacts = os.path.join(artifacts, 'model')
+        output_artifacts = os.path.join(artifacts, 'output')
+
-        s3_model_artifacts = os.path.join(s3_artifacts, 'model')
-        s3_output_artifacts = os.path.join(s3_artifacts, 'output')
-        os.mkdir(s3_model_artifacts)
-        os.mkdir(s3_output_artifacts)
+        artifact_dirs = [model_artifacts, output_artifacts, compressed_artifacts]
+        for d in artifact_dirs:
+            os.mkdir(d)

+        # Gather the artifacts from all nodes into artifacts/model and artifacts/output
         for host in self.hosts:
             volumes = compose_data['services'][str(host)]['volumes']

             for volume in volumes:
                 host_dir, container_dir = volume.split(':')
                 if container_dir == '/opt/ml/model':
-                    sagemaker.local.utils.recursive_copy(host_dir, s3_model_artifacts)
+                    sagemaker.local.utils.recursive_copy(host_dir, model_artifacts)
                 elif container_dir == '/opt/ml/output':
-                    sagemaker.local.utils.recursive_copy(host_dir, s3_output_artifacts)
+                    sagemaker.local.utils.recursive_copy(host_dir, output_artifacts)

+        # Tar Artifacts -> model.tar.gz and output.tar.gz
+        model_files = [os.path.join(model_artifacts, name) for name in os.listdir(model_artifacts)]
+        output_files = [os.path.join(output_artifacts, name) for name in os.listdir(output_artifacts)]
+        sagemaker.utils.create_tar_file(model_files, os.path.join(compressed_artifacts, 'model.tar.gz'))
+        sagemaker.utils.create_tar_file(output_files, os.path.join(compressed_artifacts, 'output.tar.gz'))
Reviewer: Should that be a function? (all 5 lines)
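If it were factored out, a minimal sketch of such a helper could look like this (the name _make_tar_archive is hypothetical, not part of this PR):

    def _make_tar_archive(source_dir, compressed_dir, tar_name):
        # Collect everything under source_dir and tar it into compressed_dir/tar_name,
        # mirroring the listdir + create_tar_file pattern used twice above.
        files = [os.path.join(source_dir, name) for name in os.listdir(source_dir)]
        sagemaker.utils.create_tar_file(files, os.path.join(compressed_dir, tar_name))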
+        if output_data_config['S3OutputPath'] == '':
+            output_data = 'file://%s' % compressed_artifacts
+        else:
+            # Now we just need to move the compressed artifacts to wherever they are required
+            output_data = sagemaker.local.utils.move_to_destination(
+                compressed_artifacts,
+                output_data_config['S3OutputPath'],
+                job_name,
+                self.sagemaker_session)
+
+        _delete_tree(model_artifacts)
+        _delete_tree(output_artifacts)
+
-        return s3_model_artifacts
+        return os.path.join(output_data, 'model.tar.gz')
Reviewer: Since we moved both files, should we return both (model and output) too?

Author: Not really, SageMaker just returns the S3ModelArtifacts - if you want to look at output.tar.gz you basically have to do a replace in the string. This gets sent directly to the local client describeTrainingJob().
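A sketch of the workaround the author describes, with a hypothetical local-mode return value:

    # value returned by retrieve_artifacts(), e.g. in the S3OutputPath == '' case:
    artifacts = 'file:///tmp/sagemaker-job/compressed_artifacts/model.tar.gz'
    # output.tar.gz lives next to it, so a string replace recovers it:
    output_artifacts = artifacts.replace('model.tar.gz', 'output.tar.gz')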
     def write_config_files(self, host, hyperparameters, input_data_config):
         """Write the config files for the training containers.

@@ -235,7 +255,6 @@ def write_config_files(self, host, hyperparameters, input_data_config):
         Returns: None

         """
-
         config_path = os.path.join(self.container_root, host, 'input', 'config')

         resource_config = {
@@ -261,29 +280,13 @@ def _prepare_training_volumes(self, data_dir, input_data_config, hyperparameters
         # mount the local directory to the container. For S3 Data we will download the S3 data
         # first.
         for channel in input_data_config:
-            if channel['DataSource'] and 'S3DataSource' in channel['DataSource']:
-                uri = channel['DataSource']['S3DataSource']['S3Uri']
-            elif channel['DataSource'] and 'FileDataSource' in channel['DataSource']:
-                uri = channel['DataSource']['FileDataSource']['FileUri']
-            else:
-                raise ValueError('Need channel[\'DataSource\'] to have'
-                                 ' [\'S3DataSource\'] or [\'FileDataSource\']')
-
-            parsed_uri = urlparse(uri)
-            key = parsed_uri.path.lstrip('/')
-
+            uri = channel['DataUri']
             channel_name = channel['ChannelName']
             channel_dir = os.path.join(data_dir, channel_name)
             os.mkdir(channel_dir)

-            if parsed_uri.scheme == 's3':
-                bucket_name = parsed_uri.netloc
-                sagemaker.utils.download_folder(bucket_name, key, channel_dir, self.sagemaker_session)
-            elif parsed_uri.scheme == 'file':
-                path = parsed_uri.path
-                volumes.append(_Volume(path, channel=channel_name))
-            else:
-                raise ValueError('Unknown URI scheme {}'.format(parsed_uri.scheme))
+            data_source = sagemaker.local.data.get_data_source_instance(uri, self.sagemaker_session)
+            volumes.append(_Volume(data_source.get_root_dir(), channel=channel_name))

         # If there is a training script directory and it is a local directory,
         # mount it to the container.
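Both URI schemes now go through the DataSource abstraction in sagemaker.local.data; a hedged sketch of its use (URIs hypothetical; get_root_dir() and get_file_list() are the only accessors this diff relies on):

    # 's3://...' data is downloaded to a local root; 'file://...' data is used in place.
    data_source = sagemaker.local.data.get_data_source_instance(
        's3://my-bucket/training-data', sagemaker_session)
    root = data_source.get_root_dir()    # local directory that can be volume-mounted
    files = data_source.get_file_list()  # files available under that root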
@@ -301,25 +304,20 @@ def _prepare_serving_volumes(self, model_location):
         volumes = []
         host = self.hosts[0]
         # Make the model available to the container. If this is a local file just mount it to
-        # the container as a volume. If it is an S3 location download it and extract the tar file.
+        # the container as a volume. If it is an S3 location, the DataSource will download it, we
+        # just need to extract the tar file.
         host_dir = os.path.join(self.container_root, host)
         os.makedirs(host_dir)

-        if model_location.startswith('s3'):
-            container_model_dir = os.path.join(self.container_root, host, 'model')
-            os.makedirs(container_model_dir)
+        model_data_source = sagemaker.local.data.get_data_source_instance(
+            model_location, self.sagemaker_session)

-            parsed_uri = urlparse(model_location)
-            filename = os.path.basename(parsed_uri.path)
-            tar_location = os.path.join(container_model_dir, filename)
-            sagemaker.utils.download_file(parsed_uri.netloc, parsed_uri.path, tar_location, self.sagemaker_session)
+        for filename in model_data_source.get_file_list():
+            if tarfile.is_tarfile(filename):
+                with tarfile.open(filename) as tar:
+                    tar.extractall(path=model_data_source.get_root_dir())

-            if tarfile.is_tarfile(tar_location):
-                with tarfile.open(tar_location) as tar:
-                    tar.extractall(path=container_model_dir)
-                volumes.append(_Volume(container_model_dir, '/opt/ml/model'))
-        else:
-            volumes.append(_Volume(model_location, '/opt/ml/model'))
+        volumes.append(_Volume(model_data_source.get_root_dir(), '/opt/ml/model'))

         return volumes
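One nuance worth noting: the old code only extracted tarballs in the S3 branch, while the new loop untars any tar file the data source lists, whether the model location was s3:// or file://. A quick sanity check of what would be extracted (path hypothetical):

    import tarfile

    path = '/tmp/my-model/model.tar.gz'  # hypothetical local model artifact
    print(tarfile.is_tarfile(path))      # True -> extracted into the mounted root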
@@ -368,7 +366,6 @@ def _generate_compose_file(self, command, additional_volumes=None, additional_en
             'networks': {
                 'sagemaker-local': {'name': 'sagemaker-local'}
             }
-
         }

         docker_compose_path = os.path.join(self.container_root, DOCKER_COMPOSE_FILENAME)
@@ -469,9 +466,15 @@ def _build_optml_volumes(self, host, subdirs):

         return volumes

-    def _cleanup(self):
-        # we don't need to cleanup anything at the moment
-        pass
+    def _cleanup(self, dirs_to_delete=None):
+        if dirs_to_delete:
+            for d in dirs_to_delete:
+                _delete_tree(d)
+
+        # Free the container config files.
+        for host in self.hosts:
+            container_config_path = os.path.join(self.container_root, host)
+            _delete_tree(container_config_path)


 class _HostingContainer(Thread):
@@ -610,7 +613,7 @@ def _aws_credentials(session):
             'AWS_SECRET_ACCESS_KEY=%s' % (str(secret_key))
         ]
     elif not _aws_credentials_available_in_metadata_service():
-        logger.warn("Using the short-lived AWS credentials found in session. They might expire while running.")
+        logger.warning("Using the short-lived AWS credentials found in session. They might expire while running.")
         return [
             'AWS_ACCESS_KEY_ID=%s' % (str(access_key)),
             'AWS_SECRET_ACCESS_KEY=%s' % (str(secret_key)),
Reviewer: 👌