
[Fori_loop|While_loop] Enable while_loop/fori_loop, Add linear/MNIST test case #6867

Closed
wants to merge 323 commits

Conversation

@ManfeiBai (Collaborator) commented Apr 1, 2024

  • add a linear model test case using fori_loop or while_loop
  • modify the logic of buildforiloop to change cond/body's xla_computation to meet xla::While's requirement: cond's input, body's input/output, and init must all have the same shape (sketched below)
  • modify the input arguments' shape/order to meet the xla::While requirement mentioned above
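
To make the xla::While shape requirement concrete, here is a minimal sketch, assuming the torch.while_loop-style interface this PR targets (import path as in PyTorch's own tests; names approximate):

import torch
from torch._higher_order_ops.while_loop import while_loop

# init, cond_fn's inputs, and body_fn's inputs/outputs must all carry
# the same shapes and dtypes end to end.
def cond_fn(upper, lower, x):
  return lower < upper  # scalar bool tensor

def body_fn(upper, lower, x):
  # returns a tuple matching (upper, lower, x) in shape and dtype
  return upper.clone(), lower + 1, x + 1

init = (torch.tensor(10), torch.tensor(0), torch.tensor(0))
upper, lower, x = while_loop(cond_fn, body_fn, init)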

Next plan:

  • init_python_bindings.cpp: first get the number of the body xlacomputation's arguments, then decide the items in additional_inputs_list; this might be implemented at the Python level

@ManfeiBai force-pushed the fori_loop_simple_case_test branch 3 times, most recently from 4e0dcaa to 50193f3 on April 9, 2024 20:17
@ManfeiBai force-pushed the fori_loop_simple_case_test branch 8 times, most recently from 92d17fd to a79a609 on April 16, 2024 17:34
@ManfeiBai marked this pull request as ready for review on April 18, 2024 00:31
res_list.insert(0, lower)
res_list.insert(0, torch.sub(upper, one_value_i))
return res_list
if (hasattr(body_fun, 'weight') or hasattr(body_fun, 'bias')):
Collaborator

what is the plan to remove this condition?

Collaborator

Actually, why are we special-casing the weight and bias? Is it only for the linear layer?

@ManfeiBai (Collaborator, Author) Apr 19, 2024

Yes, it's only for the linear layer.

We special-case the weight and bias because of the different body_fn return: weight/bias are not mentioned in the inputs, but they need to be returned, i.e. added to the xlacomputation's return arguments.

The plan to remove this condition is one of:

  • A: check additional_inputs (weight/bias) like PyTorch does and add them to the xlacomputation arguments at the C++ level
  • B: only check whether additional_inputs (weight/bias) exist before running the weight/bias code in the next step (see the sketch below)
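
A rough sketch of what plan B could look like (additional_inputs here is a hypothetical list collected by our while_loop layer, not an existing variable in this PR):

# Instead of special-casing with hasattr(body_fun, 'weight') / 'bias',
# gate on whether any additional inputs were captured at all:
if additional_inputs:
  for extra in additional_inputs:  # e.g. weight and bias of a linear layer
    res_list.append(extra)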

Comment on lines +926 to +927
xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape,
"UnusedArgumentsPlaceholder");
Collaborator

I don't quite understand this part. x is a local variable that will be released after the for loop; is the intention of calling xla::Parameter to initialize a parameter at position parameter_idx with the given shape? If so, can you add a comment to make it clearer?

@ManfeiBai (Collaborator, Author) Apr 19, 2024

Yes, x is added here to register arguments with the expected shape/type in the body/cond xlacomputation's argument list, because unused input arguments are otherwise dropped when the computation is built via LTC.

This meets the xla::While requirement: the parameter of the condition and body, the result of the body, and init must all have the same shape.

xla::Shape shape =
xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {1});
int64_t parameter_idx =
2; // parameter_idx start from 2 after used upper and lower
Collaborator

can you move the comment above the code? Thanks!

}
}

// hard-code modify body xlacomputation input arguments with unusedarguments
Collaborator

nit, space between unused and arguments

Comment on lines +32 to +33
weight = body_fun.weight # not be used actually, initialized as placeholder xlacomputation requirement
bias = body_fun.bias # not be used actually, initialized as placeholder xlacomputation requirement
Collaborator

nit, can we move the comment above the code?

Comment on lines +34 to +36
return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(
one_value, x), input_value.clone(), bias.clone(), weight.clone(
), output_value.clone()
Collaborator

why do we need to clone?

Collaborator Author

Missing .clone() here causes the PyTorch error torch.while_loop's body_fn might be aliasing the input!, so we add .clone() to avoid returning the input arguments directly.
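
For illustration, a schematic of the aliasing rule (hypothetical single-carry loop):

def body_fn(x):
  return (x,)  # ERROR: torch.while_loop's body_fn might be aliasing the input!

def body_fn_ok(x):
  return (x.clone(),)  # a fresh tensor with identical shape/dtype passes the check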

Comment on lines +105 to +107
return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(
one_value, x), input_value.clone(), bias.clone(), weight.clone(
), output_value.clone()
Collaborator

remind me why do we need to return weight and bias in this function?

@ManfeiBai (Collaborator, Author) Apr 19, 2024

We need to make sure body_fn's xlacomputation input and output are the same: the input will include weight automatically, so here we return weight and bias at the Python level to ensure they are included in the output too. We also return bias to avoid output_value being used as the bias in the calculation, because bias has the same shape and value as output_value.

But we also have a plan to lower adding weight and bias to the xlacomputation arguments to the C++ level; let me test that locally too, and if it passes, we can avoid returning weight and bias from the Python level.

Collaborator

This is too confusing; we need to think of a better UX. body_fn should also take linear_0 as an input instead of calling it from the parent scope.

I think instead of manually returning each parameter in the module, we should just return the module's named_parameters. The user also shouldn't need to manually order the return parameters (this will be super confusing to the user); we should do it in our layer.
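
Something along the lines of this sketch, where our layer (hypothetical helper name run_body) appends the module's parameters itself so the user never hand-orders them:

def run_body(body_fn, module, carried):
  # user's body_fn only computes; our while_loop layer appends the
  # module's parameters in a fixed order afterwards
  outputs = body_fn(*carried)
  params = tuple(p.clone() for _, p in module.named_parameters())
  return (*outputs, *params)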

@miladm (Collaborator) commented May 6, 2024

do we have an ETA for this PR?

@JackCaoG (Collaborator) commented May 7, 2024

I feel like the way we determine the HLO input parameter order is not ideal; it only works on simple examples. In this PR, Manfei is trying to support Linear, and she needs to do some hacks to make the order correct. I am trying to figure out if there is a more general way to determine the parameter order; hopefully I can find some time today and tomorrow to draft a POC.

Comment on lines +105 to +106
return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(
one_value, x), input_value.clone(), bias.clone(), weight.clone(
Collaborator

why are we returning a torch.add(one_value, x) here?

Collaborator Author

It is used here as a counter, to confirm the calculation runs the expected number of times.

Collaborator

I think our test case is too complicated; we should aim to support what PyTorch supports, similar to https://github.com/pytorch/pytorch/blob/8573d9551a7694b9313310412867ac3b6b751f26/test/functorch/test_control_flow.py#L137-L150.
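
For reference, the kind of minimal case the linked upstream test exercises looks roughly like this (paraphrased, not an exact copy of that test):

import torch
from torch._higher_order_ops.while_loop import while_loop

def cond_fn(it, x):
  return it < 3

def body_fn(it, x):
  return it + 1, x + it

it, x = while_loop(cond_fn, body_fn, (torch.tensor(0), torch.zeros(2)))
# it == 3, x == tensor([3., 3.])  (accumulated 0 + 1 + 2)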

@ManfeiBai (Collaborator, Author)

> I feel like the way we determine the HLO input parameter order is not ideal; it only works on simple examples. […]

Thanks, @JackCaoG. For comparison, this WIP PR (#7031) is also trying post-order like XLAGraphExecutor::Compile, following your idea and suggestion, and then adds a Python interface for while_loop to generate the xlacomputation for cond/body.

@JackCaoG
Copy link
Collaborator

JackCaoG commented May 7, 2024

I have some high level ideas

As long as there is no in-place update to any parameter (which will not be true, I guess, since we need to decrement the iterator), get_hlo(input1, input2, input3, output) should have the parameters in the same order. This is because in order to compute the new input1, we need input1 as a result. The complication of this approach is if somehow input1 is updated, for example:

fn(input1, input2, input3, output):
  input1 += input2
  ...

get_hlo(input1, input2, input3, output) will likely result in the parameter order input2, input1, input3, output. The key is to always put all inputs in front of the outputs.

The same rule should apply to the named_parameters from the module: if they are being passed as additional_input to the while loop, we just need to stitch them into the inputs list (not at the user-code level, but in our while_loop implementation layer). Hopefully this will just work.
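
A sketch of that stitching idea inside the while_loop implementation layer (helper name and carry layout are hypothetical):

def build_carry(user_inputs, module, outputs):
  # all inputs (user-provided plus module parameters) go in front,
  # and outputs always come last, so the HLO parameter order stays stable
  params = [p for _, p in module.named_parameters()]
  return (*user_inputs, *params, *outputs)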

@ManfeiBai (Collaborator, Author) commented May 8, 2024

> get_hlo

Thanks, @JackCaoG,

for this idea, do we mean get_hlo is code and code?

> I have some high level ideas […] The key is to always put all inputs in front of the outputs. […]

Thanks, @JackCaoG.

For the order change of input1 and input2, we tried tests of add and linear: test code and test log, based on the current PR's branch. We saw that the wrong result might also be due to the order change when we swap the order of lower and upper.

We checked the HLO of the passing linear test (HLO) and of the failing linear test after switching lower and upper (HLO). It looks like cond's xlacomputation didn't change after we switched the order of the inputs, so we do need to fix the order of input and output as you suggested, in both fori_loop's and while_loop's implementations.

As for the input/output order, we haven't seen this error happen when we switch the order of the body's input_value and output_value yet; are there more situations that we missed here?
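
For context, the HLO compared above can be dumped with torch_xla's debug hook, e.g. this minimal sketch:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.zeros(1, dtype=torch.int32, device=xm.xla_device()) + 1
# prints the pending HLO for the given tensors, including the parameter order
print(torch_xla._XLAC._get_xla_tensors_hlo([t]))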

@ManfeiBai changed the title from [Fori_loop|While_loop] Modify XlaComputation and add linear model test case to [Fori_loop|While_loop] Enable while_loop/fori_loop, Add linear/MNIST test case on May 31, 2024
@ManfeiBai (Collaborator, Author) commented May 31, 2024

Thanks for the review and comments. Since pushing failed due to permissions, this will continue to be tracked in #7157.

@ManfeiBai closed this on May 31, 2024