Skip to content

Commit 394285c

Browse files
committed
Ensure that deleter is called even for a no-data tensor.
Summary: The deleter is called indirectly by a std::unique_ptr<> whch only fires if there's actually data to delete. Fixes #117273 Test Plan: [ghstack-poisoned]
1 parent 4b25948 commit 394285c

File tree

1 file changed

+33
-27
lines changed

1 file changed

+33
-27
lines changed

aten/src/ATen/templates/Functions.cpp

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,41 +7,47 @@
77
namespace at {
88

99
Tensor TensorMaker::make_tensor() {
10-
AutoDispatchBelowADInplaceOrView guard{}; // TODO: Remove.
11-
tracer::impl::NoTracerDispatchMode tracer_guard{};
10+
AutoDispatchBelowADInplaceOrView guard{}; // TODO: Remove.
11+
tracer::impl::NoTracerDispatchMode tracer_guard{};
1212

13-
check_size_nonnegative(sizes_);
13+
check_size_nonnegative(sizes_);
1414

15-
TORCH_CHECK_VALUE(
16-
!deleter_ || !ctx_,
17-
"The deleter and context arguments are mutually exclusive.");
15+
TORCH_CHECK_VALUE(
16+
!deleter_ || !ctx_,
17+
"The deleter and context arguments are mutually exclusive.");
1818

19-
if (device_ == nullopt) {
20-
device_ = globalContext().getDeviceFromPtr(data_, opts_.device().type());
21-
}
19+
if (device_ == nullopt) {
20+
device_ = globalContext().getDeviceFromPtr(data_, opts_.device().type());
21+
}
2222

23-
if (opts_.device().has_index()) {
24-
// clang-format off
25-
TORCH_CHECK_VALUE(
26-
opts_.device() == *device_,
27-
"Specified device ", opts_.device(), " does not match device of data ", *device_);
28-
// clang-format on
29-
}
23+
if (opts_.device().has_index()) {
24+
// clang-format off
25+
TORCH_CHECK_VALUE(
26+
opts_.device() == *device_,
27+
"Specified device ", opts_.device(), " does not match device of data ", *device_);
28+
// clang-format on
29+
}
3030

31-
std::size_t size_bytes = computeStorageSize();
31+
std::size_t size_bytes = computeStorageSize();
3232

33-
DataPtr data_ptr{};
34-
if (deleter_) {
35-
data_ptr = makeDataPtrFromDeleter();
36-
} else {
37-
data_ptr = makeDataPtrFromContext();
38-
}
33+
if (data_ == nullptr) {
34+
// We need to ensure that there's always a valid pointer or the underlying
35+
// std::unique_ptr<> won't call the deleter (custom or context).
36+
data_ = malloc(0);
37+
}
38+
39+
DataPtr data_ptr{};
40+
if (deleter_) {
41+
data_ptr = makeDataPtrFromDeleter();
42+
} else {
43+
data_ptr = makeDataPtrFromContext();
44+
}
3945

40-
TORCH_CHECK(!resizeable_ || allocator_ != nullptr, "Must specify an allocator with allocator() if you want to use resizeable_storage()");
41-
Storage storage{Storage::use_byte_size_t{}, size_bytes, std::move(data_ptr), /*allocator=*/allocator_, /*resizeable=*/resizeable_};
46+
TORCH_CHECK(!resizeable_ || allocator_ != nullptr, "Must specify an allocator with allocator() if you want to use resizeable_storage()");
47+
Storage storage{Storage::use_byte_size_t{}, size_bytes, std::move(data_ptr), /*allocator=*/allocator_, /*resizeable=*/resizeable_};
4248

43-
Tensor tensor = detail::make_tensor<TensorImpl>(
44-
std::move(storage), opts_.computeDispatchKey(), opts_.dtype());
49+
Tensor tensor = detail::make_tensor<TensorImpl>(
50+
std::move(storage), opts_.computeDispatchKey(), opts_.dtype());
4551

4652
TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl();
4753
if (strides_) {

0 commit comments

Comments
 (0)