Skip to content

Commit

Permalink
feat: more flexible/configurable TrainValStage
Browse files Browse the repository at this point in the history
  • Loading branch information
sehoffmann committed Apr 5, 2024
1 parent daa9d08 commit 16de972
Showing 1 changed file with 37 additions and 20 deletions.
57 changes: 37 additions & 20 deletions dmlcloud/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,34 @@ def __init__(self):
super().__init__()
self.is_train = True

def train_dataset(self):
    """Return the dataset registered under the 'train' key in the pipeline.

    Raises:
        ValueError: if no 'train' dataset has been registered.
    """
    dataset = self.pipeline.datasets.get('train')
    if dataset is not None:
        return dataset
    raise ValueError(
        'No "train" dataset found in pipeline. Use register_dataset("train", ...) to register a dataset.'
    )

def val_dataset(self):
    """Return the dataset registered under the 'val' key in the pipeline.

    Raises:
        ValueError: if no 'val' dataset has been registered.
    """
    dataset = self.pipeline.datasets.get('val')
    if dataset is not None:
        return dataset
    raise ValueError(
        'No "val" dataset found in pipeline. Use register_dataset("val", ...) to register a dataset.'
    )

def optimizers(self):
    """Return all optimizers registered with the pipeline (a dict-values view)."""
    registered = self.pipeline.optimizers
    return registered.values()

def loss_metric_name(self):
    """Name under which the per-batch loss is tracked. Override to customize."""
    return 'loss'

def train_metric_prefix(self):
    """Prefix applied to metrics tracked during training. Override to customize."""
    return 'train'

def val_metric_prefix(self):
    """Prefix applied to metrics tracked during validation. Override to customize."""
    return 'val'

def run_epoch(self):
    """Run one full epoch: a training pass followed by a validation pass."""
    self.train_epoch()
    self.val_epoch()
Expand All @@ -240,49 +268,38 @@ def val_step(self, batch):

def train_epoch(self):
    """Run one training pass over the train dataset.

    Sets the stage into training mode, iterates the 'train' dataset,
    and for each batch: zeroes all optimizers, runs `train_step` to obtain
    the loss, backpropagates, and steps all optimizers. The loss is tracked
    under `loss_metric_name()`. After the epoch, every registered scheduler
    is stepped once.
    """
    self.is_train = True
    self.metric_prefix = self.train_metric_prefix()

    train_ds = self.train_dataset()
    # For distributed samplers: reshuffle per epoch so every rank sees a
    # different ordering each epoch.
    if hasattr(train_ds, 'sampler') and hasattr(train_ds.sampler, 'set_epoch'):
        train_ds.sampler.set_epoch(self.current_epoch)

    for batch in train_ds:
        for optimizer in self.optimizers():
            optimizer.zero_grad()

        loss = self.train_step(batch)
        loss.backward()

        for optimizer in self.optimizers():
            optimizer.step()

        self.track_reduce(self.loss_metric_name(), loss)

    # Schedulers are stepped once per epoch, not per batch.
    for scheduler in self.pipeline.schedulers.values():
        scheduler.step()

@torch.no_grad()
def val_epoch(self):
    """Run one validation pass over the val dataset.

    Sets the stage into evaluation mode and iterates the 'val' dataset,
    tracking the per-batch loss from `val_step` under `loss_metric_name()`.
    Gradients are disabled for the whole pass via `torch.no_grad`.
    """
    self.is_train = False
    self.metric_prefix = self.val_metric_prefix()

    for batch in self.val_dataset():
        loss = self.val_step(batch)
        # Use loss_metric_name() for consistency with train_epoch instead of
        # the hard-coded 'loss' literal.
        self.track_reduce(self.loss_metric_name(), loss)

def table_columns(self):
    """Extend the parent's progress-table columns with train/val loss columns.

    The metric keys are built from the configurable prefix and loss-name
    hooks so subclasses that override them get matching columns.
    """
    columns = super().table_columns()
    columns.insert(1, {'name': '[Train] Loss', 'metric': f'{self.train_metric_prefix()}/{self.loss_metric_name()}'})
    columns.insert(2, {'name': '[Val] Loss', 'metric': f'{self.val_metric_prefix()}/{self.loss_metric_name()}'})
    return columns

0 comments on commit 16de972

Please sign in to comment.