Skip to content

Commit 5a064ff

Browse files
committed
Optimize NdMatrix.
Previous implementation recomputed dimension strides every usage. When using the map lookahead, this has a very significant affect, as the core of the map lookahead is an NdMatrix. Signed-off-by: Keith Rothman <[email protected]>
1 parent b88108c commit 5a064ff

File tree

1 file changed

+29
-33
lines changed

1 file changed

+29
-33
lines changed

libs/libvtrutil/src/vtr_ndmatrix.h

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -28,27 +28,21 @@ class NdMatrixProxy {
2828
// idim: The dimension associated with this proxy
2929
// dim_stride: The stride of this dimension (i.e. how many element in memory between indicies of this dimension)
3030
// start: Pointer to the start of the sub-matrix this proxy represents
31-
NdMatrixProxy<T, N>(const size_t* dim_sizes, size_t idim, size_t dim_stride, T* start)
31+
NdMatrixProxy<T, N>(const size_t* dim_sizes, const size_t* dim_strides, T* start)
3232
: dim_sizes_(dim_sizes)
33-
, idim_(idim)
34-
, dim_stride_(dim_stride)
33+
, dim_strides_(dim_strides)
3534
, start_(start) {}
3635

3736
const NdMatrixProxy<T, N - 1> operator[](size_t index) const {
3837
VTR_ASSERT_SAFE_MSG(index >= 0, "Index out of range (below dimension minimum)");
39-
VTR_ASSERT_SAFE_MSG(index < dim_sizes_[idim_], "Index out of range (above dimension maximum)");
40-
41-
size_t next_dim_size = dim_sizes_[idim_ + 1];
42-
VTR_ASSERT_SAFE_MSG(next_dim_size > 0, "Can not index into zero-sized dimension");
43-
44-
//Determine the stride of the next dimension
45-
size_t next_dim_stride = dim_stride_ / next_dim_size;
38+
VTR_ASSERT_SAFE_MSG(index < dim_sizes_[0], "Index out of range (above dimension maximum)");
39+
VTR_ASSERT_SAFE_MSG(dim_sizes_[1] > 0, "Can not index into zero-sized dimension");
4640

4741
//Strip off one dimension
48-
return NdMatrixProxy<T, N - 1>(dim_sizes_, //Pass the dimension information
49-
idim_ + 1, //Pass the next dimension
50-
next_dim_stride, //Pass the stride for the next dimension
51-
start_ + dim_stride_ * index); //Advance to index in this dimension
42+
return NdMatrixProxy<T, N - 1>(
43+
dim_sizes_ + 1, //Pass the dimension information
44+
dim_strides_ + 1, //Pass the stride for the next dimension
45+
start_ + dim_strides_[0] * index); //Advance to index in this dimension
5246
}
5347

5448
NdMatrixProxy<T, N - 1> operator[](size_t index) {
@@ -58,25 +52,23 @@ class NdMatrixProxy {
5852

5953
private:
6054
const size_t* dim_sizes_;
61-
const size_t idim_;
62-
const size_t dim_stride_;
55+
const size_t* dim_strides_;
6356
T* start_;
6457
};
6558

6659
//Base case: 1-dimensional array
6760
template<typename T>
6861
class NdMatrixProxy<T, 1> {
6962
public:
70-
NdMatrixProxy<T, 1>(const size_t* dim_sizes, size_t idim, size_t dim_stride, T* start)
63+
NdMatrixProxy<T, 1>(const size_t* dim_sizes, const size_t* dim_stride, T* start)
7164
: dim_sizes_(dim_sizes)
72-
, idim_(idim)
73-
, dim_stride_(dim_stride)
65+
, dim_strides_(dim_stride)
7466
, start_(start) {}
7567

7668
const T& operator[](size_t index) const {
77-
VTR_ASSERT_SAFE_MSG(dim_stride_ == 1, "Final dimension must have stride 1");
69+
VTR_ASSERT_SAFE_MSG(dim_strides_[0] == 1, "Final dimension must have stride 1");
7870
VTR_ASSERT_SAFE_MSG(index >= 0, "Index out of range (below dimension minimum)");
79-
VTR_ASSERT_SAFE_MSG(index < dim_sizes_[idim_], "Index out of range (above dimension maximum)");
71+
VTR_ASSERT_SAFE_MSG(index < dim_sizes_[0], "Index out of range (above dimension maximum)");
8072

8173
//Base case
8274
return start_[index];
@@ -103,8 +95,7 @@ class NdMatrixProxy<T, 1> {
10395

10496
private:
10597
const size_t* dim_sizes_;
106-
const size_t idim_;
107-
const size_t dim_stride_;
98+
const size_t* dim_strides_;
10899
T* start_;
109100
};
110101

@@ -207,12 +198,21 @@ class NdMatrixBase {
207198
size_ = calc_size();
208199
alloc();
209200
fill(value);
201+
if (size_ > 0) {
202+
dim_strides_[0] = size_ / dim_sizes_[0];
203+
for (size_t dim = 1; dim < N; ++dim) {
204+
dim_strides_[dim] = dim_strides_[dim - 1] / dim_sizes_[dim];
205+
}
206+
} else {
207+
dim_strides_.fill(0);
208+
}
210209
}
211210

212211
//Reset the matrix to size zero
213212
void clear() {
214213
data_.reset(nullptr);
215214
dim_sizes_.fill(0);
215+
dim_strides_.fill(0);
216216
size_ = 0;
217217
}
218218

@@ -242,6 +242,7 @@ class NdMatrixBase {
242242
using std::swap;
243243
swap(m1.size_, m2.size_);
244244
swap(m1.dim_sizes_, m2.dim_sizes_);
245+
swap(m1.dim_strides_, m2.dim_strides_);
245246
swap(m1.data_, m2.data_);
246247
}
247248

@@ -265,6 +266,7 @@ class NdMatrixBase {
265266
protected:
266267
size_t size_ = 0;
267268
std::array<size_t, N> dim_sizes_;
269+
std::array<size_t, N> dim_strides_;
268270
std::unique_ptr<T[]> data_ = nullptr;
269271
};
270272

@@ -316,17 +318,11 @@ class NdMatrix : public NdMatrixBase<T, N> {
316318
VTR_ASSERT_SAFE_MSG(index >= 0, "Index out of range (below dimension minimum)");
317319
VTR_ASSERT_SAFE_MSG(index < this->dim_sizes_[0], "Index out of range (above dimension maximum)");
318320

319-
//Calculate the stride for the current dimension
320-
size_t dim_stride = this->size() / this->dim_size(0);
321-
322-
//Calculate the stride for the next dimension
323-
size_t next_dim_stride = dim_stride / this->dim_size(1);
324-
325321
//Peel off the first dimension
326-
return NdMatrixProxy<T, N - 1>(this->dim_sizes_.data(), //Pass the dimension information
327-
1, //Pass the next dimension
328-
next_dim_stride, //Pass the stride for the next dimension
329-
this->data_.get() + dim_stride * index); //Advance to index in this dimension
322+
return NdMatrixProxy<T, N - 1>(
323+
this->dim_sizes_.data() + 1, //Pass the dimension information
324+
this->dim_strides_.data() + 1, //Pass the stride for the next dimension
325+
this->data_.get() + this->dim_strides_[0] * index); //Advance to index in this dimension
330326
}
331327

332328
//Access an element

0 commit comments

Comments
 (0)