diff --git a/torchhd/tests/test_models.py b/torchhd/tests/test_models.py index 93a9eca7..a721a31f 100644 --- a/torchhd/tests/test_models.py +++ b/torchhd/tests/test_models.py @@ -113,7 +113,9 @@ def test_initialization(self, dtype): assert model.weight.device.type == device.type def test_fit_ridge_regression(self): - samples = torch.eye(10, 12) + a = torch.randn(10) + b = torch.randn(12) + samples = torch.outer(a, b) targets = torch.arange(10) model = models.IntRVFL(12, 1245, 10)