diff --git a/tests/test_end2end.py b/tests/test_end2end.py index c2c18482..7a7161d5 100644 --- a/tests/test_end2end.py +++ b/tests/test_end2end.py @@ -2,5 +2,5 @@ def test_model_trainer_fit(multimodal_model, sample_train_val_datamodule): - trainer = lightning.pytorch.trainer.trainer.Trainer(fast_dev_run=True) + trainer = lightning.pytorch.trainer.trainer.Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model=multimodal_model, datamodule=sample_train_val_datamodule)