Commit 38fae84 ("old")
ManfeiBai committed May 6, 2024
1 parent 33fa1fb
Showing 5 changed files with 519 additions and 9 deletions.
170 changes: 170 additions & 0 deletions test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
@@ -7,6 +7,7 @@
# We need to import the underlying implementation function to register with the dispatcher
import torch_xla.experimental.fori_loop
from torch_xla.experimental.fori_loop import fori_loop
from torch_xla.experimental.fori_loop import _xla_while_loop, get_module_parameters
from torch._higher_order_ops.while_loop import while_loop
import torch_xla.core.xla_model as xm
import torch_xla.core.xla_builder as xb
@@ -121,6 +122,40 @@ def body_fn(upper, lower, one_value, x, input_value, output_value):

return self.assertTrue(torch.all(torch.eq(expected, output_value_real__)))

def test_while_loop_tpu_simple_linear_clean(self):

xm.mark_step()
device = xm.xla_device()
torch.set_grad_enabled(False)

linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device())

def cond_fn(upper, lower, one_value, x, input_value, output_value):
return lower[0] < upper[0]

def body_fn(upper, lower, one_value, x, input_value, output_value):
new_lower = torch.add(one_value, lower)
output_value = linear_0(input_value)
return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(
one_value, x), input_value.clone(), output_value.clone()

upper = torch.tensor([1], dtype=torch.int32, device=device)
lower = torch.tensor([0], dtype=torch.int32, device=device)
one_value = torch.tensor([1], dtype=torch.int32, device=device)
init_val = torch.tensor([1], dtype=torch.int32, device=device)
l_in_0 = torch.rand(10, device=xm.xla_device())
output_value = torch.zeros([20], dtype=torch.float32, device=device)

upper__, lower__, one_value__, torch_add_res__, input_value__, bias__, weight__, output_value_real__, = while_loop(
cond_fn, body_fn,
(upper, lower, one_value, init_val, l_in_0, output_value))

expected = _fake_fori_loop(lower, upper, linear_0, l_in_0)

return self.assertTrue(torch.all(torch.eq(expected, output_value_real__)))

def test_while_loop_tpu_simple_linear_class(self):

xm.mark_step()
@@ -180,6 +215,135 @@ def body_fn(upper, lower, one_value, x, input_value, output_value):
self.assertTrue(torch.all(torch.eq(expected, output_value_real__)))
return aaa

def test_while_loop_tpu_simple_linear_class_clean(self):

xm.mark_step()
device = xm.xla_device()
torch.set_grad_enabled(False)

# define simple linear model class
class SimpleWithLinear(torch.nn.Module):

def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 20).to(xm.xla_device())

def forward(self, one_value, x, input_value):
output_value_real = self.linear(input_value)
torch_add_res = torch.add(one_value, x)
return torch_add_res, output_value_real

simple_with_linear = SimpleWithLinear()

# define cond and body
def cond_fn(upper, lower, one_value, x, input_value, output_value, *args):
return lower[0] < upper[0]

def body_fn(upper, lower, one_value, x, input_value, output_value, *args):
new_lower = torch.add(one_value, lower)
output_value_real = simple_with_linear(one_value, x, input_value)
res = [upper.clone(), new_lower.clone(), one_value.clone(), output_value_real[0], input_value.clone(), output_value_real[1]]
res = get_module_parameters(res, simple_with_linear, contain_output_balue=True)
return tuple(res)

# simple_with_linear = SimpleWithLinear()
upper = torch.tensor([52], dtype=torch.int32, device=device)
lower = torch.tensor([1], dtype=torch.int32, device=device)
one_value = torch.tensor([1], dtype=torch.int32, device=device)
init_val = torch.tensor([1], dtype=torch.int32, device=device)
l_in_0 = torch.rand(10, device=xm.xla_device())
output_value = torch.zeros([20], dtype=torch.float32, device=device)

aaa = {
"simple_with_linear":
(simple_with_linear, (upper, lower, one_value, init_val, l_in_0,
output_value))
}

additional_inputs = get_module_parameters([], simple_with_linear, contain_output_balue=False)

upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, output_value_real__, = _xla_while_loop(
cond_fn, body_fn,
(upper, lower, one_value, init_val, l_in_0, output_value), tuple(additional_inputs))

print("while_loop run times: ", torch_add_res__)
print("output_value_real__: ", output_value_real__)

expected = simple_with_linear(one_value, init_val, l_in_0)
expected = expected[-1]
print("expected: ", expected)

self.assertTrue(torch.all(torch.eq(expected, output_value_real__)))
return aaa

def test_while_loop_tpu_simple_linear_class_clean_only_linear(self):

xm.mark_step()
device = xm.xla_device()
torch.set_grad_enabled(False)

# define simple linear model class
class SimpleWithLinear(torch.nn.Module):

def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 20).to(xm.xla_device())

def forward(self, x):
x = self.linear(x)
return x

simple_with_linear = SimpleWithLinear()

# define cond and body
def cond_fn(upper, lower, one_value, x, input_value, output_value, *args):
return lower[0] < upper[0]

def body_fn(upper, lower, one_value, x, input_value, output_value, *args):
new_lower = torch.add(one_value, lower)
output_value_real = simple_with_linear(input_value)
res = [upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x).clone(), input_value.clone(), output_value_real.clone()]
res = get_module_parameters(res, simple_with_linear, contain_output_balue=False)
return tuple(res)

# simple_with_linear = SimpleWithLinear()
upper = torch.tensor([52], dtype=torch.int32, device=device)
lower = torch.tensor([1], dtype=torch.int32, device=device)
one_value = torch.tensor([1], dtype=torch.int32, device=device)
init_val = torch.tensor([1], dtype=torch.int32, device=device)
l_in_0 = torch.rand(10, dtype=torch.float32, device=device)
output_value = torch.zeros([20], dtype=torch.float32, device=device)
print("output_value: ", output_value)

aaa = {
"simple_with_linear":
(simple_with_linear, (upper, lower, one_value, init_val, l_in_0,
output_value))
}

additional_inputs = get_module_parameters([], simple_with_linear, contain_output_balue=False)
# print("additional_inputs: ", additional_inputs)

# while_loop
upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, output_value_real__, = while_loop(
cond_fn, body_fn,
(upper, lower, one_value, init_val, l_in_0, output_value))

# _xla_while_loop
# upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, output_value_real__, = _xla_while_loop(
# cond_fn, body_fn,
# (upper, lower, one_value, init_val, l_in_0, output_value), tuple(additional_inputs))

print("while_loop run times: ", torch_add_res__)
print("output_value_real__: ", output_value_real__)

expected = simple_with_linear(l_in_0)
print("expected: ", expected)

self.assertTrue(torch.all(torch.eq(expected, output_value_real__)))
return aaa


def test_fori_loop_tpu_addition(self):

xm.mark_step()
@@ -218,6 +382,12 @@ def test_fori_loop_tpu_simple_linear(self):

self.assertTrue(torch.all(torch.eq(expected, l_out_)))

def test_get_xlacomputation(self):
xla_device = xm.xla_device()
t1 = torch.randn(20, 5).to(xla_device)
t2 = torch.randn(20, 5).to(xla_device)
# expected_data_handles = torch_xla._XLAC._get_xla_computation([t1], [], False)
print("finish test here")

if __name__ == '__main__':
test = unittest.main()
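For reference, a minimal sketch of the carried-value pattern these tests exercise, assuming an XLA device and the same `while_loop` entry point imported above; the tensor names and the four-value carry are illustrative only and not part of this commit:

import torch
import torch_xla.core.xla_model as xm
# Importing the experimental module registers the XLA dispatch for while_loop.
import torch_xla.experimental.fori_loop
from torch._higher_order_ops.while_loop import while_loop

xm.mark_step()
device = xm.xla_device()

# Every tensor the loop reads or writes is threaded through cond_fn/body_fn
# and returned from body_fn in the same order.
upper = torch.tensor([10], dtype=torch.int32, device=device)
lower = torch.tensor([0], dtype=torch.int32, device=device)
one = torch.tensor([1], dtype=torch.int32, device=device)
acc = torch.tensor([0], dtype=torch.int32, device=device)

def cond_fn(upper, lower, one, acc):
  return lower[0] < upper[0]

def body_fn(upper, lower, one, acc):
  return upper.clone(), torch.add(one, lower), one.clone(), torch.add(one, acc)

upper_, lower_, one_, acc_ = while_loop(cond_fn, body_fn, (upper, lower, one, acc))
print(acc_)  # value accumulated over the iterations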
94 changes: 94 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -706,6 +706,87 @@ runtime::ComputationClient::ComputationPtr CreateComputationFromProto(
name, std::move(computation));
}

// // std::shared_ptr<XLAGraphExecutor::Async>
// // XLAGraphExecutor::CompilationResult
// // runtime::ComputationClient::ComputationPtr
// // XLAGraphExecutor::saveComputation*
// // std::vector<runtime::ComputationClient::ComputationPtr>
// // XLAGraphExecutor::CachedComputation
// runtime::ComputationClient::ComputationPtr GetXLAComputation(std::vector<XLATensorPtr> tensors,
// absl::Span<const std::string> devices, bool warm_up_cache_only = false) {
// tsl::profiler::TraceMe activity("GetXLAComputation", tsl::profiler::TraceMeLevel::kInfo);
// SyncTensorsConfig config;
// config.force_ltc_data = false;
// SyncTensorCollection coll = CollectSyncTensors(*tensors, config);
// if (coll.indices.empty()) {
// TensorCollectionBarrier(&coll);
// return nullptr;
// }
// DebugUtil::SaveTensorsGraphInfo("ScheduleSyncTensorsGraph", *tensors,
// &coll.indices);
// std::vector<torch::lazy::Value> ir_values;
// std::vector<torch::lazy::BackendDataPtr> tensor_data_vec;
// ExtractIRAndPrepareXlaData_(tensors, coll.config, coll.indices, ir_values,
// tensor_data_vec);
// PostOrderData po_data = RunPostOrder(ir_values, &coll);
// coll.hash = torch::lazy::HashCombine(
// coll.hash, torch::lazy::Hash(po_data.parameter_sequence));
// if (GetAliasWithBufferDonorConfig()) {
// std::vector<size_t> buffer_donor_index =
// GetBufferDonorIndexFromUserConfig(po_data.parameters_data);
// if (buffer_donor_index.size() > 0) {
// // Do not include hash on a empty vector.
// coll.hash = torch::lazy::HashCombine(
// coll.hash, torch::lazy::Hash(buffer_donor_index));
// }
// }
// {
// // Auto-sharding configs
// coll.hash = torch::lazy::HashCombine(
// coll.hash, torch::lazy::MHash(ShardingUtil::GetAutoSharding()));
// coll.hash = torch::lazy::HashCombine(
// coll.hash,
// torch::lazy::StringHash(
// runtime::sys_util::GetEnvString("XLA_AUTO_SPMD_MESH", "").c_str()));
// }

// DebugUtil::SaveGraphHash(coll.hash);
// TF_VLOG(4) << "Parameter sequence graph hash "
// << torch::lazy::HashToString(coll.hash);

// std::pair<bool, std::shared_ptr<XLAGraphExecutor::Async>> cache_res =
// TryRunCachedSync(tensors, &coll, &po_data, tensor_data_vec,
// warm_up_cache_only);
// if (cache_res.first) {
// // we have a cache hit, execution has been scheduled by TryRunCachedSync.
// return cache_res.second;
// }

// // CompilationResult compile_result =
// // Compile(*tensors, devices, coll, &po_data, ir_values);

// // runtime::ComputationClient::ComputationPtr
// // saveComputation* compile_result = std::move(
// // Compile(*tensors, devices, coll, &po_data, ir_values).computation)
// XLAGraphExecutor::saveComputation* compile_result = Compile(*tensors, devices, coll, &po_data, ir_values).computation
// return compile_result

// // TORCH_LAZY_VALUE_METRIC("TensorsGraphSize", compile_result.emitted_nodes);
// // TF_VLOG(5) << "TensorsGraphSize=" << compile_result.emitted_nodes;
// // auto cached_computation = std::make_shared<CachedComputation>(
// // std::move(compile_result.computation), compile_result.is_sharded);
// // GetComputationCache()->Add(coll.hash, cached_computation);

// // if (warm_up_cache_only) {
// // return nullptr;
// // } else {
// // return ScheduleSyncTensorsGraph(
// // tensors, &coll, std::move(compile_result.parameters_data),
// // compile_result.device.toString(), std::move(cached_computation),
// // tensor_data_vec);
// // }
// }

xla::Shape GetTensorShape(const at::Tensor& tensor,
const std::string& device_str) {
auto xtensor = bridge::TryGetXlaTensor(tensor);
@@ -885,6 +966,8 @@ class PyLoweringContext {
public:
PyLoweringContext() : PyLoweringContext(bridge::GetCurrentDevice()) {}

// PostOrderData po_data = RunPostOrder(ir_values, &coll);

PyLoweringContext(torch::lazy::BackendDevice device)
: lowering_ctx("PyLoweringContext", device) {}

@@ -2379,6 +2462,17 @@ void InitXlaModuleBindings(py::module m) {
BuildProfilerSubmodule(&m);
BuildLoweringContextSubmodule(&m);

m.def("_get_xla_computation", [](const std::vector<at::Tensor>& tensors,
const std::vector<std::string>& devices, const bool warm_up_cache_only) {
std::vector<XLATensorPtr> xtensors;
xtensors.reserve(tensors.size());
for (auto& tensor : tensors) {
xtensors.push_back(bridge::GetXlaTensor(tensor));
}

runtime::ComputationClient::ComputationPtr xla_computation = XLAGraphExecutor::Get()->GetXLAComputation(xtensors, {}, true);
return xla_computation;
});
m.def("_get_tensors_handle",
[](const std::vector<at::Tensor>& tensors) -> std::vector<int64_t> {
std::vector<torch::lazy::BackendData::Handle> handles;
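A hedged sketch of how the new `_get_xla_computation` binding could be exercised from Python, following the commented-out call in `test_get_xlacomputation` above; whether the returned `ComputationPtr` is directly usable on the Python side depends on the pybind registration for that type, so treat the return handling as an assumption rather than a documented API:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

xla_device = xm.xla_device()
t1 = torch.randn(20, 5).to(xla_device)
t2 = torch.randn(20, 5).to(xla_device)
t3 = t1 + t2  # pending IR that the binding will lower

# Signature added in this commit: (tensors, devices, warm_up_cache_only).
# The C++ side currently ignores the last two arguments.
computation = torch_xla._XLAC._get_xla_computation([t3], [], False)
print(type(computation))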
(Remaining 3 changed files not shown.)
