From bc2e0359939faa0309bd298c941d9eb76a76a300 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 10 Apr 2024 22:22:46 +0000 Subject: [PATCH] test --- torch_xla/experimental/fori_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index eababe82d55..544105a6de4 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -46,7 +46,7 @@ def fori_loop(one_value, lower, upper, body_fun, init_val, *input_value): #, wei # b=fake_carried_inputs[-3], # c=fake_carried_inputs[-2], # output_value=fake_carried_inputs[-1] - def cond_fn(lower, upper, one_value, x, bias_0, weight_0, *input_value, output_value): + def cond_fn(lower, upper, one_value, x, bias_0, weight_0, output_value, *input_value): return lower[0] <= upper[0] # one_value, init_val, l_in_i @@ -57,7 +57,7 @@ def cond_fn(lower, upper, one_value, x, bias_0, weight_0, *input_value, output_v # s32[1] # s32[1], s32[1], s32[1], s32[1], f32[20], f32[20,10], f32[10], f32[20])) # def body_fn(upper, lower, x, *input_value, a, b, c, output_value): - def body_fn(one_value, lower, upper, x, bias_0, weight_0, *input_value, output_value): + def body_fn(one_value, lower, upper, x, bias_0, weight_0, output_value, *input_value): # one_value, upper, lower, x_i, bias, weight, l_in_i # init_one_value = torch.ones(1, dtype=torch.int32, device=device)