diff --git a/graph_weather/models/gencast/README.md b/graph_weather/models/gencast/README.md new file mode 100644 index 00000000..d5e04e62 --- /dev/null +++ b/graph_weather/models/gencast/README.md @@ -0,0 +1,2 @@ +# GenCast + diff --git a/graph_weather/models/gencast/graph/graph_builder.py b/graph_weather/models/gencast/graph/graph_builder.py index 0a0566e5..23abccb0 100644 --- a/graph_weather/models/gencast/graph/graph_builder.py +++ b/graph_weather/models/gencast/graph/graph_builder.py @@ -71,7 +71,7 @@ def __init__( grid_lat: np.ndarray, splits: int = 5, num_hops: int = 0, - device: str = "cpu", + device: torch.device = torch.device("cpu"), ): """Initialize the GraphBuilder object.