@@ -743,6 +743,235 @@ def _get_regional_property(
743
743
return alias_value
744
744
745
745
746
+ class ModelAccessConfig (JumpStartDataHolderType ):
747
+ """Data class of model access config that mirrors CreateModel API."""
748
+
749
+ __slots__ = ["accept_eula" ]
750
+
751
+ def __init__ (self , spec : Dict [str , Any ]):
752
+ """Initializes a ModelAccessConfig object.
753
+
754
+ Args:
755
+ spec (Dict[str, Any]): Dictionary representation of data source.
756
+ """
757
+ self .from_json (spec )
758
+
759
+ def from_json (self , json_obj : Dict [str , Any ]) -> None :
760
+ """Sets fields in object based on json.
761
+
762
+ Args:
763
+ json_obj (Dict[str, Any]): Dictionary representation of data source.
764
+ """
765
+ self .accept_eula : bool = json_obj ["accept_eula" ]
766
+
767
+ def to_json (self ) -> Dict [str , Any ]:
768
+ """Returns json representation of ModelAccessConfig object."""
769
+ json_obj = {att : getattr (self , att ) for att in self .__slots__ if hasattr (self , att )}
770
+ return json_obj
771
+
772
+
773
+ class HubAccessConfig (JumpStartDataHolderType ):
774
+ """Data class of model access config that mirrors CreateModel API."""
775
+
776
+ __slots__ = ["hub_content_arn" ]
777
+
778
+ def __init__ (self , spec : Dict [str , Any ]):
779
+ """Initializes a HubAccessConfig object.
780
+
781
+ Args:
782
+ spec (Dict[str, Any]): Dictionary representation of data source.
783
+ """
784
+ self .from_json (spec )
785
+
786
+ def from_json (self , json_obj : Dict [str , Any ]) -> None :
787
+ """Sets fields in object based on json.
788
+
789
+ Args:
790
+ json_obj (Dict[str, Any]): Dictionary representation of data source.
791
+ """
792
+ self .hub_content_arn : bool = json_obj ["accept_eula" ]
793
+
794
+ def to_json (self ) -> Dict [str , Any ]:
795
+ """Returns json representation of ModelAccessConfig object."""
796
+ json_obj = {att : getattr (self , att ) for att in self .__slots__ if hasattr (self , att )}
797
+ return json_obj
798
+
799
+
800
+ class S3DataSource (JumpStartDataHolderType ):
801
+ """Data class of S3 data source that mirrors CreateModel API."""
802
+
803
+ __slots__ = [
804
+ "compression_type" ,
805
+ "s3_data_type" ,
806
+ "s3_uri" ,
807
+ "model_access_config" ,
808
+ "hub_access_config" ,
809
+ ]
810
+
811
+ def __init__ (self , spec : Dict [str , Any ]):
812
+ """Initializes a S3DataSource object.
813
+
814
+ Args:
815
+ spec (Dict[str, Any]): Dictionary representation of data source.
816
+ """
817
+ self .from_json (spec )
818
+
819
+ def from_json (self , json_obj : Dict [str , Any ]) -> None :
820
+ """Sets fields in object based on json.
821
+
822
+ Args:
823
+ json_obj (Dict[str, Any]): Dictionary representation of data source.
824
+ """
825
+ self .compression_type : str = json_obj ["compression_type" ]
826
+ self .s3_data_type : str = json_obj ["s3_data_type" ]
827
+ self .s3_uri : str = json_obj ["s3_uri" ]
828
+ self .model_access_config : ModelAccessConfig = (
829
+ ModelAccessConfig (json_obj ["model_access_config" ])
830
+ if json_obj .get ("model_access_config" )
831
+ else None
832
+ )
833
+ self .hub_access_config : HubAccessConfig = (
834
+ HubAccessConfig (json_obj ["hub_access_config" ])
835
+ if json_obj .get ("hub_access_config" )
836
+ else None
837
+ )
838
+
839
+ def to_json (self ) -> Dict [str , Any ]:
840
+ """Returns json representation of S3DataSource object."""
841
+ json_obj = {}
842
+ for att in self .__slots__ :
843
+ if hasattr (self , att ):
844
+ cur_val = getattr (self , att )
845
+ if issubclass (type (cur_val ), JumpStartDataHolderType ):
846
+ json_obj [att ] = cur_val .to_json ()
847
+ else :
848
+ json_obj [att ] = cur_val
849
+ return json_obj
850
+
851
+
852
+ class AdditionalModelDataSource (JumpStartDataHolderType ):
853
+ """Data class of additional model data source mirrors Hosting API."""
854
+
855
+ __slots__ = ["channel_name" , "s3_data_source" ]
856
+
857
+ def __init__ (self , spec : Dict [str , Any ]):
858
+ """Initializes a AdditionalModelDataSource object.
859
+
860
+ Args:
861
+ spec (Dict[str, Any]): Dictionary representation of data source.
862
+ """
863
+ self .from_json (spec )
864
+
865
+ def from_json (self , json_obj : Dict [str , Any ]) -> None :
866
+ """Sets fields in object based on json.
867
+
868
+ Args:
869
+ json_obj (Dict[str, Any]): Dictionary representation of data source.
870
+ """
871
+ self .channel_name : str = json_obj ["channel_name" ]
872
+ self .s3_data_source : S3DataSource = S3DataSource (json_obj ["s3_data_source" ])
873
+
874
+ def to_json (self ) -> Dict [str , Any ]:
875
+ """Returns json representation of AdditionalModelDataSource object."""
876
+ json_obj = {}
877
+ for att in self .__slots__ :
878
+ if hasattr (self , att ):
879
+ cur_val = getattr (self , att )
880
+ if issubclass (type (cur_val ), JumpStartDataHolderType ):
881
+ json_obj [att ] = cur_val .to_json ()
882
+ else :
883
+ json_obj [att ] = cur_val
884
+ return json_obj
885
+
886
+
887
+ class JumpStartModelDataSource (JumpStartDataHolderType ):
888
+ """Data class JumpStart additional model data source."""
889
+
890
+ __slots__ = ["version" , "additional_model_data_source" ]
891
+
892
+ def __init__ (self , spec : Dict [str , Any ]):
893
+ """Initializes a JumpStartModelDataSource object.
894
+
895
+ Args:
896
+ spec (Dict[str, Any]): Dictionary representation of data source.
897
+ """
898
+ self .from_json (spec )
899
+
900
+ def from_json (self , json_obj : Dict [str , Any ]) -> None :
901
+ """Sets fields in object based on json.
902
+
903
+ Args:
904
+ json_obj (Dict[str, Any]): Dictionary representation of data source.
905
+ """
906
+ self .version : str = json_obj ["artifact_version" ]
907
+ self .additional_model_data_source : AdditionalModelDataSource = AdditionalModelDataSource (
908
+ json_obj
909
+ )
910
+
911
+ def to_json (self ) -> Dict [str , Any ]:
912
+ """Returns json representation of JumpStartModelDataSource object."""
913
+ json_obj = {}
914
+ for att in self .__slots__ :
915
+ if hasattr (self , att ):
916
+ cur_val = getattr (self , att )
917
+ if issubclass (type (cur_val ), JumpStartDataHolderType ):
918
+ json_obj [att ] = cur_val .to_json ()
919
+ else :
920
+ json_obj [att ] = cur_val
921
+ return json_obj
922
+
923
+
924
+ class JumpStartAdditionalDataSources (JumpStartDataHolderType ):
925
+ """Data class of additional data sources."""
926
+
927
+ __slots__ = ["speculative_decoding" , "scripts" ]
928
+
929
+ def __init__ (self , spec : Dict [str , Any ]):
930
+ """Initializes a AdditionalDataSources object.
931
+
932
+ Args:
933
+ spec (Dict[str, Any]): Dictionary representation of data source.
934
+ """
935
+ self .from_json (spec )
936
+
937
+ def from_json (self , json_obj : Dict [str , Any ]) -> None :
938
+ """Sets fields in object based on json.
939
+
940
+ Args:
941
+ json_obj (Dict[str, Any]): Dictionary representation of data source.
942
+ """
943
+ self .speculative_decoding : Optional [List [JumpStartModelDataSource ]] = (
944
+ [
945
+ JumpStartModelDataSource (data_source )
946
+ for data_source in json_obj ["speculative_decoding" ]
947
+ ]
948
+ if json_obj .get ("speculative_decoding" )
949
+ else None
950
+ )
951
+ self .scripts : Optional [List [JumpStartModelDataSource ]] = (
952
+ [JumpStartModelDataSource (data_source ) for data_source in json_obj ["scripts" ]]
953
+ if json_obj .get ("scripts" )
954
+ else None
955
+ )
956
+
957
+ def to_json (self ) -> Dict [str , Any ]:
958
+ """Returns json representation of AdditionalDataSources object."""
959
+ json_obj = {}
960
+ for att in self .__slots__ :
961
+ if hasattr (self , att ):
962
+ cur_val = getattr (self , att )
963
+ if isinstance (cur_val , list ):
964
+ json_obj [att ] = []
965
+ for obj in cur_val :
966
+ if issubclass (type (obj ), JumpStartDataHolderType ):
967
+ json_obj [att ].append (obj .to_json ())
968
+ else :
969
+ json_obj [att ].append (obj )
970
+ else :
971
+ json_obj [att ] = cur_val
972
+ return json_obj
973
+
974
+
746
975
class JumpStartBenchmarkStat (JumpStartDataHolderType ):
747
976
"""Data class JumpStart benchmark stat."""
748
977
@@ -857,6 +1086,7 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType):
857
1086
"default_payloads" ,
858
1087
"gated_bucket" ,
859
1088
"model_subscription_link" ,
1089
+ "hosting_additional_data_sources" ,
860
1090
]
861
1091
862
1092
def __init__ (self , fields : Dict [str , Any ]):
@@ -962,6 +1192,11 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
962
1192
if json_obj .get ("hosting_instance_type_variants" )
963
1193
else None
964
1194
)
1195
+ self .hosting_additional_data_sources : Optional [JumpStartAdditionalDataSources ] = (
1196
+ JumpStartAdditionalDataSources (json_obj ["hosting_additional_data_sources" ])
1197
+ if json_obj .get ("hosting_additional_data_sources" )
1198
+ else None
1199
+ )
965
1200
966
1201
if self .training_supported :
967
1202
self .training_ecr_specs : Optional [JumpStartECRSpecs ] = (
0 commit comments