diff --git a/colossalai/nn/layer/moe/moe_param.py b/colossalai/nn/layer/moe/moe_param.py index 65aed3870cd6..11d07ef8c804 100644 --- a/colossalai/nn/layer/moe/moe_param.py +++ b/colossalai/nn/layer/moe/moe_param.py @@ -14,5 +14,13 @@ def is_moe_param(tensor: torch.Tensor) -> bool: return hasattr(tensor, "moe_info") -def set_moe_param_info(tensor: torch.Tensor, moe_info: dict): +def set_moe_param_info(tensor: torch.Tensor, moe_info: dict) -> None: + """ + Set moe info for the given tensor. + + Args: + tensor (torch.Tensor): The tensor to be set. + moe_info (dict): The moe info to be set. + + """ tensor.__setattr__('moe_info', moe_info)