Skip to content

Commit

Permalink
fix: corrected randomization
Browse files Browse the repository at this point in the history
  • Loading branch information
aski02 committed Jan 8, 2025
1 parent 5882bc8 commit 81e9fdc
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 18 deletions.
31 changes: 25 additions & 6 deletions quanda/metrics/heuristics/model_randomization.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,12 +212,31 @@ def _randomize_model(self) -> Tuple[torch.nn.Module, List[str]]:
parent = get_parent_module_from_name(rand_model, name)
param_name = name.split(".")[-1]

random_param_tensor = torch.nn.init.normal_(
param, generator=self.generator
)
parent.__setattr__(
param_name, torch.nn.Parameter(random_param_tensor)
)
if "weight" in name:
if isinstance(
parent,
(
torch.nn.BatchNorm1d,
torch.nn.BatchNorm2d,
torch.nn.BatchNorm3d,
),
):
torch.nn.init.ones_(param)
elif (
isinstance(
parent, (torch.nn.LayerNorm, torch.nn.Embedding)
)
or param.dim() == 1
):
torch.nn.init.normal_(param)
else:
torch.nn.init.kaiming_normal_(
param, generator=self.generator
)
else:
torch.nn.init.normal_(param)

parent.__setattr__(param_name, torch.nn.Parameter(param))

# save randomized checkpoint
chckpt_path = os.path.join(
Expand Down
62 changes: 50 additions & 12 deletions tests/metrics/test_heuristics_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,31 @@ def _randomize_model(
parent = get_parent_module_from_name(rand_model, name)
param_name = name.split(".")[-1]

random_param_tensor = torch.nn.init.normal_(
param, generator=generator
)
parent.__setattr__(
param_name, torch.nn.Parameter(random_param_tensor)
)
if "weight" in name:
if isinstance(
parent,
(
torch.nn.BatchNorm1d,
torch.nn.BatchNorm2d,
torch.nn.BatchNorm3d,
),
):
torch.nn.init.ones_(param)
elif (
isinstance(
parent, (torch.nn.LayerNorm, torch.nn.Embedding)
)
or param.dim() == 1
):
torch.nn.init.normal_(param)
else:
torch.nn.init.kaiming_normal_(
param, generator=generator
)
else:
torch.nn.init.normal_(param)

parent.__setattr__(param_name, torch.nn.Parameter(param))

chckpt_path = os.path.join(cache_dir, f"{model_id}_rand_{i}.pth")
torch.save(
Expand Down Expand Up @@ -230,12 +249,31 @@ def _randomize_model(
parent = get_parent_module_from_name(rand_model, name)
param_name = name.split(".")[-1]

random_param_tensor = torch.nn.init.normal_(
param, generator=generator
)
parent.__setattr__(
param_name, torch.nn.Parameter(random_param_tensor)
)
if "weight" in name:
if isinstance(
parent,
(
torch.nn.BatchNorm1d,
torch.nn.BatchNorm2d,
torch.nn.BatchNorm3d,
),
):
torch.nn.init.ones_(param)
elif (
isinstance(
parent, (torch.nn.LayerNorm, torch.nn.Embedding)
)
or param.dim() == 1
):
torch.nn.init.normal_(param)
else:
torch.nn.init.kaiming_normal_(
param, generator=generator
)
else:
torch.nn.init.normal_(param)

parent.__setattr__(param_name, torch.nn.Parameter(param))

chckpt_path = os.path.join(cache_dir, f"{model_id}_rand_{i}.pth")
torch.save(rand_model.state_dict(), chckpt_path)
Expand Down

0 comments on commit 81e9fdc

Please sign in to comment.