Skip to content

Fix revert mode for memref.alloca#2826

Open
xys-syx wants to merge 8 commits into
mainfrom
memref-alloca
Open

Fix revert mode for memref.alloca#2826
xys-syx wants to merge 8 commits into
mainfrom
memref-alloca

Conversation

@xys-syx
Copy link
Copy Markdown
Collaborator

@xys-syx xys-syx commented May 15, 2026

  1. use scf.parallel + memref.store/memref.dim to replace linalg.fill
  2. substitue the enzyme.placeholder with real shadow

The key issue for memref.alloca is: the mutable type's enzyme.placeholder does not be subtitute by the real shadow in reverse mode.

The output before fix:

module {
  func.func @foo_flat(%arg0: f64) -> f64 {
    %alloca = memref.alloca() : memref<f64>
    memref.store %arg0, %alloca[] : memref<f64>
    %0 = memref.load %alloca[] : memref<f64>
    return %0 : f6 }
  func.func @dfoo_flat(%arg0: f64, %arg1: f64) -> f64 {
    %0 = call @diffefoo_flat(%arg0, %arg1) : (f64, f64) -> f64
    return %0 : f64
  }
  func.func private @diffefoo_flat(%arg0: f64, %arg1: f64) -> f64 {
    %0 = "enzyme.init"() : () -> !enzyme.Gradient<f64>
    %cst = arith.constant 0.000000e+00 : f64
    "enzyme.set"(%0, %cst) : (!enzyme.Gradient<f64>, f64) -> ()
    %1 = "enzyme.init"() : () -> !enzyme.Cache<memref<f64>>
    %2 = "enzyme.init"() : () -> !enzyme.Cache<memref<f64>>
    %3 = "enzyme.init"() : () -> !enzyme.Gradient<f64>
    %cst_0 = arith.constant 0.000000e+00 : f64
    "enzyme.set"(%3, %cst_0) : (!enzyme.Gradient<f64>, f64) -> ()
    %4 = "enzyme.placeholder"() : () -> memref<f64>
    %alloca = memref.alloca() : memref<f64>
    %cst_1 = arith.constant 0.000000e+00 : f64
    memref.store %cst_1, %alloca[] : memref<f64>
    %alloca_2 = memref.alloca() : memref<f64>
    "enzyme.push"(%1, %4) : (!enzyme.Cache<memref<f64>>, memref<f64>) -> ()
    memref.store %arg0, %alloca_2[] : memref<f64>
    "enzyme.push"(%2, %4) : (!enzyme.Cache<memref<f64>>, memref<f64>) -> ()
    %5 = memref.load %alloca_2[] : memref<f64>
    cf.br ^bb1
  ^bb1:  // pred: ^bb0
    %6 = "enzyme.get"(%3) : (!enzyme.Gradient<f64>) -> f64
    %7 = arith.addf %6, %arg1 : f64
    "enzyme.set"(%3, %7) : (!enzyme.Gradient<f64>, f64) -> ()
    %8 = "enzyme.get"(%3) : (!enzyme.Gradient<f64>) -> f64
    %9 = "enzyme.pop"(%2) : (!enzyme.Cache<memref<f64>>) -> memref<f64>
    %10 = memref.load %9[] : memref<f64>
    %11 = arith.addf %10, %8 : f64
    memref.store %11, %9[] : memref<f64>
    %12 = "enzyme.pop"(%1) : (!enzyme.Cache<memref<f64>>) -> memref<f64>
    %13 = memref.load %12[] : memref<f64>
    %14 = "enzyme.get"(%0) : (!enzyme.Gradient<f64>) -> f64
    %15 = arith.addf %14, %13 : f64
    "enzyme.set"(%0, %15) : (!enzyme.Gradient<f64>, f64) -> ()
    %cst_3 = arith.constant 0.000000e+00 : f64
    memref.store %cst_3, %12[] : memref<f64>
    %16 = "enzyme.get"(%0) : (!enzyme.Gradient<f64>) -> f64
    return %16 : f64
  }
}

@xys-syx xys-syx requested a review from vimarsh6739 May 15, 2026 04:06
void mlir::enzyme::MGradientUtils::setInvertedPointer(Value val, Value toset) {
assert(getShadowType(val.getType()) == toset.getType());
auto found = invertedPointers.lookupOrNull(val);
assert(found != nullptr);
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are you removing the assert here?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are we expecting missing entries for val?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I previously thought that a mutable block argument is mapped to a real value rather than a PlaceholderOp, so found.getDefiningOp<PlaceholderOp>() would return null. But I have now carefully checked that mutable values do not go through the addToDiffe → setDiffe path, we do not need to worry about that

@xys-syx xys-syx requested a review from vimarsh6739 May 20, 2026 21:47

if (mode == DerivativeMode::ForwardMode ||
mode == DerivativeMode::ForwardModeSplit) {
mode == DerivativeMode::ForwardModeSplit || isMutable) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

func.func @square(%x : memref<f64>){
     %y = memref.load %x[] : f64
     return %y
}

%out = enzyme.fwddiff dsquare(%x : memref<f64>, %dx : memref<f64>) {act = [enzyme_dup] ....}

In this case, the user provides the inverted Pointer right? We should ensure that we arent creating a shadow

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i have add a guard in setDiffe when val is mutable or intertedPointers[val] is not a placeholderop, we skip instead of rewriting it.

@xys-syx xys-syx requested a review from vimarsh6739 May 23, 2026 04:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants