diff --git a/.gitignore b/.gitignore index 386183b26a39..0d37bac550c8 100644 --- a/.gitignore +++ b/.gitignore @@ -131,4 +131,4 @@ dataset/ output/ # gen codes -autogen/ \ No newline at end of file +autogen/ diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 000000000000..3498741eb542 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,6 @@ +[submodule "csrc/third_party/cutlass"] + path = csrc/third_party/cutlass + url = https://github.com/NVIDIA/cutlass.git +[submodule "csrc/third_party/nlohmann_json"] + path = csrc/third_party/nlohmann_json + url = https://github.com/nlohmann/json.git diff --git a/csrc/setup_cuda.py b/csrc/setup_cuda.py index a4e92b02c4b6..c9a49bd676b3 100644 --- a/csrc/setup_cuda.py +++ b/csrc/setup_cuda.py @@ -17,7 +17,14 @@ import paddle from paddle.utils.cpp_extension import CUDAExtension, setup +import subprocess +def update_git_submodule(): + try: + subprocess.run(["git", "submodule", "update", "--init"], check=True) + except subprocess.CalledProcessError as e: + print(f"Error occurred while updating git submodule: {str(e)}") + raise def clone_git_repo(version, repo_url, destination_path): try: @@ -121,18 +128,7 @@ def get_gencode_flags(): cutlass_dir = "third_party/cutlass" nvcc_compile_args = gencode_flags - -if not os.path.exists(cutlass_dir) or not os.listdir(cutlass_dir): - if not os.path.exists(cutlass_dir): - os.makedirs(cutlass_dir) - clone_git_repo("v3.5.0", "https://github.com/NVIDIA/cutlass.git", cutlass_dir) - -json_dir = "third_party/nlohmann_json" -if not os.path.exists(json_dir) or not os.listdir(json_dir): - if not os.path.exists(json_dir): - os.makedirs(json_dir) - clone_git_repo("v3.11.3", "https://github.com/nlohmann/json.git", json_dir) - +update_git_submodule() nvcc_compile_args += [ "-O3", "-U__CUDA_NO_HALF_OPERATORS__", diff --git a/csrc/third_party/cutlass b/csrc/third_party/cutlass new file mode 160000 index 000000000000..7d49e6c7e2f8 --- /dev/null +++ b/csrc/third_party/cutlass @@ -0,0 +1 @@ +Subproject commit 7d49e6c7e2f8896c47f586706e67e1fb215529dc diff --git a/csrc/third_party/nlohmann_json b/csrc/third_party/nlohmann_json new file mode 160000 index 000000000000..9cca280a4d0c --- /dev/null +++ b/csrc/third_party/nlohmann_json @@ -0,0 +1 @@ +Subproject commit 9cca280a4d0ccf0c08f47a99aa71d1b0e52f8d03