diff --git a/megablocks/layers/arguments.py b/megablocks/layers/arguments.py index 6492dfd..ec0eb4a 100644 --- a/megablocks/layers/arguments.py +++ b/megablocks/layers/arguments.py @@ -52,7 +52,7 @@ class Arguments: # Initialization arguments. fp16: bool = True bf16: bool = False - device: Union[int, torch.device] = torch.cuda.current_device() + device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device) init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02) output_layer_init_method: InitFn = init_method