Patch summary (text before the "diff --git" line is skipped by patch tools):
  In jitter(), drop the runtime `assert isinstance(self.args.moe_jitter_eps, float)`
  and annotate the locals `low` / `high` as `float` instead. The arithmetic and
  the returned value, low + noise * (high - low), are unchanged.
  NOTE(review): after this change nothing in the hunk checks moe_jitter_eps
  before the arithmetic; presumably callers only invoke jitter() when it is
  set -- confirm against the call site.

diff --git a/megablocks/layers/router.py b/megablocks/layers/router.py
index a6deae0..2c9dcd9 100644
--- a/megablocks/layers/router.py
+++ b/megablocks/layers/router.py
@@ -79,9 +79,8 @@ def __init__(self, args: Arguments):
         args.init_method(self.layer.weight)
 
     def jitter(self, x: torch.Tensor):
-        assert isinstance(self.args.moe_jitter_eps, float)
-        low = 1.0 - self.args.moe_jitter_eps
-        high = 1.0 + self.args.moe_jitter_eps
+        low: float = 1.0 - self.args.moe_jitter_eps
+        high: float = 1.0 + self.args.moe_jitter_eps
         noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
         return low + noise * (high - low)
 