Can MONeT be used with custom models and higher PyTorch versions? #5

Open
Musisoul opened this issue Sep 19, 2022 · 3 comments

Comments

@Musisoul commented Sep 19, 2022

Thanks for your work!
Currently we have two questions:

  1. Can MONeT work with PyTorch versions higher than 1.5.1? We tried PyTorch 1.11.0 with CUDA 11.3, but got an error in the load function at https://github.com/utsaslab/MONeT/blob/master/monet/lm_ops/conv.py#L8 when running examples/training.py. We also tried PyTorch 1.5.0 with CUDA 10.1; that error went away, but we got "cuDNN error: CUDNN_STATUS_EXECUTION_FAILED" in the forward function of monet/lm_ops/bn.py, and the program (examples/training.py) took a long time to initialize. Can you post the detailed configurations, including PyTorch, CUDA, g++, etc.?
  2. Can MONeT be used for custom models? In the README, you mention that to create a MONeT solution we can run python cvxpy_solver.py MODEL ..., and that the model format should be "torchvision.models.<model>()". Can we use MONeT to generate solutions for our own models?
@TraceCS commented Sep 19, 2022

I also tried to use MONeT with a newer PyTorch version and ran the example code:

import torch, torchvision
from monet.cvxpy_solver import Solution
from monet.monet_wrapper import MONeTWrapper
import time

input = torch.randn(184,3,224,224).cuda()
model = torchvision.models.resnet50()
input_shape = (3,224,224)

# Can change to use absolute path instead of relative
sol_file = "/data/dev/MONeT/data/monet_r50_184_24hr/solution_resnet50_184_inplace_conv_multiway_newnode_10.00.pkl"
# import pickle
# with open(sol_file, 'rb') as f:
#     data = pickle.load(f)
#     # print(data)
train_model = MONeTWrapper(model, sol_file, (3,224,224)).cuda()
output = train_model(input)
output.sum().backward()
print("Memory used: %6.2f MB" % (torch.cuda.max_memory_allocated()/1024/1024))

and I got the error message:

  File "/data/tmp/miniconda3/envs/tdy/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 1881, in _run_ninja_build
    subprocess.run(
  File "/data/tmp/miniconda3/envs/tdy/lib/python3.8/subprocess.py", line 516, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/data/dev/MONeT/monet/lm_ops/compress.py", line 6, in <module>
    compress_cpp = load(name="compress_cpp", sources=[this_dir / "compress.cpp", this_dir / "compress.cu"], extra_cflags=['-std=c++17', '-lcusparse'], extra_cuda_cflags=['-lcusparse'],extra_ldflags=['-lcusparse'])
  File "/data/tmp/miniconda3/envs/tdy/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 1265, in load
    return _jit_compile(
  File "/data/tmp/miniconda3/envs/tdy/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 1489, in _jit_compile
    _write_ninja_file_and_build_library(
  File "/data/tmp/miniconda3/envs/tdy/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 1604, in _write_ninja_file_and_build_library
    _run_ninja_build(
  File "/data/tmp/miniconda3/envs/tdy/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 1897, in _run_ninja_build
    raise RuntimeError(message) from e
RuntimeError: Error building extension 'compress_cpp': [1/2] c++ -MMD -MF compress.o.d -DTORCH_EXTENSION_NAME=compress_cpp -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /data/tmp/miniconda3/envs/tdy/lib/python3.8/site-packages/torch/include -isystem /data/tmp/miniconda3/envs/tdy/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -isystem /data/tmp/miniconda3/envs/tdy/lib/python3.8/site-packages/torch/include/TH -isystem /data/tmp/miniconda3/envs/tdy/lib/python3.8/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /data/tmp/miniconda3/envs/tdy/include/python3.8 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++14 -std=c++17 -lcusparse -c /data/dev/MONeT/monet/lm_ops/compress.cpp -o compress.o 
FAILED: compress.o 
c++ -MMD -MF compress.o.d -DTORCH_EXTENSION_NAME=compress_cpp -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /data/tmp/miniconda3/envs/tdy/lib/python3.8/site-packages/torch/include -isystem /data/tmp/miniconda3/envs/tdy/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -isystem /data/tmp/miniconda3/envs/tdy/lib/python3.8/site-packages/torch/include/TH -isystem /data/tmp/miniconda3/envs/tdy/lib/python3.8/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /data/tmp/miniconda3/envs/tdy/include/python3.8 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++14 -std=c++17 -lcusparse -c /data/dev/MONeT/monet/lm_ops/compress.cpp -o compress.o 
/data/dev/MONeT/monet/lm_ops/compress.cpp: In function ‘std::tuple<at::Tensor, at::Tensor, at::Tensor> compress_csr_256(const at::Tensor&, const at::Tensor&, size_t)’:
/data/dev/MONeT/monet/lm_ops/compress.cpp:31:17: warning: ‘at::DeprecatedTypeProperties& at::Tensor::type() const’ is deprecated: Tensor.type() is deprecated. Instead use Tensor.options(), which in many cases (e.g. in a constructor) is a drop-in replacement. If you were using data from type(), that is now available from Tensor itself, so instead of tensor.type().scalar_type(), use tensor.scalar_type() instead and instead of tensor.type().backend() use tensor.device(). [-Wdeprecated-declarations]
     if (ip.type().is_cuda()) {
                 ^
In file included from /data/tmp/miniconda3/envs/tdy/lib/python3.8/site-packages/torch/include/ATen/core/Tensor.h:3:0,
                 from /data/tmp/miniconda3/envs/tdy/lib/python3.8/site-packages/torch/include/ATen/Tensor.h:3,
                 from /data/tmp/miniconda3/envs/tdy/lib/python3.8/site-packages/torch/include/torch/csrc/autograd/function_hook.h:3,
                 from /data/tmp/miniconda3/envs/tdy/lib/python3.8/site-packages/torch/include/torch/csrc/autograd/cpp_hook.h:2,
                 from /data/tmp/miniconda3/envs/tdy/lib/python3.8/site-packages/torch/include/torch/csrc/autograd/variable.h:6,
                 from /data/tmp/miniconda3/envs/tdy/lib/python3.8/site-packages/torch/include/torch/csrc/autograd/autograd.h:3,
                 from /data/tmp/miniconda3/envs/tdy/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/autograd.h:3,
                 from /data/tmp/miniconda3/envs/tdy/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/all.h:7,
                 from /data/tmp/miniconda3/envs/tdy/lib/python3.8/site-packages/torch/include/torch/extension.h:4,
                 from /data/dev/MONeT/monet/lm_ops/compress.cpp:1:
/data/tmp/miniconda3/envs/tdy/lib/python3.8/site-packages/torch/include/ATen/core/TensorBody.h:214:30: note: declared here
   DeprecatedTypeProperties & type() const {
                              ^
/data/dev/MONeT/monet/lm_ops/compress.cpp:40:29: error: converting to ‘std::tuple<at::Tensor, at::Tensor, at::Tensor>’ from initializer list would use explicit constructor ‘constexpr std::tuple< <template-parameter-1-1> >::tuple(_UElements&& ...) [with _UElements = {at::Tensor&, at::Tensor&, at::Tensor&}; <template-parameter-2-2> = void; _Elements = {at::Tensor, at::Tensor, at::Tensor}]’
     return {cip, idx, rowidx};
                             ^
/data/dev/MONeT/monet/lm_ops/compress.cpp: In function ‘at::Tensor uncompress_csr_256(const at::Tensor&, const at::Tensor&, const at::Tensor&, size_t)’:
/data/dev/MONeT/monet/lm_ops/compress.cpp:47:45: warning: narrowing conversion of ‘((N + 255ul) / 256ul)’ from ‘size_t {aka long unsigned int}’ to ‘long int’ inside { } [-Wnarrowing]
     torch::Tensor op = torch::zeros({(N+255)/256,256}, torch::dtype(torch::kFloat32).device(compip.device()));
                                             ^
/data/dev/MONeT/monet/lm_ops/compress.cpp:47:45: warning: narrowing conversion of ‘((N + 255ul) / 256ul)’ from ‘size_t {aka long unsigned int}’ to ‘long int’ inside { } [-Wnarrowing]
/data/dev/MONeT/monet/lm_ops/compress.cpp:50:21: warning: ‘at::DeprecatedTypeProperties& at::Tensor::type() const’ is deprecated: Tensor.type() is deprecated. Instead use Tensor.options(), which in many cases (e.g. in a constructor) is a drop-in replacement. If you were using data from type(), that is now available from Tensor itself, so instead of tensor.type().scalar_type(), use tensor.scalar_type() instead and instead of tensor.type().backend() use tensor.device(). [-Wdeprecated-declarations]
     if (compip.type().is_cuda()) {
                     ^
In file included from /data/tmp/miniconda3/envs/tdy/lib/python3.8/site-packages/torch/include/ATen/core/Tensor.h:3:0,
                 from /data/tmp/miniconda3/envs/tdy/lib/python3.8/site-packages/torch/include/ATen/Tensor.h:3,
                 from /data/tmp/miniconda3/envs/tdy/lib/python3.8/site-packages/torch/include/torch/csrc/autograd/function_hook.h:3,
                 from /data/tmp/miniconda3/envs/tdy/lib/python3.8/site-packages/torch/include/torch/csrc/autograd/cpp_hook.h:2,
                 from /data/tmp/miniconda3/envs/tdy/lib/python3.8/site-packages/torch/include/torch/csrc/autograd/variable.h:6,
                 from /data/tmp/miniconda3/envs/tdy/lib/python3.8/site-packages/torch/include/torch/csrc/autograd/autograd.h:3,
                 from /data/tmp/miniconda3/envs/tdy/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/autograd.h:3,
                 from /data/tmp/miniconda3/envs/tdy/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/all.h:7,
                 from /data/tmp/miniconda3/envs/tdy/lib/python3.8/site-packages/torch/include/torch/extension.h:4,
                 from /data/dev/MONeT/monet/lm_ops/compress.cpp:1:
/data/tmp/miniconda3/envs/tdy/lib/python3.8/site-packages/torch/include/ATen/core/TensorBody.h:214:30: note: declared here
   DeprecatedTypeProperties & type() const {
                              ^
ninja: build stopped: subcommand failed.

I am using PyTorch 1.13.0.dev20220801+cu113 and g++ (GCC) 5.4.0.

@Jack47 commented Sep 20, 2022

  1. Can you post the detailed configurations, including PyTorch, CUDA, g++, etc.?

Please see https://github.com/utsaslab/MONeT/blob/master/install.sh#L11

conda install pytorch==1.5.1 torchvision==0.6.1 cudatoolkit=10.1 -c pytorch -y
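
For a quick sanity check that an environment matches the tested configuration, something like the following generic snippet (not part of MONeT) can be used:

import torch, torchvision

# MONeT is only tested with PyTorch 1.5.1 / torchvision 0.6.1 / CUDA 10.1
print(torch.__version__)        # expect 1.5.1
print(torchvision.__version__)  # expect 0.6.1
print(torch.version.cuda)       # expect 10.1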

@aashaka (Member) commented Sep 26, 2022

Thanks for checking out MONeT.

  1. MONeT has only been tested with PyTorch 1.5.1, Torchvision 0.6.1, and cudatoolkit 10.1. Some possible reasons for it not working with a newer PyTorch version are:
  • In order to allow explicitly picking an algorithm for convolution, we adapt part of the convolution code from PyTorch 1.5.1 into lm_ops/conv.cpp.
  • We also make use of functions from the at and at::native namespaces to implement other operations, such as the output-activated backward operations.

If those ATen operations are deprecated or their signatures change in later versions, that could also cause problems. Identifying which function is failing will help in updating MONeT to newer versions of PyTorch.
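
One way to narrow this down (a debugging sketch, not part of MONeT; the module names are taken from the file paths mentioned in this thread) is to import the JIT-compiled lm_ops modules one at a time and see which extension fails to build:

import importlib
import traceback

# conv.py and compress.py call torch.utils.cpp_extension.load() at import time
# (see the tracebacks above); bn.py is assumed to behave the same way.
for name in ["conv", "bn", "compress"]:
    try:
        importlib.import_module(f"monet.lm_ops.{name}")
        print(f"{name}: built OK")
    except Exception:
        print(f"{name}: failed to build")
        traceback.print_exc()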

  2. Yes, you can use MONeT for custom models (as long as all ops in the model are implemented in MONeT). UNet is one example of how we allow a custom model that is not in torchvision (by explicitly checking its name), or you could add your custom model's package to the eval call (the way torchvision is added). Similar changes should be made in schedule.py when you want to run the generated solution.
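
As a rough illustration of the second point, wrapping a custom model follows the same pattern as the ResNet-50 example earlier in this thread. The sketch below assumes a solution file has already been generated for the custom model with cvxpy_solver.py and that every op in the model is implemented in MONeT; MyNet and the paths are placeholders:

import torch
import torch.nn as nn
from monet.monet_wrapper import MONeTWrapper

# A small placeholder model; any custom nn.Module whose ops MONeT supports
# should work the same way.
class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(64, 10),
        )

    def forward(self, x):
        return self.body(x)

# Placeholder path: the solution must have been solved for this model,
# batch size, and memory budget.
sol_file = "path/to/solution_mynet.pkl"
input_shape = (3, 224, 224)

train_model = MONeTWrapper(MyNet(), sol_file, input_shape).cuda()
output = train_model(torch.randn(32, *input_shape).cuda())
output.sum().backward()
print("Memory used: %6.2f MB" % (torch.cuda.max_memory_allocated()/1024/1024))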
