diff --git a/setup.py b/setup.py index 2772db1..24d80c3 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ setup( name='torchfcpe', description='The official Pytorch implementation of Fast Context-based Pitch Estimation (FCPE)', - version='0.0.1', + version='0.0.2', author='CNChTu', author_email='2921046558@qq.com', url='https://github.com/CNChTu/FCPE', diff --git a/torchfcpe/tools.py b/torchfcpe/tools.py index ab654e9..37cfdf2 100644 --- a/torchfcpe/tools.py +++ b/torchfcpe/tools.py @@ -352,19 +352,26 @@ def get_device(device: str, func_name: str) -> str: """Get device""" if device is None: - device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu' + if torch.cuda.is_available(): + device = 'cuda' + elif torch.backends.mps.is_available(): + device = 'mps' + else: + device = 'cpu' + print(f' [INFO]: Using {device} automatically.') print(f' [INFO] > call by: {func_name}') - - # Check if the specified device is available, if not, switch to cpu - if device == 'cuda' and not torch.cuda.is_available() or device == 'mps' and not torch.backends.mps.is_available(): - print(f' [WARNING]: Specified device ({device}) is not available, switching to cpu.') - device = 'cpu' - else: print(f' [INFO]: device is not None, use {device}') print(f' [INFO] > call by:{func_name}') device = device + + # Check if the specified device is available, if not, switch to cpu + if ((device == 'cuda' and not torch.cuda.is_available()) or + (device == 'mps' and not torch.backends.mps.is_available())): + print(f' [WARN]: Specified device ({device}) is not available, switching to cpu.') + device = 'cpu' + return device