From 2971c88880b3778edc7f6b1a703d1aa1d4d27750 Mon Sep 17 00:00:00 2001 From: cansun <80425164+can-sun@users.noreply.github.com> Date: Thu, 28 Sep 2023 13:14:18 -0700 Subject: [PATCH] [feat] @feature-processor extra data sources support (#1087) * Enable feature_processor to connect with user defined data sources * Add local integration test for BaseDataSource * Fix the integration tests * Add custom data source name validations * Remove test unused imports * Fix the data source name invalid pattern * Resolve minor PR comments * Add extra abstraction layer for feature processor * Fix the docstring of BaseDataSource * Fix the failed integration test --- .../feature_processor/__init__.py | 2 + .../feature_processor/_config_uploader.py | 4 +- .../feature_processor/_data_source.py | 70 +++++++- .../feature_processor/_factory.py | 32 +++- .../_feature_processor_config.py | 12 +- .../feature_processor/_params_loader.py | 2 +- .../feature_processor/_spark_factory.py | 37 ++-- .../feature_processor/_udf_arg_provider.py | 28 ++- .../feature_processor/_validation.py | 36 ++++ .../feature_processor/feature_processor.py | 14 +- .../lineage/_feature_processor_lineage.py | 26 ++- .../lineage/_s3_lineage_entity_handler.py | 30 +++- .../test_feature_processor.py | 169 +++++++++++++++++- .../lineage/test_constants.py | 16 +- .../lineage/test_feature_processor_lineage.py | 43 +++-- .../lineage/test_s3_lineage_entity_handler.py | 55 ++++++ .../feature_processor/test_config_uploader.py | 4 +- .../feature_processor/test_data_helpers.py | 2 + .../feature_processor/test_data_source.py | 34 ++++ .../feature_processor/test_factory.py | 4 +- .../test_feature_processor_config.py | 1 + .../test_feature_scheduler.py | 4 +- .../test_spark_session_factory.py | 9 +- .../test_udf_arg_provider.py | 31 +++- .../feature_processor/test_validation.py | 44 +++++ 25 files changed, 635 insertions(+), 74 deletions(-) create mode 100644 tests/unit/sagemaker/feature_store/feature_processor/test_data_source.py diff --git a/src/sagemaker/feature_store/feature_processor/__init__.py b/src/sagemaker/feature_store/feature_processor/__init__.py index 867365a620..5064a5a839 100644 --- a/src/sagemaker/feature_store/feature_processor/__init__.py +++ b/src/sagemaker/feature_store/feature_processor/__init__.py @@ -17,6 +17,8 @@ CSVDataSource, FeatureGroupDataSource, ParquetDataSource, + BaseDataSource, + PySparkDataSource, ) from sagemaker.feature_store.feature_processor._exceptions import ( # noqa: F401 IngestionError, diff --git a/src/sagemaker/feature_store/feature_processor/_config_uploader.py b/src/sagemaker/feature_store/feature_processor/_config_uploader.py index 018a56d785..5dc0fe5331 100644 --- a/src/sagemaker/feature_store/feature_processor/_config_uploader.py +++ b/src/sagemaker/feature_store/feature_processor/_config_uploader.py @@ -31,7 +31,7 @@ _JobSettings, RUNTIME_SCRIPTS_CHANNEL_NAME, REMOTE_FUNCTION_WORKSPACE, - SPARK_CONF_WORKSPACE, + SPARK_CONF_CHANNEL_NAME, _prepare_and_upload_spark_dependent_files, ) from sagemaker.remote_function.runtime_environment.runtime_environment_manager import ( @@ -99,7 +99,7 @@ def prepare_step_input_channel_for_spark_mode( ) if config_file_s3_uri: - input_data_config[SPARK_CONF_WORKSPACE] = TrainingInput( + input_data_config[SPARK_CONF_CHANNEL_NAME] = TrainingInput( s3_data=config_file_s3_uri, s3_data_type="S3Prefix", distribution=S3_DATA_DISTRIBUTION_TYPE, diff --git a/src/sagemaker/feature_store/feature_processor/_data_source.py b/src/sagemaker/feature_store/feature_processor/_data_source.py 
index f09ea1a0cd..a6c452267c 100644 --- a/src/sagemaker/feature_store/feature_processor/_data_source.py +++ b/src/sagemaker/feature_store/feature_processor/_data_source.py @@ -13,10 +13,76 @@ """Contains classes to define input data sources.""" from __future__ import absolute_import -from typing import Optional +from typing import Optional, Dict, Union, TypeVar, Generic +from abc import ABC, abstractmethod +from pyspark.sql import DataFrame, SparkSession + import attr +T = TypeVar("T") + + +@attr.s +class BaseDataSource(Generic[T], ABC): + """Abstract base class for feature processor data sources. + + Provides a skeleton for customization requiring the overriding of the method to read data from + data source and return the specified type. + """ + + @abstractmethod + def read_data(self, *args, **kwargs) -> T: + """Read data from data source and return the specified type. + + Args: + args: Arguments for reading the data. + kwargs: Keyword argument for reading the data. + Returns: + T: The specified abstraction of data source. + """ + + @property + @abstractmethod + def data_source_unique_id(self) -> str: + """The identifier for the customized feature processor data source. + + Returns: + str: The data source unique id. + """ + + @property + @abstractmethod + def data_source_name(self) -> str: + """The name for the customized feature processor data source. + + Returns: + str: The data source name. + """ + + +@attr.s +class PySparkDataSource(BaseDataSource[DataFrame], ABC): + """Abstract base class for feature processor data sources. + + Provides a skeleton for customization requiring the overriding of the method to read data from + data source and return the Spark DataFrame. + """ + + @abstractmethod + def read_data( + self, spark: SparkSession, params: Optional[Dict[str, Union[str, Dict]]] = None + ) -> DataFrame: + """Read data from data source and convert the data to Spark DataFrame. + + Args: + spark (SparkSession): The Spark session to read the data. + params (Optional[Dict[str, Union[str, Dict]]]): Parameters provided to the + feature_processor decorator. + Returns: + DataFrame: The Spark DataFrame as an abstraction on the data source. + """ + @attr.s class FeatureGroupDataSource: @@ -26,7 +92,7 @@ class FeatureGroupDataSource: name (str): The name or ARN of the Feature Group. input_start_offset (Optional[str], optional): A duration specified as a string in the format ' ' where 'no' is a number and 'unit' is a unit of time in ['hours', - 'days', 'weeks', 'months', 'years'] (plural and singluar forms). Inputs contain data + 'days', 'weeks', 'months', 'years'] (plural and singular forms). Inputs contain data with event times no earlier than input_start_offset in the past. Offsets are relative to the function execution time. If the function is executed by a Schedule, then the offset is relative to the scheduled start time. Defaults to None. 
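Usage sketch (illustrative, not part of the patch): the new BaseDataSource / PySparkDataSource abstractions above are meant to be subclassed by users. A minimal example follows; the class name, the s3_uri field, the JSON format, and the output ARN placeholder are assumptions for illustration, but the overall shape mirrors the TestCSVDataSource defined in the integration test later in this patch.

from typing import Dict, Optional, Union

import attr
from pyspark.sql import DataFrame, SparkSession

from sagemaker.feature_store.feature_processor import PySparkDataSource, feature_processor


@attr.s
class JsonDataSource(PySparkDataSource):
    """Hypothetical custom data source that reads JSON files from S3."""

    s3_uri = attr.ib()

    # Required by BaseDataSource; checked by the BaseDataSourceValidator added below.
    data_source_name = "JsonDataSource"
    data_source_unique_id = "s3_uri"

    def read_data(
        self, spark: SparkSession, params: Optional[Dict[str, Union[str, Dict]]] = None
    ) -> DataFrame:
        # The feature_processor runtime passes its own SparkSession and the
        # decorator's `parameters` dict when it loads this input.
        return spark.read.json(self.s3_uri.replace("s3://", "s3a://"))


@feature_processor(
    inputs=[JsonDataSource("s3://my-bucket/raw-data/")],  # hypothetical S3 prefix
    output="<feature-group-arn>",                         # hypothetical Feature Group ARN
)
def transform(raw_df):
    return raw_df
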
diff --git a/src/sagemaker/feature_store/feature_processor/_factory.py b/src/sagemaker/feature_store/feature_processor/_factory.py index 36b8de4e45..4b73b5c4d3 100644 --- a/src/sagemaker/feature_store/feature_processor/_factory.py +++ b/src/sagemaker/feature_store/feature_processor/_factory.py @@ -13,6 +13,7 @@ """Contains static factory classes to instantiate complex objects for the FeatureProcessor.""" from __future__ import absolute_import +from typing import Dict from pyspark.sql import DataFrame from sagemaker.feature_store.feature_processor._enums import FeatureProcessorMode @@ -41,6 +42,7 @@ InputValidator, SparkUDFSignatureValidator, InputOffsetValidator, + BaseDataSourceValidator, ValidatorChain, ) @@ -55,6 +57,7 @@ def get_validation_chain(fp_config: FeatureProcessorConfig) -> ValidatorChain: InputValidator(), FeatureProcessorArgValidator(), InputOffsetValidator(), + BaseDataSourceValidator(), ] mode = fp_config.mode @@ -85,14 +88,19 @@ def get_udf_wrapper(fp_config: FeatureProcessorConfig) -> UDFWrapper: mode = fp_config.mode if FeatureProcessorMode.PYSPARK == mode: - return UDFWrapperFactory._get_spark_udf_wrapper() + return UDFWrapperFactory._get_spark_udf_wrapper(fp_config) raise ValueError(f"FeatureProcessorMode {mode} is not supported.") @staticmethod - def _get_spark_udf_wrapper() -> UDFWrapper[DataFrame]: - """Instantiate a new UDFWrapper for PySpark functions.""" - spark_session_factory = UDFWrapperFactory._get_spark_session_factory() + def _get_spark_udf_wrapper(fp_config: FeatureProcessorConfig) -> UDFWrapper[DataFrame]: + """Instantiate a new UDFWrapper for PySpark functions. + + Args: + fp_config (FeatureProcessorConfig): the configuration values for the feature_processor + decorator. + """ + spark_session_factory = UDFWrapperFactory._get_spark_session_factory(fp_config.spark_config) feature_store_manager_factory = UDFWrapperFactory._get_feature_store_manager_factory() output_manager = UDFWrapperFactory._get_spark_output_receiver(feature_store_manager_factory) @@ -131,7 +139,7 @@ def _get_spark_output_receiver( Args: feature_store_manager_factory (FeatureStoreManagerFactory): A factory to provide - that provides a FeaturStoreManager that handles data ingestion to a Feature Group. + that provides a FeatureStoreManager that handles data ingestion to a Feature Group. The factory lazily loads the FeatureStoreManager. Returns: @@ -140,10 +148,18 @@ def _get_spark_output_receiver( return SparkOutputReceiver(feature_store_manager_factory) @staticmethod - def _get_spark_session_factory() -> SparkSessionFactory: - """Instantiate a new SparkSessionFactory""" + def _get_spark_session_factory(spark_config: Dict[str, str]) -> SparkSessionFactory: + """Instantiate a new SparkSessionFactory + + Args: + spark_config (Dict[str, str]): The Spark configuration that will be passed to the + initialization of Spark session. + + Returns: + SparkSessionFactory: A Spark session factory instance. 
+ """ environment_helper = EnvironmentHelper() - return SparkSessionFactory(environment_helper) + return SparkSessionFactory(environment_helper, spark_config) @staticmethod def _get_feature_store_manager_factory() -> FeatureStoreManagerFactory: diff --git a/src/sagemaker/feature_store/feature_processor/_feature_processor_config.py b/src/sagemaker/feature_store/feature_processor/_feature_processor_config.py index 28627eb1db..d98baa7b9a 100644 --- a/src/sagemaker/feature_store/feature_processor/_feature_processor_config.py +++ b/src/sagemaker/feature_store/feature_processor/_feature_processor_config.py @@ -21,6 +21,7 @@ CSVDataSource, FeatureGroupDataSource, ParquetDataSource, + BaseDataSource, ) from sagemaker.feature_store.feature_processor._enums import FeatureProcessorMode @@ -37,21 +38,27 @@ class FeatureProcessorConfig: It only serves as an immutable data class. """ - inputs: Sequence[Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource]] = attr.ib() + inputs: Sequence[ + Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource, BaseDataSource] + ] = attr.ib() output: str = attr.ib() mode: FeatureProcessorMode = attr.ib() target_stores: Optional[List[str]] = attr.ib() parameters: Optional[Dict[str, Union[str, Dict]]] = attr.ib() enable_ingestion: bool = attr.ib() + spark_config: Dict[str, str] = attr.ib() @staticmethod def create( - inputs: Sequence[Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource]], + inputs: Sequence[ + Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource, BaseDataSource] + ], output: str, mode: FeatureProcessorMode, target_stores: Optional[List[str]], parameters: Optional[Dict[str, Union[str, Dict]]], enable_ingestion: bool, + spark_config: Dict[str, str], ) -> "FeatureProcessorConfig": """Static initializer.""" return FeatureProcessorConfig( @@ -61,4 +68,5 @@ def create( target_stores=target_stores, parameters=parameters, enable_ingestion=enable_ingestion, + spark_config=spark_config, ) diff --git a/src/sagemaker/feature_store/feature_processor/_params_loader.py b/src/sagemaker/feature_store/feature_processor/_params_loader.py index 1bdaa77114..64ee001bc3 100644 --- a/src/sagemaker/feature_store/feature_processor/_params_loader.py +++ b/src/sagemaker/feature_store/feature_processor/_params_loader.py @@ -72,7 +72,7 @@ def get_parameter_args( feature_processor decorator. Returns: - Dict[str, Union[str, Dict]]: A dictionary containin both user provided + Dict[str, Union[str, Dict]]: A dictionary that contains both user provided parameters (feature_processor argument) and system parameters. """ return { diff --git a/src/sagemaker/feature_store/feature_processor/_spark_factory.py b/src/sagemaker/feature_store/feature_processor/_spark_factory.py index 474315bb77..76a48218b0 100644 --- a/src/sagemaker/feature_store/feature_processor/_spark_factory.py +++ b/src/sagemaker/feature_store/feature_processor/_spark_factory.py @@ -14,7 +14,7 @@ from __future__ import absolute_import from functools import lru_cache -from typing import List, Tuple +from typing import List, Tuple, Dict import feature_store_pyspark import feature_store_pyspark.FeatureStoreManager as fsm @@ -34,14 +34,19 @@ class SparkSessionFactory: instance throughout the application. """ - def __init__(self, environment_helper: EnvironmentHelper) -> None: + def __init__( + self, environment_helper: EnvironmentHelper, spark_config: Dict[str, str] = None + ) -> None: """Initialize the SparkSessionFactory. 
Args: environment_helper (EnvironmentHelper): A helper class to determine the current execution. + spark_config (Dict[str, str]): The Spark configuration that will be passed to the + initialization of Spark session. """ self.environment_helper = environment_helper + self.spark_config = spark_config @property @lru_cache() @@ -106,24 +111,32 @@ def _get_spark_configs(self, is_training_job) -> List[Tuple[str, str]]: ("spark.port.maxRetries", "50"), ] + if self.spark_config: + spark_configs.extend(self.spark_config.items()) + if not is_training_job: + fp_spark_jars = feature_store_pyspark.classpath_jars() + fp_spark_packages = [ + "org.apache.hadoop:hadoop-aws:3.3.1", + "org.apache.hadoop:hadoop-common:3.3.1", + ] + + if self.spark_config and "spark.jars" in self.spark_config: + fp_spark_jars.append(self.spark_config.get("spark.jars")) + + if self.spark_config and "spark.jars.packages" in self.spark_config: + fp_spark_packages.append(self.spark_config.get("spark.jars.packages")) + spark_configs.extend( ( - ( - "spark.jars", - ",".join(feature_store_pyspark.classpath_jars()), - ), + ("spark.jars", ",".join(fp_spark_jars)), ( "spark.jars.packages", - ",".join( - [ - "org.apache.hadoop:hadoop-aws:3.3.1", - "org.apache.hadoop:hadoop-common:3.3.1", - ] - ), + ",".join(fp_spark_packages), ), ) ) + return spark_configs def _get_jsc_hadoop_configs(self) -> List[Tuple[str, str]]: diff --git a/src/sagemaker/feature_store/feature_processor/_udf_arg_provider.py b/src/sagemaker/feature_store/feature_processor/_udf_arg_provider.py index abf55c5c7f..604505cba7 100644 --- a/src/sagemaker/feature_store/feature_processor/_udf_arg_provider.py +++ b/src/sagemaker/feature_store/feature_processor/_udf_arg_provider.py @@ -15,7 +15,7 @@ from abc import ABC, abstractmethod from inspect import signature -from typing import Any, Callable, Dict, Generic, List, OrderedDict, TypeVar, Union +from typing import Any, Callable, Dict, Generic, List, OrderedDict, TypeVar, Union, Optional import attr from pyspark.sql import DataFrame, SparkSession @@ -24,6 +24,8 @@ CSVDataSource, FeatureGroupDataSource, ParquetDataSource, + BaseDataSource, + PySparkDataSource, ) from sagemaker.feature_store.feature_processor._feature_processor_config import ( FeatureProcessorConfig, @@ -119,6 +121,9 @@ def provide_input_args( """ udf_parameter_names = list(signature(udf).parameters.keys()) udf_input_names = self._get_input_parameters(udf_parameter_names) + udf_params = self.params_loader.get_parameter_args(fp_config).get( + self.PARAMS_ARG_NAME, None + ) if len(udf_input_names) == 0: raise ValueError("Expected at least one input to the user defined function.") @@ -130,7 +135,7 @@ def provide_input_args( ) return OrderedDict( - (input_name, self._load_data_frame(input_uri)) + (input_name, self._load_data_frame(data_source=input_uri, params=udf_params)) for (input_name, input_uri) in zip(udf_input_names, fp_config.inputs) ) @@ -189,13 +194,19 @@ def _get_input_parameters(self, udf_parameter_names: List[str]) -> List[str]: def _load_data_frame( self, - data_source: Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource], + data_source: Union[ + FeatureGroupDataSource, CSVDataSource, ParquetDataSource, BaseDataSource + ], + params: Optional[Dict[str, Union[str, Dict]]] = None, ) -> DataFrame: """Given a data source definition, load the data as a Spark DataFrame. Args: - data_source (Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource]): - A user specified data source from the feature_processor decorator's parameters. 
+ data_source (Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource, + BaseDataSource]): A user specified data source from the feature_processor + decorator's parameters. + params (Optional[Dict[str, Union[str, Dict]]]): Parameters provided to the + feature_processor decorator. Returns: DataFrame: The contents of the data source as a Spark DataFrame. @@ -206,6 +217,13 @@ def _load_data_frame( if isinstance(data_source, FeatureGroupDataSource): return self.input_loader.load_from_feature_group(data_source) + if isinstance(data_source, PySparkDataSource): + spark_session = self.spark_session_factory.spark_session + return data_source.read_data(spark=spark_session, params=params) + + if isinstance(data_source, BaseDataSource): + return data_source.read_data(params=params) + raise ValueError(f"Unknown data source type: {type(data_source)}") def _has_param(self, udf: Callable, name: str) -> bool: diff --git a/src/sagemaker/feature_store/feature_processor/_validation.py b/src/sagemaker/feature_store/feature_processor/_validation.py index d9458f062e..e5bab0ed07 100644 --- a/src/sagemaker/feature_store/feature_processor/_validation.py +++ b/src/sagemaker/feature_store/feature_processor/_validation.py @@ -14,6 +14,7 @@ from __future__ import absolute_import import inspect +import re from abc import ABC, abstractmethod from typing import Any, Callable, List @@ -21,6 +22,7 @@ from sagemaker.feature_store.feature_processor._data_source import ( FeatureGroupDataSource, + BaseDataSource, ) from sagemaker.feature_store.feature_processor._feature_processor_config import ( FeatureProcessorConfig, @@ -172,3 +174,37 @@ def validate(self, udf: Callable[..., Any], fp_config: FeatureProcessorConfig) - end_td = InputOffsetParser.parse_offset_to_timedelta(input_end_offset) if start_td and end_td and start_td > end_td: raise ValueError("input_start_offset should be always before input_end_offset.") + + +class BaseDataSourceValidator(Validator): + """An Validator for BaseDataSource.""" + + def validate(self, udf: Callable[..., Any], fp_config: FeatureProcessorConfig) -> None: + """Validate the BaseDataSource provided to the decorator. + + Args: + udf (Callable[..., T]): The feature_processor wrapped user function. + fp_config (FeatureProcessorConfig): The configuration for the feature_processor. + + Raises (ValueError): raises ValueError when data_source_unique_id or data_source_name + of the input data source is not valid. + """ + + for config_input in fp_config.inputs: + if isinstance(config_input, BaseDataSource): + source_name = config_input.data_source_name + source_id = config_input.data_source_unique_id + + source_name_pattern = r"^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,119}$" + source_id_pattern = r"^.{1,2048}$" + + if not re.match(source_name_pattern, source_name): + raise ValueError( + f"data_source_name of input does not match pattern '{source_name_pattern}'." + ) + + if not re.match(source_id_pattern, source_id): + raise ValueError( + f"data_source_unique_id of input does not match " + f"pattern '{source_id_pattern}'." 
+ ) diff --git a/src/sagemaker/feature_store/feature_processor/feature_processor.py b/src/sagemaker/feature_store/feature_processor/feature_processor.py index c1b9053e57..e957dbd0ea 100644 --- a/src/sagemaker/feature_store/feature_processor/feature_processor.py +++ b/src/sagemaker/feature_store/feature_processor/feature_processor.py @@ -19,6 +19,7 @@ CSVDataSource, FeatureGroupDataSource, ParquetDataSource, + BaseDataSource, ) from sagemaker.feature_store.feature_processor._enums import FeatureProcessorMode from sagemaker.feature_store.feature_processor._factory import ( @@ -31,11 +32,14 @@ def feature_processor( - inputs: Sequence[Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource]], + inputs: Sequence[ + Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource, BaseDataSource] + ], output: str, target_stores: Optional[List[str]] = None, parameters: Optional[Dict[str, Union[str, Dict]]] = None, enable_ingestion: bool = True, + spark_config: Dict[str, str] = None, ) -> Callable: """Decorator to facilitate feature engineering for Feature Groups. @@ -46,7 +50,7 @@ def feature_processor( Decorated functions must conform to the expected signature. Parameters: one parameter of type pyspark.sql.DataFrame for each DataSource in 'inputs'; followed by the optional parameters with - names nand types in [params: Dict[str, Any], spark: SparkSession]. Outputs: a single return + names and types in [params: Dict[str, Any], spark: SparkSession]. Outputs: a single return value of type pyspark.sql.DataFrame. The function can have any name. **Example:** @@ -75,8 +79,8 @@ def transform(input_feature_group, input_csv): return ... Args: - inputs (Sequence[Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource]]): A list - of data sources. + inputs (Sequence[Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource, + BaseDataSource]]): A list of data sources. output (str): A Feature Group ARN to write results of this function to. target_stores (Optional[list[str]], optional): A list containing at least one of 'OnlineStore' or 'OfflineStore'. If unspecified, data will be ingested to the enabled @@ -91,6 +95,7 @@ def transform(input_feature_group, input_csv): return value is ingested to the 'output' Feature Group. This flag is useful during the development phase to ensure that data is not used until the function is ready. It also useful for users that want to manage their own data ingestion. Defaults to True. + spark_config (Dict[str, str]): A dict contains the key-value paris for Spark configurations. Raises: IngestionError: If any rows are not ingested successfully then a sample of the records, @@ -108,6 +113,7 @@ def decorator(udf: Callable[..., Any]) -> Callable: target_stores=target_stores, parameters=parameters, enable_ingestion=enable_ingestion, + spark_config=spark_config, ) validator_chain = ValidatorFactory.get_validation_chain(fp_config) diff --git a/src/sagemaker/feature_store/feature_processor/lineage/_feature_processor_lineage.py b/src/sagemaker/feature_store/feature_processor/lineage/_feature_processor_lineage.py index 0f0f913fa2..f8e198a1f9 100644 --- a/src/sagemaker/feature_store/feature_processor/lineage/_feature_processor_lineage.py +++ b/src/sagemaker/feature_store/feature_processor/lineage/_feature_processor_lineage.py @@ -73,6 +73,7 @@ CSVDataSource, FeatureGroupDataSource, ParquetDataSource, + BaseDataSource, ) logger = logging.getLogger(SAGEMAKER) @@ -86,8 +87,8 @@ class FeatureProcessorLineageHandler: pipeline_name (str): Pipeline Name. 
pipeline_arn (str): The ARN of the Pipeline. pipeline (str): The details of the Pipeline. - inputs (Sequence[Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource]]): - The inputs to the Feature processor. + inputs (Sequence[Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource, + BaseDataSource]]): The inputs to the Feature processor. output (str): The output Feature Group. transformation_code (TransformationCode): The Transformation Code for Feature Processor. sagemaker_session (Session): Session object which manages interactions @@ -99,9 +100,9 @@ class FeatureProcessorLineageHandler: pipeline_arn: str = attr.ib() pipeline: Dict = attr.ib() sagemaker_session: Session = attr.ib() - inputs: Sequence[Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource]] = attr.ib( - default=None - ) + inputs: Sequence[ + Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource, BaseDataSource] + ] = attr.ib(default=None) output: str = attr.ib(default=None) transformation_code: TransformationCode = attr.ib(default=None) @@ -398,17 +399,24 @@ def _retrieve_input_raw_data_artifacts(self) -> List[Artifact]: List[Artifact]: List of Raw Data Artifacts. """ raw_data_artifacts: List[Artifact] = list() - raw_data_s3_uri_set: Set[str] = set() + raw_data_uri_set: Set[str] = set() + for data_source in self.inputs: - if isinstance(data_source, (CSVDataSource, ParquetDataSource)): - if data_source.s3_uri not in raw_data_s3_uri_set: - raw_data_s3_uri_set.add(data_source.s3_uri) + if isinstance(data_source, (CSVDataSource, ParquetDataSource, BaseDataSource)): + data_source_uri = ( + data_source.s3_uri + if isinstance(data_source, (CSVDataSource, ParquetDataSource)) + else data_source.data_source_unique_id + ) + if data_source_uri not in raw_data_uri_set: + raw_data_uri_set.add(data_source_uri) raw_data_artifacts.append( S3LineageEntityHandler.retrieve_raw_data_artifact( raw_data=data_source, sagemaker_session=self.sagemaker_session, ) ) + return raw_data_artifacts def _compare_upstream_raw_data( diff --git a/src/sagemaker/feature_store/feature_processor/lineage/_s3_lineage_entity_handler.py b/src/sagemaker/feature_store/feature_processor/lineage/_s3_lineage_entity_handler.py index ee60d9475b..29bd79a75e 100644 --- a/src/sagemaker/feature_store/feature_processor/lineage/_s3_lineage_entity_handler.py +++ b/src/sagemaker/feature_store/feature_processor/lineage/_s3_lineage_entity_handler.py @@ -16,7 +16,11 @@ from typing import Union, Optional, List from sagemaker import Session -from sagemaker.feature_store.feature_processor import CSVDataSource, ParquetDataSource +from sagemaker.feature_store.feature_processor import ( + CSVDataSource, + ParquetDataSource, + BaseDataSource, +) # pylint: disable=C0301 from sagemaker.feature_store.feature_processor.lineage._feature_processor_lineage_name_helper import ( @@ -43,12 +47,14 @@ class S3LineageEntityHandler: @staticmethod def retrieve_raw_data_artifact( - raw_data: Union[CSVDataSource, ParquetDataSource], sagemaker_session: Session + raw_data: Union[CSVDataSource, ParquetDataSource, BaseDataSource], + sagemaker_session: Session, ) -> Artifact: """Load or create the FeatureProcessor Pipeline's raw data Artifact. Arguments: - raw_data (Union[CSVDataSource, ParquetDataSource]): The raw data to be retrieved. + raw_data (Union[CSVDataSource, ParquetDataSource, BaseDataSource]): The raw data to be + retrieved. sagemaker_session (Session): Session object which manages interactions with Amazon SageMaker APIs and any other AWS services needed. 
If not specified, the function creates one using the default AWS configuration chain. @@ -56,18 +62,30 @@ def retrieve_raw_data_artifact( Returns: Artifact: The raw data artifact. """ + raw_data_uri = ( + raw_data.s3_uri + if isinstance(raw_data, (CSVDataSource, ParquetDataSource)) + else raw_data.data_source_unique_id + ) + raw_data_artifact_name = ( + "sm-fs-fe-raw-data" + if isinstance(raw_data, (CSVDataSource, ParquetDataSource)) + else raw_data.data_source_name + ) + load_artifact: ArtifactSummary = S3LineageEntityHandler._load_artifact_from_s3_uri( - s3_uri=raw_data.s3_uri, sagemaker_session=sagemaker_session + s3_uri=raw_data_uri, sagemaker_session=sagemaker_session ) if load_artifact is not None: return S3LineageEntityHandler.load_artifact_from_arn( artifact_arn=load_artifact.artifact_arn, sagemaker_session=sagemaker_session, ) + return S3LineageEntityHandler._create_artifact( - s3_uri=raw_data.s3_uri, + s3_uri=raw_data_uri, artifact_type="DataSet", - artifact_name="sm-fs-fe-raw-data", + artifact_name=raw_data_artifact_name, sagemaker_session=sagemaker_session, ) diff --git a/tests/integ/sagemaker/feature_store/feature_processor/test_feature_processor.py b/tests/integ/sagemaker/feature_store/feature_processor/test_feature_processor.py index f602296eb7..eff8be8a13 100644 --- a/tests/integ/sagemaker/feature_store/feature_processor/test_feature_processor.py +++ b/tests/integ/sagemaker/feature_store/feature_processor/test_feature_processor.py @@ -19,12 +19,14 @@ import time from typing import Dict from datetime import datetime +from pyspark.sql import DataFrame import pytz import pytest import pandas as pd import numpy as np import json +import attr from boto3 import client from tests.integ import DATA_DIR @@ -39,6 +41,7 @@ from sagemaker.feature_store.feature_processor import ( feature_processor, CSVDataSource, + PySparkDataSource, ) from sagemaker.feature_store.feature_processor.feature_scheduler import ( to_pipeline, @@ -222,6 +225,148 @@ def transform(raw_s3_data_as_df): ) +@pytest.mark.slow_test +def test_feature_processor_transform_with_customized_data_source( + sagemaker_session, +): + car_data_feature_group_name = get_car_data_feature_group_name() + car_data_aggregated_feature_group_name = get_car_data_aggregated_feature_group_name() + + try: + feature_groups = create_feature_groups( + sagemaker_session=sagemaker_session, + car_data_feature_group_name=car_data_feature_group_name, + car_data_aggregated_feature_group_name=car_data_aggregated_feature_group_name, + offline_store_s3_uri=get_offline_store_s3_uri(sagemaker_session=sagemaker_session), + ) + + raw_data_uri = get_raw_car_data_s3_uri(sagemaker_session=sagemaker_session) + + @attr.s + class TestCSVDataSource(PySparkDataSource): + + s3_uri = attr.ib() + data_source_name = "TestCSVDataSource" + data_source_unique_id = "s3_uri" + + def read_data(self, spark, params) -> DataFrame: + s3a_uri = self.s3_uri.replace("s3://", "s3a://") + return spark.read.csv(s3a_uri, header=True, inferSchema=False) + + @feature_processor( + inputs=[TestCSVDataSource(raw_data_uri)], + output=feature_groups["car_data_arn"], + target_stores=["OnlineStore"], + spark_config={ + "spark.hadoop.fs.s3a.aws.credentials.provider": ",".join( + [ + "com.amazonaws.auth.ContainerCredentialsProvider", + "com.amazonaws.auth.profile.ProfileCredentialsProvider", + "com.amazonaws.auth.DefaultAWSCredentialsProviderChain", + ] + ) + }, + ) + def transform(raw_s3_data_as_df): + """Load data from S3, perform basic feature engineering, store it in a Feature Group""" + 
from pyspark.sql.functions import regexp_replace + from pyspark.sql.functions import lit + + transformed_df = ( + raw_s3_data_as_df + # Rename Columns + .withColumnRenamed("Id", "id") + .withColumnRenamed("Model", "model") + .withColumnRenamed("Year", "model_year") + .withColumnRenamed("Status", "status") + .withColumnRenamed("Mileage", "mileage") + .withColumnRenamed("Price", "price") + .withColumnRenamed("MSRP", "msrp") + # Add Event Time + .withColumn("ingest_time", lit(int(time.time()))) + # Remove punctuation and fluff; replace with NA + .withColumn("Price", regexp_replace("Price", "\$", "")) # noqa: W605 + .withColumn("mileage", regexp_replace("mileage", "(,)|(mi\.)", "")) # noqa: W605 + .withColumn("mileage", regexp_replace("mileage", "Not available", "NA")) + .withColumn("price", regexp_replace("price", ",", "")) + .withColumn("msrp", regexp_replace("msrp", "(^MSRP\s\\$)|(,)", "")) # noqa: W605 + .withColumn("msrp", regexp_replace("msrp", "Not specified", "NA")) + .withColumn("msrp", regexp_replace("msrp", "\\$\d+[a-zA-Z\s]+", "NA")) # noqa: W605 + .withColumn("model", regexp_replace("model", "^\d\d\d\d\s", "")) # noqa: W605 + ) + + transformed_df.show() + return transformed_df + + transform() + + featurestore_client = sagemaker_session.sagemaker_featurestore_runtime_client + results = featurestore_client.batch_get_record( + Identifiers=[ + { + "FeatureGroupName": car_data_feature_group_name, + "RecordIdentifiersValueAsString": [ + "0", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "10", + "11", + "12", + "13", + "14", + "15", + "16", + "17", + "18", + "19", + "20", + "21", + "22", + "23", + "24", + "25", + ], + }, + ] + ) + + assert len(results["Records"]) == 26 + + car_sales_query = feature_groups["car_data_feature_group"].athena_query() + query = f'SELECT * FROM "sagemaker_featurestore".{car_sales_query.table_name} LIMIT 1000;' + output_uri = "s3://{}/{}/input/data/{}".format( + sagemaker_session.default_bucket(), + "feature-processor-test", + "csv-data-fg-result", + ) + car_sales_query.run(query_string=query, output_location=output_uri) + car_sales_query.wait() + dataset = car_sales_query.as_dataframe() + assert dataset.empty + finally: + cleanup_offline_store( + feature_group=feature_groups["car_data_feature_group"], + sagemaker_session=sagemaker_session, + ) + cleanup_offline_store( + feature_group=feature_groups["car_data_aggregated_feature_group"], + sagemaker_session=sagemaker_session, + ) + cleanup_feature_group( + feature_groups["car_data_feature_group"], sagemaker_session=sagemaker_session + ) + cleanup_feature_group( + feature_groups["car_data_aggregated_feature_group"], sagemaker_session=sagemaker_session + ) + + @pytest.mark.slow_test @pytest.mark.flaky(reruns=5, reruns_delay=2) def test_feature_processor_transform_offline_only_store_ingestion( @@ -368,11 +513,17 @@ def test_feature_processor_transform_offline_only_store_ingestion_run_with_remot raw_data_uri = get_raw_car_data_s3_uri(sagemaker_session=sagemaker_session) whl_file_uri = get_wheel_file_s3_uri(sagemaker_session=sagemaker_session) + whl_file_name = os.path.basename(whl_file_uri) + + pre_execution_commands = [ + f"aws s3 cp {whl_file_uri} ./", + f"/usr/local/bin/python3.9 -m pip install ./{whl_file_name} --force-reinstall", + ] @remote( + pre_execution_commands=pre_execution_commands, spark_config=SparkConfig(), instance_type="ml.m5.xlarge", - python_sdk_whl_s3_uri=whl_file_uri, ) @feature_processor( inputs=[CSVDataSource(raw_data_uri)], @@ -504,11 +655,17 @@ def 
test_to_pipeline_and_execute( raw_data_uri = get_raw_car_data_s3_uri(sagemaker_session=sagemaker_session) whl_file_uri = get_wheel_file_s3_uri(sagemaker_session=sagemaker_session) + whl_file_name = os.path.basename(whl_file_uri) + + pre_execution_commands = [ + f"aws s3 cp {whl_file_uri} ./", + f"/usr/local/bin/python3.9 -m pip install ./{whl_file_name} --force-reinstall", + ] @remote( + pre_execution_commands=pre_execution_commands, spark_config=SparkConfig(), instance_type="ml.m5.xlarge", - python_sdk_whl_s3_uri=whl_file_uri, ) @feature_processor( inputs=[CSVDataSource(raw_data_uri)], @@ -621,11 +778,17 @@ def test_schedule( raw_data_uri = get_raw_car_data_s3_uri(sagemaker_session=sagemaker_session) whl_file_uri = get_wheel_file_s3_uri(sagemaker_session=sagemaker_session) + whl_file_name = os.path.basename(whl_file_uri) + + pre_execution_commands = [ + f"aws s3 cp {whl_file_uri} ./", + f"/usr/local/bin/python3.9 -m pip install ./{whl_file_name} --force-reinstall", + ] @remote( + pre_execution_commands=pre_execution_commands, spark_config=SparkConfig(), instance_type="ml.m5.xlarge", - python_sdk_whl_s3_uri=whl_file_uri, ) @feature_processor( inputs=[CSVDataSource(raw_data_uri)], diff --git a/tests/unit/sagemaker/feature_store/feature_processor/lineage/test_constants.py b/tests/unit/sagemaker/feature_store/feature_processor/lineage/test_constants.py index 1595b1164a..73b6cddd99 100644 --- a/tests/unit/sagemaker/feature_store/feature_processor/lineage/test_constants.py +++ b/tests/unit/sagemaker/feature_store/feature_processor/lineage/test_constants.py @@ -18,12 +18,14 @@ from botocore.exceptions import ClientError from mock import Mock +from pyspark.sql import DataFrame from sagemaker import Session from sagemaker.feature_store.feature_processor._data_source import ( CSVDataSource, FeatureGroupDataSource, ParquetDataSource, + BaseDataSource, ) from sagemaker.feature_store.feature_processor.lineage._feature_group_contexts import ( FeatureGroupContexts, @@ -46,6 +48,16 @@ CONTEXT_MOCK_01 = Mock(Context) CONTEXT_MOCK_02 = Mock(Context) + +class MockDataSource(BaseDataSource): + + data_source_unique_id = "test_source_unique_id" + data_source_name = "test_source_name" + + def read_data(self, spark, params) -> DataFrame: + return None + + FEATURE_GROUP_DATA_SOURCE: List[FeatureGroupDataSource] = [ FeatureGroupDataSource( name="feature-group-01", @@ -68,16 +80,18 @@ ), ] -RAW_DATA_INPUT: Sequence[Union[CSVDataSource, ParquetDataSource]] = [ +RAW_DATA_INPUT: Sequence[Union[CSVDataSource, ParquetDataSource, BaseDataSource]] = [ CSVDataSource(s3_uri="raw-data-uri-01"), CSVDataSource(s3_uri="raw-data-uri-02"), ParquetDataSource(s3_uri="raw-data-uri-03"), + MockDataSource(), ] RAW_DATA_INPUT_ARTIFACTS: List[Artifact] = [ Artifact(artifact_arn="artifact-01-arn"), Artifact(artifact_arn="artifact-02-arn"), Artifact(artifact_arn="artifact-03-arn"), + Artifact(artifact_arn="artifact-04-arn"), ] PIPELINE_SCHEDULE = PipelineSchedule( diff --git a/tests/unit/sagemaker/feature_store/feature_processor/lineage/test_feature_processor_lineage.py b/tests/unit/sagemaker/feature_store/feature_processor/lineage/test_feature_processor_lineage.py index f0cdead06d..f755905ddd 100644 --- a/tests/unit/sagemaker/feature_store/feature_processor/lineage/test_feature_processor_lineage.py +++ b/tests/unit/sagemaker/feature_store/feature_processor/lineage/test_feature_processor_lineage.py @@ -119,6 +119,7 @@ def test_create_lineage_when_no_lineage_exists_with_fg_only(): RAW_DATA_INPUT_ARTIFACTS[0], RAW_DATA_INPUT_ARTIFACTS[1], 
RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], ], ) as retrieve_raw_data_artifact_method, patch.object( S3LineageEntityHandler, @@ -264,6 +265,7 @@ def test_create_lineage_when_no_lineage_exists_with_raw_data_only(): RAW_DATA_INPUT_ARTIFACTS[0], RAW_DATA_INPUT_ARTIFACTS[1], RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], ], ) as retrieve_raw_data_artifact_method, patch.object( S3LineageEntityHandler, @@ -329,9 +331,10 @@ def test_create_lineage_when_no_lineage_exists_with_raw_data_only(): call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), ] ) - assert 3 == retrieve_raw_data_artifact_method.call_count + assert 4 == retrieve_raw_data_artifact_method.call_count create_transformation_code_artifact_method.assert_called_once_with( transformation_code=TRANSFORMATION_CODE_INPUT_1, @@ -411,6 +414,7 @@ def test_create_lineage_when_no_lineage_exists_with_fg_and_raw_data_with_tags(): RAW_DATA_INPUT_ARTIFACTS[0], RAW_DATA_INPUT_ARTIFACTS[1], RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], ], ) as retrieve_raw_data_artifact_method, patch.object( S3LineageEntityHandler, @@ -489,9 +493,10 @@ def test_create_lineage_when_no_lineage_exists_with_fg_and_raw_data_with_tags(): call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), ] ) - assert 3 == retrieve_raw_data_artifact_method.call_count + assert 4 == retrieve_raw_data_artifact_method.call_count create_transformation_code_artifact_method.assert_called_once_with( transformation_code=TRANSFORMATION_CODE_INPUT_1, @@ -570,6 +575,7 @@ def test_create_lineage_when_no_lineage_exists_with_no_transformation_code(): RAW_DATA_INPUT_ARTIFACTS[0], RAW_DATA_INPUT_ARTIFACTS[1], RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], ], ) as retrieve_raw_data_artifact_method, patch.object( S3LineageEntityHandler, @@ -648,9 +654,10 @@ def test_create_lineage_when_no_lineage_exists_with_no_transformation_code(): call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), ] ) - assert 3 == retrieve_raw_data_artifact_method.call_count + assert 4 == retrieve_raw_data_artifact_method.call_count create_transformation_code_artifact_method.assert_called_once_with( transformation_code=None, @@ -727,6 +734,7 @@ def test_create_lineage_when_already_exist_with_no_version_change(): RAW_DATA_INPUT_ARTIFACTS[0], RAW_DATA_INPUT_ARTIFACTS[1], RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], ], ) as retrieve_raw_data_artifact_method, patch.object( S3LineageEntityHandler, @@ -808,9 +816,10 @@ def test_create_lineage_when_already_exist_with_no_version_change(): call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + 
call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), ] ) - assert 3 == retrieve_raw_data_artifact_method.call_count + assert 4 == retrieve_raw_data_artifact_method.call_count create_transformation_code_artifact_method.assert_called_once_with( transformation_code=TRANSFORMATION_CODE_INPUT_1, @@ -1133,6 +1142,7 @@ def test_create_lineage_when_already_exist_with_changed_input_fg(): RAW_DATA_INPUT_ARTIFACTS[0], RAW_DATA_INPUT_ARTIFACTS[1], RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], ], ) as retrieve_raw_data_artifact_method, patch.object( S3LineageEntityHandler, @@ -1210,9 +1220,10 @@ def test_create_lineage_when_already_exist_with_changed_input_fg(): call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), ] ) - assert 3 == retrieve_raw_data_artifact_method.call_count + assert 4 == retrieve_raw_data_artifact_method.call_count create_transformation_code_artifact_method.assert_called_once_with( transformation_code=TRANSFORMATION_CODE_INPUT_1, @@ -1349,6 +1360,7 @@ def test_create_lineage_when_already_exist_with_changed_output_fg(): RAW_DATA_INPUT_ARTIFACTS[0], RAW_DATA_INPUT_ARTIFACTS[1], RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], ], ) as retrieve_raw_data_artifact_method, patch.object( S3LineageEntityHandler, @@ -1430,9 +1442,10 @@ def test_create_lineage_when_already_exist_with_changed_output_fg(): call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), ] ) - assert 3 == retrieve_raw_data_artifact_method.call_count + assert 4 == retrieve_raw_data_artifact_method.call_count create_transformation_code_artifact_method.assert_called_once_with( transformation_code=TRANSFORMATION_CODE_INPUT_1, @@ -1569,6 +1582,7 @@ def test_create_lineage_when_already_exist_with_changed_transformation_code(): RAW_DATA_INPUT_ARTIFACTS[0], RAW_DATA_INPUT_ARTIFACTS[1], RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], ], ) as retrieve_raw_data_artifact_method, patch.object( S3LineageEntityHandler, @@ -1650,9 +1664,10 @@ def test_create_lineage_when_already_exist_with_changed_transformation_code(): call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), ] ) - assert 3 == retrieve_raw_data_artifact_method.call_count + assert 4 == retrieve_raw_data_artifact_method.call_count create_transformation_code_artifact_method.assert_called_once_with( transformation_code=TRANSFORMATION_CODE_INPUT_2, @@ -1769,6 +1784,7 @@ def test_create_lineage_when_already_exist_with_last_transformation_code_as_none RAW_DATA_INPUT_ARTIFACTS[0], RAW_DATA_INPUT_ARTIFACTS[1], RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], ], ) as retrieve_raw_data_artifact_method, patch.object( S3LineageEntityHandler, @@ -1850,9 +1866,10 @@ def test_create_lineage_when_already_exist_with_last_transformation_code_as_none call(raw_data=RAW_DATA_INPUT[0], 
sagemaker_session=SAGEMAKER_SESSION_MOCK), call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), ] ) - assert 3 == retrieve_raw_data_artifact_method.call_count + assert 4 == retrieve_raw_data_artifact_method.call_count create_transformation_code_artifact_method.assert_called_once_with( transformation_code=TRANSFORMATION_CODE_INPUT_2, @@ -1957,6 +1974,7 @@ def test_create_lineage_when_already_exist_with_all_previous_transformation_code RAW_DATA_INPUT_ARTIFACTS[0], RAW_DATA_INPUT_ARTIFACTS[1], RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], ], ) as retrieve_raw_data_artifact_method, patch.object( S3LineageEntityHandler, @@ -2037,9 +2055,10 @@ def test_create_lineage_when_already_exist_with_all_previous_transformation_code call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), ] ) - assert 3 == retrieve_raw_data_artifact_method.call_count + assert 4 == retrieve_raw_data_artifact_method.call_count create_transformation_code_artifact_method.assert_called_once_with( transformation_code=TRANSFORMATION_CODE_INPUT_2, @@ -2141,6 +2160,7 @@ def test_create_lineage_when_already_exist_with_removed_transformation_code(): RAW_DATA_INPUT_ARTIFACTS[0], RAW_DATA_INPUT_ARTIFACTS[1], RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], ], ) as retrieve_raw_data_artifact_method, patch.object( S3LineageEntityHandler, @@ -2222,9 +2242,10 @@ def test_create_lineage_when_already_exist_with_removed_transformation_code(): call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), ] ) - assert 3 == retrieve_raw_data_artifact_method.call_count + assert 4 == retrieve_raw_data_artifact_method.call_count create_transformation_code_artifact_method.assert_called_once_with( transformation_code=None, @@ -2471,6 +2492,7 @@ def test_upsert_tags_for_lineage_resources(): RAW_DATA_INPUT_ARTIFACTS[0], RAW_DATA_INPUT_ARTIFACTS[1], RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], ], ) as retrieve_raw_data_artifact_method, patch.object( PipelineLineageEntityHandler, @@ -2518,6 +2540,7 @@ def test_upsert_tags_for_lineage_resources(): call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=mock_session), call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=mock_session), call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=mock_session), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=mock_session), ] ) diff --git a/tests/unit/sagemaker/feature_store/feature_processor/lineage/test_s3_lineage_entity_handler.py b/tests/unit/sagemaker/feature_store/feature_processor/lineage/test_s3_lineage_entity_handler.py index 7fc49df59a..d71605fc0e 100644 --- a/tests/unit/sagemaker/feature_store/feature_processor/lineage/test_s3_lineage_entity_handler.py +++ b/tests/unit/sagemaker/feature_store/feature_processor/lineage/test_s3_lineage_entity_handler.py @@ -24,6 +24,7 @@ TRANSFORMATION_CODE_ARTIFACT_1, TRANSFORMATION_CODE_INPUT_1, LAST_UPDATE_TIME, + 
MockDataSource, ) from test_pipeline_lineage_entity_handler import SAGEMAKER_SESSION_MOCK @@ -94,6 +95,60 @@ def test_retrieve_raw_data_artifact_when_artifact_does_not_exist(): ) +def test_retrieve_user_defined_raw_data_artifact_when_artifact_already_exist(): + data_source = MockDataSource() + with patch.object(Artifact, "list", return_value=[ARTIFACT_SUMMARY]) as artifact_list_method: + with patch.object(Artifact, "load", return_value=ARTIFACT_RESULT) as artifact_load_method: + with patch.object( + Artifact, "create", return_value=ARTIFACT_RESULT + ) as artifact_create_method: + result = S3LineageEntityHandler.retrieve_raw_data_artifact( + raw_data=data_source, sagemaker_session=SAGEMAKER_SESSION_MOCK + ) + + assert result == ARTIFACT_RESULT + + artifact_list_method.assert_called_once_with( + source_uri=data_source.data_source_unique_id, sagemaker_session=SAGEMAKER_SESSION_MOCK + ) + + artifact_load_method.assert_called_once_with( + artifact_arn=ARTIFACT_SUMMARY.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_create_method.assert_not_called() + + +def test_retrieve_user_defined_raw_data_artifact_when_artifact_does_not_exist(): + data_source = MockDataSource() + with patch.object(Artifact, "list", return_value=[]) as artifact_list_method: + with patch.object(Artifact, "load", return_value=ARTIFACT_RESULT) as artifact_load_method: + with patch.object( + Artifact, "create", return_value=ARTIFACT_RESULT + ) as artifact_create_method: + result = S3LineageEntityHandler.retrieve_raw_data_artifact( + raw_data=data_source, sagemaker_session=SAGEMAKER_SESSION_MOCK + ) + + assert result == ARTIFACT_RESULT + + artifact_list_method.assert_called_once_with( + source_uri=data_source.data_source_unique_id, sagemaker_session=SAGEMAKER_SESSION_MOCK + ) + + artifact_load_method.assert_not_called() + + artifact_create_method.assert_called_once_with( + source_uri=data_source.data_source_unique_id, + artifact_type="DataSet", + artifact_name=data_source.data_source_name, + properties=None, + source_types=None, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + def test_create_transformation_code_artifact(): with patch.object( Artifact, "create", return_value=TRANSFORMATION_CODE_ARTIFACT_1 diff --git a/tests/unit/sagemaker/feature_store/feature_processor/test_config_uploader.py b/tests/unit/sagemaker/feature_store/feature_processor/test_config_uploader.py index fa664ac695..def1499a24 100644 --- a/tests/unit/sagemaker/feature_store/feature_processor/test_config_uploader.py +++ b/tests/unit/sagemaker/feature_store/feature_processor/test_config_uploader.py @@ -27,7 +27,7 @@ _JobSettings, RUNTIME_SCRIPTS_CHANNEL_NAME, REMOTE_FUNCTION_WORKSPACE, - SPARK_CONF_WORKSPACE, + SPARK_CONF_CHANNEL_NAME, ) from sagemaker.remote_function.spark_config import SparkConfig from sagemaker.session import Session @@ -218,7 +218,7 @@ def test_prepare_step_input_channel( s3_data=f"{config_uploader.remote_decorator_config.s3_root_uri}/pipeline_name/sm_rf_user_ws", s3_data_type="S3Prefix", ), - SPARK_CONF_WORKSPACE: mock_training_input(s3_data="path_d", s3_data_type="S3Prefix"), + SPARK_CONF_CHANNEL_NAME: mock_training_input(s3_data="path_d", s3_data_type="S3Prefix"), } assert spark_dependency_paths == { diff --git a/tests/unit/sagemaker/feature_store/feature_processor/test_data_helpers.py b/tests/unit/sagemaker/feature_store/feature_processor/test_data_helpers.py index a539c1b8d0..9c4f0fef49 100644 --- a/tests/unit/sagemaker/feature_store/feature_processor/test_data_helpers.py +++ 
b/tests/unit/sagemaker/feature_store/feature_processor/test_data_helpers.py @@ -137,6 +137,7 @@ def create_fp_config( target_stores=None, enable_ingestion=True, parameters=None, + spark_config=None, ): """Helper method to create a FeatureProcessorConfig with fewer arguments.""" @@ -147,4 +148,5 @@ def create_fp_config( target_stores=target_stores, enable_ingestion=enable_ingestion, parameters=parameters, + spark_config=spark_config, ) diff --git a/tests/unit/sagemaker/feature_store/feature_processor/test_data_source.py b/tests/unit/sagemaker/feature_store/feature_processor/test_data_source.py new file mode 100644 index 0000000000..06c1a5a351 --- /dev/null +++ b/tests/unit/sagemaker/feature_store/feature_processor/test_data_source.py @@ -0,0 +1,34 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from pyspark.sql import DataFrame + +from sagemaker.feature_store.feature_processor._data_source import PySparkDataSource + + +def test_pyspark_data_source(): + class TestDataSource(PySparkDataSource): + + data_source_unique_id = "test_unique_id" + data_source_name = "test_source_name" + + def read_data(self, spark, params) -> DataFrame: + return None + + test_data_source = TestDataSource() + + assert test_data_source.data_source_name == "test_source_name" + assert test_data_source.data_source_unique_id == "test_unique_id" + assert test_data_source.read_data(spark=None, params=None) is None diff --git a/tests/unit/sagemaker/feature_store/feature_processor/test_factory.py b/tests/unit/sagemaker/feature_store/feature_processor/test_factory.py index d163f2b6f3..8c76c5406f 100644 --- a/tests/unit/sagemaker/feature_store/feature_processor/test_factory.py +++ b/tests/unit/sagemaker/feature_store/feature_processor/test_factory.py @@ -31,6 +31,7 @@ InputValidator, SparkUDFSignatureValidator, InputOffsetValidator, + BaseDataSourceValidator, ) from sagemaker.session import Session @@ -44,6 +45,7 @@ def test_get_validation_chain(): InputValidator, FeatureProcessorArgValidator, InputOffsetValidator, + BaseDataSourceValidator, SparkUDFSignatureValidator, } == {type(instance) for instance in result.validators} @@ -53,7 +55,7 @@ def test_get_udf_wrapper(): udf_wrapper = Mock(UDFWrapper) with patch.object( - UDFWrapperFactory, "get_udf_wrapper", return_value=udf_wrapper + UDFWrapperFactory, "_get_spark_udf_wrapper", return_value=udf_wrapper ) as get_udf_wrapper_method: result = UDFWrapperFactory.get_udf_wrapper(fp_config) diff --git a/tests/unit/sagemaker/feature_store/feature_processor/test_feature_processor_config.py b/tests/unit/sagemaker/feature_store/feature_processor/test_feature_processor_config.py index e6aece5c64..20c289d8a2 100644 --- a/tests/unit/sagemaker/feature_store/feature_processor/test_feature_processor_config.py +++ b/tests/unit/sagemaker/feature_store/feature_processor/test_feature_processor_config.py @@ -31,6 +31,7 @@ def test_feature_processor_config_is_immutable(): 
target_stores=None, enable_ingestion=True, parameters=None, + spark_config=None, ) with pytest.raises(attr.exceptions.FrozenInstanceError): diff --git a/tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py b/tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py index 57ec73b377..2d193e9d30 100644 --- a/tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py +++ b/tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py @@ -49,8 +49,8 @@ SPARK_APP_SCRIPT_PATH, RUNTIME_SCRIPTS_CHANNEL_NAME, REMOTE_FUNCTION_WORKSPACE, - SPARK_CONF_WORKSPACE, ENTRYPOINT_SCRIPT_NAME, + SPARK_CONF_CHANNEL_NAME, ) from sagemaker.workflow.parameters import Parameter, ParameterTypeEnum from sagemaker.workflow.retry import ( @@ -316,7 +316,7 @@ def test_to_pipeline( REMOTE_FUNCTION_WORKSPACE: mock_training_input( s3_data=f"{S3_URI}/pipeline_name/sm_rf_user_ws", s3_data_type="S3Prefix" ), - SPARK_CONF_WORKSPACE: mock_training_input(s3_data="path_d", s3_data_type="S3Prefix"), + SPARK_CONF_CHANNEL_NAME: mock_training_input(s3_data="path_d", s3_data_type="S3Prefix"), }, retry_policies=[ StepRetryPolicy( diff --git a/tests/unit/sagemaker/feature_store/feature_processor/test_spark_session_factory.py b/tests/unit/sagemaker/feature_store/feature_processor/test_spark_session_factory.py index 0c7d5f0de9..f2a1daf788 100644 --- a/tests/unit/sagemaker/feature_store/feature_processor/test_spark_session_factory.py +++ b/tests/unit/sagemaker/feature_store/feature_processor/test_spark_session_factory.py @@ -33,7 +33,8 @@ def env_helper(): def test_spark_session_factory_configuration(): env_helper = Mock() - spark_session_factory = SparkSessionFactory(env_helper) + spark_config = {"spark.test.key": "spark.test.value"} + spark_session_factory = SparkSessionFactory(env_helper, spark_config) spark_configs = dict(spark_session_factory._get_spark_configs(is_training_job=False)) jsc_hadoop_configs = dict(spark_session_factory._get_jsc_hadoop_configs()) @@ -65,6 +66,8 @@ def test_spark_session_factory_configuration(): assert spark_configs.get("spark.hadoop.fs.trash.interval") == "0" assert spark_configs.get("spark.port.maxRetries") == "50" + assert spark_configs.get("spark.test.key") == "spark.test.value" + assert jsc_hadoop_configs.get("mapreduce.fileoutputcommitter.marksuccessfuljobs") == "false" # Verify configurations when not running on a training job @@ -79,9 +82,11 @@ def test_spark_session_factory_configuration(): def test_spark_session_factory_configuration_on_training_job(): env_helper = Mock() - spark_session_factory = SparkSessionFactory(env_helper) + spark_config = {"spark.test.key": "spark.test.value"} + spark_session_factory = SparkSessionFactory(env_helper, spark_config) spark_config = spark_session_factory._get_spark_configs(is_training_job=True) + assert dict(spark_config).get("spark.test.key") == "spark.test.value" assert all(tup[0] != "spark.jars" for tup in spark_config) assert all(tup[0] != "spark.jars.packages" for tup in spark_config) diff --git a/tests/unit/sagemaker/feature_store/feature_processor/test_udf_arg_provider.py b/tests/unit/sagemaker/feature_store/feature_processor/test_udf_arg_provider.py index 64dc12ab31..5fd9230680 100644 --- a/tests/unit/sagemaker/feature_store/feature_processor/test_udf_arg_provider.py +++ b/tests/unit/sagemaker/feature_store/feature_processor/test_udf_arg_provider.py @@ -15,19 +15,20 @@ import pytest import test_data_helpers as tdh -from mock import Mock +from mock import Mock, patch 
from pyspark.sql import DataFrame, SparkSession from sagemaker.feature_store.feature_processor._input_loader import InputLoader from sagemaker.feature_store.feature_processor._params_loader import ParamsLoader from sagemaker.feature_store.feature_processor._spark_factory import SparkSessionFactory from sagemaker.feature_store.feature_processor._udf_arg_provider import SparkArgProvider +from sagemaker.feature_store.feature_processor._data_source import PySparkDataSource @pytest.fixture def params_loader(): params_loader = Mock(ParamsLoader) - params_loader.get_parameter_args.return_value = Mock() + params_loader.get_parameter_args = Mock(return_value={"params": {"key": "value"}}) return params_loader @@ -41,6 +42,11 @@ def s3_uri_as_spark_df(): return Mock(DataFrame) +@pytest.fixture +def base_data_source_as_spark_df(): + return Mock(DataFrame) + + @pytest.fixture def input_loader(feature_group_as_spark_df, s3_uri_as_spark_df): input_loader = Mock(InputLoader) @@ -65,6 +71,15 @@ def spark_arg_provider(params_loader, input_loader, spark_session_factory): return SparkArgProvider(params_loader, input_loader, spark_session_factory) +class MockDataSource(PySparkDataSource): + + data_source_unique_id = "test_id" + data_source_name = "test_source" + + def read_data(self, spark, params) -> DataFrame: + return Mock(DataFrame) + + def test_provide_additional_kw_args(spark_arg_provider, spark_session): def udf(fg_input, s3_input, params, spark): return None @@ -251,3 +266,15 @@ def udf_only_spark(input_fg=None, input_s3_uri=None, spark=None) -> DataFrame: assert inputs.keys() == {"input_fg", "input_s3_uri"} assert inputs["input_fg"] == feature_group_as_spark_df assert inputs["input_s3_uri"] == s3_uri_as_spark_df + + +def test_provide_input_arg_for_base_data_source(spark_arg_provider, params_loader, spark_session): + fp_config = tdh.create_fp_config(inputs=[MockDataSource()], output=tdh.OUTPUT_FEATURE_GROUP_ARN) + + def udf(input_df) -> DataFrame: + return input_df + + with patch.object(MockDataSource, "read_data", return_value=Mock(DataFrame)) as mock_read: + spark_arg_provider.provide_input_args(udf, fp_config) + mock_read.assert_called_with(spark=spark_session, params={"key": "value"}) + params_loader.get_parameter_args.assert_called_with(fp_config) diff --git a/tests/unit/sagemaker/feature_store/feature_processor/test_validation.py b/tests/unit/sagemaker/feature_store/feature_processor/test_validation.py index 000c8d617b..8e0115afd2 100644 --- a/tests/unit/sagemaker/feature_store/feature_processor/test_validation.py +++ b/tests/unit/sagemaker/feature_store/feature_processor/test_validation.py @@ -14,16 +14,23 @@ from __future__ import absolute_import from typing import Callable +from pyspark.sql import DataFrame import pytest import test_data_helpers as tdh +import string +import random from mock import Mock from sagemaker.feature_store.feature_processor._validation import ( SparkUDFSignatureValidator, Validator, ValidatorChain, + BaseDataSourceValidator, +) +from sagemaker.feature_store.feature_processor._data_source import ( + BaseDataSource, ) @@ -148,3 +155,40 @@ def invalid_spark_position(spark, fg_data_source, s3_data_source): return None SparkUDFSignatureValidator().validate(invalid_spark_position, fp_config) + + +@pytest.mark.parametrize( + "data_source_name, data_source_unique_id, error_pattern", + [ + ("$_invalid_source", "unique_id", "data_source_name of input does not match pattern '.*'."), + ("", "unique_id", "data_source_name of input does not match pattern '.*'."), + ( + "source", + 
"".join(random.choices(string.ascii_uppercase, k=2050)), + "data_source_unique_id of input does not match pattern '.*'.", + ), + ("source", "", "data_source_unique_id of input does not match pattern '.*'."), + ], +) +def test_spark_udf_signature_validator_udf_invalid_base_data_source( + data_source_name, data_source_unique_id, error_pattern +): + class TestInValidCustomDataSource(BaseDataSource): + + data_source_name = None + data_source_unique_id = None + + def read_data(self, spark, params) -> DataFrame: + return None + + test_data_source = TestInValidCustomDataSource() + test_data_source.data_source_name = data_source_name + test_data_source.data_source_unique_id = data_source_unique_id + + fp_config = tdh.create_fp_config(inputs=[test_data_source]) + + def udf(input_data_source, params, spark): + return None + + with pytest.raises(ValueError, match=error_pattern): + BaseDataSourceValidator().validate(udf, fp_config)