Skip to content

Commit 8b92a02

Browse files
author
Luke Hinds
committed
Pytorch Load / Save Plugin
This plugin checks for the use of `torch.load` and `torch.save`. Using `torch.load` with untrusted data can lead to arbitrary code execution, and improper use of `torch.save` might expose sensitive data or lead to data corruption. Signed-off-by: Luke Hinds <[email protected]>
1 parent 4c5b3c8 commit 8b92a02

File tree

6 files changed

+123
-0
lines changed

6 files changed

+123
-0
lines changed

bandit/blacklists/calls.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,19 @@
320320
| | | - os.tmpnam | |
321321
+------+---------------------+------------------------------------+-----------+
322322
323+
B704: pytorch_load_save
324+
325+
Use of unsafe PyTorch load. `torch.load` can lead to arbitrary code execution,
326+
and improper use of `torch.save` might expose sensitive data or lead to data
327+
corruption.
328+
329+
+------+---------------------+--------------------------------------+-----------+
330+
| ID | Name | Calls | Severity |
331+
+======+=====================+======================================+===========+
332+
| B704 | pytorch_load_save| | - torch.load | Medium |
333+
| B704 | pytorch_load_save| | - torch.save | Medium |
334+
+------+---------------------+--------------------------------------+-----------+
335+
323336
"""
324337
import sys
325338

@@ -685,6 +698,18 @@ def gen_blacklist():
685698
)
686699
)
687700

701+
sets.append(
702+
utils.build_conf_dict(
703+
"pytorch_load_save",
704+
"B704",
705+
issue.Cwe.DESERIALIZATION_OF_UNTRUSTED_DATA,
706+
["torch.load", "torch.save"],
707+
"Use of unsafe PyTorch load or save",
708+
"MEDIUM",
709+
)
710+
)
711+
712+
688713
# skipped B324 (used in bandit/plugins/hashlib_new_insecure_functions.py)
689714

690715
# skipped B325 as the check for a call to os.tempnam and os.tmpnam have

bandit/plugins/pytorch_load_save.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright (c) 2024 Stacklok, Inc.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
r"""
5+
=========================================
6+
B704: Test for unsafe PyTorch load or save
7+
=========================================
8+
9+
This plugin checks for the use of `torch.load` and `torch.save`. Using `torch.load`
10+
with untrusted data can lead to arbitrary code execution, and improper use of
11+
`torch.save` might expose sensitive data or lead to data corruption.
12+
13+
:Example:
14+
15+
.. code-block:: none
16+
17+
>> Issue: Use of unsafe PyTorch load or save
18+
Severity: Medium Confidence: High
19+
CWE: CWE-94 (https://cwe.mitre.org/data/definitions/94.html)
20+
Location: examples/pytorch_load_save.py:8
21+
7 loaded_model.load_state_dict(torch.load('model_weights.pth'))
22+
8 another_model.load_state_dict(torch.load('model_weights.pth', map_location='cpu'))
23+
9
24+
10 print("Model loaded successfully!")
25+
26+
.. seealso::
27+
28+
- https://cwe.mitre.org/data/definitions/94.html
29+
30+
.. versionadded:: 1.7.8
31+
32+
"""
33+
import bandit
34+
from bandit.core import issue
35+
from bandit.core import test_properties as test
36+
37+
38+
@test.checks("Call")
39+
@test.test_id("B704") # Ensure the test ID is unique and does not conflict with existing Bandit tests
40+
def pytorch_load_save(context):
41+
"""
42+
This plugin checks for the use of `torch.load` and `torch.save`. Using `torch.load`
43+
with untrusted data can lead to arbitrary code execution, and improper use of
44+
`torch.save` might expose sensitive data or lead to data corruption.
45+
"""
46+
imported = context.is_module_imported_exact("torch")
47+
qualname = context.call_function_name_qual
48+
if not imported and isinstance(qualname, str):
49+
return
50+
51+
qualname_list = qualname.split(".")
52+
func = qualname_list[-1]
53+
if all(
54+
[
55+
"torch" in qualname_list,
56+
func in ["load"],
57+
not context.check_call_arg_value("map_location", "cpu"),
58+
]
59+
):
60+
return bandit.Issue(
61+
severity=bandit.MEDIUM,
62+
confidence=bandit.HIGH,
63+
text="Use of unsafe PyTorch load or save",
64+
cwe=issue.Cwe.UNTRUSTED_INPUT,
65+
lineno=context.get_lineno_for_call_arg("load"),
66+
)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
-----------------------
2+
B704: pytorch_load_save
3+
-----------------------
4+
5+
.. automodule:: bandit.plugins.pytorch_load_save

examples/pytorch_load_save.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import torch
2+
import torchvision.models as models
3+
4+
# Example of saving a model
5+
model = models.resnet18(pretrained=True)
6+
torch.save(model.state_dict(), 'model_weights.pth')
7+
8+
# Example of loading the model weights in an insecure way
9+
loaded_model = models.resnet18()
10+
loaded_model.load_state_dict(torch.load('model_weights.pth'))
11+
12+
# Another example using torch.load with more parameters
13+
another_model = models.resnet18()
14+
another_model.load_state_dict(torch.load('model_weights.pth', map_location='cpu'))
15+
16+
print("Model loaded successfully!")

setup.cfg

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,9 @@ bandit.plugins =
148148
#bandit/plugins/tarfile_unsafe_members.py
149149
tarfile_unsafe_members = bandit.plugins.tarfile_unsafe_members:tarfile_unsafe_members
150150

151+
#bandit/plugins/pytorch_load_save.py
152+
pytorch_load_save = bandit.plugins.pytorch_load_save:pytorch_load_save
153+
151154
[build_sphinx]
152155
all_files = 1
153156
build-dir = doc/build

tests/functional/test_functional.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -930,3 +930,11 @@ def test_tarfile_unsafe_members(self):
930930
"CONFIDENCE": {"UNDEFINED": 0, "LOW": 1, "MEDIUM": 2, "HIGH": 1},
931931
}
932932
self.check_example("tarfile_extractall.py", expect)
933+
934+
def test_pytorch_load_save(self):
935+
"""Test insecure usage of torch.load and torch.save."""
936+
expect = {
937+
"SEVERITY": {"UNDEFINED": 0, "LOW": 1, "MEDIUM": 3, "HIGH": 0},
938+
"CONFIDENCE": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 1, "HIGH": 3},
939+
}
940+
self.check_example("pytorch_load_save.py", expect)

0 commit comments

Comments
 (0)