[Fori_loop|While_loop] Enable while_loop/fori_loop, Add linear/MNIST test case #6867
Conversation
```python
res_list.insert(0, lower)
res_list.insert(0, torch.sub(upper, one_value_i))
return res_list
if (hasattr(body_fun, 'weight') or hasattr(body_fun, 'bias')):
```
what is the plan to remove this condition?
Actually, why are we special-casing the weight and bias? Is it only for the linear layer?
Yes, it's only for the linear layer. We special-case the weight and bias because of how body_fn returns: weight/bias are not listed in the inputs, but they still need to be returned, i.e. added to the xlacomputation's return arguments.
Plans to remove this condition:
- A: check additional_inputs (weight/bias) like PyTorch does and add them to the xlacomputation arguments at the C++ level
- B: only check whether additional_inputs (weight/bias) exist before running the weight/bias-specific code, in the next step
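A minimal sketch of what option B could look like at the Python level, assuming body_fun is either a plain callable or an nn.Module; collect_additional_inputs is a hypothetical helper, not code from this PR:

```python
import torch


def collect_additional_inputs(body_fun):
  # Hypothetical helper: gather the parameters that body_fun carries (e.g. the
  # weight and bias of a Linear) so the caller can branch on whether any
  # additional inputs exist, instead of hard-coding hasattr checks.
  if isinstance(body_fun, torch.nn.Module):
    return [param for _, param in body_fun.named_parameters()]
  return []  # a plain function captures no extra parameters


print(len(collect_additional_inputs(torch.nn.Linear(10, 20))))  # 2
print(len(collect_additional_inputs(lambda x: x + 1)))          # 0
```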
```cpp
xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape,
                              "UnusedArgumentsPlaceholder");
```
I don't quite understand this part. x is a local variable that will be released after the for loop; is the intention of calling xla::Parameter to initialize a parameter at position parameter_idx with the given shape? If so, can you add a comment to make it clearer?
Yes, x is added here to put arguments of the expected shape/type into the body/cond xlacomputation's argument list, because unused input arguments are missed when the computation is built via LTC. This is needed to meet the xla::While requirement: the parameters of the condition and body, the result of the body, and init must all have the same shape.
```cpp
xla::Shape shape =
    xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {1});
int64_t parameter_idx =
    2;  // parameter_idx starts from 2, after the already-used upper and lower
```
can you move the comment above the code? Thanks!
```cpp
}
}

// hard-code modify body xlacomputation input arguments with unusedarguments
```
Nit: add a space between "unused" and "arguments".
```python
weight = body_fun.weight  # not actually used; initialized as a placeholder to satisfy the xlacomputation requirement
bias = body_fun.bias  # not actually used; initialized as a placeholder to satisfy the xlacomputation requirement
```
Nit: can we move the comments above the code?
```python
return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(
    one_value, x), input_value.clone(), bias.clone(), weight.clone(
    ), output_value.clone()
```
why do we need to clone?
Missing .clone() here would cause the PyTorch error "torch.while_loop's body_fn might be aliasing the input!", so .clone() is added to avoid returning the input arguments directly.
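A small illustration of that error (a sketch, not this PR's test; the while_loop import path is assumed from a recent PyTorch): a body_fn that hands back one of its inputs unchanged aliases it, while returning a clone keeps the same shape and value without aliasing.

```python
import torch
from torch._higher_order_ops.while_loop import while_loop


def cond_fn(i, x):
  return i < 3


def aliasing_body(i, x):
  # Returns x as-is, so the output aliases the input; this is the pattern the
  # "body_fn might be aliasing the input!" check rejects.
  return i + 1, x


def cloned_body(i, x):
  # Returning a fresh tensor with the same shape/value avoids the aliasing.
  return i + 1, x.clone()


i0, x0 = torch.tensor(0), torch.zeros(2)
print(while_loop(cond_fn, cloned_body, (i0, x0)))
```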
```python
return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(
    one_value, x), input_value.clone(), bias.clone(), weight.clone(
    ), output_value.clone()
```
Remind me, why do we need to return weight and bias in this function?
We need to make sure body_fn's xlacomputation has the same inputs and outputs. The input would include weight automatically, so here we return weight and bias at the Python level to ensure they are included in the output too. Bias is added to avoid output_value being used as the bias in the calculation, because bias has the same shape and value as output_value.
We also have a plan to lower the addition of weight and bias to the xlacomputation arguments to the C++ level; let me test that locally too, and if it passes we could avoid returning weight and bias from the Python level.
This is too confusing; we need to think of a better UX. body_fn should also take linear_0 as an input instead of calling it from the parent scope.
I think instead of manually returning each parameter of the module, we should just return the module's named_parameters. The user also shouldn't need to manually order the returned parameters (this would be super confusing to the user); we should do that in our layer.
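A rough sketch of the named_parameters part of that suggestion (an illustration only, not this PR's implementation; the body_fn signature below is assumed, and the module is still captured from the enclosing scope here):

```python
import torch

linear_0 = torch.nn.Linear(10, 20)


def body_fn(upper, lower, one_value, x, input_value, output_value):
  new_lower = torch.add(one_value, lower)
  output_value = linear_0(input_value)
  # Splat the module's parameters instead of returning bias/weight by hand;
  # matching them to the HLO parameter order would then be the framework
  # layer's job rather than the user's.
  params = tuple(p.clone() for _, p in linear_0.named_parameters())
  return (upper.clone(), new_lower.clone(), one_value.clone(),
          torch.add(one_value, x), input_value.clone(), *params,
          output_value.clone())


out = body_fn(torch.tensor([10]), torch.tensor([0]), torch.tensor([1]),
              torch.zeros(1), torch.rand(10), torch.zeros(20))
print([t.shape for t in out])
```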
Do we have an ETA for this PR?
I feel like the way we determine the HLO input parameter order is not ideal; it only works on simple examples. In this PR Manfei is trying to support Linear, and she needs to do some hacks to make the order correct. I am trying to figure out whether there is a more general way to determine the parameter order; hopefully I can find some time today and tomorrow to draft a POC.
```python
return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(
    one_value, x), input_value.clone(), bias.clone(), weight.clone(
```
Why are we returning a torch.add(one_value, x) here?
It is used here as a counter, to confirm that the calculation runs the expected number of times.
I think our test case is too complicated; we should aim to support what PyTorch supports, similar to https://github.com/pytorch/pytorch/blob/8573d9551a7694b9313310412867ac3b6b751f26/test/functorch/test_control_flow.py#L137-L150.
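For reference, a minimal case in that spirit might look roughly like the following (a paraphrased sketch, not a copy of the linked test; the while_loop import path is assumed):

```python
import torch
from torch._higher_order_ops.while_loop import while_loop


def cond_fn(x):
  # Keep looping while the carried value is still small.
  return x.sum() < 10


def body_fn(x):
  # One simple op per iteration; the returned tuple mirrors the carried inputs.
  return (x + 1,)


out = while_loop(cond_fn, body_fn, (torch.zeros(1),))
print(out)  # expected: (tensor([10.]),)
```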
Thanks @JackCaoG; for comparison, this WIP PR is also trying post-order like …
I have some high-level ideas: as long as there is no in-place update to any parameters (which will not be true, I guess, since we need to decrement the iterator), get_hlo(input1, input2, input3, output) should have the parameters in the same order. This is because in order to compute the new …
The same rule should apply to the named_parameters from the …
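One way to poke at this (a sketch under assumptions: it uses the private torch_xla._XLAC._get_xla_tensors_hlo hook and needs an XLA device available) is to dump the HLO for a fixed list of tensors and inspect the parameter order directly:

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
linear = torch.nn.Linear(2, 2).to(device)
x = torch.ones(2).to(device)
out = linear(x)

# Dump the HLO that produces `out`; device-data inputs (x, weight, bias) show
# up as numbered parameters, so their ordering can be checked by eye.
tensors = [x, *[p for _, p in linear.named_parameters()], out]
print(torch_xla._XLAC._get_xla_tensors_hlo(tensors))
```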
Thanks @JackCaoG for this idea, do we mean …
Thanks @JackCaoG. Regarding the order of the checked HLO passed for the input/output order: we haven't seen this error happen when we switch the order of the body's …
Thanks for the review and comments. Since pushing failed due to a permission issue, this will continue to be tracked in #7157.
This PR enables fori_loop or while_loop:
- xla::While: cond's input, body's input/output, and init should have the same shape (the xla::While requirement mentioned above)

Next plan:
- init_python_bindings.cpp: get the number of the body xlacomputation's arguments first, then decide the items in additional_inputs_list; maybe implement this at the Python level
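To make the fori_loop/while_loop relationship concrete, here is a hypothetical illustration of fori_loop semantics expressed on top of torch.while_loop (this is not torch_xla's fori_loop implementation or API; fori_loop_sketch and the import path are assumptions):

```python
import torch
from torch._higher_order_ops.while_loop import while_loop


def fori_loop_sketch(lower, upper, body_fun, init_val):
  # Loop a JAX-style body_fun(i, val) for i in [lower, upper), carrying the
  # loop index and the value through while_loop.
  def cond_fn(i, val):
    return i < upper

  def body_fn(i, val):
    return i + 1, body_fun(i, val)

  _, final_val = while_loop(cond_fn, body_fn, (lower, init_val))
  return final_val


# Example: add 1 ten times.
print(fori_loop_sketch(torch.tensor(0), torch.tensor(10),
                       lambda i, v: v + 1, torch.zeros(1)))  # expected: tensor([10.])
```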