diff --git a/galai/__init__.py b/galai/__init__.py index 659c961..ba535a6 100644 --- a/galai/__init__.py +++ b/galai/__init__.py @@ -71,7 +71,7 @@ def load_model( dtype = default_dtype if isinstance(dtype, str): - dtype = getattr(torch, "float16", None) + dtype = getattr(torch, dtype, None) if dtype not in (torch.float16, torch.float32, torch.bfloat16): raise ValueError( f"Unsupported dtype: {dtype}" diff --git a/setup.py b/setup.py index cf85197..c263d02 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ from setuptools import setup, find_packages PACKAGE_NAME = 'galai' -VERSION = "1.1.6" +VERSION = "1.1.7.dev0" DESCRIPTION = "API for the GALACTICA model" KEYWORDS = "Scientific Intelligence" URL = 'https://github.com/paperswithcode/galai'