diff --git a/ldm/generate.py b/ldm/generate.py index 8d68314f386..3f2c9347617 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -986,6 +986,8 @@ def _set_sampler(self): self.sampler = KSampler(self.model, 'heun', device=self.device) elif self.sampler_name == 'k_lms': self.sampler = KSampler(self.model, 'lms', device=self.device) + elif self.sampler_name[:2] == 'k_': # and f'sample_{self.sampler_name[2:]}' in K.sampling.__dict__: + self.sampler = KSampler(self.model, self.sampler_name[2:], device=self.device) else: msg = f'>> Unsupported Sampler: {self.sampler_name}, Defaulting to plms' self.sampler = PLMSSampler(self.model, device=self.device)