Skip to content

Commit

Permalink
Updated readme and added checkpointing to training funciton
Browse files Browse the repository at this point in the history
  • Loading branch information
fkapl committed Aug 14, 2023
1 parent f3cd999 commit a174a7f
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 6 deletions.
12 changes: 9 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ inVAE incorporates biological covariates and mechanisms such as disease states,

1. Load the data: <br/>
```adata = sc.read(path/to/data)```<br/>
2. Optional - Split the data into train, val, test (in supervised case for training prediction)<br/>
2. Optional - Split the data into train, val, test (in supervised case for training classifier as well)<br/>
3. Initialize the model, either Factorized or Non-Factorized:<br/>

```
Expand Down Expand Up @@ -61,7 +61,7 @@ model = NFinVAE(
```

4. Train the generative model: <br/>
```model.train(n_epochs=1, lr_train=0.001, weight_decay=0.0001)```<br/>
```model.train(n_epochs=500, lr_train=0.001, weight_decay=0.0001)```<br/>
5. Get the latent representation: <br/>
```latent = model.get_latent_representation(adata)```<br/>
6. Optional - Train the classifer (for cell types):
Expand All @@ -75,7 +75,13 @@ model.train_classifier(
```

7. Optional - Predict cell types: <br/>
```pred_train = model.predict(adata_test, dataset_type='test')```<br/>
```pred_test = model.predict(adata_test, dataset_type='test')```<br/>

8. Optional - Saving and loading model: <br/>
```
model.save('./checkpoints/path.pt')
model.load('./checkpoints/path.pt')
```<br/>
## Dependencies
Expand Down
13 changes: 10 additions & 3 deletions src/inVAE/model/_abstract_invae.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,11 @@ def get_negative_elbo(
def save(
self,
save_dir: str,
):
print('Saving the pytorch module...')
print('To load the model later you need to save the hyperparameters in a separate file/dictionary.')
verbose: bool = True,
):
if verbose:
print('Saving the pytorch module...')
print('To load the model later you need to save the hyperparameters in a separate file/dictionary.')

torch.save(self.module.state_dict(), save_dir)

Expand Down Expand Up @@ -402,6 +404,8 @@ def train(
log_dir: str = None,
log_freq: int = 25, # in iterations
print_every_n_epochs: int = None,
checkpoint_dir: str = None,
n_checkpoints: int = 0,
):
if n_epochs is None:
n_epochs = 500
Expand Down Expand Up @@ -479,6 +483,9 @@ def train(
if np.isnan(loss_epoch):
print(f'Loss is nan at epoch {int(iteration/len(self.data_loader))}/{n_epochs}, stopping training!')
break

if (n_checkpoints > 0) and (iteration % int(max_iter/n_checkpoints) == 0):
self.save(f'{checkpoint_dir}/checkpoint_epoch_{int(iteration/len(self.data_loader))}.pt')

self.module.eval()
print('Training done!')
Expand Down

0 comments on commit a174a7f

Please sign in to comment.