Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
SeanLee97 committed May 29, 2024
1 parent 94657ea commit dada88d
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion angle_emb/angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ def l2_normalize(arr: np.ndarray) -> np.ndarray:
Normalize array using L2
:param arr: np.ndarray, input array
:return: np.ndarray
"""
norms = (arr**2).sum(axis=1, keepdims=True)**0.5
Expand Down Expand Up @@ -791,6 +792,7 @@ def distillation_loss(self,
:param targets: torch.Tensor. Target tensor.
:param mse_weight: float. MSE weight. Default 1.0.
:param kl_temperature: float. KL temperature. Default 1.0.
:return: torch.Tensor. Distillation loss.
"""
loss = 0.
Expand All @@ -809,6 +811,7 @@ def compute_loss(self, model, inputs, return_outputs=False):
:param model: Huggingface model.
:param inputs: Dict. Model inputs.
:param return_outputs: bool. Return outputs or not. Default False.
:return: torch.Tensor. Loss.
"""
labels = inputs.pop("labels", None)
Expand Down Expand Up @@ -843,7 +846,6 @@ class AngleESETrainer(AngleTrainer):
:param loss_kwargs: Optional[Dict]. Default None.
:param dataset_format: str. Default DatasetFormats.A
:param teacher_name_or_path: Optional[str]. For distribution alignment.
:param **kwargs: other parameters of Trainer.
"""
def __init__(self,
Expand All @@ -869,8 +871,10 @@ def __init__(self,
@torch.no_grad()
def pca_compress(self, m: torch.Tensor, k: int) -> torch.Tensor:
""" Get topk feature via PCA.
:param m: torch.Tensor. Input tensor.
:param k: int. Top-k feature size.
:return: torch.Tensor. Top-k feature.
"""
A = F.softmax(m.T @ m / m.shape[-1]**0.5, dim=-1)
Expand Down Expand Up @@ -907,9 +911,11 @@ def compute_student_loss(self,

def compute_loss(self, model, inputs, return_outputs=False):
""" Compute loss for Espresso.
:param model: Huggingface model.
:param inputs: Dict. Model inputs.
:param return_outputs: bool. Return outputs or not. Default False.
:return: torch.Tensor. Loss.
"""
labels = inputs.pop("labels", None)
Expand Down Expand Up @@ -997,6 +1003,7 @@ def __call__(self,
:param labels: torch.Tensor. Labels.
:param outputs: torch.Tensor. Outputs.
:return: torch.Tensor. Loss.
"""
if self.dataset_format == DatasetFormats.A:
Expand Down

0 comments on commit dada88d

Please sign in to comment.