Skip to content

Commit fa98bdb

Browse files
author
Kun Wu
committed
[mlir][sparse][gpu] make computeType mandatory
Differential Revision: https://reviews.llvm.org/D152018
1 parent 1c8b7c5 commit fa98bdb

File tree

6 files changed

+93
-105
lines changed

6 files changed

+93
-105
lines changed

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 40 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1869,7 +1869,7 @@ def GPU_SpMVBufferSizeOp : GPU_Op<"spmv_buffer_size", [GPU_AsyncOpInterface]> {
18691869
GPU_SparseSpMatHandle:$spmatA,
18701870
GPU_SparseDnVecHandle:$dnX,
18711871
GPU_SparseDnVecHandle:$dnY,
1872-
OptionalAttr<TypeAttr>:$computeType);
1872+
TypeAttr:$computeType);
18731873
let results = (outs Res<Index>:$bufferSz,
18741874
Optional<GPU_AsyncToken>:$asyncToken);
18751875

@@ -1880,16 +1880,17 @@ def GPU_SpMVBufferSizeOp : GPU_Op<"spmv_buffer_size", [GPU_AsyncOpInterface]> {
18801880
"Value":$env,
18811881
"Value":$spmatA,
18821882
"Value":$dnX,
1883-
"Value":$dnY)
1883+
"Value":$dnY,
1884+
"Type":$computeType)
18841885
, [{
18851886
auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
18861887
return build($_builder, $_state, bufferSz, asyncToken, asyncDependencies,
1887-
env, modeA, spmatA, dnX, dnY, {});}]>
1888+
env, modeA, spmatA, dnX, dnY, computeType);}]>
18881889
];
18891890

18901891
let assemblyFormat = [{
18911892
custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
1892-
$env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnX `,` $dnY attr-dict ( `into` $computeType^)?
1893+
$env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnX `,` $dnY attr-dict `into` $computeType
18931894
}];
18941895
}
18951896

@@ -1921,7 +1922,7 @@ def GPU_SpMVOp : GPU_Op<"spmv", [GPU_AsyncOpInterface]> {
19211922
GPU_SparseSpMatHandle:$spmatA,
19221923
GPU_SparseDnVecHandle:$dnX,
19231924
GPU_SparseDnVecHandle:$dnY,
1924-
OptionalAttr<TypeAttr>:$computeType,
1925+
TypeAttr:$computeType,
19251926
AnyMemRef:$buffer);
19261927
let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
19271928

@@ -1932,15 +1933,16 @@ def GPU_SpMVOp : GPU_Op<"spmv", [GPU_AsyncOpInterface]> {
19321933
"Value":$spmatA,
19331934
"Value":$dnX,
19341935
"Value":$dnY,
1936+
"Type":$computeType,
19351937
"Value":$buffer), [{
19361938
auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
19371939
return build($_builder, $_state, asyncToken, asyncDependencies, env, modeA,
1938-
spmatA, dnX, dnY, {}, buffer);}]>
1940+
spmatA, dnX, dnY, computeType, buffer);}]>
19391941
];
19401942

19411943
let assemblyFormat = [{
19421944
custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
1943-
$env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnX `,` $dnY `,` $buffer attr-dict `:` type($buffer) ( `into` $computeType^)?
1945+
$env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnX `,` $dnY `,` $buffer attr-dict `:` type($buffer) `into` $computeType
19441946
}];
19451947
}
19461948

@@ -1974,7 +1976,7 @@ def GPU_SpMMBufferSizeOp : GPU_Op<"spmm_buffer_size", [GPU_AsyncOpInterface]> {
19741976
GPU_SparseSpMatHandle:$spmatA,
19751977
GPU_SparseDnMatHandle:$dnmatB,
19761978
GPU_SparseDnMatHandle:$dnmatC,
1977-
OptionalAttr<TypeAttr>:$computeType);
1979+
TypeAttr:$computeType);
19781980
let results = (outs Res<Index>:$bufferSz,
19791981
Optional<GPU_AsyncToken>:$asyncToken);
19801982

@@ -1985,16 +1987,17 @@ def GPU_SpMMBufferSizeOp : GPU_Op<"spmm_buffer_size", [GPU_AsyncOpInterface]> {
19851987
"Value":$env,
19861988
"Value":$spmatA,
19871989
"Value":$dnmatB,
1988-
"Value":$dnmatC), [{
1990+
"Value":$dnmatC,
1991+
"Type":$computeType), [{
19891992
auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
19901993
auto modeB = gpu::TransposeMode::NON_TRANSPOSE;
19911994
return build($_builder, $_state, bufferSz, asyncToken, asyncDependencies,
1992-
env, modeA, modeB, spmatA, dnmatB, dnmatC, {});}]>
1995+
env, modeA, modeB, spmatA, dnmatB, dnmatC, computeType);}]>
19931996
];
19941997

19951998
let assemblyFormat = [{
19961999
custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
1997-
$env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $dnmatC attr-dict ( `into` $computeType^)?
2000+
$env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $dnmatC attr-dict `into` $computeType
19982001
}];
19992002
}
20002003

@@ -2028,7 +2031,7 @@ def GPU_SpMMOp : GPU_Op<"spmm", [GPU_AsyncOpInterface]> {
20282031
GPU_SparseSpMatHandle:$spmatA,
20292032
GPU_SparseDnMatHandle:$dnmatB,
20302033
GPU_SparseDnMatHandle:$dnmatC,
2031-
OptionalAttr<TypeAttr>:$computeType,
2034+
TypeAttr:$computeType,
20322035
AnyMemRef:$buffer);
20332036
let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
20342037

@@ -2039,16 +2042,17 @@ def GPU_SpMMOp : GPU_Op<"spmm", [GPU_AsyncOpInterface]> {
20392042
"Value":$spmatA,
20402043
"Value":$dnmatB,
20412044
"Value":$dnmatC,
2045+
"Type":$computeType,
20422046
"Value":$buffer), [{
20432047
auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
20442048
auto modeB = gpu::TransposeMode::NON_TRANSPOSE;
20452049
return build($_builder, $_state, asyncToken, asyncDependencies, env, modeA,
2046-
modeB, spmatA, dnmatB, dnmatC, {}, buffer);}]>
2050+
modeB, spmatA, dnmatB, dnmatC, computeType, buffer);}]>
20472051
];
20482052

20492053
let assemblyFormat = [{
20502054
custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
2051-
$env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $dnmatC `,` $buffer attr-dict `:` type($buffer) ( `into` $computeType^)?
2055+
$env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $dnmatC `,` $buffer attr-dict `:` type($buffer) `into` $computeType
20522056
}];
20532057
}
20542058

@@ -2082,26 +2086,27 @@ def GPU_SDDMMBufferSizeOp : GPU_Op<"sddmm_buffer_size", [GPU_AsyncOpInterface]>
20822086
GPU_SparseDnMatHandle:$dnmatA,
20832087
GPU_SparseDnMatHandle:$dnmatB,
20842088
GPU_SparseSpMatHandle:$spmatC,
2085-
OptionalAttr<TypeAttr>:$computeType);
2089+
TypeAttr:$computeType);
20862090
let results = (outs Res<Index>:$bufferSz, Optional<GPU_AsyncToken>:$asyncToken);
20872091

20882092
let builders = [OpBuilder<(ins
2089-
"::mlir::Type":$bufferSz,
2090-
"::mlir::Type":$asyncToken,
2091-
"::mlir::ValueRange":$asyncDependencies,
2092-
"::mlir::Value":$env,
2093-
"::mlir::Value":$dnmatA,
2094-
"::mlir::Value":$dnmatB,
2095-
"::mlir::Value":$spmatC), [{
2093+
"Type":$bufferSz,
2094+
"Type":$asyncToken,
2095+
"ValueRange":$asyncDependencies,
2096+
"Value":$env,
2097+
"Value":$dnmatA,
2098+
"Value":$dnmatB,
2099+
"Value":$spmatC,
2100+
"Type":$computeType), [{
20962101
auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
20972102
auto modeB = gpu::TransposeMode::NON_TRANSPOSE;
20982103
return build($_builder, $_state, bufferSz, asyncToken, asyncDependencies,
2099-
env, modeA, modeB, dnmatA, dnmatB, spmatC, {});}]>
2104+
env, modeA, modeB, dnmatA, dnmatB, spmatC, computeType);}]>
21002105
];
21012106

21022107
let assemblyFormat = [{
21032108
custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
2104-
$env `,` $dnmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $spmatC attr-dict ( `into` $computeType^)?
2109+
$env `,` $dnmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $spmatC attr-dict `into` $computeType
21052110
}];
21062111
}
21072112

@@ -2135,27 +2140,28 @@ def GPU_SDDMMOp : GPU_Op<"sddmm", [GPU_AsyncOpInterface]> {
21352140
GPU_SparseDnMatHandle:$dnmatA,
21362141
GPU_SparseDnMatHandle:$dnmatB,
21372142
GPU_SparseSpMatHandle:$spmatC,
2138-
OptionalAttr<TypeAttr>:$computeType,
2143+
TypeAttr:$computeType,
21392144
AnyMemRef:$buffer);
21402145
let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
21412146

21422147
let builders = [OpBuilder<(ins
2143-
"::mlir::Type":$asyncToken,
2144-
"::mlir::ValueRange":$asyncDependencies,
2145-
"::mlir::Value":$env,
2146-
"::mlir::Value":$dnmatA,
2147-
"::mlir::Value":$dnmatB,
2148-
"::mlir::Value":$spmatC,
2149-
"::mlir::Value":$buffer), [{
2148+
"Type":$asyncToken,
2149+
"ValueRange":$asyncDependencies,
2150+
"Value":$env,
2151+
"Value":$dnmatA,
2152+
"Value":$dnmatB,
2153+
"Value":$spmatC,
2154+
"Type":$computeType,
2155+
"Value":$buffer), [{
21502156
auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
21512157
auto modeB = gpu::TransposeMode::NON_TRANSPOSE;
21522158
return build($_builder, $_state, asyncToken, asyncDependencies, env, modeA,
2153-
modeB, dnmatA, dnmatB, spmatC, {}, buffer);}]>
2159+
modeB, dnmatA, dnmatB, spmatC, computeType, buffer);}]>
21542160
];
21552161

21562162
let assemblyFormat = [{
21572163
custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
2158-
$env `,` $dnmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $spmatC `,` $buffer attr-dict `:` type($buffer) ( `into` $computeType^)?
2164+
$env `,` $dnmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $spmatC `,` $buffer attr-dict `:` type($buffer) `into` $computeType
21592165
}];
21602166
}
21612167

mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp

Lines changed: 17 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1274,40 +1274,18 @@ LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
12741274
return success();
12751275
}
12761276

1277-
// Returns the element type of the defining spmat op.
1278-
// TODO: safer and more flexible to store data type in actual op instead?
1279-
static Type getSpMatElemType(Value spMat) {
1280-
if (auto op = spMat.getDefiningOp<gpu::CreateCooOp>())
1281-
return llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1282-
if (auto op = spMat.getDefiningOp<gpu::CreateCsrOp>())
1283-
return llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1284-
llvm_unreachable("cannot find spmat def");
1285-
}
1286-
1287-
// Returns the element type of the defining dnmat or dnvec op.
1288-
static Type getDnElemType(Value dn) {
1289-
if (auto op = dn.getDefiningOp<gpu::CreateDnMatOp>())
1290-
return op.getMemref().getType().getElementType();
1291-
if (auto op = dn.getDefiningOp<gpu::CreateDnVecOp>())
1292-
return op.getMemref().getType().getElementType();
1293-
llvm_unreachable("cannot find dn def");
1294-
}
1295-
12961277
template <typename T>
12971278
static Value genConstInt32From(OpBuilder &builder, Location loc, T TValue) {
12981279
Type llvmInt32Type = builder.getIntegerType(32);
12991280
return builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
13001281
static_cast<int32_t>(TValue));
13011282
}
13021283

1303-
static Value
1304-
genConstInt32FromOptionalComputeMode(OpBuilder &builder, Location loc,
1305-
std::optional<Type> computeTypeOptional,
1306-
Type defaultType) {
1307-
auto computeTypeInt =
1308-
getCuSparseDataTypeFrom(computeTypeOptional.value_or(defaultType));
1309-
auto computeType = genConstInt32From(builder, loc, computeTypeInt);
1310-
return computeType;
1284+
static Value genConstInt32FromComputeMode(OpBuilder &builder, Location loc,
1285+
Type computeType) {
1286+
auto computeTypeInt = getCuSparseDataTypeFrom(computeType);
1287+
auto computeTypeConst = genConstInt32From(builder, loc, computeTypeInt);
1288+
return computeTypeConst;
13111289
}
13121290

13131291
LogicalResult ConvertCreateSparseEnvOpToGpuRuntimeCallPattern::matchAndRewrite(
@@ -1502,9 +1480,8 @@ LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
15021480
return failure();
15031481
Location loc = op.getLoc();
15041482
auto modeA = genConstInt32From(rewriter, loc, op.getModeA());
1505-
// retrieve the compute type, notice that it may be optional
1506-
auto computeType = genConstInt32FromOptionalComputeMode(
1507-
rewriter, loc, adaptor.getComputeType(), getDnElemType(op.getDnY()));
1483+
auto computeType =
1484+
genConstInt32FromComputeMode(rewriter, loc, adaptor.getComputeType());
15081485
auto stream = adaptor.getAsyncDependencies().front();
15091486
auto bufferSize =
15101487
spMVBufferSizeCallBuilder
@@ -1524,9 +1501,8 @@ LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
15241501
return failure();
15251502
Location loc = op.getLoc();
15261503
auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1527-
// retrieve the compute type, notice that it may be optional
1528-
auto computeType = genConstInt32FromOptionalComputeMode(
1529-
rewriter, loc, adaptor.getComputeType(), getDnElemType(op.getDnY()));
1504+
auto computeType =
1505+
genConstInt32FromComputeMode(rewriter, loc, adaptor.getComputeType());
15301506
auto stream = adaptor.getAsyncDependencies().front();
15311507
Value pBuf =
15321508
MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
@@ -1550,9 +1526,8 @@ LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
15501526
auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
15511527
auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
15521528
auto stream = adaptor.getAsyncDependencies().front();
1553-
// retrieve the compute type, notice that it may be optional
1554-
auto computeType = genConstInt32FromOptionalComputeMode(
1555-
rewriter, loc, adaptor.getComputeType(), getDnElemType(op.getDnmatC()));
1529+
auto computeType =
1530+
genConstInt32FromComputeMode(rewriter, loc, adaptor.getComputeType());
15561531

15571532
auto bufferSize = spMMBufferSizeCallBuilder
15581533
.create(loc, rewriter,
@@ -1573,9 +1548,8 @@ LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
15731548
Location loc = op.getLoc();
15741549
auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
15751550
auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1576-
auto computeType = genConstInt32FromOptionalComputeMode(
1577-
rewriter, loc, adaptor.getComputeType(),
1578-
getSpMatElemType(op.getSpmatC()));
1551+
auto computeType =
1552+
genConstInt32FromComputeMode(rewriter, loc, adaptor.getComputeType());
15791553
auto stream = adaptor.getAsyncDependencies().front();
15801554
auto bufferSize = SDDMMBufferSizeCallBuilder
15811555
.create(loc, rewriter,
@@ -1596,9 +1570,8 @@ LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
15961570
Location loc = op.getLoc();
15971571
auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
15981572
auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1599-
// retrieve the compute type, notice that it may be optional
1600-
auto computeType = genConstInt32FromOptionalComputeMode(
1601-
rewriter, loc, adaptor.getComputeType(), getDnElemType(op.getDnmatC()));
1573+
auto computeType =
1574+
genConstInt32FromComputeMode(rewriter, loc, adaptor.getComputeType());
16021575

16031576
auto stream = adaptor.getAsyncDependencies().front();
16041577
Value pBuf =
@@ -1628,9 +1601,8 @@ LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
16281601
failed(isAsyncWithOneDependency(rewriter, op)))
16291602
return failure();
16301603
Location loc = op.getLoc();
1631-
auto computeType = genConstInt32FromOptionalComputeMode(
1632-
rewriter, loc, adaptor.getComputeType(),
1633-
getSpMatElemType(op.getSpmatC()));
1604+
auto computeType =
1605+
genConstInt32FromComputeMode(rewriter, loc, adaptor.getComputeType());
16341606
auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
16351607
auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
16361608
auto stream = adaptor.getAsyncDependencies().front();

mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -462,18 +462,22 @@ static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
462462
Value dnY = dvecY.getResult(0);
463463
token = dvecY.getAsyncToken();
464464

465+
auto dnYType = llvm::cast<ShapedType>(y.getType()).getElementType();
466+
465467
// Precompute buffersize for SpMV.
466468
auto bufferComp = rewriter.create<gpu::SpMVBufferSizeOp>(
467-
loc, indexTp, tokenTp, token, handle, spMatA, dnX, dnY);
469+
loc, indexTp, tokenTp, token, handle, spMatA, dnX, dnY,
470+
/*computeType=*/dnYType);
468471
Value bufferSz = bufferComp.getResult(0);
469472
token = bufferComp.getAsyncToken();
470473
auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
471474
Value buffer = buf.getResult(0);
472475
token = buf.getAsyncToken();
473476

474477
// Perform the SpMV.
475-
auto spmvComp = rewriter.create<gpu::SpMVOp>(loc, tokenTp, token, handle,
476-
spMatA, dnX, dnY, buffer);
478+
auto spmvComp =
479+
rewriter.create<gpu::SpMVOp>(loc, tokenTp, token, handle, spMatA, dnX,
480+
dnY, /*computeType=*/dnYType, buffer);
477481
token = spmvComp.getAsyncToken();
478482

479483
// Copy data back to host and free all the resources.
@@ -565,18 +569,24 @@ static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
565569
Value dnC = dmatC.getResult(0);
566570
token = dmatC.getAsyncToken();
567571

572+
auto dmatCType = llvm::cast<ShapedType>(c.getType()).getElementType();
573+
568574
// Precompute buffersize for SpMM.
569575
auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>(
570-
loc, indexTp, tokenTp, token, handle, spMatA, dnB, dnC);
576+
loc, indexTp, tokenTp, token, handle, spMatA, dnB, dnC,
577+
/*computeType=*/dmatCType);
571578
Value bufferSz = bufferComp.getResult(0);
572579
token = bufferComp.getAsyncToken();
573580
auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
574581
Value buffer = buf.getResult(0);
575582
token = buf.getAsyncToken();
576583

584+
auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();
585+
577586
// Perform the SpMM.
578-
auto spmmComp = rewriter.create<gpu::SpMMOp>(loc, tokenTp, token, handle,
579-
spMatA, dnB, dnC, buffer);
587+
auto spmmComp =
588+
rewriter.create<gpu::SpMMOp>(loc, tokenTp, token, handle, spMatA, dnB,
589+
dnC, /*computeType=*/dnCType, buffer);
580590
token = spmmComp.getAsyncToken();
581591

582592
// Copy data back to host and free all the resources.

mlir/test/Conversion/GPUCommon/lower-sparse-to-gpu-runtime-calls.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ module attributes {gpu.container_module} {
2323
%env, %token3 = gpu.create_sparse_env async [%token2]
2424
%spmat, %token4 = gpu.create_coo async [%token3] %arg0, %arg0, %arg0, %mem1, %mem1, %mem2 : memref<?xindex>, memref<?xindex>, memref<?xf64>
2525
%dnvec, %token5 = gpu.create_dn_vec async [%token4] %mem2, %arg0 : memref<?xf64>
26-
%bufferSz, %token6 = gpu.spmv_buffer_size async [%token5] %env, %spmat, %dnvec, %dnvec
27-
%token7 = gpu.spmv async [%token6] %env, %spmat, %dnvec, %dnvec, %mem2 : memref<?xf64>
26+
%bufferSz, %token6 = gpu.spmv_buffer_size async [%token5] %env, %spmat, %dnvec, %dnvec into f64
27+
%token7 = gpu.spmv async [%token6] %env, %spmat, %dnvec, %dnvec, %mem2 : memref<?xf64> into f64
2828
%token8 = gpu.destroy_sp_mat async [%token7] %spmat
2929
%token9 = gpu.destroy_dn_vec async [%token8] %dnvec
3030
%token10 = gpu.destroy_sparse_env async [%token9] %env
@@ -53,8 +53,8 @@ module attributes {gpu.container_module} {
5353
%env, %token3 = gpu.create_sparse_env async [%token2]
5454
%spmat, %token4 = gpu.create_csr async [%token3] %arg0, %arg0, %arg0, %mem1, %mem1, %mem2 : memref<?xindex>, memref<?xindex>, memref<?xf64>
5555
%dnmat, %token5 = gpu.create_dn_mat async [%token4] %arg0, %arg0, %mem2 : memref<?xf64>
56-
%bufferSz, %token6 = gpu.spmm_buffer_size async [%token5] %env, %spmat, %dnmat, %dnmat
57-
%token7 = gpu.spmm async [%token6] %env, %spmat, %dnmat, %dnmat, %mem2 : memref<?xf64>
56+
%bufferSz, %token6 = gpu.spmm_buffer_size async [%token5] %env, %spmat, %dnmat, %dnmat into f64
57+
%token7 = gpu.spmm async [%token6] %env, %spmat, %dnmat, %dnmat, %mem2 : memref<?xf64> into f64
5858
%token8 = gpu.destroy_sp_mat async [%token7] %spmat
5959
%token9 = gpu.destroy_dn_mat async [%token8] %dnmat
6060
%token10 = gpu.destroy_sparse_env async [%token9] %env
@@ -83,8 +83,8 @@ module attributes {gpu.container_module} {
8383
%env, %token3 = gpu.create_sparse_env async [%token2]
8484
%spmat, %token4 = gpu.create_csr async [%token3] %arg0, %arg0, %arg0, %mem1, %mem1, %mem2 : memref<?xindex>, memref<?xindex>, memref<?xf64>
8585
%dnmat, %token5 = gpu.create_dn_mat async [%token4] %arg0, %arg0, %mem2 : memref<?xf64>
86-
%bufferSz, %token6 = gpu.sddmm_buffer_size async [%token5] %env, %dnmat, %dnmat, %spmat
87-
%token7 = gpu.sddmm async [%token6] %env, %dnmat, %dnmat, %spmat, %mem2 : memref<?xf64>
86+
%bufferSz, %token6 = gpu.sddmm_buffer_size async [%token5] %env, %dnmat, %dnmat, %spmat into f64
87+
%token7 = gpu.sddmm async [%token6] %env, %dnmat, %dnmat, %spmat, %mem2 : memref<?xf64> into f64
8888
%token8 = gpu.destroy_sp_mat async [%token7] %spmat
8989
%token9 = gpu.destroy_dn_mat async [%token8] %dnmat
9090
%token10 = gpu.destroy_sparse_env async [%token9] %env

0 commit comments

Comments
 (0)