Skip to content

feature: integrate amtviz for visualization of tuning jobs #5044

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/sagemaker/amtviz/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0

# Permission is hereby granted, free of charge, to any person obtaining a copy of this
# software and associated documentation files (the "Software"), to deal in the Software
# without restriction, including without limitation the rights to use, copy, modify,
# merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
# PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

from sagemaker.amtviz.visualization import visualize_tuning_job
__all__ = ['visualize_tuning_job']
185 changes: 185 additions & 0 deletions src/sagemaker/amtviz/job_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
# IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

from datetime import datetime, timedelta
from typing import Callable, List, Optional, Tuple, Dict, Any
import hashlib
import os
from pathlib import Path

import pandas as pd
import numpy as np
import boto3
import logging

logger = logging.getLogger(__name__)

cw = boto3.client("cloudwatch")
sm = boto3.client("sagemaker")


def disk_cache(outer: Callable) -> Callable:
"""A decorator that implements disk-based caching for CloudWatch metrics data.

This decorator caches the output of the wrapped function to disk in JSON Lines format.
It creates a cache key using MD5 hash of the function arguments and stores the data
in the user's home directory under .amtviz/cw_metrics_cache/.

Args:
outer (Callable): The function to be wrapped. Must return a pandas DataFrame
containing CloudWatch metrics data.

Returns:
Callable: A wrapper function that implements the caching logic.
"""

def inner(*args: Any, **kwargs: Any) -> pd.DataFrame:
key_input = str(args) + str(kwargs)
# nosec b303 - Not used for cryptography, but to create lookup key
key = hashlib.md5(key_input.encode("utf-8")).hexdigest()
cache_dir = Path.home().joinpath(".amtviz/cw_metrics_cache")
fn = f"{cache_dir}/req_{key}.jsonl.gz"
if Path(fn).exists():
try:
df = pd.read_json(fn, lines=True)
logger.debug("H", end="")
df["ts"] = pd.to_datetime(df["ts"])
df["ts"] = df["ts"].dt.tz_localize(None)
df["rel_ts"] = pd.to_datetime(df["rel_ts"]) # pyright: ignore [reportIndexIssue, reportOptionalSubscript]
df["rel_ts"] = df["rel_ts"].dt.tz_localize(None)
return df
except KeyError:
# Empty file leads to empty df, hence no df['ts'] possible
pass
# nosec b110 - doesn't matter why we could not load it.
except BaseException as e:
logger.error("\nException", type(e), e)
pass # continue with calling the outer function

logger.debug("M", end="")
df = outer(*args, **kwargs)
assert isinstance(df, pd.DataFrame), "Only caching Pandas DataFrames."

os.makedirs(cache_dir, exist_ok=True)
df.to_json(fn, orient="records", date_format="iso", lines=True)

return df

return inner


def _metric_data_query_tpl(metric_name: str, dim_name: str, dim_value: str) -> Dict[str, Any]:
return {
"Id": metric_name.lower().replace(":", "_").replace("-", "_"),
"MetricStat": {
"Stat": "Average",
"Metric": {
"Namespace": "/aws/sagemaker/TrainingJobs",
"MetricName": metric_name,
"Dimensions": [
{"Name": dim_name, "Value": dim_value},
],
},
"Period": 60,
},
"ReturnData": True,
}


def _get_metric_data(
queries: List[Dict[str, Any]],
start_time: datetime,
end_time: datetime
) -> pd.DataFrame:
start_time = start_time - timedelta(hours=1)
end_time = end_time + timedelta(hours=1)
response = cw.get_metric_data(MetricDataQueries=queries, StartTime=start_time, EndTime=end_time)

df = pd.DataFrame()
if "MetricDataResults" not in response:
return df

for metric_data in response["MetricDataResults"]:
values = metric_data["Values"]
ts = np.array(metric_data["Timestamps"], dtype=np.datetime64)
labels = [metric_data["Label"]] * len(values)

df = pd.concat([df, pd.DataFrame({"value": values, "ts": ts, "label": labels})])

# We now calculate the relative time based on the first actual observed
# time stamps, not the potentially start time that we used to scope our CW
# API call. The difference could be for example startup times or waiting
# for Spot.
if not df.empty:
df["rel_ts"] = datetime.fromtimestamp(1) + (df["ts"] - df["ts"].min()) # pyright: ignore
return df


@disk_cache
def _collect_metrics(
dimensions: List[Tuple[str, str]],
start_time: datetime,
end_time: Optional[datetime]
) -> pd.DataFrame:

df = pd.DataFrame()
for dim_name, dim_value in dimensions:
response = cw.list_metrics(
Namespace="/aws/sagemaker/TrainingJobs",
Dimensions=[
{"Name": dim_name, "Value": dim_value},
],
)
if not response["Metrics"]:
continue
metric_names = [metric["MetricName"] for metric in response["Metrics"]]
if not metric_names:
# No metric data yet, or not any longer, because the data were aged out
continue
metric_data_queries = [
_metric_data_query_tpl(metric_name, dim_name, dim_value) for metric_name in metric_names
]
df = pd.concat([df, _get_metric_data(metric_data_queries, start_time, end_time)])

return df


def get_cw_job_metrics(
job_name: str,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None
) -> pd.DataFrame:
"""Retrieves CloudWatch metrics for a SageMaker training job.

Args:
job_name (str): Name of the SageMaker training job.
start_time (datetime, optional): Start time for metrics collection.
Defaults to now - 4 hours.
end_time (datetime, optional): End time for metrics collection.
Defaults to start_time + 4 hours.

Returns:
pd.DataFrame: Metrics data with columns for value, timestamp, and metric name.
Results are cached to disk for improved performance.
"""
dimensions = [
("TrainingJobName", job_name),
("Host", job_name + "/algo-1"),
]
# If not given, use reasonable defaults for start and end time
start_time = start_time or datetime.now() - timedelta(hours=4)
end_time = end_time or start_time + timedelta(hours=4)
return _collect_metrics(dimensions, start_time, end_time)
Loading
Loading