 from abc import ABC, abstractmethod
 from inspect import signature
-from typing import Any, Callable, Dict, Generic, List, OrderedDict, TypeVar, Union
+from typing import Any, Callable, Dict, Generic, List, OrderedDict, TypeVar, Union, Optional
 
 import attr
 from pyspark.sql import DataFrame, SparkSession
 
 from sagemaker.feature_store.feature_processor._data_source import (
     CSVDataSource,
     FeatureGroupDataSource,
     ParquetDataSource,
+    BaseDataSource,
+    PySparkDataSource,
 )
 from sagemaker.feature_store.feature_processor._feature_processor_config import (
     FeatureProcessorConfig,
@@ -119,6 +121,9 @@ def provide_input_args(
         """
         udf_parameter_names = list(signature(udf).parameters.keys())
         udf_input_names = self._get_input_parameters(udf_parameter_names)
+        udf_params = self.params_loader.get_parameter_args(fp_config).get(
+            self.PARAMS_ARG_NAME, None
+        )
 
         if len(udf_input_names) == 0:
             raise ValueError("Expected at least one input to the user defined function.")
@@ -130,7 +135,7 @@ def provide_input_args(
         )
 
         return OrderedDict(
-            (input_name, self._load_data_frame(input_uri))
+            (input_name, self._load_data_frame(data_source=input_uri, params=udf_params))
             for (input_name, input_uri) in zip(udf_input_names, fp_config.inputs)
         )
 
@@ -189,13 +194,19 @@ def _get_input_parameters(self, udf_parameter_names: List[str]) -> List[str]:
 
     def _load_data_frame(
         self,
-        data_source: Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource],
+        data_source: Union[
+            FeatureGroupDataSource, CSVDataSource, ParquetDataSource, BaseDataSource
+        ],
+        params: Optional[Dict[str, Union[str, Dict]]] = None,
     ) -> DataFrame:
         """Given a data source definition, load the data as a Spark DataFrame.
 
         Args:
-            data_source (Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource]):
-                A user specified data source from the feature_processor decorator's parameters.
+            data_source (Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource,
+                BaseDataSource]): A user specified data source from the feature_processor
+                decorator's parameters.
+            params (Optional[Dict[str, Union[str, Dict]]]): Parameters provided to the
+                feature_processor decorator.
 
         Returns:
             DataFrame: The contents of the data source as a Spark DataFrame.
@@ -206,6 +217,13 @@ def _load_data_frame(
         if isinstance(data_source, FeatureGroupDataSource):
             return self.input_loader.load_from_feature_group(data_source)
 
+        if isinstance(data_source, PySparkDataSource):
+            spark_session = self.spark_session_factory.spark_session
+            return data_source.read_data(spark=spark_session, params=params)
+
+        if isinstance(data_source, BaseDataSource):
+            return data_source.read_data(params=params)
+
         raise ValueError(f"Unknown data source type: {type(data_source)}")
 
     def _has_param(self, udf: Callable, name: str) -> bool:
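For context on how the new branches in _load_data_frame get exercised, below is a minimal sketch of a user-defined Spark data source. It is an assumption-laden illustration, not part of this PR: it assumes read_data(spark, params) is the abstract method a PySparkDataSource subclass must supply, the class name JsonLinesDataSource, the uri argument, and the "read_options" params key are all hypothetical, and a real subclass may also need to implement any other abstract members BaseDataSource declares (such as identifier properties).

from typing import Dict, Optional, Union

from pyspark.sql import DataFrame, SparkSession

# Assumption: importing from the internal module shown in this diff; adjust the
# path if the package re-exports PySparkDataSource publicly.
from sagemaker.feature_store.feature_processor._data_source import PySparkDataSource


class JsonLinesDataSource(PySparkDataSource):
    """Hypothetical custom data source that reads JSON Lines files via Spark."""

    def __init__(self, uri: str):
        self.uri = uri  # e.g. an S3 prefix; illustrative only

    def read_data(
        self, spark: SparkSession, params: Optional[Dict[str, Union[str, Dict]]] = None
    ) -> DataFrame:
        # `params` is the value wired through provide_input_args above, i.e.
        # params_loader.get_parameter_args(fp_config).get(PARAMS_ARG_NAME, None).
        reader = spark.read
        if params and isinstance(params.get("read_options"), dict):
            # "read_options" is a made-up key for this sketch, not an SDK contract.
            reader = reader.options(**params["read_options"])
        return reader.json(self.uri)

Because _load_data_frame checks PySparkDataSource before the more general BaseDataSource, a subclass like this is handed the managed Spark session (spark_session_factory.spark_session), while a plain BaseDataSource implementation is invoked with params only and must produce its data without Spark.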