Noir not optimizing out increments to mutable variable in Brillig #6971

Open
asterite opened this issue Jan 7, 2025 · 6 comments
Labels
bug Something isn't working

Comments

@asterite
Collaborator

asterite commented Jan 7, 2025

Aim

We arrived at this reduction together with @aakoshh while debugging an increase in opcode count in a program.

Consider this code:

global MAX: u32 = 2;
global MIN: u32 = 1;

fn main(input: [u8; MIN]) -> pub [u8; MAX] {
    let mut output = [0; MAX];
    let mut offset = 0;

    for i in 0..MIN {
        output[offset] = input[i];
    }
    offset += MIN;

    let size = MAX - offset;
    for _ in 0..size {
        output[offset] = 0;
    }

    output
}

The final SSA, when compiled with --force-brillig, is:

brillig(inline) fn main f0 {
  b0(v0: [u8; 1]):
    v2 = allocate -> &mut [u8; 2]
    v3 = allocate -> &mut u32
    v5 = array_get v0, index u32 0 -> u8
    v7 = make_array [v5, u8 0] : [u8; 2]
    store v7 at v2
    store u32 1 at v3
    jmp b1(u32 0)
  b1(v1: u32):
    v9 = eq v1, u32 0
    jmpif v9 then: b3, else: b2
  b2():
    v10 = load v2 -> [u8; 2]
    return v10
  b3():
    v11 = load v2 -> [u8; 2]
    v12 = load v3 -> u32
    v13 = array_set v11, index v12, value u8 0
    store v13 at v2
    v14 = add v1, u32 1
    jmp b1(v14)
}

Now, if we change offset += MIN to offset = MIN:

global MAX: u32 = 2;
global MIN: u32 = 1;

fn main(input: [u8; MIN]) -> pub [u8; MAX] {
    let mut output = [0; MAX];
    let mut offset = 0;

    for i in 0..MIN {
        output[offset] = input[i];
    }
    offset = MIN;

    let size = MAX - offset;
    for _ in 0..size {
        output[offset] = 0;
    }

    output
}

we get this SSA:

brillig(inline) fn main f0 {
  b0(v0: [u8; 1]):
    v2 = array_get v0, index u32 0 -> u8
    v4 = make_array [v2, u8 0] : [u8; 2]
    return v4
}

Given that tracking an offset value, mutating it, and using it to modify arrays is a very common operation, this could be a good optimization to try... or first understand why the compiler isn't already optimizing this.

Expected Behavior

The compiler should produce the second final SSA in both cases.

Bug

The compiler doesn't optimize the program as fully as it could.

Relatedly, if the code is this:

global MAX: u32 = 2;
global MIN: u32 = 1;

fn main(input: [u8; MIN]) -> pub [u8; MAX] {
    let mut output = [0; MAX];
    let mut offset = 0;

    for i in 0..MIN {
        output[offset] = input[i];
    }
    offset += MIN;

    let size = MAX - MIN;
    for _ in 0..size {
        output[offset] = 0;
    }

    output
}

then the program is correctly optimized (the only change from the first program is computing size as MAX - MIN instead of MAX - offset), so there might be more going on here.

Workaround

None

Project Impact

None

@asterite added the bug Something isn't working label Jan 7, 2025
@asterite changed the title from "Noir not optimizing increments to mutable variable in Brillig" to "Noir not optimizing out increments to mutable variable in Brillig" Jan 7, 2025
@TomAFrench transferred this issue from noir-lang/vscode-noir Jan 7, 2025
@jfecher
Contributor

jfecher commented Jan 7, 2025

> Given that tracking an offset value, mutating it, and using it to modify arrays is a very common operation, this could be a good optimization to try... or first understand why the compiler isn't already optimizing this.

This looks like a known downside of the current mem2reg pass: handling loops. See the early comments on #4535 for an example. In short, because of the back edges in the control-flow graph, we're unsure whether the reference may have been mutated at that point, so we can't resolve offset to a known value. When you do offset += MIN; we have to load from offset first to get the current value which would be unknown, so the result is unknown. When you do offset = MIN; we set it regardless of the previous value, so the result becomes known, hence the improved optimizations afterward.
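To make the Known/Unknown reasoning above concrete, here is a minimal sketch in Rust (the language the Noir compiler is written in). It is a simplified, hypothetical model, not the actual mem2reg types: a reference is either Known or Unknown, a block with several predecessors keeps a Known value only if every incoming edge agrees, a back edge that hasn't been analyzed yet is treated as Unknown, offset += MIN goes through a load of the current value, and offset = MIN overwrites it.

```rust
/// Simplified model of what mem2reg knows about one reference (hypothetical type).
#[derive(Clone, Copy, Debug, PartialEq)]
enum Value {
    Known(u32),
    Unknown,
}

/// Unify the values flowing into a block from all of its predecessors:
/// the result stays Known only if every incoming edge agrees.
fn unify(incoming: &[Value]) -> Value {
    match incoming.split_first() {
        Some((first, rest)) if rest.iter().all(|v| v == first) => *first,
        _ => Value::Unknown,
    }
}

/// `offset += MIN`: loads the current value first, so Unknown stays Unknown.
fn add_assign(current: Value, rhs: u32) -> Value {
    match current {
        Value::Known(v) => Value::Known(v.wrapping_add(rhs)),
        Value::Unknown => Value::Unknown,
    }
}

/// `offset = MIN`: overwrites the reference whatever it held before.
fn assign(_current: Value, rhs: u32) -> Value {
    Value::Known(rhs)
}
```

For the first loop in the issue, the loop header sees Known(0) from the entry block and Unknown from the back edge, so unify yields Unknown after the loop. add_assign(Unknown, 1) then stays Unknown, while assign(Unknown, 1) yields Known(1), mirroring the difference between the two programs.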

@asterite
Collaborator Author

asterite commented Jan 7, 2025

> When you do offset += MIN; we have to load from offset first to get the current value which would be unknown, so the result is unknown.

Couldn't the previously stored value be tracked, so that when we have to load it to add to it, we'd already know it?

@asterite
Collaborator Author

asterite commented Jan 7, 2025

Some other things I'm finding. If we don't use offset before incrementing it:

global MAX: u32 = 2;
global MIN: u32 = 1;

fn main(input: [u8; MIN]) -> pub [u8; MAX] {
    let mut output = [0; MAX];
    let mut offset = 0;

    // for i in 0..MIN {
    //     output[offset] = input[i];
    // }
    offset += MIN;

    let size = MAX - offset;
    for _ in 0..size {
        output[offset] = 0;
    }

    output
}

then it's correctly optimized.

Also if we print offset instead of using it in a loop:

global MAX: u32 = 2;
global MIN: u32 = 1;

fn main(input: [u8; MIN]) -> pub [u8; MAX] {
    let mut output = [0; MAX];
    let mut offset = 0;

    for i in 0..MIN {
        output[offset] = input[i];
    }
    offset += MIN;

    println(offset);

    output
}

the SSA has the value 1 hardcoded in it, so the compiler was able to compute offset from the += operation. Maybe the loop is preventing offset's initial value inside the loop from being computed...

@asterite
Collaborator Author

asterite commented Jan 7, 2025

I'll keep writing down what I find; I'm trying to see if I can improve this.

For this program:

global MAX: u32 = 2;
global MIN: u32 = 1;

fn main(input: [u8; MIN]) -> pub [u8; MAX] {
    let mut output = [0; MAX];
    let mut offset: u32 = 0;

    for i in 0..MIN {
        output[offset] = input[i];
    }
    offset += MIN;

    let size = MAX - offset;
    for i in 0..size {
        output[offset + i] = 0;
    }

    output
}

at one point (with --force-brillig) we get:

After After Removing Bit Shifts:
brillig(inline) fn main f0 {
  b0(v0: [u8; 1]):
    v3 = make_array [u8 0, u8 0] : [u8; 2]
    v4 = allocate -> &mut [u8; 2]
    store v3 at v4
    v5 = allocate -> &mut u32
    store u32 0 at v5
    v7 = load v4 -> [u8; 2]
    v8 = load v5 -> u32
    v9 = array_get v0, index u32 0 -> u8
    v10 = array_set v7, index v8, value v9
    v12 = add v8, u32 1
    store v10 at v4
    v13 = load v5 -> u32
    v14 = add v13, u32 1
    store v14 at v5
    v16 = sub u32 2, v14
    jmp b1(u32 0)
  b1(v1: u32):
    v17 = lt v1, v16
    jmpif v17 then: b3, else: b2
  b2():
    v18 = load v4 -> [u8; 2]
    return v18
  b3():
    v19 = load v4 -> [u8; 2]
    v20 = load v5 -> u32
    v21 = add v20, v1
    v22 = array_set v19, index v21, value u8 0
    v23 = add v21, u32 1
    store v22 at v4
    v24 = add v1, u32 1
    jmp b1(v24)
}

After Mem2Reg (2nd):
brillig(inline) fn main f0 {
  b0(v0: [u8; 1]):
    v3 = make_array [u8 0, u8 0] : [u8; 2]
    v4 = allocate -> &mut [u8; 2]
    v5 = allocate -> &mut u32
    v7 = array_get v0, index u32 0 -> u8
    v8 = make_array [v7, u8 0] : [u8; 2]
    store v8 at v4
    store u32 1 at v5
    jmp b1(u32 0)
  b1(v1: u32):
    v10 = eq v1, u32 0
    jmpif v10 then: b3, else: b2
  b2():
    v11 = load v4 -> [u8; 2]
    return v11
  b3():
    v12 = load v4 -> [u8; 2]
    v13 = load v5 -> u32
    v14 = add v13, v1
    v15 = array_set v12, index v14, value u8 0
    v16 = add v14, u32 1
    store v15 at v4
    v17 = add v1, u32 1
    jmp b1(v17)
}

We can see that these:

    v13 = load v5 -> u32
    v14 = add v13, u32 1
    store v14 at v5

were replaced by this:

    store u32 1 at v5

However, a bit below there's still this:

    v13 = load v5 -> u32

I don't know why that load wasn't replaced with the constant "1". Maybe it's because b3's predecessor is b1, whose predecessors are b0 and b3, so mem2reg gets confused (v5 is read in b3 but never written; maybe that confuses mem2reg).

@asterite
Collaborator Author

@jfecher I was wondering if mem2reg could be optimized in the following way:

  • First, do a pass over all blocks to compute which addresses each block stores to
  • When unifying the values from two blocks, if one of the blocks doesn't store to that address (and its predecessors also don't have stores to that address), keep the value from the other block

For example, given this SSA:

brillig(inline) fn main f0 {
  b0(v0: [u8; 1]):
    v2 = allocate -> &mut u32
    store u32 1 at v2
    jmp b1(u32 0)
  b1(v1: u32):
    jmpif u1 1 then: b3, else: b2
  b2():
    v7 = load v2 -> u32
    return v7
  b3():
    v8 = load v2 -> u32
    jmp b1(v8)
}

When computing the value of v2 in block b1, b1 has b0 and b3 as predecessors. For b0 we know it's u32 1. For b3 we don't know yet, but we can see that it has no stores to v2. And b3's only predecessor is b1, the block we are currently analyzing, so we stop the recursion and conclude that "coming from b3 there are no stores to v2", so we keep the value from b0.

Given this code, though:

brillig(inline) fn main f0 {
  b0(v0: [u8; 1]):
    v2 = allocate -> &mut u32
    store u32 1 at v2
    jmp b1(u32 0)
  b1(v1: u32):
    jmpif u1 1 then: b3, else: b2
  b2():
    v7 = load v2 -> u32
    return v7
  b3():
    jmpif u1 1 then: b4, else: b5
  b4():
    store u32 2 at v2
    jmp b6()
  b5():
    jmp b6()
  b6():
    v8 = load v2 -> u32
    jmp b1(v8)
}

When computing the value of v2 in block b1, b1 has b0 and b6 as predecessors. For b0 we know it's u32 1. For b6 we don't know, but it has no stores to v2. However, b6 has b4 and b5 as predecessors, and b4 stores to v2, so we conclude that the value of v2 may change along this path and we don't optimize it out.
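The predecessor walk proposed above could look roughly like the following Rust sketch. It assumes a hypothetical summary of the function (predecessor lists plus the set of addresses each block stores to, gathered by the proposed first scan), uses plain numbers for block and value ids, and deliberately matches store targets by literal address only, ignoring aliasing.

```rust
use std::collections::{HashMap, HashSet};

type BlockId = usize;
type Address = usize;

/// Hypothetical CFG summary produced by the proposed first scan.
struct Cfg {
    /// Predecessors of each block.
    predecessors: HashMap<BlockId, Vec<BlockId>>,
    /// Addresses each block stores to (literal targets only, no alias info).
    stores: HashMap<BlockId, HashSet<Address>>,
}

impl Cfg {
    /// Returns true if `address` is stored to in `from` or in any block that
    /// reaches `from` without passing back through `header` (the block whose
    /// starting value we are computing).
    fn stored_on_path(&self, from: BlockId, header: BlockId, address: Address) -> bool {
        let mut stack = vec![from];
        let mut visited = HashSet::new();
        while let Some(block) = stack.pop() {
            // Stop at the header (end of the recursion) and skip blocks already seen.
            if block == header || !visited.insert(block) {
                continue;
            }
            if self.stores.get(&block).map_or(false, |s| s.contains(&address)) {
                return true;
            }
            if let Some(preds) = self.predecessors.get(&block) {
                stack.extend(preds.iter().copied());
            }
        }
        false
    }
}
```

For the first SSA, walking back from b3 reaches only the loop header b1, so no store to v2 is found and the value from b0 can be kept. For the second, the walk from b6 reaches b4, which stores to v2. Because it ignores aliasing, a store through a different reference pointing at the same memory would go unnoticed.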

I actually started to code this to see if it works, but Block.references is not what I thought it is in the mem2reg code so now I'm not sure how to do it... but I wanted to leave this idea here in case it's useful.

@jfecher
Contributor

jfecher commented Jan 13, 2025

@asterite it depends on your "(and its predecessors also don't have stores to that address)" qualification.

If by that you mean only its direct predecessors, then it could be broken by adding more blocks so that the store is hidden further back. If it doesn't find the store, it'll take the other block's value, ignoring the store (which it shouldn't).

If by that you mean recursively checking all predecessors (for every reference?), then I think it almost works but gets caught in the weeds. Where it gets caught is: how do you know the values of each reference in b6 when you've only computed up to b1 so far? We need to actually analyze these blocks instead of just looking for stores after the fact, because if we don't know the reference values, we don't actually know what the stores are storing to. A block may say store v10 at v100, but does v100 alias another reference?

You mentioned doing a scanning pass beforehand, but that would just move the same question into the scanning pass. As long as you're only doing one scanning pass, I think you'll inherently have this issue.

My current thinking is that we'll probably have to move to something closer to other existing mem2reg passes and actually scan the function multiple times. Those passes scan the entire function repeatedly until the results become stable (stop changing between scans). In our case, we'd expect loops to start with Unknown values and become Known after the second pass. With our current algorithm, though, I think this approach would have performance implications. After the second pass the value becomes known... but what if there is another loop afterward? Its starting value would only now be known, which is the same situation we started with, so we'd need another pass for each subsequent loop. For a program with many loops, we'd iterate at least once per loop if we really wanted to run this until the results were stable. One option would be to statically limit this to a maximum of, say, 5 iterations (an arbitrary number) and accept that the generated code could sometimes be optimized further.
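For comparison, a textbook iterative dataflow formulation of the multi-scan approach, with the iteration cap, might look like this Rust sketch. It tracks a single reference, uses hypothetical per-block effects (no store, a constant store, or an add-assign), and stops after max_passes. One assumption differs from the description above: a predecessor that hasn't been computed yet contributes no information (Unvisited) rather than Unknown, which is the usual optimistic choice, so the number of passes it needs won't necessarily match the counts described above. Aliasing is ignored.

```rust
/// Lattice for one reference: Unvisited (no info yet) > Known(v) > Unknown.
#[derive(Clone, Copy, Debug, PartialEq)]
enum State {
    Unvisited,
    Known(u32),
    Unknown,
}

/// What a block does to the tracked reference (hypothetical, simplified).
#[derive(Clone, Copy)]
enum Effect {
    NoStore,
    Store(u32),
    AddAssign(u32),
}

/// Meet of two incoming states: Unvisited is ignored, disagreement gives Unknown.
fn meet(a: State, b: State) -> State {
    match (a, b) {
        (State::Unvisited, x) | (x, State::Unvisited) => x,
        (State::Known(x), State::Known(y)) if x == y => State::Known(x),
        _ => State::Unknown,
    }
}

/// Apply a block's effect to the state at its entry.
fn transfer(input: State, effect: Effect) -> State {
    match (effect, input) {
        (Effect::NoStore, s) => s,
        (Effect::Store(v), _) => State::Known(v),
        (Effect::AddAssign(d), State::Known(v)) => State::Known(v.wrapping_add(d)),
        // Unvisited stays Unvisited and Unknown stays Unknown.
        (Effect::AddAssign(_), s) => s,
    }
}

/// Rescan all blocks until nothing changes or `max_passes` is reached.
/// Returns the state at the end of each block and the number of passes run.
fn solve(preds: &[Vec<usize>], effects: &[Effect], max_passes: usize) -> (Vec<State>, usize) {
    let mut out = vec![State::Unvisited; effects.len()];
    for pass in 1..=max_passes {
        let mut changed = false;
        for b in 0..effects.len() {
            let input = preds[b].iter().fold(State::Unvisited, |acc, &p| meet(acc, out[p]));
            let new = transfer(input, effects[b]);
            if new != out[b] {
                out[b] = new;
                changed = true;
            }
        }
        if !changed {
            return (out, pass);
        }
    }
    (out, max_passes)
}
```

With preds = [[], [0, 3], [1], [1]] (the loop shape from the earlier SSA example), a constant store in b0 stays Known(1) at the loop header and the solver stabilizes on its second pass, while an add-assign inside the loop body correctly decays to Unknown and stabilizes on the third.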

Edit: If we don't go with a new mem2reg algorithm, then to solve this with (mostly) our current algorithm, I think we'll need some way to tell mem2reg ahead of time that "v1 or its aliases do not change in this loop (set of blocks)", so that we can use that in the unify check. If we're finding starting values and are between b0, where *v1 = 9, and b4, where *v1 = ?, but know that neither v1 nor its aliases change in b4 or its predecessors, then we can just take the result from b0. Getting that information is more difficult, though, due to the "or its aliases" part of the constraint.
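One way the "v1 or its aliases do not change in this loop" check could be consumed is sketched below in Rust. It assumes a hypothetical precomputed alias set for every reference (including the reference itself) and a list of the store targets found in the loop's blocks. Computing accurate alias sets is exactly the hard part described above; the sketch only shows how they would be used, and it treats unbounded or missing alias information conservatively.

```rust
use std::collections::{HashMap, HashSet};

type Address = u32;

/// Alias information for a reference: the exact set of references it may
/// point to (including itself), or Any when the analysis couldn't bound it.
enum AliasSet {
    Exactly(HashSet<Address>),
    Any,
}

/// Returns true if no store in the loop can write to `reference` or any of its aliases.
fn invariant_in_loop(
    reference: Address,
    loop_store_targets: &[Address],
    aliases: &HashMap<Address, AliasSet>,
) -> bool {
    // Without a bounded alias set for the reference itself, assume it may change.
    let reference_set = match aliases.get(&reference) {
        Some(AliasSet::Exactly(set)) => set,
        _ => return false,
    };
    loop_store_targets.iter().all(|target| match aliases.get(target) {
        Some(AliasSet::Exactly(target_set)) => target_set.is_disjoint(reference_set),
        // A store through an unbounded or unknown alias set could write anywhere.
        Some(AliasSet::Any) | None => false,
    })
}
```

If this returns true for v1, the unify check could take the value from the loop's entry edge (b0 in the example above) and ignore the back edge.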

> I actually started to code this to see if it works, but Block.references is not what I thought it is in the mem2reg code so now I'm not sure how to do it... but I wanted to leave this idea here in case it's useful.

Block.references should contain the value a given reference is thought to store: either Known(value) or Unknown. To get to it, though, you first have to go through Block.expressions and Block.aliases. Block.expressions gives you the "canonical" expression for a reference value, and Block.aliases tells you whether that reference is known to refer to a single alias or may refer to several. Only if it refers to exactly one can you then look up that single alias in Block.references. These fields should really be private. I've made them private in a few PRs in the past, but those have been riding along with other experimental changes that were never merged.
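The lookup chain just described could be sketched as follows in Rust. The struct, its field types, and the known_value method are hypothetical simplifications named after the fields mentioned above; the real mem2reg code uses its own expression and alias-set types.

```rust
use std::collections::{BTreeSet, HashMap};

type ValueId = u32;
type ExpressionId = u32;

/// What a reference is thought to hold (simplified).
#[derive(Debug)]
enum ReferenceValue {
    Known(ValueId),
    Unknown,
}

/// Hypothetical stand-in for the per-block mem2reg state.
struct Block {
    /// Reference value -> its canonical expression.
    expressions: HashMap<ValueId, ExpressionId>,
    /// Canonical expression -> the references it may refer to.
    aliases: HashMap<ExpressionId, BTreeSet<ValueId>>,
    /// Reference -> the value it is thought to hold.
    references: HashMap<ValueId, ReferenceValue>,
}

impl Block {
    /// Follow expressions -> aliases -> references; only a single,
    /// unambiguous alias with a Known value produces a result.
    fn known_value(&self, reference: ValueId) -> Option<ValueId> {
        let expression = self.expressions.get(&reference)?;
        let alias_set = self.aliases.get(expression)?;
        if alias_set.len() != 1 {
            return None; // may refer to several references: can't tell which
        }
        let alias = alias_set.iter().next()?;
        match self.references.get(alias)? {
            ReferenceValue::Known(v) => Some(*v),
            ReferenceValue::Unknown => None,
        }
    }
}
```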

Projects
Status: No status
Development

No branches or pull requests

2 participants