
Commit

Format bert.py
LLukas22 committed Oct 20, 2023
1 parent 3f21ebb · commit d7c3fde
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions candle-pyo3/py_src/candle/models/bert.py
@@ -59,8 +59,7 @@ def forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
         attention_scores = attention_scores / float(self.attention_head_size) ** 0.5
         if attention_mask is not None:
             b_size, _, _, last_dim = attention_scores.shape
-            attention_scores = attention_scores.broadcast_add(
-                attention_mask.reshape((b_size, 1, 1, last_dim)))
+            attention_scores = attention_scores.broadcast_add(attention_mask.reshape((b_size, 1, 1, last_dim)))
         attention_probs = F.softmax(attention_scores, dim=-1)

         context_layer = attention_probs.matmul(value)
@@ -198,7 +197,9 @@ def __init__(self, config: Config, add_pooling_layer=True) -> None:
         self.encoder = BertEncoder(config)
         self.pooler = BertPooler(config) if add_pooling_layer else None

-    def forward(self, input_ids: Tensor, token_type_ids: Tensor, attention_mask=None) -> Tuple[Tensor, Optional[Tensor]]:
+    def forward(
+        self, input_ids: Tensor, token_type_ids: Tensor, attention_mask=None
+    ) -> Tuple[Tensor, Optional[Tensor]]:
         if attention_mask is not None:
             # Replace 0s with -inf, and 1s with 0s.
             attention_mask = masked_fill(float("-inf"), attention_mask, 1.0)
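For context, the reformatted broadcast_add line leaves the behavior unchanged: the 0/1 attention mask is converted into an additive mask (0 becomes -inf, 1 becomes 0), reshaped to (b_size, 1, 1, last_dim), and added onto the attention scores before the softmax, so padded key positions receive zero attention probability. Below is a minimal NumPy sketch of that arithmetic; it is an illustration only (plain NumPy, not the candle API), and the shapes and example values are assumptions.

    # Illustration only (NumPy, not candle): the same additive-mask arithmetic.
    import numpy as np

    def softmax(x, axis=-1):
        e = np.exp(x - x.max(axis=axis, keepdims=True))
        return e / e.sum(axis=axis, keepdims=True)

    b_size, num_heads, seq_len, head_size = 2, 4, 5, 8
    attention_scores = np.random.randn(b_size, num_heads, seq_len, seq_len) / float(head_size) ** 0.5

    # 0/1 padding mask per token: 1 = attend, 0 = padding (hypothetical example values).
    attention_mask = np.array([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]], dtype=np.float32)

    # "Replace 0s with -inf, and 1s with 0s", as the comment in bert.py puts it.
    additive_mask = np.where(attention_mask == 1.0, 0.0, -np.inf)

    # Reshape to (b_size, 1, 1, last_dim) so the mask broadcasts over heads
    # and query positions, then add it to the scores before the softmax.
    attention_scores = attention_scores + additive_mask.reshape(b_size, 1, 1, seq_len)
    attention_probs = softmax(attention_scores, axis=-1)

    # Masked (padded) key positions receive zero probability.
    assert np.allclose(attention_probs[0, :, :, 3:], 0.0)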
