|
20 | 20 | from typing import List, Union
|
21 | 21 | from sagemaker import image_uris
|
22 | 22 | from sagemaker.inputs import TrainingInput
|
23 |
| -from sagemaker.s3 import ( |
24 |
| - S3Downloader, |
25 |
| - S3Uploader, |
26 |
| -) |
27 | 23 | from sagemaker.estimator import EstimatorBase
|
28 | 24 | from sagemaker.sklearn.estimator import SKLearn
|
29 | 25 | from sagemaker.workflow.entities import RequestType
|
|
35 | 31 | Step,
|
36 | 32 | ConfigurableRetryStep,
|
37 | 33 | )
|
| 34 | +from sagemaker.utils import _save_model, download_file_from_url |
38 | 35 | from sagemaker.workflow.retry import RetryPolicy
|
39 | 36 |
|
40 | 37 | FRAMEWORK_VERSION = "0.23-1"
|
@@ -203,40 +200,36 @@ def _establish_source_dir(self):
|
203 | 200 | self._entry_point = self._entry_point_basename
|
204 | 201 |
|
205 | 202 | def _inject_repack_script(self):
|
206 |
| - """Injects the _repack_model.py script where it belongs. |
| 203 | + """Injects the _repack_model.py script into S3 or local source directory. |
207 | 204 |
|
208 | 205 | If the source_dir is an S3 path:
|
209 | 206 | 1) downloads the source_dir tar.gz
|
210 |
| - 2) copies the _repack_model.py script where it belongs |
211 |
| - 3) uploads the mutated source_dir |
| 207 | + 2) extracts it |
| 208 | + 3) copies the _repack_model.py script into the extracted directory |
| 209 | + 4) rezips the directory |
| 210 | + 5) overwrites the S3 source_dir with the new tar.gz |
212 | 211 |
|
213 | 212 | If the source_dir is a local path:
|
214 | 213 | 1) copies the _repack_model.py script into the source dir
|
215 | 214 | """
|
216 | 215 | fname = os.path.join(os.path.dirname(__file__), REPACK_SCRIPT)
|
217 | 216 | if self._source_dir.lower().startswith("s3://"):
|
218 | 217 | with tempfile.TemporaryDirectory() as tmp:
|
219 |
| - local_path = os.path.join(tmp, "local.tar.gz") |
220 |
| - |
221 |
| - S3Downloader.download( |
222 |
| - s3_uri=self._source_dir, |
223 |
| - local_path=local_path, |
224 |
| - sagemaker_session=self.sagemaker_session, |
225 |
| - ) |
226 |
| - |
227 |
| - src_dir = os.path.join(tmp, "src") |
228 |
| - with tarfile.open(name=local_path, mode="r:gz") as tf: |
229 |
| - tf.extractall(path=src_dir) |
230 |
| - |
231 |
| - shutil.copy2(fname, os.path.join(src_dir, REPACK_SCRIPT)) |
232 |
| - with tarfile.open(name=local_path, mode="w:gz") as tf: |
233 |
| - tf.add(src_dir, arcname=".") |
234 |
| - |
235 |
| - S3Uploader.upload( |
236 |
| - local_path=local_path, |
237 |
| - desired_s3_uri=self._source_dir, |
238 |
| - sagemaker_session=self.sagemaker_session, |
239 |
| - ) |
| 218 | + targz_contents_dir = os.path.join(tmp, "extracted") |
| 219 | + |
| 220 | + old_targz_path = os.path.join(tmp, "old.tar.gz") |
| 221 | + download_file_from_url(self._source_dir, old_targz_path, self.sagemaker_session) |
| 222 | + |
| 223 | + with tarfile.open(name=old_targz_path, mode="r:gz") as t: |
| 224 | + t.extractall(path=targz_contents_dir) |
| 225 | + |
| 226 | + shutil.copy2(fname, os.path.join(targz_contents_dir, REPACK_SCRIPT)) |
| 227 | + |
| 228 | + new_targz_path = os.path.join(tmp, "new.tar.gz") |
| 229 | + with tarfile.open(new_targz_path, mode="w:gz") as t: |
| 230 | + t.add(targz_contents_dir, arcname=os.path.sep) |
| 231 | + |
| 232 | + _save_model(self._source_dir, new_targz_path, self.sagemaker_session, kms_key=None) |
240 | 233 | else:
|
241 | 234 | shutil.copy2(fname, os.path.join(self._source_dir, REPACK_SCRIPT))
|
242 | 235 |
|
|
0 commit comments