Skip to content

Commit

Permalink
Fix bottleneck average composition computation (#590)
Browse files Browse the repository at this point in the history
  • Loading branch information
calpt authored Oct 12, 2023
1 parent 5652c0b commit 6da07d1
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/adapters/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,8 +631,8 @@ def adapter_average_output(self, adapter_setup: Average, hidden_states, input_te
)
# Case X: No adapter which is part of this module -> ignore

weights = torch.tensor(adapter_setup.weights).unsqueeze(1).unsqueeze(1).to(hidden_states.device)
hidden_states = torch.mean(torch.cat(children_hidden, 0) * weights, 0)
weights = torch.tensor(adapter_setup.weights)[:, None, None, None].to(hidden_states.device)
hidden_states = torch.mean(torch.stack(children_hidden, 0) * weights, 0)

return hidden_states

Expand Down
4 changes: 2 additions & 2 deletions tests_adapters/composition/test_adapter_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,9 @@ def test_average(self):
model.set_active_adapters(Average("a", "b", "c", "d"))

inputs = {}
inputs["input_ids"] = ids_tensor((1, 128), 1000)
inputs["input_ids"] = ids_tensor((2, 128), 1000)
logits = model(**inputs).logits
self.assertEqual(logits.shape, (1, 2))
self.assertEqual(logits.shape, (2, 2))


class PrefixTuningCompositionTest(AdapterCompositionTest):
Expand Down

0 comments on commit 6da07d1

Please sign in to comment.