-
Notifications
You must be signed in to change notification settings - Fork 38
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
252 additions
and
281 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,32 +1,74 @@ | ||
import sys | ||
from PySide6.QtWidgets import QApplication, QMessageBox | ||
import torch | ||
|
||
def display_info(): | ||
app = QApplication(sys.argv) | ||
info_message = "" | ||
try: | ||
import torch | ||
except ImportError: | ||
def display_info(): | ||
app = QApplication(sys.argv) | ||
msg_box = QMessageBox(QMessageBox.Information, "PyTorch Not Installed", "PyTorch is not installed on this system.") | ||
msg_box.exec() | ||
|
||
if torch.cuda.is_available(): | ||
info_message += "CUDA is available!\n" | ||
info_message += "CUDA version: {}\n\n".format(torch.version.cuda) | ||
else: | ||
info_message += "CUDA is not available.\n\n" | ||
else: | ||
def check_bitsandbytes(): | ||
try: | ||
import bitsandbytes as bnb | ||
p = torch.nn.Parameter(torch.rand(10, 10).cuda()) | ||
a = torch.rand(10, 10).cuda() | ||
|
||
if torch.backends.mps.is_available(): | ||
info_message += "Metal/MPS is available!\n\n" | ||
else: | ||
info_message += "Metal/MPS is not available.\n\n" | ||
p1 = p.data.sum().item() | ||
|
||
info_message += "If you want to check the version of Metal and MPS on your macOS device, you can go to \"About This Mac\" -> \"System Report\" -> \"Graphics/Displays\" and look for information related to Metal and MPS.\n\n" | ||
adam = bnb.optim.Adam([p]) | ||
|
||
if torch.version.hip is not None: | ||
info_message += "ROCm is available!\n" | ||
info_message += "ROCm version: {}\n".format(torch.version.hip) | ||
else: | ||
info_message += "ROCm is not available.\n" | ||
out = a * p | ||
loss = out.sum() | ||
loss.backward() | ||
adam.step() | ||
|
||
msg_box = QMessageBox(QMessageBox.Information, "GPU Acceleration Available?", info_message) | ||
msg_box.exec() | ||
p2 = p.data.sum().item() | ||
|
||
assert p1 != p2 | ||
return "SUCCESS!\nInstallation of bitsandbytes was successful!" | ||
except ImportError: | ||
return "bitsandbytes is not installed." | ||
except AssertionError: | ||
return "bitsandbytes is installed, but the installation seems incorrect." | ||
except Exception as e: | ||
return f"An error occurred while checking bitsandbytes: {e}" | ||
|
||
def display_info(): | ||
app = QApplication(sys.argv) | ||
info_message = "" | ||
|
||
if torch.cuda.is_available(): | ||
info_message += "CUDA is available!\n" | ||
info_message += "CUDA version: {}\n\n".format(torch.version.cuda) | ||
else: | ||
info_message += "CUDA is not available.\n\n" | ||
|
||
if torch.backends.mps.is_available(): | ||
info_message += "Metal/MPS is available!\n\n" | ||
else: | ||
info_message += "Metal/MPS is not available.\n\n" | ||
if not torch.backends.mps.is_built(): | ||
info_message += "MPS not available because the current PyTorch install was not built with MPS enabled.\n\n" | ||
else: | ||
info_message += "MPS not available because the current MacOS version is not 12.3+ and/or you do not have an MPS-enabled device on this machine.\n\n" | ||
|
||
info_message += "If you want to check the version of Metal and MPS on your macOS device, you can go to \"About This Mac\" -> \"System Report\" -> \"Graphics/Displays\" and look for information related to Metal and MPS.\n\n" | ||
|
||
if torch.version.hip is not None: | ||
info_message += "ROCm is available!\n" | ||
info_message += "ROCm version: {}\n".format(torch.version.hip) | ||
else: | ||
info_message += "ROCm is not available.\n" | ||
|
||
# Check for bitsandbytes | ||
bitsandbytes_message = check_bitsandbytes() | ||
info_message += "\n" + bitsandbytes_message | ||
|
||
msg_box = QMessageBox(QMessageBox.Information, "GPU Acceleration and Library Check", info_message) | ||
msg_box.exec() | ||
|
||
if __name__ == "__main__": | ||
display_info() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.