diff --git a/src/gui.py b/src/gui.py index 8625d2c5..a655c76f 100644 --- a/src/gui.py +++ b/src/gui.py @@ -14,12 +14,21 @@ def set_cuda_paths(): venv_base = Path(sys.executable).parent nvidia_base_path = venv_base / 'Lib' / 'site-packages' / 'nvidia' + cudnn_bin_path = nvidia_base_path / 'cudnn' / 'bin' + + # Set CUDA_PATH and CUDA_PATH_V12_2 for env_var in ['CUDA_PATH', 'CUDA_PATH_V12_2']: current_path = os.environ.get(env_var, '') os.environ[env_var] = os.pathsep.join(filter(None, [str(nvidia_base_path), current_path])) + + # Add nvidia folder and cudnn bin folder to system PATH + current_path = os.environ.get('PATH', '') + new_path = os.pathsep.join(filter(None, [str(cudnn_bin_path), str(nvidia_base_path), current_path])) + os.environ['PATH'] = new_path set_cuda_paths() + class DocQA_GUI(QWidget): def __init__(self): super().__init__()