diff --git a/hubconf.py b/hubconf.py index 84d0665..ab47bc6 100644 --- a/hubconf.py +++ b/hubconf.py @@ -126,8 +126,9 @@ def load_model(repo_id, unet_subfolder, device="cuda", local_dir=None): vae.load_state_dict(load_ckpt_vae) # Load empty text embed - empty_text_embed = torch.from_numpy(np.load('empty_text_embed.npy')).to(device, torch.float32)[None] - + empty_text_embed = torch.from_numpy(np.load(os.path.join(os.path.dirname(os.path.abspath(__file__)), + 'empty_text_embed.npy'))).to(device, torch.float32)[None] + genpercept_params_ckpt = dict( unet=unet, vae=vae,