@@ -30,13 +30,14 @@ class NdMatrixProxy {
30
30
* @brief Construct a matrix proxy object
31
31
*
32
32
* @param dim_sizes: Array of dimension sizes
33
- * @param idim: The dimension associated with this proxy
34
33
* @param dim_stride: The stride of this dimension (i.e. how many element in memory between indicies of this dimension)
35
- * @param start: Pointer to the start of the sub-matrix this proxy represents
34
+ * @param offset: The offset from the start that this sub-matrix starts at.
35
+ * @param start: Pointer to the start of the base NDMatrix of this proxy
36
36
*/
37
- NdMatrixProxy (const size_t * dim_sizes, const size_t * dim_strides, T* start)
37
+ NdMatrixProxy (const size_t * dim_sizes, const size_t * dim_strides, size_t offset, const std::unique_ptr<T[]>& start)
38
38
: dim_sizes_(dim_sizes)
39
39
, dim_strides_(dim_strides)
40
+ , offset_(offset)
40
41
, start_(start) {}
41
42
42
43
NdMatrixProxy& operator =(const NdMatrixProxy& other) = delete ;
@@ -50,7 +51,8 @@ class NdMatrixProxy {
50
51
return NdMatrixProxy<T, N - 1 >(
51
52
dim_sizes_ + 1 , // Pass the dimension information
52
53
dim_strides_ + 1 , // Pass the stride for the next dimension
53
- start_ + dim_strides_[0 ] * index); // Advance to index in this dimension
54
+ offset_ + dim_strides_[0 ] * index, // Advance to index in this dimension
55
+ start_); // Pass the base pointer.
54
56
}
55
57
56
58
// /@brief [] operator
@@ -60,9 +62,22 @@ class NdMatrixProxy {
60
62
}
61
63
62
64
private:
65
+ // / @brief The sizes of each dimension of this proxy. This is an array of
66
+ // / length N.
63
67
const size_t * dim_sizes_;
68
+
69
+ // / @brief The stride of each dimension of this proxy. This is an array of
70
+ // / length N.
64
71
const size_t * dim_strides_;
65
- T* start_;
72
+
73
+ // / @brief The offset from the base NDMatrix object that this sub-matrix
74
+ // / starts at.
75
+ size_t offset_;
76
+
77
+ // / @brief The pointer to the start of the base NDMatrix data. Since the
78
+ // / base NDMatrix object owns the memory, we hold onto a reference
79
+ // / to its unique pointer. This is safer than passing a bare pointer.
80
+ const std::unique_ptr<T[]>& start_;
66
81
};
67
82
68
83
// /@brief Base case: 1-dimensional array
@@ -74,11 +89,13 @@ class NdMatrixProxy<T, 1> {
74
89
*
75
90
* @param dim_sizes: Array of dimension sizes
76
91
* @param dim_stride: The stride of this dimension (i.e. how many element in memory between indicies of this dimension)
77
- * @param start: Pointer to the start of the sub-matrix this proxy represents
92
+ * @param offset: The offset from the start that this sub-matrix starts at.
93
+ * @param start: Pointer to the start of the base NDMatrix of this proxy
78
94
*/
79
- NdMatrixProxy (const size_t * dim_sizes, const size_t * dim_stride, T* start)
95
+ NdMatrixProxy (const size_t * dim_sizes, const size_t * dim_stride, size_t offset, const std::unique_ptr<T[]>& start)
80
96
: dim_sizes_(dim_sizes)
81
97
, dim_strides_(dim_stride)
98
+ , offset_(offset)
82
99
, start_(start) {}
83
100
84
101
NdMatrixProxy& operator =(const NdMatrixProxy& other) = delete ;
@@ -89,7 +106,7 @@ class NdMatrixProxy<T, 1> {
89
106
VTR_ASSERT_SAFE_MSG (index < dim_sizes_[0 ], " Index out of range (above dimension maximum)" );
90
107
91
108
// Base case
92
- return start_[index];
109
+ return start_[offset_ + index];
93
110
}
94
111
95
112
// /@brief [] operator
@@ -108,7 +125,7 @@ class NdMatrixProxy<T, 1> {
108
125
* not to clobber elements in other dimensions
109
126
*/
110
127
const T* data () const {
111
- return start_;
128
+ return start_. get () + offset_ ;
112
129
}
113
130
114
131
// /@brief same as above but allow update the value
@@ -118,9 +135,22 @@ class NdMatrixProxy<T, 1> {
118
135
}
119
136
120
137
private:
138
+ // / @brief The sizes of each dimension of this proxy. This is an array of
139
+ // / length N.
121
140
const size_t * dim_sizes_;
141
+
142
+ // / @brief The stride of each dimension of this proxy. This is an array of
143
+ // / length N.
122
144
const size_t * dim_strides_;
123
- T* start_;
145
+
146
+ // / @brief The offset from the base NDMatrix object that this sub-matrix
147
+ // / starts at.
148
+ size_t offset_;
149
+
150
+ // / @brief The pointer to the start of the base NDMatrix data. Since the
151
+ // / base NDMatrix object owns the memory, we hold onto a reference
152
+ // / to its unique pointer. This is safer than passing a bare pointer.
153
+ const std::unique_ptr<T[]>& start_;
124
154
};
125
155
126
156
/* *
@@ -359,7 +389,8 @@ class NdMatrix : public NdMatrixBase<T, N> {
359
389
return NdMatrixProxy<T, N - 1 >(
360
390
this ->dim_sizes_ .data () + 1 , // Pass the dimension information
361
391
this ->dim_strides_ .data () + 1 , // Pass the stride for the next dimension
362
- this ->data_ .get () + this ->dim_strides_ [0 ] * index); // Advance to index in this dimension
392
+ this ->dim_strides_ [0 ] * index, // Advance to index in this dimension
393
+ this ->data_ ); // Pass the base pointer
363
394
}
364
395
365
396
/* *
0 commit comments