Skip to content

feat(data-masking): add support for Pydantic models, dataclasses, and standard classes #6413

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 69 additions & 2 deletions aws_lambda_powertools/utilities/data_masking/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,72 @@

logger = logging.getLogger(__name__)


def prepare_data(data: Any, _visited: set[int] | None = None) -> Any:
"""
Recursively convert complex objects into dictionaries (or simple types) so that they can be
processed by the data masking utility. This function handles:

- Dataclasses (using dataclasses.asdict)
- Pydantic models (using model_dump)
- Custom classes with a dict() method
- Fallback to using __dict__ if available
- Recursively traverses dicts, lists, tuples, and sets
- Guards against circular references

Parameters
----------
data : Any
The input data which may be a complex type.
_visited : set, optional
Internal set of visited object IDs to prevent infinite recursion on cyclic references.

Returns
-------
Any
A primitive type, or a recursively converted structure (dict, list, etc.)
"""
# Initialize _visited set if not provided.
if _visited is None:
_visited = set()

# Prevent circular references by checking if the object's id has been seen.
data_id = id(data)
if data_id in _visited:
return data # Return the object as-is if it has already been processed.
_visited.add(data_id)

# If data is a primitive type, return it directly.
if isinstance(data, (str, int, float, bool, type(None))):
return data

# Handle dataclasses by converting them to a dictionary.
if hasattr(data, "__dataclass_fields__"):
import dataclasses
return prepare_data(dataclasses.asdict(data), _visited=_visited)

# Handle Pydantic models (Pydantic v2 uses 'model_dump').
if callable(getattr(data, "model_dump", None)):
return prepare_data(data.model_dump(), _visited=_visited)

# Handle custom objects that implement a dict() method (but are not already a dict).
if callable(getattr(data, "dict", None)) and not isinstance(data, dict):
return prepare_data(data.dict(), _visited=_visited)

# If data is a dictionary, process both keys and values recursively.
if isinstance(data, dict):
return {prepare_data(key, _visited=_visited): prepare_data(value, _visited=_visited)
for key, value in data.items()}

# If data is an iterable (like a list, tuple, or set), process each element recursively.
if isinstance(data, (list, tuple, set)):
return type(data)(prepare_data(item, _visited=_visited) for item in data)

# As a fallback, if the object has a __dict__, convert its attributes.
if hasattr(data, "__dict__"):
return prepare_data(vars(data), _visited=_visited)

# If no conversion is applicable, return the data as is.
return data
class DataMasking:
"""
The DataMasking class orchestrates erasing, encrypting, and decrypting
Expand Down Expand Up @@ -93,6 +158,7 @@ def encrypt(
data_masker = DataMasking(provider=encryption_provider)
encrypted = data_masker.encrypt({"secret": "value"})
"""
data = prepare_data(data)
return self._apply_action(
data=data,
fields=None,
Expand Down Expand Up @@ -135,7 +201,7 @@ def decrypt(
data_masker = DataMasking(provider=encryption_provider)
encrypted = data_masker.decrypt(encrypted_data)
"""

data = prepare_data(data)
return self._apply_action(
data=data,
fields=None,
Expand Down Expand Up @@ -184,6 +250,7 @@ def erase(
Any
The data with sensitive information erased or masked.
"""
data = prepare_data(data)
if masking_rules:
return self._apply_masking_rules(data=data, masking_rules=masking_rules)
else:
Expand Down
62 changes: 60 additions & 2 deletions docs/utilities/data_masking.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,58 @@ Erasing will remove the original data and replace it with a `*****`. This means
--8<-- "examples/data_masking/src/getting_started_erase_data_output.json"
```

### Supported Input Types

You can pass in different types of Python objects. Internally, we convert these to dictionaries for processing.

Examples below show how `erase()` works with each type.

=== "Pydantic Model"

```python
from pydantic import BaseModel
from aws_lambda_powertools.utilities.data_masking import DataMasking

class User(BaseModel):
name: str
age: int

model = User(name="powertools", age=42)
masked = DataMasking().erase(model, fields=["age"])
print(masked) # {'name': 'powertools', 'age': '*****'}
```

=== "Dataclass"

```python
from dataclasses import dataclass
from aws_lambda_powertools.utilities.data_masking import DataMasking

@dataclass
class User:
name: str
age: int

model = User(name="powertools", age=42)
masked = DataMasking().erase(model, fields=["age"])
print(masked) # {'name': 'powertools', 'age': '*****'}
```

=== "Custom Class with dict()"

```python
from aws_lambda_powertools.utilities.data_masking import DataMasking

class User:
def __init__(self, name, age):
self.name = name
self.age = age
def dict(self):
return {"name": self.name, "age": self.age}

model = User("powertools", 42)
masked = DataMasking().erase(model, fields=["age"])
print(masked) # {'name': 'powertools', 'age': '*****'}
```

#### Custom masking

The `erase` method also supports additional flags for more advanced and flexible masking:
Expand Down Expand Up @@ -440,8 +492,14 @@ Note that the return will be a deserialized JSON and your desired fields updated

### Data serialization

???+ note "Current limitations"
1. Python classes, `Dataclasses`, and `Pydantic models` are not supported yet.
???+ tip "Extended input support"
We now support `Pydantic models`, `Dataclasses`, and custom classes with `dict()` or `__dict__` for input.

These types are automatically converted into dictionaries before masking, encrypting, or decrypting.

However, please note that we don't convert the result **back** into the original object type. The returned object will be a dictionary.

This may impact validation or schema enforcement when using tools like Pydantic.

Before we traverse the data structure, we perform two important operations on input data:

Expand Down
218 changes: 218 additions & 0 deletions tests/unit/data_masking/test_data_masking_input_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
import dataclasses
import pytest
from pydantic import BaseModel

from aws_lambda_powertools.utilities.data_masking.base import DataMasking, prepare_data
from aws_lambda_powertools.utilities.data_masking.constants import DATA_MASKING_STRING


@pytest.fixture
def data_masker() -> DataMasking:
    """Provide a DataMasking instance constructed without any arguments."""
    return DataMasking()


def test_prepare_data_primitive():
    """Primitive values come back from prepare_data unchanged."""
    for value in ("hello", 123):
        assert prepare_data(value) == value
    assert prepare_data(3.14) == pytest.approx(3.14)
    for singleton in (True, None):
        assert prepare_data(singleton) is singleton


def test_prepare_data_dict_no_change():
    """A dict holding only primitives is returned as an equal dict."""
    original = {"x": "y", "z": 10}
    converted = prepare_data(original)
    assert isinstance(converted, dict)
    assert converted == original


def test_prepare_data_list():
    """A list keeps its type and its (already plain) contents."""
    converted = prepare_data([1, "a", {"b": 2}])
    assert isinstance(converted, list)
    assert converted == [1, "a", {"b": 2}]


def test_prepare_data_tuple():
    """A tuple stays a tuple and its nested dict remains accessible."""
    converted = prepare_data((1, 2, {"a": 3}))
    assert isinstance(converted, tuple)
    nested = converted[2]
    assert nested["a"] == 3


def test_prepare_data_set():
    """A set of primitives is returned as an equal set."""
    converted = prepare_data({1, 2, 3})
    assert isinstance(converted, set)
    assert converted == {1, 2, 3}


def test_prepare_data_dataclass():
    """A dataclass instance is converted into a dict keyed by field name."""

    @dataclasses.dataclass
    class MyDataClass:
        name: str
        age: int

    converted = prepare_data(MyDataClass(name="delta", age=50))
    assert isinstance(converted, dict)
    assert converted["name"] == "delta"
    assert converted["age"] == 50


def test_prepare_data_pydantic():
    """A Pydantic model instance is converted into a dict keyed by field name."""

    class MyPydanticModel(BaseModel):
        name: str
        age: int

    converted = prepare_data(MyPydanticModel(name="alpha", age=30))
    assert isinstance(converted, dict)
    assert converted["name"] == "alpha"
    assert converted["age"] == 30


def test_prepare_data_custom_class_with_dict():
    """An object exposing a dict() method is converted through that method."""

    class MyCustom:
        def __init__(self, name, age):
            self.name = name
            self.age = age

        def dict(self):
            return {"name": self.name, "age": self.age}

    converted = prepare_data(MyCustom("beta", 40))
    assert isinstance(converted, dict)
    assert converted["name"] == "beta"
    assert converted["age"] == 40


def test_prepare_data_fallback_dict_via_dunder():
    """An object without dict() is converted from its instance attributes."""

    class WithDict:
        def __init__(self, value):
            self.value = value

    converted = prepare_data(WithDict(100))
    assert isinstance(converted, dict)
    assert converted["value"] == 100


def test_prepare_data_nested_structure():
    """Dataclasses, Pydantic models and dict()-objects nested in containers are all converted."""

    @dataclasses.dataclass
    class NestedDC:
        x: int
        y: str

    class NestedPM(BaseModel):
        a: int
        b: str

    class NestedCustom:
        def __init__(self, z):
            self.z = z

        def dict(self):
            return {"z": self.z}

    inner_items = [NestedDC(x=1, y="inner"), NestedPM(a=2, b="inner2")]
    payload = {
        "dc": NestedDC(x=10, y="foo"),
        "pm": NestedPM(a=5, b="bar"),
        "custom": NestedCustom(z="baz"),
        "nested": {"list": inner_items},
    }

    converted = prepare_data(payload)

    dataclass_part = converted["dc"]
    assert dataclass_part["x"] == 10
    assert dataclass_part["y"] == "foo"

    pydantic_part = converted["pm"]
    assert pydantic_part["a"] == 5
    assert pydantic_part["b"] == "bar"

    assert converted["custom"]["z"] == "baz"

    converted_items = converted["nested"]["list"]
    assert converted_items[0]["y"] == "inner"
    assert converted_items[1]["a"] == 2


def test_prepare_data_circular_reference():
    """A dict that contains itself is converted without recursing forever."""
    cyclic = {"a": 1}
    cyclic["self"] = cyclic
    converted = prepare_data(cyclic)
    assert converted["a"] == 1
    assert "self" in converted


class MyPydanticModel(BaseModel):
    """Pydantic model with ``name`` and ``age`` fields, used as erase() input in this module."""

    name: str
    age: int


@dataclasses.dataclass
class MyDataClass:
    """Dataclass with ``name`` and ``age`` fields, used as erase() input in this module."""

    name: str
    age: int


class MyCustomClass:
    """Plain class exposing its ``name`` and ``age`` attributes through a dict() method."""

    def __init__(self, name, age):
        self.name = name
        self.age = age

    def dict(self):
        """Return the instance attributes as a ``{"name": ..., "age": ...}`` dictionary."""
        return {"name": self.name, "age": self.age}


def test_erase_on_pydantic_model(data_masker):
    """erase() on a Pydantic model returns a dict with only the requested field masked."""
    model = MyPydanticModel(name="powertools", age=5)
    masked = data_masker.erase(model, fields=["age"])
    assert isinstance(masked, dict)
    assert masked["age"] == DATA_MASKING_STRING
    assert masked["name"] == "powertools"


def test_erase_on_dataclass(data_masker):
    """erase() on a dataclass returns a dict with only the requested field masked."""
    record = MyDataClass(name="powertools", age=5)
    masked = data_masker.erase(record, fields=["age"])
    assert isinstance(masked, dict)
    assert masked["age"] == DATA_MASKING_STRING
    assert masked["name"] == "powertools"


def test_erase_on_custom_class(data_masker):
    """erase() on an object with dict() returns a dict with only the requested field masked."""
    custom = MyCustomClass("powertools", 5)
    masked = data_masker.erase(custom, fields=["age"])
    assert isinstance(masked, dict)
    assert masked["age"] == DATA_MASKING_STRING
    assert masked["name"] == "powertools"


def test_erase_on_nested_complex_structure(data_masker):
    """erase() with ``$..value`` masks every ``value`` key across mixed nested input types."""

    @dataclasses.dataclass
    class NestedDC:
        value: int

    class NestedPM(BaseModel):
        value: int

    class MyCustomClass:
        def __init__(self, name, age):
            self.name = name
            self.age = age

        def dict(self):
            return {"name": self.name, "age": self.age}

    payload = {
        "pydantic": NestedPM(value=10),
        "dataclass": NestedDC(value=20),
        "custom": MyCustomClass("example", 30),
        "plain_dict": {"value": 40},
        "list": [NestedPM(value=50), {"value": 60}],
    }

    masked = data_masker.erase(payload, fields=["$..value"])

    for key in ("pydantic", "dataclass"):
        assert masked[key]["value"] == DATA_MASKING_STRING
    # The custom object has no "value" key, so it must come back converted but unmasked.
    assert masked["custom"] == {"name": "example", "age": 30}
    assert masked["plain_dict"]["value"] == DATA_MASKING_STRING
    for index in (0, 1):
        assert masked["list"][index]["value"] == DATA_MASKING_STRING
Loading