
Commit f646180

feat: feature-processor extra data sources support (#4155)
1 parent (f631e41), commit f646180

25 files changed, +635 -74 lines

src/sagemaker/feature_store/feature_processor/__init__.py (+2)
@@ -17,6 +17,8 @@
     CSVDataSource,
     FeatureGroupDataSource,
     ParquetDataSource,
+    BaseDataSource,
+    PySparkDataSource,
 )
 from sagemaker.feature_store.feature_processor._exceptions import (  # noqa: F401
     IngestionError,

src/sagemaker/feature_store/feature_processor/_config_uploader.py (+2 -2)
@@ -31,7 +31,7 @@
     _JobSettings,
     RUNTIME_SCRIPTS_CHANNEL_NAME,
     REMOTE_FUNCTION_WORKSPACE,
-    SPARK_CONF_WORKSPACE,
+    SPARK_CONF_CHANNEL_NAME,
     _prepare_and_upload_spark_dependent_files,
 )
 from sagemaker.remote_function.runtime_environment.runtime_environment_manager import (
@@ -99,7 +99,7 @@ def prepare_step_input_channel_for_spark_mode(
         )
 
         if config_file_s3_uri:
-            input_data_config[SPARK_CONF_WORKSPACE] = TrainingInput(
+            input_data_config[SPARK_CONF_CHANNEL_NAME] = TrainingInput(
                 s3_data=config_file_s3_uri,
                 s3_data_type="S3Prefix",
                 distribution=S3_DATA_DISTRIBUTION_TYPE,

src/sagemaker/feature_store/feature_processor/_data_source.py (+68 -2)
@@ -13,10 +13,76 @@
 """Contains classes to define input data sources."""
 from __future__ import absolute_import
 
-from typing import Optional
+from typing import Optional, Dict, Union, TypeVar, Generic
+from abc import ABC, abstractmethod
+from pyspark.sql import DataFrame, SparkSession
+
 
 import attr
 
+T = TypeVar("T")
+
+
+@attr.s
+class BaseDataSource(Generic[T], ABC):
+    """Abstract base class for feature processor data sources.
+
+    Provides a skeleton for customization requiring the overriding of the method to read data from
+    data source and return the specified type.
+    """
+
+    @abstractmethod
+    def read_data(self, *args, **kwargs) -> T:
+        """Read data from data source and return the specified type.
+
+        Args:
+            args: Arguments for reading the data.
+            kwargs: Keyword argument for reading the data.
+        Returns:
+            T: The specified abstraction of data source.
+        """
+
+    @property
+    @abstractmethod
+    def data_source_unique_id(self) -> str:
+        """The identifier for the customized feature processor data source.
+
+        Returns:
+            str: The data source unique id.
+        """
+
+    @property
+    @abstractmethod
+    def data_source_name(self) -> str:
+        """The name for the customized feature processor data source.
+
+        Returns:
+            str: The data source name.
+        """
+
+
+@attr.s
+class PySparkDataSource(BaseDataSource[DataFrame], ABC):
+    """Abstract base class for feature processor data sources.
+
+    Provides a skeleton for customization requiring the overriding of the method to read data from
+    data source and return the Spark DataFrame.
+    """
+
+    @abstractmethod
+    def read_data(
+        self, spark: SparkSession, params: Optional[Dict[str, Union[str, Dict]]] = None
+    ) -> DataFrame:
+        """Read data from data source and convert the data to Spark DataFrame.
+
+        Args:
+            spark (SparkSession): The Spark session to read the data.
+            params (Optional[Dict[str, Union[str, Dict]]]): Parameters provided to the
+                feature_processor decorator.
+        Returns:
+            DataFrame: The Spark DataFrame as an abstraction on the data source.
+        """
+
 
 @attr.s
 class FeatureGroupDataSource:
@@ -26,7 +92,7 @@ class FeatureGroupDataSource:
         name (str): The name or ARN of the Feature Group.
         input_start_offset (Optional[str], optional): A duration specified as a string in the
             format '<no> <unit>' where 'no' is a number and 'unit' is a unit of time in ['hours',
-            'days', 'weeks', 'months', 'years'] (plural and singluar forms). Inputs contain data
+            'days', 'weeks', 'months', 'years'] (plural and singular forms). Inputs contain data
            with event times no earlier than input_start_offset in the past. Offsets are relative
            to the function execution time. If the function is executed by a Schedule, then the
            offset is relative to the scheduled start time. Defaults to None.
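
These two abstract classes are the extension point added by this commit: a custom source subclasses PySparkDataSource (or BaseDataSource for non-Spark abstractions), supplies data_source_name and data_source_unique_id, and implements read_data. Below is a minimal sketch of such a subclass; the class name, S3 prefix, and JSON Lines format are hypothetical and only illustrate the overrides.

from typing import Dict, Optional, Union

from pyspark.sql import DataFrame, SparkSession

from sagemaker.feature_store.feature_processor import PySparkDataSource


class JsonLinesDataSource(PySparkDataSource):
    """Hypothetical custom source that reads JSON Lines files from an S3 prefix."""

    # Plain class attributes satisfy the abstract properties of the base class.
    data_source_name = "json-lines"
    data_source_unique_id = "s3://example-bucket/raw/events/"

    def read_data(
        self, spark: SparkSession, params: Optional[Dict[str, Union[str, Dict]]] = None
    ) -> DataFrame:
        # `spark` is the session the framework creates; `params` carries the values
        # passed to the feature_processor decorator, if any.
        return spark.read.json(self.data_source_unique_id)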

src/sagemaker/feature_store/feature_processor/_factory.py (+24 -8)
@@ -13,6 +13,7 @@
 """Contains static factory classes to instantiate complex objects for the FeatureProcessor."""
 from __future__ import absolute_import
 
+from typing import Dict
 from pyspark.sql import DataFrame
 
 from sagemaker.feature_store.feature_processor._enums import FeatureProcessorMode
@@ -41,6 +42,7 @@
     InputValidator,
     SparkUDFSignatureValidator,
     InputOffsetValidator,
+    BaseDataSourceValidator,
     ValidatorChain,
 )
 
@@ -55,6 +57,7 @@ def get_validation_chain(fp_config: FeatureProcessorConfig) -> ValidatorChain:
             InputValidator(),
             FeatureProcessorArgValidator(),
             InputOffsetValidator(),
+            BaseDataSourceValidator(),
         ]
 
         mode = fp_config.mode
@@ -85,14 +88,19 @@ def get_udf_wrapper(fp_config: FeatureProcessorConfig) -> UDFWrapper:
         mode = fp_config.mode
 
         if FeatureProcessorMode.PYSPARK == mode:
-            return UDFWrapperFactory._get_spark_udf_wrapper()
+            return UDFWrapperFactory._get_spark_udf_wrapper(fp_config)
 
         raise ValueError(f"FeatureProcessorMode {mode} is not supported.")
 
     @staticmethod
-    def _get_spark_udf_wrapper() -> UDFWrapper[DataFrame]:
-        """Instantiate a new UDFWrapper for PySpark functions."""
-        spark_session_factory = UDFWrapperFactory._get_spark_session_factory()
+    def _get_spark_udf_wrapper(fp_config: FeatureProcessorConfig) -> UDFWrapper[DataFrame]:
+        """Instantiate a new UDFWrapper for PySpark functions.
+
+        Args:
+            fp_config (FeatureProcessorConfig): the configuration values for the feature_processor
+                decorator.
+        """
+        spark_session_factory = UDFWrapperFactory._get_spark_session_factory(fp_config.spark_config)
         feature_store_manager_factory = UDFWrapperFactory._get_feature_store_manager_factory()
 
         output_manager = UDFWrapperFactory._get_spark_output_receiver(feature_store_manager_factory)
@@ -131,7 +139,7 @@ def _get_spark_output_receiver(
 
         Args:
             feature_store_manager_factory (FeatureStoreManagerFactory): A factory to provide
-                that provides a FeaturStoreManager that handles data ingestion to a Feature Group.
+                that provides a FeatureStoreManager that handles data ingestion to a Feature Group.
                 The factory lazily loads the FeatureStoreManager.
 
         Returns:
@@ -140,10 +148,18 @@ def _get_spark_output_receiver(
         return SparkOutputReceiver(feature_store_manager_factory)
 
     @staticmethod
-    def _get_spark_session_factory() -> SparkSessionFactory:
-        """Instantiate a new SparkSessionFactory"""
+    def _get_spark_session_factory(spark_config: Dict[str, str]) -> SparkSessionFactory:
+        """Instantiate a new SparkSessionFactory
+
+        Args:
+            spark_config (Dict[str, str]): The Spark configuration that will be passed to the
+                initialization of Spark session.
+
+        Returns:
+            SparkSessionFactory: A Spark session factory instance.
+        """
         environment_helper = EnvironmentHelper()
-        return SparkSessionFactory(environment_helper)
+        return SparkSessionFactory(environment_helper, spark_config)
 
     @staticmethod
     def _get_feature_store_manager_factory() -> FeatureStoreManagerFactory:

src/sagemaker/feature_store/feature_processor/_feature_processor_config.py (+10 -2)
@@ -21,6 +21,7 @@
     CSVDataSource,
     FeatureGroupDataSource,
     ParquetDataSource,
+    BaseDataSource,
 )
 from sagemaker.feature_store.feature_processor._enums import FeatureProcessorMode
 
@@ -37,21 +38,27 @@ class FeatureProcessorConfig:
     It only serves as an immutable data class.
     """
 
-    inputs: Sequence[Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource]] = attr.ib()
+    inputs: Sequence[
+        Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource, BaseDataSource]
+    ] = attr.ib()
     output: str = attr.ib()
     mode: FeatureProcessorMode = attr.ib()
     target_stores: Optional[List[str]] = attr.ib()
     parameters: Optional[Dict[str, Union[str, Dict]]] = attr.ib()
     enable_ingestion: bool = attr.ib()
+    spark_config: Dict[str, str] = attr.ib()
 
     @staticmethod
     def create(
-        inputs: Sequence[Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource]],
+        inputs: Sequence[
+            Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource, BaseDataSource]
+        ],
         output: str,
         mode: FeatureProcessorMode,
        target_stores: Optional[List[str]],
        parameters: Optional[Dict[str, Union[str, Dict]]],
        enable_ingestion: bool,
+        spark_config: Dict[str, str],
     ) -> "FeatureProcessorConfig":
         """Static initializer."""
         return FeatureProcessorConfig(
@@ -61,4 +68,5 @@ def create(
             target_stores=target_stores,
             parameters=parameters,
             enable_ingestion=enable_ingestion,
+            spark_config=spark_config,
         )

src/sagemaker/feature_store/feature_processor/_params_loader.py (+1 -1)
@@ -72,7 +72,7 @@ def get_parameter_args(
                 feature_processor decorator.
 
         Returns:
-            Dict[str, Union[str, Dict]]: A dictionary containin both user provided
+            Dict[str, Union[str, Dict]]: A dictionary that contains both user provided
                 parameters (feature_processor argument) and system parameters.
         """
         return {

src/sagemaker/feature_store/feature_processor/_spark_factory.py (+25 -12)
@@ -14,7 +14,7 @@
 from __future__ import absolute_import
 
 from functools import lru_cache
-from typing import List, Tuple
+from typing import List, Tuple, Dict
 
 import feature_store_pyspark
 import feature_store_pyspark.FeatureStoreManager as fsm
@@ -34,14 +34,19 @@ class SparkSessionFactory:
     instance throughout the application.
     """
 
-    def __init__(self, environment_helper: EnvironmentHelper) -> None:
+    def __init__(
+        self, environment_helper: EnvironmentHelper, spark_config: Dict[str, str] = None
+    ) -> None:
         """Initialize the SparkSessionFactory.
 
         Args:
             environment_helper (EnvironmentHelper): A helper class to determine the current
                 execution.
+            spark_config (Dict[str, str]): The Spark configuration that will be passed to the
+                initialization of Spark session.
         """
         self.environment_helper = environment_helper
+        self.spark_config = spark_config
 
     @property
     @lru_cache()
@@ -106,24 +111,32 @@ def _get_spark_configs(self, is_training_job) -> List[Tuple[str, str]]:
             ("spark.port.maxRetries", "50"),
         ]
 
+        if self.spark_config:
+            spark_configs.extend(self.spark_config.items())
+
         if not is_training_job:
+            fp_spark_jars = feature_store_pyspark.classpath_jars()
+            fp_spark_packages = [
+                "org.apache.hadoop:hadoop-aws:3.3.1",
+                "org.apache.hadoop:hadoop-common:3.3.1",
+            ]
+
+            if self.spark_config and "spark.jars" in self.spark_config:
+                fp_spark_jars.append(self.spark_config.get("spark.jars"))
+
+            if self.spark_config and "spark.jars.packages" in self.spark_config:
+                fp_spark_packages.append(self.spark_config.get("spark.jars.packages"))
+
             spark_configs.extend(
                 (
-                    (
-                        "spark.jars",
-                        ",".join(feature_store_pyspark.classpath_jars()),
-                    ),
+                    ("spark.jars", ",".join(fp_spark_jars)),
                     (
                         "spark.jars.packages",
-                        ",".join(
-                            [
-                                "org.apache.hadoop:hadoop-aws:3.3.1",
-                                "org.apache.hadoop:hadoop-common:3.3.1",
-                            ]
-                        ),
+                        ",".join(fp_spark_packages),
                     ),
                 )
            )
+
         return spark_configs
 
     def _get_jsc_hadoop_configs(self) -> List[Tuple[str, str]]:
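
At the user-facing level, this plumbing lets callers pass extra Spark settings through to the session. A minimal sketch follows, assuming the feature_processor decorator in this commit forwards a spark_config argument into FeatureProcessorConfig (that decorator change is in one of the other changed files, not shown in this excerpt); the bucket, feature group ARN, and configuration values are hypothetical. Per the logic above, user entries are appended after the factory defaults, and spark.jars / spark.jars.packages values are appended to the built-in feature-store jars and Hadoop packages rather than replacing them.

from sagemaker.feature_store.feature_processor import CSVDataSource, feature_processor


@feature_processor(
    inputs=[CSVDataSource("s3://example-bucket/raw/csv/")],
    output="arn:aws:sagemaker:us-east-1:123456789012:feature-group/example-fg",
    spark_config={
        # Appended after the factory's default Spark configs.
        "spark.sql.shuffle.partitions": "200",
        # Appended to the built-in hadoop-aws/hadoop-common packages, not a replacement.
        "spark.jars.packages": "org.postgresql:postgresql:42.6.0",
    },
)
def transform(raw_df):
    # raw_df is the CSV input, already loaded as a Spark DataFrame.
    return raw_df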

src/sagemaker/feature_store/feature_processor/_udf_arg_provider.py (+23 -5)
@@ -15,7 +15,7 @@
 
 from abc import ABC, abstractmethod
 from inspect import signature
-from typing import Any, Callable, Dict, Generic, List, OrderedDict, TypeVar, Union
+from typing import Any, Callable, Dict, Generic, List, OrderedDict, TypeVar, Union, Optional
 
 import attr
 from pyspark.sql import DataFrame, SparkSession
@@ -24,6 +24,8 @@
     CSVDataSource,
     FeatureGroupDataSource,
     ParquetDataSource,
+    BaseDataSource,
+    PySparkDataSource,
 )
 from sagemaker.feature_store.feature_processor._feature_processor_config import (
     FeatureProcessorConfig,
@@ -119,6 +121,9 @@ def provide_input_args(
         """
         udf_parameter_names = list(signature(udf).parameters.keys())
         udf_input_names = self._get_input_parameters(udf_parameter_names)
+        udf_params = self.params_loader.get_parameter_args(fp_config).get(
+            self.PARAMS_ARG_NAME, None
+        )
 
         if len(udf_input_names) == 0:
             raise ValueError("Expected at least one input to the user defined function.")
@@ -130,7 +135,7 @@ def provide_input_args(
             )
 
         return OrderedDict(
-            (input_name, self._load_data_frame(input_uri))
+            (input_name, self._load_data_frame(data_source=input_uri, params=udf_params))
             for (input_name, input_uri) in zip(udf_input_names, fp_config.inputs)
         )
 
@@ -189,13 +194,19 @@ def _get_input_parameters(self, udf_parameter_names: List[str]) -> List[str]:
 
     def _load_data_frame(
         self,
-        data_source: Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource],
+        data_source: Union[
+            FeatureGroupDataSource, CSVDataSource, ParquetDataSource, BaseDataSource
+        ],
+        params: Optional[Dict[str, Union[str, Dict]]] = None,
     ) -> DataFrame:
         """Given a data source definition, load the data as a Spark DataFrame.
 
         Args:
-            data_source (Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource]):
-                A user specified data source from the feature_processor decorator's parameters.
+            data_source (Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource,
+                BaseDataSource]): A user specified data source from the feature_processor
+                decorator's parameters.
+            params (Optional[Dict[str, Union[str, Dict]]]): Parameters provided to the
+                feature_processor decorator.
 
         Returns:
             DataFrame: The contents of the data source as a Spark DataFrame.
@@ -206,6 +217,13 @@ def _load_data_frame(
         if isinstance(data_source, FeatureGroupDataSource):
             return self.input_loader.load_from_feature_group(data_source)
 
+        if isinstance(data_source, PySparkDataSource):
+            spark_session = self.spark_session_factory.spark_session
+            return data_source.read_data(spark=spark_session, params=params)
+
+        if isinstance(data_source, BaseDataSource):
+            return data_source.read_data(params=params)
+
         raise ValueError(f"Unknown data source type: {type(data_source)}")
 
     def _has_param(self, udf: Callable, name: str) -> bool:
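
Tying the pieces together, here is a sketch of how a custom source sits alongside the built-in ones in a feature_processor UDF, reusing the hypothetical JsonLinesDataSource sketched earlier; the Feature Group name, output ARN, and join key are made up. Inputs map positionally to the UDF parameters, and per the dispatch above the provider calls read_data(spark=..., params=...) on a PySparkDataSource and read_data(params=...) on any other BaseDataSource.

from sagemaker.feature_store.feature_processor import FeatureGroupDataSource, feature_processor


@feature_processor(
    inputs=[JsonLinesDataSource(), FeatureGroupDataSource("example-feature-group")],
    output="arn:aws:sagemaker:us-east-1:123456789012:feature-group/example-output-fg",
)
def join_sources(events_df, features_df):
    # events_df comes from JsonLinesDataSource.read_data(); features_df from the Feature Group.
    return events_df.join(features_df, on="record_id")  # hypothetical join key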
