Skip to content

Commit a686e6d

Browse files
author
Uemit Yoldas
committed
fix: codestyle, type hints, license, and docstrings
1 parent 09aa5b0 commit a686e6d

File tree

6 files changed

+201
-107
lines changed

6 files changed

+201
-107
lines changed

src/sagemaker/amtviz/__init__.py

+14-14
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2-
# SPDX-License-Identifier: MIT-0
3-
4-
# Permission is hereby granted, free of charge, to any person obtaining a copy of this
5-
# software and associated documentation files (the "Software"), to deal in the Software
6-
# without restriction, including without limitation the rights to use, copy, modify,
7-
# merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
8-
# permit persons to whom the Software is furnished to do so.
9-
10-
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
11-
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
12-
# PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
13-
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
14-
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Placeholder docstring"""
14+
from __future__ import absolute_import
1515

1616
from sagemaker.amtviz.visualization import visualize_tuning_job
17-
__all__ = ['visualize_tuning_job']
17+
__all__ = ['visualize_tuning_job']

src/sagemaker/amtviz/job_metrics.py

+28-27
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,28 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2-
# SPDX-License-Identifier: MIT-0
3-
4-
# Permission is hereby granted, free of charge, to any person obtaining a copy
5-
# of this software and associated documentation files (the "Software"), to deal
6-
# in the Software without restriction, including without limitation the rights
7-
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8-
# copies of the Software, and to permit persons to whom the Software is
9-
# furnished to do so.
10-
11-
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
12-
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
13-
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
14-
# AUTHORS OR COPYRIGHT OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
15-
# IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Helper functions to retrieve job metrics from CloudWatch."""
14+
from __future__ import absolute_import
1615

1716
from datetime import datetime, timedelta
1817
from typing import Callable, List, Optional, Tuple, Dict, Any
1918
import hashlib
2019
import os
2120
from pathlib import Path
2221

22+
import logging
2323
import pandas as pd
2424
import numpy as np
2525
import boto3
26-
import logging
2726

2827
logger = logging.getLogger(__name__)
2928

@@ -58,16 +57,16 @@ def inner(*args: Any, **kwargs: Any) -> pd.DataFrame:
5857
logger.debug("H", end="")
5958
df["ts"] = pd.to_datetime(df["ts"])
6059
df["ts"] = df["ts"].dt.tz_localize(None)
61-
df["rel_ts"] = pd.to_datetime(df["rel_ts"]) # pyright: ignore [reportIndexIssue, reportOptionalSubscript]
60+
# pyright: ignore [reportIndexIssue, reportOptionalSubscript]
61+
df["rel_ts"] = pd.to_datetime(df["rel_ts"])
6262
df["rel_ts"] = df["rel_ts"].dt.tz_localize(None)
6363
return df
6464
except KeyError:
6565
# Empty file leads to empty df, hence no df['ts'] possible
6666
pass
6767
# nosec b110 - doesn't matter why we could not load it.
6868
except BaseException as e:
69-
logger.error("\nException", type(e), e)
70-
pass # continue with calling the outer function
69+
logger.error("\nException: %s - %s", type(e), e)
7170

7271
logger.debug("M", end="")
7372
df = outer(*args, **kwargs)
@@ -82,6 +81,7 @@ def inner(*args: Any, **kwargs: Any) -> pd.DataFrame:
8281

8382

8483
def _metric_data_query_tpl(metric_name: str, dim_name: str, dim_value: str) -> Dict[str, Any]:
84+
"""Returns a CloudWatch metric data query template."""
8585
return {
8686
"Id": metric_name.lower().replace(":", "_").replace("-", "_"),
8787
"MetricStat": {
@@ -100,18 +100,19 @@ def _metric_data_query_tpl(metric_name: str, dim_name: str, dim_value: str) -> D
100100

101101

102102
def _get_metric_data(
103-
queries: List[Dict[str, Any]],
104-
start_time: datetime,
103+
queries: List[Dict[str, Any]],
104+
start_time: datetime,
105105
end_time: datetime
106106
) -> pd.DataFrame:
107+
"""Fetches CloudWatch metrics between timestamps and returns a DataFrame with selected columns."""
107108
start_time = start_time - timedelta(hours=1)
108109
end_time = end_time + timedelta(hours=1)
109110
response = cw.get_metric_data(MetricDataQueries=queries, StartTime=start_time, EndTime=end_time)
110111

111112
df = pd.DataFrame()
112113
if "MetricDataResults" not in response:
113114
return df
114-
115+
115116
for metric_data in response["MetricDataResults"]:
116117
values = metric_data["Values"]
117118
ts = np.array(metric_data["Timestamps"], dtype=np.datetime64)
@@ -130,11 +131,11 @@ def _get_metric_data(
130131

131132
@disk_cache
132133
def _collect_metrics(
133-
dimensions: List[Tuple[str, str]],
134-
start_time: datetime,
134+
dimensions: List[Tuple[str, str]],
135+
start_time: datetime,
135136
end_time: Optional[datetime]
136137
) -> pd.DataFrame:
137-
138+
"""Collects SageMaker training job metrics from CloudWatch based on given dimensions and time range."""
138139
df = pd.DataFrame()
139140
for dim_name, dim_value in dimensions:
140141
response = cw.list_metrics(
@@ -158,8 +159,8 @@ def _collect_metrics(
158159

159160

160161
def get_cw_job_metrics(
161-
job_name: str,
162-
start_time: Optional[datetime] = None,
162+
job_name: str,
163+
start_time: Optional[datetime] = None,
163164
end_time: Optional[datetime] = None
164165
) -> pd.DataFrame:
165166
"""Retrieves CloudWatch metrics for a SageMaker training job.
@@ -182,4 +183,4 @@ def get_cw_job_metrics(
182183
# If not given, use reasonable defaults for start and end time
183184
start_time = start_time or datetime.now() - timedelta(hours=4)
184185
end_time = end_time or start_time + timedelta(hours=4)
185-
return _collect_metrics(dimensions, start_time, end_time)
186+
return _collect_metrics(dimensions, start_time, end_time)

0 commit comments

Comments
 (0)