Skip to content

Commit d52cb9c

Browse files
wsmosestgymnich
andauthored
Add memref/llvm.ptr handling for fwd mode (rust-lang#910)
* Add memref handling for fwd mode * Add simple llvm dialect * Update enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp Co-authored-by: Tim Gymnich <[email protected]> * fixup Co-authored-by: Tim Gymnich <[email protected]>
1 parent 7da8126 commit d52cb9c

File tree

9 files changed

+278
-0
lines changed

9 files changed

+278
-0
lines changed

enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
add_mlir_library(MLIREnzymeImplementations
22
ArithAutoDiffOpInterfaceImpl.cpp
3+
LLVMAutoDiffOpInterfaceImpl.cpp
4+
MemRefAutoDiffOpInterfaceImpl.cpp
35
BuiltinAutoDiffTypeInterfaceImpl.cpp
46
SCFAutoDiffOpInterfaceImpl.cpp
57

@@ -8,6 +10,8 @@ add_mlir_library(MLIREnzymeImplementations
810

911
LINK_LIBS PUBLIC
1012
MLIRArithDialect
13+
MLIRLLVMDialect
14+
MLIRMemRefDialect
1115
MLIREnzymeAutoDiffInterface
1216
MLIRIR
1317
MLIRSCFDialect

enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ class DialectRegistry;
1818
namespace enzyme {
1919
void registerArithDialectAutoDiffInterface(DialectRegistry &registry);
2020
void registerBuiltinDialectAutoDiffInterface(DialectRegistry &registry);
21+
void registerLLVMDialectAutoDiffInterface(DialectRegistry &registry);
22+
void registerMemRefDialectAutoDiffInterface(DialectRegistry &registry);
2123
void registerSCFDialectAutoDiffInterface(DialectRegistry &registry);
2224
} // namespace enzyme
2325
} // namespace mlir
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
//===- LLVMAutoDiffOpInterfaceImpl.cpp - Interface external model --------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file contains the external model implementation of the automatic
10+
// differentiation op interfaces for the upstream LLVM dialect.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "Implementations/CoreDialectsAutoDiffImplementations.h"
15+
#include "Interfaces/AutoDiffOpInterface.h"
16+
#include "Interfaces/AutoDiffTypeInterface.h"
17+
#include "Interfaces/GradientUtils.h"
18+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19+
#include "mlir/IR/DialectRegistry.h"
20+
#include "mlir/Support/LogicalResult.h"
21+
22+
using namespace mlir;
23+
using namespace mlir::enzyme;
24+
25+
namespace {
26+
struct LoadOpInterface
27+
: public AutoDiffOpInterface::ExternalModel<LoadOpInterface, LLVM::LoadOp> {
28+
LogicalResult createForwardModeAdjoint(Operation *op, OpBuilder &builder,
29+
MGradientUtils *gutils) const {
30+
auto loadOp = cast<LLVM::LoadOp>(op);
31+
if (!gutils->isConstantValue(loadOp)) {
32+
mlir::Value res = builder.create<LLVM::LoadOp>(
33+
loadOp.getLoc(), gutils->invertPointerM(loadOp.getAddr(), builder));
34+
gutils->setDiffe(loadOp, res, builder);
35+
}
36+
gutils->eraseIfUnused(op);
37+
return success();
38+
}
39+
};
40+
41+
struct StoreOpInterface
42+
: public AutoDiffOpInterface::ExternalModel<StoreOpInterface,
43+
LLVM::StoreOp> {
44+
LogicalResult createForwardModeAdjoint(Operation *op, OpBuilder &builder,
45+
MGradientUtils *gutils) const {
46+
auto storeOp = cast<LLVM::StoreOp>(op);
47+
if (!gutils->isConstantValue(storeOp.getAddr())) {
48+
builder.create<LLVM::StoreOp>(
49+
storeOp.getLoc(), gutils->invertPointerM(storeOp.getValue(), builder),
50+
gutils->invertPointerM(storeOp.getAddr(), builder));
51+
}
52+
gutils->eraseIfUnused(op);
53+
return success();
54+
}
55+
};
56+
57+
struct AllocaOpInterface
58+
: public AutoDiffOpInterface::ExternalModel<AllocaOpInterface,
59+
LLVM::AllocaOp> {
60+
LogicalResult createForwardModeAdjoint(Operation *op, OpBuilder &builder,
61+
MGradientUtils *gutils) const {
62+
auto allocOp = cast<LLVM::AllocaOp>(op);
63+
if (!gutils->isConstantValue(allocOp)) {
64+
Operation *nop = gutils->cloneWithNewOperands(builder, op);
65+
gutils->setDiffe(allocOp, nop->getResult(0), builder);
66+
}
67+
gutils->eraseIfUnused(op);
68+
return success();
69+
}
70+
};
71+
72+
class PointerTypeInterface
73+
: public AutoDiffTypeInterface::ExternalModel<PointerTypeInterface,
74+
LLVM::LLVMPointerType> {
75+
public:
76+
Value createNullValue(Type self, OpBuilder &builder, Location loc) const {
77+
return builder.create<LLVM::NullOp>(loc, self);
78+
}
79+
80+
Type getShadowType(Type self, unsigned width) const {
81+
assert(width == 1 && "unsupported width != 1");
82+
return self;
83+
}
84+
};
85+
} // namespace
86+
87+
void mlir::enzyme::registerLLVMDialectAutoDiffInterface(
88+
DialectRegistry &registry) {
89+
registry.addExtension(+[](MLIRContext *context, LLVM::LLVMDialect *) {
90+
LLVM::LoadOp::attachInterface<LoadOpInterface>(*context);
91+
LLVM::StoreOp::attachInterface<StoreOpInterface>(*context);
92+
LLVM::AllocaOp::attachInterface<AllocaOpInterface>(*context);
93+
LLVM::LLVMPointerType::attachInterface<PointerTypeInterface>(*context);
94+
});
95+
}
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
//===- MemRefAutoDiffOpInterfaceImpl.cpp - Interface external model -------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file contains the external model implementation of the automatic
10+
// differentiation op interfaces for the upstream MLIR memref dialect.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "Implementations/CoreDialectsAutoDiffImplementations.h"
15+
#include "Interfaces/AutoDiffOpInterface.h"
16+
#include "Interfaces/AutoDiffTypeInterface.h"
17+
#include "Interfaces/GradientUtils.h"
18+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
19+
#include "mlir/IR/DialectRegistry.h"
20+
#include "mlir/Support/LogicalResult.h"
21+
22+
using namespace mlir;
23+
using namespace mlir::enzyme;
24+
25+
namespace {
26+
struct LoadOpInterface
27+
: public AutoDiffOpInterface::ExternalModel<LoadOpInterface,
28+
memref::LoadOp> {
29+
LogicalResult createForwardModeAdjoint(Operation *op, OpBuilder &builder,
30+
MGradientUtils *gutils) const {
31+
auto loadOp = cast<memref::LoadOp>(op);
32+
if (!gutils->isConstantValue(loadOp)) {
33+
SmallVector<Value> inds;
34+
for (auto ind : loadOp.getIndices())
35+
inds.push_back(gutils->getNewFromOriginal(ind));
36+
mlir::Value res = builder.create<memref::LoadOp>(
37+
loadOp.getLoc(), gutils->invertPointerM(loadOp.getMemref(), builder),
38+
inds);
39+
gutils->setDiffe(loadOp, res, builder);
40+
}
41+
gutils->eraseIfUnused(op);
42+
return success();
43+
}
44+
};
45+
46+
struct StoreOpInterface
47+
: public AutoDiffOpInterface::ExternalModel<StoreOpInterface,
48+
memref::StoreOp> {
49+
LogicalResult createForwardModeAdjoint(Operation *op, OpBuilder &builder,
50+
MGradientUtils *gutils) const {
51+
auto storeOp = cast<memref::StoreOp>(op);
52+
if (!gutils->isConstantValue(storeOp.getMemref())) {
53+
SmallVector<Value> inds;
54+
for (auto ind : storeOp.getIndices())
55+
inds.push_back(gutils->getNewFromOriginal(ind));
56+
builder.create<memref::StoreOp>(
57+
storeOp.getLoc(), gutils->invertPointerM(storeOp.getValue(), builder),
58+
gutils->invertPointerM(storeOp.getMemref(), builder), inds);
59+
}
60+
gutils->eraseIfUnused(op);
61+
return success();
62+
}
63+
};
64+
65+
struct AllocOpInterface
66+
: public AutoDiffOpInterface::ExternalModel<AllocOpInterface,
67+
memref::AllocOp> {
68+
LogicalResult createForwardModeAdjoint(Operation *op, OpBuilder &builder,
69+
MGradientUtils *gutils) const {
70+
auto allocOp = cast<memref::AllocOp>(op);
71+
if (!gutils->isConstantValue(allocOp)) {
72+
Operation *nop = gutils->cloneWithNewOperands(builder, op);
73+
gutils->setDiffe(allocOp, nop->getResult(0), builder);
74+
}
75+
gutils->eraseIfUnused(op);
76+
return success();
77+
}
78+
};
79+
80+
class MemRefTypeInterface
81+
: public AutoDiffTypeInterface::ExternalModel<MemRefTypeInterface,
82+
MemRefType> {
83+
public:
84+
Value createNullValue(Type self, OpBuilder &builder, Location loc) const {
85+
llvm_unreachable("Cannot create null of memref (todo polygeist null)");
86+
}
87+
88+
Type getShadowType(Type self, unsigned width) const {
89+
assert(width == 1 && "unsupported width != 1");
90+
return self;
91+
}
92+
};
93+
} // namespace
94+
95+
void mlir::enzyme::registerMemRefDialectAutoDiffInterface(
96+
DialectRegistry &registry) {
97+
registry.addExtension(+[](MLIRContext *context, memref::MemRefDialect *) {
98+
memref::LoadOp::attachInterface<LoadOpInterface>(*context);
99+
memref::StoreOp::attachInterface<StoreOpInterface>(*context);
100+
memref::AllocOp::attachInterface<AllocOpInterface>(*context);
101+
MemRefType::attachInterface<MemRefTypeInterface>(*context);
102+
});
103+
}

enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,14 @@ mlir::enzyme::MGradientUtils::getNewFromOriginal(Operation *originst) const {
126126
return found->second;
127127
}
128128

129+
Operation *mlir::enzyme::MGradientUtils::cloneWithNewOperands(OpBuilder &B,
130+
Operation *op) {
131+
BlockAndValueMapping map;
132+
for (auto operand : op->getOperands())
133+
map.map(operand, getNewFromOriginal(operand));
134+
return B.clone(*op, map);
135+
}
136+
129137
bool mlir::enzyme::MGradientUtils::isConstantValue(Value v) const {
130138
if (isa<mlir::IntegerType>(v.getType()))
131139
return true;

enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ class MGradientUtils {
7777
void setDiffe(mlir::Value val, mlir::Value toset, OpBuilder &BuilderM);
7878
void forceAugmentedReturns();
7979

80+
Operation *cloneWithNewOperands(OpBuilder &B, Operation *op);
81+
8082
LogicalResult visitChild(Operation *op);
8183
};
8284

enzyme/Enzyme/MLIR/enzymemlir-opt.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ int main(int argc, char **argv) {
9292
// Register the autodiff interface implementations for upstream dialects.
9393
enzyme::registerArithDialectAutoDiffInterface(registry);
9494
enzyme::registerBuiltinDialectAutoDiffInterface(registry);
95+
enzyme::registerLLVMDialectAutoDiffInterface(registry);
96+
enzyme::registerMemRefDialectAutoDiffInterface(registry);
9597
enzyme::registerSCFDialectAutoDiffInterface(registry);
9698

9799
return mlir::failed(

enzyme/test/MLIR/llvm.mlir

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// RUN: %eopt --enzyme %s | FileCheck %s
2+
3+
module {
4+
func.func @square(%x : f64) -> f64 {
5+
%c1_i64 = arith.constant 1 : i64
6+
%tmp = llvm.alloca %c1_i64 x f64 : (i64) -> !llvm.ptr<f64>
7+
%y = arith.mulf %x, %x : f64
8+
llvm.store %y, %tmp : !llvm.ptr<f64>
9+
%r = llvm.load %tmp : !llvm.ptr<f64>
10+
return %r : f64
11+
}
12+
func.func @dsq(%x : f64, %dx : f64) -> f64 {
13+
%r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme<activity enzyme_dup>] } : (f64, f64) -> (f64)
14+
return %r : f64
15+
}
16+
}
17+
18+
// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f64 {
19+
// CHECK-NEXT: %[[c1_i64:.+]] = arith.constant 1 : i64
20+
// CHECK-NEXT: %[[i0:.+]] = llvm.alloca %[[c1_i64]] x f64 : (i64) -> !llvm.ptr<f64>
21+
// CHECK-NEXT: %[[i1:.+]] = llvm.alloca %[[c1_i64]] x f64 : (i64) -> !llvm.ptr<f64>
22+
// CHECK-NEXT: %[[i2:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64
23+
// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64
24+
// CHECK-NEXT: %[[i4:.+]] = arith.addf %[[i2]], %[[i3]] : f64
25+
// CHECK-NEXT: %[[i5:.+]] = arith.mulf %[[arg0]], %[[arg0]] : f64
26+
// CHECK-NEXT: llvm.store %[[i4]], %[[i0]] : !llvm.ptr<f64>
27+
// CHECK-NEXT: llvm.store %[[i5]], %[[i1]] : !llvm.ptr<f64>
28+
// CHECK-NEXT: %[[i6:.+]] = llvm.load %[[i0]] : !llvm.ptr<f64>
29+
// CHECK-NEXT: %[[i7:.+]] = llvm.load %[[i1]] : !llvm.ptr<f64>
30+
// CHECK-NEXT: return %[[i6]] : f64
31+
// CHECK-NEXT: }

enzyme/test/MLIR/memref.mlir

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// RUN: %eopt --enzyme %s | FileCheck %s
2+
3+
module {
4+
func.func @square(%x : f64) -> f64 {
5+
%c0 = arith.constant 0 : index
6+
%tmp = memref.alloc() : memref<1xf64>
7+
%y = arith.mulf %x, %x : f64
8+
memref.store %y, %tmp[%c0] : memref<1xf64>
9+
%r = memref.load %tmp[%c0] : memref<1xf64>
10+
return %r : f64
11+
}
12+
func.func @dsq(%x : f64, %dx : f64) -> f64 {
13+
%r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme<activity enzyme_dup>] } : (f64, f64) -> (f64)
14+
return %r : f64
15+
}
16+
}
17+
18+
// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f64 {
19+
// CHECK-NEXT: %[[c0:.+]] = arith.constant 0 : index
20+
// CHECK-NEXT: %[[i0:.+]] = memref.alloc() : memref<1xf64>
21+
// CHECK-NEXT: %[[i1:.+]] = memref.alloc() : memref<1xf64>
22+
// CHECK-NEXT: %[[i2:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64
23+
// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64
24+
// CHECK-NEXT: %[[i4:.+]] = arith.addf %[[i2]], %[[i3]] : f64
25+
// CHECK-NEXT: %[[i5:.+]] = arith.mulf %[[arg0]], %[[arg0]] : f64
26+
// CHECK-NEXT: memref.store %[[i4]], %[[i0]][%[[c0]]] : memref<1xf64>
27+
// CHECK-NEXT: memref.store %[[i5]], %[[i1]][%[[c0]]] : memref<1xf64>
28+
// CHECK-NEXT: %[[i6:.+]] = memref.load %[[i0]][%[[c0]]] : memref<1xf64>
29+
// CHECK-NEXT: %[[i7:.+]] = memref.load %[[i1]][%[[c0]]] : memref<1xf64>
30+
// CHECK-NEXT: return %[[i6]] : f64
31+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)