From cd56f6c1d6099833a3bbd06f6da1589184c7e18b Mon Sep 17 00:00:00 2001 From: macroexpansion Date: Wed, 25 Oct 2023 19:21:11 +0700 Subject: [PATCH] separate tests for convert pytorch tensor --- candle-pyo3/test.py | 13 ------------- candle-pyo3/test_pytorch.py | 14 ++++++++++++++ 2 files changed, 14 insertions(+), 13 deletions(-) create mode 100644 candle-pyo3/test_pytorch.py diff --git a/candle-pyo3/test.py b/candle-pyo3/test.py index 4d0b52f9e7..e4ff772a1f 100644 --- a/candle-pyo3/test.py +++ b/candle-pyo3/test.py @@ -1,5 +1,4 @@ import candle -import torch print(f"mkl: {candle.utils.has_mkl()}") print(f"accelerate: {candle.utils.has_accelerate()}") @@ -30,15 +29,3 @@ dequant_t = quant_t.dequantize() diff2 = (t - dequant_t).sqr() print(diff2.mean_all()) - -# convert from candle tensor to torch tensor -t = candle.randn((3, 512, 512)) -torch_tensor = t.to_torch() -print(torch_tensor) -print(type(torch_tensor)) - -# convert from torch tensor to candle tensor -t = torch.randn((3, 512, 512)) -candle_tensor = candle.Tensor(t) -print(candle_tensor) -print(type(candle_tensor)) diff --git a/candle-pyo3/test_pytorch.py b/candle-pyo3/test_pytorch.py new file mode 100644 index 0000000000..db0f35227c --- /dev/null +++ b/candle-pyo3/test_pytorch.py @@ -0,0 +1,14 @@ +import candle +import torch + +# convert from candle tensor to torch tensor +t = candle.randn((3, 512, 512)) +torch_tensor = t.to_torch() +print(torch_tensor) +print(type(torch_tensor)) + +# convert from torch tensor to candle tensor +t = torch.randn((3, 512, 512)) +candle_tensor = candle.Tensor(t) +print(candle_tensor) +print(type(candle_tensor))