@@ -28,27 +28,21 @@ class NdMatrixProxy {
28
28
// idim: The dimension associated with this proxy
29
29
// dim_stride: The stride of this dimension (i.e. how many element in memory between indicies of this dimension)
30
30
// start: Pointer to the start of the sub-matrix this proxy represents
31
- NdMatrixProxy<T, N>(const size_t * dim_sizes, size_t idim, size_t dim_stride , T* start)
31
+ NdMatrixProxy<T, N>(const size_t * dim_sizes, const size_t * dim_strides , T* start)
32
32
: dim_sizes_(dim_sizes)
33
- , idim_(idim)
34
- , dim_stride_(dim_stride)
33
+ , dim_strides_(dim_strides)
35
34
, start_(start) {}
36
35
37
36
const NdMatrixProxy<T, N - 1 > operator [](size_t index) const {
38
37
VTR_ASSERT_SAFE_MSG (index >= 0 , " Index out of range (below dimension minimum)" );
39
- VTR_ASSERT_SAFE_MSG (index < dim_sizes_[idim_], " Index out of range (above dimension maximum)" );
40
-
41
- size_t next_dim_size = dim_sizes_[idim_ + 1 ];
42
- VTR_ASSERT_SAFE_MSG (next_dim_size > 0 , " Can not index into zero-sized dimension" );
43
-
44
- // Determine the stride of the next dimension
45
- size_t next_dim_stride = dim_stride_ / next_dim_size;
38
+ VTR_ASSERT_SAFE_MSG (index < dim_sizes_[0 ], " Index out of range (above dimension maximum)" );
39
+ VTR_ASSERT_SAFE_MSG (dim_sizes_[1 ] > 0 , " Can not index into zero-sized dimension" );
46
40
47
41
// Strip off one dimension
48
- return NdMatrixProxy<T, N - 1 >(dim_sizes_, // Pass the dimension information
49
- idim_ + 1 , // Pass the next dimension
50
- next_dim_stride, // Pass the stride for the next dimension
51
- start_ + dim_stride_ * index ); // Advance to index in this dimension
42
+ return NdMatrixProxy<T, N - 1 >(
43
+ dim_sizes_ + 1 , // Pass the dimension information
44
+ dim_strides_ + 1 , // Pass the stride for the next dimension
45
+ start_ + dim_strides_[ 0 ] * index ); // Advance to index in this dimension
52
46
}
53
47
54
48
NdMatrixProxy<T, N - 1 > operator [](size_t index) {
@@ -58,25 +52,23 @@ class NdMatrixProxy {
58
52
59
53
private:
60
54
const size_t * dim_sizes_;
61
- const size_t idim_;
62
- const size_t dim_stride_;
55
+ const size_t * dim_strides_;
63
56
T* start_;
64
57
};
65
58
66
59
// Base case: 1-dimensional array
67
60
template <typename T>
68
61
class NdMatrixProxy <T, 1 > {
69
62
public:
70
- NdMatrixProxy<T, 1 >(const size_t * dim_sizes, size_t idim, size_t dim_stride, T* start)
63
+ NdMatrixProxy<T, 1 >(const size_t * dim_sizes, const size_t * dim_stride, T* start)
71
64
: dim_sizes_(dim_sizes)
72
- , idim_(idim)
73
- , dim_stride_(dim_stride)
65
+ , dim_strides_(dim_stride)
74
66
, start_(start) {}
75
67
76
68
const T& operator [](size_t index) const {
77
- VTR_ASSERT_SAFE_MSG (dim_stride_ == 1 , " Final dimension must have stride 1" );
69
+ VTR_ASSERT_SAFE_MSG (dim_strides_[ 0 ] == 1 , " Final dimension must have stride 1" );
78
70
VTR_ASSERT_SAFE_MSG (index >= 0 , " Index out of range (below dimension minimum)" );
79
- VTR_ASSERT_SAFE_MSG (index < dim_sizes_[idim_ ], " Index out of range (above dimension maximum)" );
71
+ VTR_ASSERT_SAFE_MSG (index < dim_sizes_[0 ], " Index out of range (above dimension maximum)" );
80
72
81
73
// Base case
82
74
return start_[index ];
@@ -103,8 +95,7 @@ class NdMatrixProxy<T, 1> {
103
95
104
96
private:
105
97
const size_t * dim_sizes_;
106
- const size_t idim_;
107
- const size_t dim_stride_;
98
+ const size_t * dim_strides_;
108
99
T* start_;
109
100
};
110
101
@@ -207,12 +198,21 @@ class NdMatrixBase {
207
198
size_ = calc_size ();
208
199
alloc ();
209
200
fill (value);
201
+ if (size_ > 0 ) {
202
+ dim_strides_[0 ] = size_ / dim_sizes_[0 ];
203
+ for (size_t dim = 1 ; dim < N; ++dim) {
204
+ dim_strides_[dim] = dim_strides_[dim - 1 ] / dim_sizes_[dim];
205
+ }
206
+ } else {
207
+ dim_strides_.fill (0 );
208
+ }
210
209
}
211
210
212
211
// Reset the matrix to size zero
213
212
void clear () {
214
213
data_.reset (nullptr );
215
214
dim_sizes_.fill (0 );
215
+ dim_strides_.fill (0 );
216
216
size_ = 0 ;
217
217
}
218
218
@@ -242,6 +242,7 @@ class NdMatrixBase {
242
242
using std::swap;
243
243
swap (m1.size_ , m2.size_ );
244
244
swap (m1.dim_sizes_ , m2.dim_sizes_ );
245
+ swap (m1.dim_strides_ , m2.dim_strides_ );
245
246
swap (m1.data_ , m2.data_ );
246
247
}
247
248
@@ -265,6 +266,7 @@ class NdMatrixBase {
265
266
protected:
266
267
size_t size_ = 0 ;
267
268
std::array<size_t , N> dim_sizes_;
269
+ std::array<size_t , N> dim_strides_;
268
270
std::unique_ptr<T[]> data_ = nullptr ;
269
271
};
270
272
@@ -316,17 +318,11 @@ class NdMatrix : public NdMatrixBase<T, N> {
316
318
VTR_ASSERT_SAFE_MSG (index >= 0 , " Index out of range (below dimension minimum)" );
317
319
VTR_ASSERT_SAFE_MSG (index < this ->dim_sizes_ [0 ], " Index out of range (above dimension maximum)" );
318
320
319
- // Calculate the stride for the current dimension
320
- size_t dim_stride = this ->size () / this ->dim_size (0 );
321
-
322
- // Calculate the stride for the next dimension
323
- size_t next_dim_stride = dim_stride / this ->dim_size (1 );
324
-
325
321
// Peel off the first dimension
326
- return NdMatrixProxy<T, N - 1 >(this -> dim_sizes_ . data (), // Pass the dimension information
327
- 1 , // Pass the next dimension
328
- next_dim_stride, // Pass the stride for the next dimension
329
- this ->data_ .get () + dim_stride * index ); // Advance to index in this dimension
322
+ return NdMatrixProxy<T, N - 1 >(
323
+ this -> dim_sizes_ . data () + 1 , // Pass the dimension information
324
+ this -> dim_strides_ . data () + 1 , // Pass the stride for the next dimension
325
+ this ->data_ .get () + this -> dim_strides_ [ 0 ] * index ); // Advance to index in this dimension
330
326
}
331
327
332
328
// Access an element
0 commit comments