@@ -201,33 +201,18 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
201
201
202
202
public:
203
203
// / Constructs a sparse tensor with the given encoding, and allocates
204
- // / overhead storage according to some simple heuristics. When the
205
- // / `bool` argument is true and `lvlTypes` are all dense, then this
206
- // / ctor will also initialize the values array with zeros. That
207
- // / argument should be true when an empty tensor is intended; whereas
208
- // / it should usually be false when the ctor will be followed up by
209
- // / some other form of initialization.
204
+ // / overhead storage according to some simple heuristics. When lvlCOO
205
+ // / is set, the sparse tensor initializes with the contents from that
206
+ // / data structure. Otherwise, an empty sparse tensor results.
210
207
SparseTensorStorage (uint64_t dimRank, const uint64_t *dimSizes,
211
208
uint64_t lvlRank, const uint64_t *lvlSizes,
212
209
const LevelType *lvlTypes, const uint64_t *dim2lvl,
213
- const uint64_t *lvl2dim, SparseTensorCOO<V> *lvlCOO,
214
- bool initializeValuesIfAllDense);
210
+ const uint64_t *lvl2dim, SparseTensorCOO<V> *lvlCOO);
215
211
216
212
// / Constructs a sparse tensor with the given encoding, and initializes
217
- // / the contents from the COO. This ctor performs the same heuristic
218
- // / overhead-storage allocation as the ctor above.
219
- SparseTensorStorage (uint64_t dimRank, const uint64_t *dimSizes,
220
- uint64_t lvlRank, const uint64_t *lvlSizes,
221
- const LevelType *lvlTypes, const uint64_t *dim2lvl,
222
- const uint64_t *lvl2dim, SparseTensorCOO<V> &lvlCOO);
223
-
224
- // / Constructs a sparse tensor with the given encoding, and initializes
225
- // / the contents from the level buffers. This ctor allocates exactly
226
- // / the required amount of overhead storage, not using any heuristics.
227
- // / It assumes that the data provided by `lvlBufs` can be directly used to
228
- // / interpret the result sparse tensor and performs *NO* integrity test on the
229
- // / input data. It also assume that the trailing COO coordinate buffer is
230
- // / passed in as a single AoS memory.
213
+ // / the contents from the level buffers. The constructor assumes that the
214
+ // / data provided by `lvlBufs` can be directly used to interpret the result
215
+ // / sparse tensor and performs no integrity test on the input data.
231
216
SparseTensorStorage (uint64_t dimRank, const uint64_t *dimSizes,
232
217
uint64_t lvlRank, const uint64_t *lvlSizes,
233
218
const LevelType *lvlTypes, const uint64_t *dim2lvl,
@@ -244,16 +229,14 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
244
229
newFromCOO (uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
245
230
const uint64_t *lvlSizes, const LevelType *lvlTypes,
246
231
const uint64_t *dim2lvl, const uint64_t *lvl2dim,
247
- SparseTensorCOO<V> & lvlCOO);
232
+ SparseTensorCOO<V> * lvlCOO);
248
233
249
- // / Allocates a new sparse tensor and initialize it with the data stored level
250
- // / buffers directly.
234
+ // / Allocates a new sparse tensor and initialize it from the given buffers.
251
235
static SparseTensorStorage<P, C, V> *
252
- packFromLvlBuffers (uint64_t dimRank, const uint64_t *dimSizes,
253
- uint64_t lvlRank, const uint64_t *lvlSizes,
254
- const LevelType *lvlTypes, const uint64_t *dim2lvl,
255
- const uint64_t *lvl2dim, uint64_t srcRank,
256
- const intptr_t *buffers);
236
+ newFromBuffers (uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
237
+ const uint64_t *lvlSizes, const LevelType *lvlTypes,
238
+ const uint64_t *dim2lvl, const uint64_t *lvl2dim,
239
+ uint64_t srcRank, const intptr_t *buffers);
257
240
258
241
~SparseTensorStorage () final = default ;
259
242
@@ -563,23 +546,24 @@ SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::newEmpty(
563
546
uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
564
547
const uint64_t *lvlSizes, const LevelType *lvlTypes,
565
548
const uint64_t *dim2lvl, const uint64_t *lvl2dim) {
549
+ SparseTensorCOO<V> *noLvlCOO = nullptr ;
566
550
return new SparseTensorStorage<P, C, V>(dimRank, dimSizes, lvlRank, lvlSizes,
567
- lvlTypes, dim2lvl, lvl2dim, nullptr ,
568
- true );
551
+ lvlTypes, dim2lvl, lvl2dim, noLvlCOO);
569
552
}
570
553
571
554
template <typename P, typename C, typename V>
572
555
SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::newFromCOO(
573
556
uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
574
557
const uint64_t *lvlSizes, const LevelType *lvlTypes,
575
558
const uint64_t *dim2lvl, const uint64_t *lvl2dim,
576
- SparseTensorCOO<V> &lvlCOO) {
559
+ SparseTensorCOO<V> *lvlCOO) {
560
+ assert (lvlCOO);
577
561
return new SparseTensorStorage<P, C, V>(dimRank, dimSizes, lvlRank, lvlSizes,
578
562
lvlTypes, dim2lvl, lvl2dim, lvlCOO);
579
563
}
580
564
581
565
template <typename P, typename C, typename V>
582
- SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::packFromLvlBuffers (
566
+ SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::newFromBuffers (
583
567
uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
584
568
const uint64_t *lvlSizes, const LevelType *lvlTypes,
585
569
const uint64_t *dim2lvl, const uint64_t *lvl2dim, uint64_t srcRank,
@@ -599,10 +583,9 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
599
583
uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
600
584
const uint64_t *lvlSizes, const LevelType *lvlTypes,
601
585
const uint64_t *dim2lvl, const uint64_t *lvl2dim,
602
- SparseTensorCOO<V> *lvlCOO, bool initializeValuesIfAllDense )
586
+ SparseTensorCOO<V> *lvlCOO)
603
587
: SparseTensorStorage(dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes,
604
588
dim2lvl, lvl2dim) {
605
- assert (!lvlCOO || lvlRank == lvlCOO->getRank ());
606
589
// Provide hints on capacity of positions and coordinates.
607
590
// TODO: needs much fine-tuning based on actual sparsity; currently
608
591
// we reserve position/coordinate space based on all previous dense
@@ -633,27 +616,20 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
633
616
sz = detail::checkedMul (sz, lvlSizes[l]);
634
617
}
635
618
}
636
- if (allDense && initializeValuesIfAllDense)
619
+ if (lvlCOO) {
620
+ /* New from COO: ensure it is sorted. */
621
+ assert (lvlCOO->getRank () == lvlRank);
622
+ lvlCOO->sort ();
623
+ // Now actually insert the `elements`.
624
+ const auto &elements = lvlCOO->getElements ();
625
+ const uint64_t nse = elements.size ();
626
+ assert (values.size () == 0 );
627
+ values.reserve (nse);
628
+ fromCOO (elements, 0 , nse, 0 );
629
+ } else if (allDense) {
630
+ /* New empty (all dense) */
637
631
values.resize (sz, 0 );
638
- }
639
-
640
- template <typename P, typename C, typename V>
641
- SparseTensorStorage<P, C, V>::SparseTensorStorage( // NOLINT
642
- uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
643
- const uint64_t *lvlSizes, const LevelType *lvlTypes,
644
- const uint64_t *dim2lvl, const uint64_t *lvl2dim,
645
- SparseTensorCOO<V> &lvlCOO)
646
- : SparseTensorStorage(dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes,
647
- dim2lvl, lvl2dim, nullptr , false ) {
648
- // Ensure lvlCOO is sorted.
649
- assert (lvlRank == lvlCOO.getRank ());
650
- lvlCOO.sort ();
651
- // Now actually insert the `elements`.
652
- const auto &elements = lvlCOO.getElements ();
653
- const uint64_t nse = elements.size ();
654
- assert (values.size () == 0 );
655
- values.reserve (nse);
656
- fromCOO (elements, 0 , nse, 0 );
632
+ }
657
633
}
658
634
659
635
template <typename P, typename C, typename V>
0 commit comments