diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index ac7bb8e7b..0dac35bf1 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -434,7 +434,19 @@ def forward(self, x: torch.Tensor):
 
 
 class LinearFP4(Linear4bit):
+    """
+    Implements the FP4 data type.
+    """
     def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None):
+        """
+        Args:
+            input_features (`int`):
+                Number of input features of the linear layer.
+            output_features (`int`):
+                Number of output features of the linear layer.
+            bias (`bool`, defaults to `True`):
+                Whether the linear class uses the bias term.
+        """
         super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', quant_storage, device)
 
 
@@ -450,6 +462,15 @@ class LinearNF4(Linear4bit):
     the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236.
     '''
     def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None):
+        """
+        Args:
+            input_features (`int`):
+                Number of input features of the linear layer.
+            output_features (`int`):
+                Number of output features of the linear layer.
+            bias (`bool`, defaults to `True`):
+                Whether the linear class uses the bias term.
+        """
         super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', quant_storage, device)
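
For reference, a minimal usage sketch of the two layers documented above, assuming the constructor signature shown in this diff; the layer sizes and `compute_dtype` value are illustrative choices, not values from the source:

```python
import torch
import bitsandbytes as bnb

# Build an FP4-quantized linear layer (sizes are hypothetical).
fp4 = bnb.nn.LinearFP4(
    input_features=768,
    output_features=3072,
    bias=True,
    compute_dtype=torch.float16,  # dtype used for the matmul
)

# Weights are quantized to 4 bits when the module is moved to the GPU.
fp4 = fp4.to("cuda")

x = torch.randn(4, 768, dtype=torch.float16, device="cuda")
out = fp4(x)  # forward pass dequantizes the 4-bit weights and runs the matmul

# LinearNF4 shares the same constructor signature.
nf4 = bnb.nn.LinearNF4(768, 3072, bias=True, compute_dtype=torch.float16).to("cuda")
```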