def prepare_data(data: Any, _visited: set[int] | None = None) -> Any:
    """
    Recursively convert complex objects into dictionaries or simple types.

    Conversion rules, applied in this order:

    1. Primitives (``str``, ``int``, ``float``, ``bool``, ``None``) and class
       objects are returned unchanged.
    2. Dataclass instances are converted with ``dataclasses.asdict``.
    3. Objects exposing a callable ``model_dump`` (e.g. Pydantic models) are
       converted with ``model_dump()``.
    4. Non-dict objects exposing a callable ``dict`` are converted with ``dict()``.
    5. Dictionaries are rebuilt as plain ``dict`` with keys and values converted.
    6. Lists, tuples (including named tuples) and sets are rebuilt with the
       same type and their items converted.
    7. Objects with a ``__dict__`` are converted through ``vars()``.
    8. Anything else is returned unchanged.

    Parameters
    ----------
    data : Any
        The value to convert.
    _visited : set[int] | None
        Internal use only. Holds the ``id()`` of every object currently being
        converted higher up the call stack so that circular references can be
        detected. Callers should leave it unset.

    Returns
    -------
    Any
        The converted value. An object that refers back to one of its own
        ancestors (a circular reference) is returned as-is at the point where
        the cycle closes.
    """
    # Primitives need no conversion and are never tracked
    if isinstance(data, (str, int, float, bool, type(None))):
        return data

    # Class objects are not data: calling asdict()/model_dump()/dict() on them would raise
    if isinstance(data, type):
        return data

    if _visited is None:
        _visited = set()

    data_id = id(data)
    if data_id in _visited:
        # `data` is one of its own ancestors: stop here to avoid infinite recursion
        return data

    # Track only the *current* conversion path. The id is released in `finally`, so:
    #   - the same object referenced twice without a cycle is converted both times
    #   - ids of short-lived temporaries (asdict()/model_dump() results) cannot
    #     collide with objects created later at the same memory address
    _visited.add(data_id)
    try:
        # Dataclass instances
        if dataclasses.is_dataclass(data):
            return prepare_data(dataclasses.asdict(data), _visited)

        # Pydantic models
        if callable(getattr(data, "model_dump", None)):
            return prepare_data(data.model_dump(), _visited)

        # Dictionaries (checked before `dict()` so mappings are never mistaken for custom objects)
        if isinstance(data, dict):
            return {prepare_data(k, _visited): prepare_data(v, _visited) for k, v in data.items()}

        # Objects with dict() method
        if callable(getattr(data, "dict", None)):
            return prepare_data(data.dict(), _visited)

        # Lists, tuples, sets
        if isinstance(data, (list, tuple, set)):
            items = [prepare_data(item, _visited) for item in data]
            if isinstance(data, tuple) and hasattr(data, "_fields"):
                # Named tuples take their fields as positional arguments, not a single iterable
                return type(data)(*items)
            return type(data)(items)

        # Objects with __dict__
        if hasattr(data, "__dict__"):
            return prepare_data(vars(data), _visited)

        # Default fallback
        return data
    finally:
        _visited.discard(data_id)
class User:
    """Plain Python class that exposes its attributes through a ``dict()`` method."""

    def __init__(self, name, age):
        self.name, self.age = name, age

    def dict(self):
        """Return the ``name`` and ``age`` attributes as a dictionary."""
        return {attribute: getattr(self, attribute) for attribute in ("name", "age")}


def lambda_handler(event, context):
    """Build a custom ``User`` object and return it with the ``age`` field erased."""
    powertools_user = User("powertools", 42)
    erased_user = data_masker.erase(powertools_user, fields=["age"])
    return erased_user
def test_prepare_data_primitive():
    """Primitive values must come out of ``prepare_data`` unchanged."""
    # GIVEN one value of each primitive kind
    # WHEN each of them goes through prepare_data
    prepared_text = prepare_data("hello")
    prepared_integer = prepare_data(123)
    prepared_float = prepare_data(3.14)
    prepared_boolean = prepare_data(True)
    prepared_none = prepare_data(None)

    # THEN every value is returned untouched
    assert prepared_text == "hello"
    assert prepared_integer == 123
    assert prepared_float == pytest.approx(3.14)
    assert prepared_boolean is True
    assert prepared_none is None
def test_prepare_data_dataclass():
    """A dataclass instance must be converted into a dictionary of its fields."""

    # GIVEN a dataclass with two fields
    @dataclasses.dataclass
    class MyDataClass:
        name: str
        age: int

    # WHEN an instance goes through prepare_data
    prepared = prepare_data(MyDataClass(name="delta", age=50))

    # THEN we get a dictionary carrying the field values
    assert isinstance(prepared, dict)
    assert prepared["name"] == "delta"
    assert prepared["age"] == 50
def test_erase_on_nested_complex_structure(data_masker):
    """Erasing ``$..value`` must reach every ``value`` key, whatever type holds it."""

    # GIVEN a dataclass, a Pydantic model and a custom class with a dict() method
    @dataclasses.dataclass
    class NestedDC:
        value: int

    class NestedPM(BaseModel):
        value: int

    class MyCustomClass:
        def __init__(self, name, age):
            self.name, self.age = name, age

        def dict(self):
            return {"name": self.name, "age": self.age}

    # AND a payload mixing them with plain dictionaries and a list
    payload = {
        "pydantic": NestedPM(value=10),
        "dataclass": NestedDC(value=20),
        "custom": MyCustomClass("example", 30),
        "plain_dict": {"value": 40},
        "list": [NestedPM(value=50), {"value": 60}],
    }

    # WHEN every "value" key is erased
    erased = data_masker.erase(payload, fields=["$..value"])

    # THEN each "value" entry is masked, regardless of the original container type
    assert erased["pydantic"]["value"] == DATA_MASKING_STRING
    assert erased["dataclass"]["value"] == DATA_MASKING_STRING
    assert erased["plain_dict"]["value"] == DATA_MASKING_STRING
    first_item, second_item = erased["list"][0], erased["list"][1]
    assert first_item["value"] == DATA_MASKING_STRING
    assert second_item["value"] == DATA_MASKING_STRING

    # AND the custom object, which has no "value" key, is converted but left unmasked
    assert erased["custom"] == {"name": "example", "age": 30}