Skip to content

Commit

Permalink
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
Browse files Browse the repository at this point in the history
…nto develop
  • Loading branch information
DesmonDay committed Nov 7, 2024
2 parents d51abba + 140ea48 commit 85b207e
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -131,4 +131,4 @@ dataset/
output/

# gen codes
autogen/
autogen/
6 changes: 6 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[submodule "csrc/third_party/cutlass"]
path = csrc/third_party/cutlass
url = https://github.com/NVIDIA/cutlass.git
[submodule "csrc/third_party/nlohmann_json"]
path = csrc/third_party/nlohmann_json
url = https://github.com/nlohmann/json.git
20 changes: 8 additions & 12 deletions csrc/setup_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,14 @@

import paddle
from paddle.utils.cpp_extension import CUDAExtension, setup
import subprocess

def update_git_submodule():
try:
subprocess.run(["git", "submodule", "update", "--init"], check=True)
except subprocess.CalledProcessError as e:
print(f"Error occurred while updating git submodule: {str(e)}")
raise

def clone_git_repo(version, repo_url, destination_path):
try:
Expand Down Expand Up @@ -121,18 +128,7 @@ def get_gencode_flags():

cutlass_dir = "third_party/cutlass"
nvcc_compile_args = gencode_flags

if not os.path.exists(cutlass_dir) or not os.listdir(cutlass_dir):
if not os.path.exists(cutlass_dir):
os.makedirs(cutlass_dir)
clone_git_repo("v3.5.0", "https://github.com/NVIDIA/cutlass.git", cutlass_dir)

json_dir = "third_party/nlohmann_json"
if not os.path.exists(json_dir) or not os.listdir(json_dir):
if not os.path.exists(json_dir):
os.makedirs(json_dir)
clone_git_repo("v3.11.3", "https://github.com/nlohmann/json.git", json_dir)

update_git_submodule()
nvcc_compile_args += [
"-O3",
"-U__CUDA_NO_HALF_OPERATORS__",
Expand Down
1 change: 1 addition & 0 deletions csrc/third_party/cutlass
Submodule cutlass added at 7d49e6
1 change: 1 addition & 0 deletions csrc/third_party/nlohmann_json
Submodule nlohmann_json added at 9cca28

0 comments on commit 85b207e

Please sign in to comment.