|
74 | 74 | "sagemaker_submit_directory": json.dumps("file:///tmp/code"),
|
75 | 75 | }
|
76 | 76 |
|
| 77 | +ENVIRONMENT = {"MYVAR": "HELLO_WORLD"} |
| 78 | + |
77 | 79 |
|
78 | 80 | @pytest.fixture()
|
79 | 81 | def sagemaker_session():
|
@@ -352,7 +354,7 @@ def test_train(
|
352 | 354 | "local", instance_count, image, sagemaker_session=sagemaker_session
|
353 | 355 | )
|
354 | 356 | sagemaker_container.train(
|
355 |
| - INPUT_DATA_CONFIG, OUTPUT_DATA_CONFIG, HYPERPARAMETERS, TRAINING_JOB_NAME |
| 357 | + INPUT_DATA_CONFIG, OUTPUT_DATA_CONFIG, HYPERPARAMETERS, ENVIRONMENT, TRAINING_JOB_NAME |
356 | 358 | )
|
357 | 359 |
|
358 | 360 | docker_compose_file = os.path.join(
|
@@ -415,7 +417,7 @@ def test_train_with_hyperparameters_without_job_name(
|
415 | 417 | "local", instance_count, image, sagemaker_session=sagemaker_session
|
416 | 418 | )
|
417 | 419 | sagemaker_container.train(
|
418 |
| - INPUT_DATA_CONFIG, OUTPUT_DATA_CONFIG, HYPERPARAMETERS, TRAINING_JOB_NAME |
| 420 | + INPUT_DATA_CONFIG, OUTPUT_DATA_CONFIG, HYPERPARAMETERS, ENVIRONMENT, TRAINING_JOB_NAME |
419 | 421 | )
|
420 | 422 |
|
421 | 423 | docker_compose_file = os.path.join(
|
@@ -456,7 +458,11 @@ def test_train_error(
|
456 | 458 |
|
457 | 459 | with pytest.raises(RuntimeError) as e:
|
458 | 460 | sagemaker_container.train(
|
459 |
| - INPUT_DATA_CONFIG, OUTPUT_DATA_CONFIG, HYPERPARAMETERS, TRAINING_JOB_NAME |
| 461 | + INPUT_DATA_CONFIG, |
| 462 | + OUTPUT_DATA_CONFIG, |
| 463 | + HYPERPARAMETERS, |
| 464 | + ENVIRONMENT, |
| 465 | + TRAINING_JOB_NAME, |
460 | 466 | )
|
461 | 467 |
|
462 | 468 | assert "this is expected" in str(e)
|
@@ -486,7 +492,11 @@ def test_train_local_code(get_data_source_instance, tmpdir, sagemaker_session):
|
486 | 492 | )
|
487 | 493 |
|
488 | 494 | sagemaker_container.train(
|
489 |
| - INPUT_DATA_CONFIG, OUTPUT_DATA_CONFIG, LOCAL_CODE_HYPERPARAMETERS, TRAINING_JOB_NAME |
| 495 | + INPUT_DATA_CONFIG, |
| 496 | + OUTPUT_DATA_CONFIG, |
| 497 | + LOCAL_CODE_HYPERPARAMETERS, |
| 498 | + ENVIRONMENT, |
| 499 | + TRAINING_JOB_NAME, |
490 | 500 | )
|
491 | 501 |
|
492 | 502 | docker_compose_file = os.path.join(
|
@@ -538,7 +548,7 @@ def test_train_local_intermediate_output(get_data_source_instance, tmpdir, sagem
|
538 | 548 | hyperparameters = {"sagemaker_s3_output": output_path}
|
539 | 549 |
|
540 | 550 | sagemaker_container.train(
|
541 |
| - INPUT_DATA_CONFIG, output_data_config, hyperparameters, TRAINING_JOB_NAME |
| 551 | + INPUT_DATA_CONFIG, output_data_config, hyperparameters, ENVIRONMENT, TRAINING_JOB_NAME |
542 | 552 | )
|
543 | 553 |
|
544 | 554 | docker_compose_file = os.path.join(
|
|
0 commit comments