Skip to content

Commit

Permalink
fix docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
mcw92 committed Feb 7, 2024
1 parent 5b5eb8f commit 464ed73
Showing 1 changed file with 15 additions and 16 deletions.
31 changes: 15 additions & 16 deletions tutorials/torch_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@


GPUS_PER_NODE: int = 4

log_path = "torch_ckpts"


Expand All @@ -36,13 +35,13 @@ def __init__(
Parameters
----------
conv_layers: int
number of convolutional layers
The number of convolutional layers.
activation: torch.nn.modules.activation
activation function to use
The activation function to use.
lr: float
learning rate
loss_fn: torch.nn.modules.loss
loss function
The loss function.
"""
super(Net, self).__init__()

Expand Down Expand Up @@ -78,12 +77,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
Parameters
----------
x: torch.Tensor
data sample
The data sample.
Returns
-------
torch.Tensor
The model's predictions for input data sample
The model's predictions for input data sample.
"""
b, c, w, h = x.size()
x = self.conv_layers(x)
Expand All @@ -100,14 +99,14 @@ def training_step(
Parameters
----------
batch: Tuple[torch.Tensor, torch.Tensor]
input batch
The input batch.
batch_idx: int
batch index
Its batch index.
Returns
-------
torch.Tensor
training loss for input batch
The training loss for this input batch.
"""
x, y = batch
pred = self(x)
Expand All @@ -126,14 +125,14 @@ def validation_step(
Parameters
----------
batch: Tuple[torch.Tensor, torch.Tensor]
current batch
The current batch
batch_idx: int
batch index
The batch index.
Returns
-------
torch.Tensor
validation loss for input batch
The validation loss for the input batch.
"""
x, y = batch
pred = self(x)
Expand All @@ -150,7 +149,7 @@ def configure_optimizers(self) -> torch.optim.SGD:
Returns
-------
torch.optim.sgd.SGD
stochastic gradient descent optimizer
A stochastic gradient descent optimizer.
"""
return torch.optim.SGD(self.parameters(), lr=self.lr)

Expand All @@ -171,14 +170,14 @@ def get_data_loaders(batch_size: int) -> Tuple[DataLoader, DataLoader]:
Parameters
----------
batch_size: int
batch size
The batch size.
Returns
-------
DataLoader
training dataloader
The training dataloader.
DataLoader
validation dataloader
The validation dataloader.
"""
data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
num_workers = 2
Expand Down

0 comments on commit 464ed73

Please sign in to comment.