|
7 | 7 | namespace at {
|
8 | 8 |
|
9 | 9 | Tensor TensorMaker::make_tensor() {
|
10 |
| - AutoDispatchBelowADInplaceOrView guard{}; // TODO: Remove. |
11 |
| - tracer::impl::NoTracerDispatchMode tracer_guard{}; |
| 10 | + AutoDispatchBelowADInplaceOrView guard{}; // TODO: Remove. |
| 11 | + tracer::impl::NoTracerDispatchMode tracer_guard{}; |
12 | 12 |
|
13 |
| - check_size_nonnegative(sizes_); |
| 13 | + check_size_nonnegative(sizes_); |
14 | 14 |
|
15 |
| - TORCH_CHECK_VALUE( |
16 |
| - !deleter_ || !ctx_, |
17 |
| - "The deleter and context arguments are mutually exclusive."); |
| 15 | + TORCH_CHECK_VALUE( |
| 16 | + !deleter_ || !ctx_, |
| 17 | + "The deleter and context arguments are mutually exclusive."); |
18 | 18 |
|
19 |
| - if (device_ == nullopt) { |
20 |
| - device_ = globalContext().getDeviceFromPtr(data_, opts_.device().type()); |
21 |
| - } |
| 19 | + if (device_ == nullopt) { |
| 20 | + device_ = globalContext().getDeviceFromPtr(data_, opts_.device().type()); |
| 21 | + } |
22 | 22 |
|
23 |
| - if (opts_.device().has_index()) { |
24 |
| - // clang-format off |
25 |
| - TORCH_CHECK_VALUE( |
26 |
| - opts_.device() == *device_, |
27 |
| - "Specified device ", opts_.device(), " does not match device of data ", *device_); |
28 |
| - // clang-format on |
29 |
| - } |
| 23 | + if (opts_.device().has_index()) { |
| 24 | + // clang-format off |
| 25 | + TORCH_CHECK_VALUE( |
| 26 | + opts_.device() == *device_, |
| 27 | + "Specified device ", opts_.device(), " does not match device of data ", *device_); |
| 28 | + // clang-format on |
| 29 | + } |
30 | 30 |
|
31 |
| - std::size_t size_bytes = computeStorageSize(); |
| 31 | + std::size_t size_bytes = computeStorageSize(); |
32 | 32 |
|
33 |
| - DataPtr data_ptr{}; |
34 |
| - if (deleter_) { |
35 |
| - data_ptr = makeDataPtrFromDeleter(); |
36 |
| - } else { |
37 |
| - data_ptr = makeDataPtrFromContext(); |
38 |
| - } |
| 33 | + if (data_ == nullptr) { |
| 34 | + // We need to ensure that there's always a valid pointer or the underlying |
| 35 | + // std::unique_ptr<> won't call the deleter (custom or context). |
| 36 | + data_ = malloc(0); |
| 37 | + } |
| 38 | + |
| 39 | + DataPtr data_ptr{}; |
| 40 | + if (deleter_) { |
| 41 | + data_ptr = makeDataPtrFromDeleter(); |
| 42 | + } else { |
| 43 | + data_ptr = makeDataPtrFromContext(); |
| 44 | + } |
39 | 45 |
|
40 |
| - TORCH_CHECK(!resizeable_ || allocator_ != nullptr, "Must specify an allocator with allocator() if you want to use resizeable_storage()"); |
41 |
| - Storage storage{Storage::use_byte_size_t{}, size_bytes, std::move(data_ptr), /*allocator=*/allocator_, /*resizeable=*/resizeable_}; |
| 46 | + TORCH_CHECK(!resizeable_ || allocator_ != nullptr, "Must specify an allocator with allocator() if you want to use resizeable_storage()"); |
| 47 | + Storage storage{Storage::use_byte_size_t{}, size_bytes, std::move(data_ptr), /*allocator=*/allocator_, /*resizeable=*/resizeable_}; |
42 | 48 |
|
43 |
| - Tensor tensor = detail::make_tensor<TensorImpl>( |
44 |
| - std::move(storage), opts_.computeDispatchKey(), opts_.dtype()); |
| 49 | + Tensor tensor = detail::make_tensor<TensorImpl>( |
| 50 | + std::move(storage), opts_.computeDispatchKey(), opts_.dtype()); |
45 | 51 |
|
46 | 52 | TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl();
|
47 | 53 | if (strides_) {
|
|
0 commit comments