From e38ba4588c5145f65d7d0aefcbc1dab15d0533d4 Mon Sep 17 00:00:00 2001
From: mihir-db <141708001+mihir-db@users.noreply.github.com>
Date: Wed, 16 Oct 2024 20:23:37 -0700
Subject: [PATCH] Update router.py

---
 megablocks/layers/router.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/megablocks/layers/router.py b/megablocks/layers/router.py
index a6deae0..2c9dcd9 100644
--- a/megablocks/layers/router.py
+++ b/megablocks/layers/router.py
@@ -79,9 +79,8 @@ def __init__(self, args: Arguments):
         args.init_method(self.layer.weight)
 
     def jitter(self, x: torch.Tensor):
-        assert isinstance(self.args.moe_jitter_eps, float)
-        low = 1.0 - self.args.moe_jitter_eps
-        high = 1.0 + self.args.moe_jitter_eps
+        low: float = 1.0 - self.args.moe_jitter_eps
+        high: float = 1.0 + self.args.moe_jitter_eps
         noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
         return low + noise * (high - low)
 