
feat: feature-processor extra data sources support #4155


Merged · 3 commits · Oct 5, 2023
2 changes: 2 additions & 0 deletions src/sagemaker/feature_store/feature_processor/__init__.py
@@ -17,6 +17,8 @@
CSVDataSource,
FeatureGroupDataSource,
ParquetDataSource,
BaseDataSource,
PySparkDataSource,
)
from sagemaker.feature_store.feature_processor._exceptions import ( # noqa: F401
IngestionError,
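With this change the two new abstractions are re-exported from the package root alongside the existing data source classes, so user code can import them directly; a minimal sketch:

```python
# Import the newly exported data source base classes (added to __init__.py above).
from sagemaker.feature_store.feature_processor import (
    BaseDataSource,
    PySparkDataSource,
)
```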
@@ -31,7 +31,7 @@
_JobSettings,
RUNTIME_SCRIPTS_CHANNEL_NAME,
REMOTE_FUNCTION_WORKSPACE,
SPARK_CONF_WORKSPACE,
SPARK_CONF_CHANNEL_NAME,
_prepare_and_upload_spark_dependent_files,
)
from sagemaker.remote_function.runtime_environment.runtime_environment_manager import (
@@ -99,7 +99,7 @@ def prepare_step_input_channel_for_spark_mode(
)

if config_file_s3_uri:
input_data_config[SPARK_CONF_WORKSPACE] = TrainingInput(
input_data_config[SPARK_CONF_CHANNEL_NAME] = TrainingInput(
s3_data=config_file_s3_uri,
s3_data_type="S3Prefix",
distribution=S3_DATA_DISTRIBUTION_TYPE,
70 changes: 68 additions & 2 deletions src/sagemaker/feature_store/feature_processor/_data_source.py
@@ -13,10 +13,76 @@
"""Contains classes to define input data sources."""
from __future__ import absolute_import

from typing import Optional
from typing import Optional, Dict, Union, TypeVar, Generic
from abc import ABC, abstractmethod
from pyspark.sql import DataFrame, SparkSession


import attr

T = TypeVar("T")


@attr.s
class BaseDataSource(Generic[T], ABC):
"""Abstract base class for feature processor data sources.

Provides a skeleton for customization requiring the overriding of the method to read data from
data source and return the specified type.
"""

@abstractmethod
def read_data(self, *args, **kwargs) -> T:
"""Read data from data source and return the specified type.

Args:
args: Arguments for reading the data.
kwargs: Keyword argument for reading the data.
Returns:
T: The specified abstraction of data source.
"""

@property
@abstractmethod
def data_source_unique_id(self) -> str:
"""The identifier for the customized feature processor data source.

Returns:
str: The data source unique id.
"""

@property
@abstractmethod
def data_source_name(self) -> str:
"""The name for the customized feature processor data source.

Returns:
str: The data source name.
"""


@attr.s
class PySparkDataSource(BaseDataSource[DataFrame], ABC):
"""Abstract base class for feature processor data sources.

Provides a skeleton for customization requiring the overriding of the method to read data from
data source and return the Spark DataFrame.
"""

@abstractmethod
def read_data(
self, spark: SparkSession, params: Optional[Dict[str, Union[str, Dict]]] = None
) -> DataFrame:
"""Read data from data source and convert the data to Spark DataFrame.

Args:
spark (SparkSession): The Spark session to read the data.
params (Optional[Dict[str, Union[str, Dict]]]): Parameters provided to the
feature_processor decorator.
Returns:
DataFrame: The Spark DataFrame as an abstraction on the data source.
"""


@attr.s
class FeatureGroupDataSource:
@@ -26,7 +92,7 @@ class FeatureGroupDataSource:
name (str): The name or ARN of the Feature Group.
input_start_offset (Optional[str], optional): A duration specified as a string in the
format '<no> <unit>' where 'no' is a number and 'unit' is a unit of time in ['hours',
'days', 'weeks', 'months', 'years'] (plural and singluar forms). Inputs contain data
'days', 'weeks', 'months', 'years'] (plural and singular forms). Inputs contain data
with event times no earlier than input_start_offset in the past. Offsets are relative
to the function execution time. If the function is executed by a Schedule, then the
offset is relative to the scheduled start time. Defaults to None.
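To illustrate the new PySparkDataSource contract defined above, here is a minimal sketch of a custom data source; the class name, constructor, and S3 URI are hypothetical, while the overridden members follow the abstract interface in this diff:

```python
from typing import Dict, Optional, Union

from pyspark.sql import DataFrame, SparkSession

from sagemaker.feature_store.feature_processor import PySparkDataSource


class JsonLinesDataSource(PySparkDataSource):
    """Hypothetical data source that reads JSON Lines files into a Spark DataFrame."""

    def __init__(self, s3_uri: str):
        self.s3_uri = s3_uri

    @property
    def data_source_name(self) -> str:
        # Human-readable name for this data source.
        return "json-lines-source"

    @property
    def data_source_unique_id(self) -> str:
        # The source URI doubles as the unique identifier.
        return self.s3_uri

    def read_data(
        self, spark: SparkSession, params: Optional[Dict[str, Union[str, Dict]]] = None
    ) -> DataFrame:
        # The framework passes in the shared SparkSession and the decorator parameters.
        return spark.read.json(self.s3_uri)
```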
32 changes: 24 additions & 8 deletions src/sagemaker/feature_store/feature_processor/_factory.py
@@ -13,6 +13,7 @@
"""Contains static factory classes to instantiate complex objects for the FeatureProcessor."""
from __future__ import absolute_import

from typing import Dict
from pyspark.sql import DataFrame

from sagemaker.feature_store.feature_processor._enums import FeatureProcessorMode
@@ -41,6 +42,7 @@
InputValidator,
SparkUDFSignatureValidator,
InputOffsetValidator,
BaseDataSourceValidator,
ValidatorChain,
)

@@ -55,6 +57,7 @@ def get_validation_chain(fp_config: FeatureProcessorConfig) -> ValidatorChain:
InputValidator(),
FeatureProcessorArgValidator(),
InputOffsetValidator(),
BaseDataSourceValidator(),
]

mode = fp_config.mode
@@ -85,14 +88,19 @@ def get_udf_wrapper(fp_config: FeatureProcessorConfig) -> UDFWrapper:
mode = fp_config.mode

if FeatureProcessorMode.PYSPARK == mode:
return UDFWrapperFactory._get_spark_udf_wrapper()
return UDFWrapperFactory._get_spark_udf_wrapper(fp_config)

raise ValueError(f"FeatureProcessorMode {mode} is not supported.")

@staticmethod
def _get_spark_udf_wrapper() -> UDFWrapper[DataFrame]:
"""Instantiate a new UDFWrapper for PySpark functions."""
spark_session_factory = UDFWrapperFactory._get_spark_session_factory()
def _get_spark_udf_wrapper(fp_config: FeatureProcessorConfig) -> UDFWrapper[DataFrame]:
"""Instantiate a new UDFWrapper for PySpark functions.

Args:
fp_config (FeatureProcessorConfig): the configuration values for the feature_processor
decorator.
"""
spark_session_factory = UDFWrapperFactory._get_spark_session_factory(fp_config.spark_config)
feature_store_manager_factory = UDFWrapperFactory._get_feature_store_manager_factory()

output_manager = UDFWrapperFactory._get_spark_output_receiver(feature_store_manager_factory)
@@ -131,7 +139,7 @@ def _get_spark_output_receiver(

Args:
feature_store_manager_factory (FeatureStoreManagerFactory): A factory to provide
that provides a FeaturStoreManager that handles data ingestion to a Feature Group.
that provides a FeatureStoreManager that handles data ingestion to a Feature Group.
The factory lazily loads the FeatureStoreManager.

Returns:
Expand All @@ -140,10 +148,18 @@ def _get_spark_output_receiver(
return SparkOutputReceiver(feature_store_manager_factory)

@staticmethod
def _get_spark_session_factory() -> SparkSessionFactory:
"""Instantiate a new SparkSessionFactory"""
def _get_spark_session_factory(spark_config: Dict[str, str]) -> SparkSessionFactory:
"""Instantiate a new SparkSessionFactory

Args:
spark_config (Dict[str, str]): The Spark configuration that will be passed to the
initialization of Spark session.

Returns:
SparkSessionFactory: A Spark session factory instance.
"""
environment_helper = EnvironmentHelper()
return SparkSessionFactory(environment_helper)
return SparkSessionFactory(environment_helper, spark_config)

@staticmethod
def _get_feature_store_manager_factory() -> FeatureStoreManagerFactory:
@@ -21,6 +21,7 @@
CSVDataSource,
FeatureGroupDataSource,
ParquetDataSource,
BaseDataSource,
)
from sagemaker.feature_store.feature_processor._enums import FeatureProcessorMode

@@ -37,21 +38,27 @@ class FeatureProcessorConfig:
It only serves as an immutable data class.
"""

inputs: Sequence[Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource]] = attr.ib()
inputs: Sequence[
Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource, BaseDataSource]
] = attr.ib()
output: str = attr.ib()
mode: FeatureProcessorMode = attr.ib()
target_stores: Optional[List[str]] = attr.ib()
parameters: Optional[Dict[str, Union[str, Dict]]] = attr.ib()
enable_ingestion: bool = attr.ib()
spark_config: Dict[str, str] = attr.ib()

@staticmethod
def create(
inputs: Sequence[Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource]],
inputs: Sequence[
Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource, BaseDataSource]
],
output: str,
mode: FeatureProcessorMode,
target_stores: Optional[List[str]],
parameters: Optional[Dict[str, Union[str, Dict]]],
enable_ingestion: bool,
spark_config: Dict[str, str],
) -> "FeatureProcessorConfig":
"""Static initializer."""
return FeatureProcessorConfig(
Expand All @@ -61,4 +68,5 @@ def create(
target_stores=target_stores,
parameters=parameters,
enable_ingestion=enable_ingestion,
spark_config=spark_config,
)
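For context on where the new spark_config field and the widened inputs type surface in user code, here is a hedged sketch of decorator usage; the keyword arguments mirror FeatureProcessorConfig.create above, but the public feature_processor signature is not part of this diff, and the bucket paths, ARN, and column names are placeholders:

```python
from pyspark.sql import DataFrame

from sagemaker.feature_store.feature_processor import CSVDataSource, feature_processor

# JsonLinesDataSource is the hypothetical PySparkDataSource subclass sketched earlier.
@feature_processor(
    inputs=[
        CSVDataSource("s3://example-bucket/raw/labels/"),        # built-in source
        JsonLinesDataSource("s3://example-bucket/raw/events/"),  # custom source
    ],
    output="arn:aws:sagemaker:us-east-1:111122223333:feature-group/example-feature-group",
    parameters={"lookback": "7 days"},
    spark_config={"spark.sql.shuffle.partitions": "8"},
)
def transform(labels: DataFrame, events: DataFrame, params: dict) -> DataFrame:
    # Each positional argument maps to one entry in `inputs`; `params` receives the
    # user-provided and system parameters assembled by the params loader.
    return labels.join(events, on="record_id", how="inner")
```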
@@ -72,7 +72,7 @@ def get_parameter_args(
feature_processor decorator.

Returns:
Dict[str, Union[str, Dict]]: A dictionary containin both user provided
Dict[str, Union[str, Dict]]: A dictionary that contains both user provided
parameters (feature_processor argument) and system parameters.
"""
return {
37 changes: 25 additions & 12 deletions src/sagemaker/feature_store/feature_processor/_spark_factory.py
@@ -14,7 +14,7 @@
from __future__ import absolute_import

from functools import lru_cache
from typing import List, Tuple
from typing import List, Tuple, Dict

import feature_store_pyspark
import feature_store_pyspark.FeatureStoreManager as fsm
@@ -34,14 +34,19 @@ class SparkSessionFactory:
instance throughout the application.
"""

def __init__(self, environment_helper: EnvironmentHelper) -> None:
def __init__(
self, environment_helper: EnvironmentHelper, spark_config: Dict[str, str] = None
) -> None:
"""Initialize the SparkSessionFactory.

Args:
environment_helper (EnvironmentHelper): A helper class to determine the current
execution.
spark_config (Dict[str, str]): The Spark configuration that will be passed to the
initialization of Spark session.
"""
self.environment_helper = environment_helper
self.spark_config = spark_config

@property
@lru_cache()
@@ -106,24 +111,32 @@ def _get_spark_configs(self, is_training_job) -> List[Tuple[str, str]]:
("spark.port.maxRetries", "50"),
]

if self.spark_config:
spark_configs.extend(self.spark_config.items())

if not is_training_job:
fp_spark_jars = feature_store_pyspark.classpath_jars()
fp_spark_packages = [
"org.apache.hadoop:hadoop-aws:3.3.1",
"org.apache.hadoop:hadoop-common:3.3.1",
]

if self.spark_config and "spark.jars" in self.spark_config:
fp_spark_jars.append(self.spark_config.get("spark.jars"))

if self.spark_config and "spark.jars.packages" in self.spark_config:
fp_spark_packages.append(self.spark_config.get("spark.jars.packages"))

spark_configs.extend(
(
(
"spark.jars",
",".join(feature_store_pyspark.classpath_jars()),
),
("spark.jars", ",".join(fp_spark_jars)),
(
"spark.jars.packages",
",".join(
[
"org.apache.hadoop:hadoop-aws:3.3.1",
"org.apache.hadoop:hadoop-common:3.3.1",
]
),
",".join(fp_spark_packages),
),
)
)

return spark_configs

def _get_jsc_hadoop_configs(self) -> List[Tuple[str, str]]:
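A short usage sketch of the new spark_config argument; note that user-supplied spark.jars and spark.jars.packages entries are appended to the feature-store defaults in _get_spark_configs rather than replacing them. The EnvironmentHelper import path is an assumption, as it is not shown in this diff:

```python
from sagemaker.feature_store.feature_processor._env import EnvironmentHelper  # assumed path
from sagemaker.feature_store.feature_processor._spark_factory import SparkSessionFactory

factory = SparkSessionFactory(
    environment_helper=EnvironmentHelper(),
    spark_config={
        "spark.sql.shuffle.partitions": "8",
        # Appended to the hadoop-aws/hadoop-common defaults, not substituted for them.
        "spark.jars.packages": "org.postgresql:postgresql:42.6.0",
    },
)
spark = factory.spark_session  # cached property, so repeated calls reuse one session
```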
@@ -15,7 +15,7 @@

from abc import ABC, abstractmethod
from inspect import signature
from typing import Any, Callable, Dict, Generic, List, OrderedDict, TypeVar, Union
from typing import Any, Callable, Dict, Generic, List, OrderedDict, TypeVar, Union, Optional

import attr
from pyspark.sql import DataFrame, SparkSession
@@ -24,6 +24,8 @@
CSVDataSource,
FeatureGroupDataSource,
ParquetDataSource,
BaseDataSource,
PySparkDataSource,
)
from sagemaker.feature_store.feature_processor._feature_processor_config import (
FeatureProcessorConfig,
@@ -119,6 +121,9 @@ def provide_input_args(
"""
udf_parameter_names = list(signature(udf).parameters.keys())
udf_input_names = self._get_input_parameters(udf_parameter_names)
udf_params = self.params_loader.get_parameter_args(fp_config).get(
self.PARAMS_ARG_NAME, None
)

if len(udf_input_names) == 0:
raise ValueError("Expected at least one input to the user defined function.")
@@ -130,7 +135,7 @@
)

return OrderedDict(
(input_name, self._load_data_frame(input_uri))
(input_name, self._load_data_frame(data_source=input_uri, params=udf_params))
for (input_name, input_uri) in zip(udf_input_names, fp_config.inputs)
)

@@ -189,13 +194,19 @@ def _get_input_parameters(self, udf_parameter_names: List[str]) -> List[str]:

def _load_data_frame(
self,
data_source: Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource],
data_source: Union[
FeatureGroupDataSource, CSVDataSource, ParquetDataSource, BaseDataSource
],
params: Optional[Dict[str, Union[str, Dict]]] = None,
) -> DataFrame:
"""Given a data source definition, load the data as a Spark DataFrame.

Args:
data_source (Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource]):
A user specified data source from the feature_processor decorator's parameters.
data_source (Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource,
BaseDataSource]): A user specified data source from the feature_processor
decorator's parameters.
params (Optional[Dict[str, Union[str, Dict]]]): Parameters provided to the
feature_processor decorator.

Returns:
DataFrame: The contents of the data source as a Spark DataFrame.
Expand All @@ -206,6 +217,13 @@ def _load_data_frame(
if isinstance(data_source, FeatureGroupDataSource):
return self.input_loader.load_from_feature_group(data_source)

if isinstance(data_source, PySparkDataSource):
spark_session = self.spark_session_factory.spark_session
return data_source.read_data(spark=spark_session, params=params)

if isinstance(data_source, BaseDataSource):
return data_source.read_data(params=params)

raise ValueError(f"Unknown data source type: {type(data_source)}")

def _has_param(self, udf: Callable, name: str) -> bool:
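Finally, a hedged sketch of a non-Spark source handled by the new BaseDataSource branch above, which calls read_data(params=params) without a SparkSession; the class and its pandas return type are illustrative, and whether non-Spark return types flow through a given processor mode end to end is not established by this diff:

```python
from typing import Dict, Optional, Union

import pandas as pd

from sagemaker.feature_store.feature_processor import BaseDataSource


class CsvPandasDataSource(BaseDataSource[pd.DataFrame]):
    """Hypothetical data source that returns a pandas DataFrame instead of a Spark one."""

    def __init__(self, path: str):
        self.path = path

    @property
    def data_source_name(self) -> str:
        return "csv-pandas-source"

    @property
    def data_source_unique_id(self) -> str:
        return self.path

    def read_data(self, params: Optional[Dict[str, Union[str, Dict]]] = None) -> pd.DataFrame:
        # Only the decorator parameters are passed; no SparkSession is involved.
        return pd.read_csv(self.path)
```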