Skip to content

Commit

Permalink
train
Browse files Browse the repository at this point in the history
  • Loading branch information
mzouink committed Feb 15, 2024
1 parent 5371ded commit e45ce1b
Showing 1 changed file with 19 additions and 1 deletion.
20 changes: 19 additions & 1 deletion dacapo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,17 @@


def train(run_name: str, compute_context: ComputeContext = LocalTorch()):
"""Train a run"""
"""
Trains a model with the given run name using the specified compute context.
Args:
run_name (str): The name of the run.
compute_context (ComputeContext, optional): The compute context to use for training. Defaults to LocalTorch(),
Can be set to distribute Bsub() to using LSF cluster.
Returns:
The trained model.
"""
if compute_context.train(run_name):
logger.error("Run %s is already being trained", run_name)
# if compute context runs train in some other process
Expand All @@ -36,6 +45,15 @@ def train_run(
run: Run,
compute_context: ComputeContext = LocalTorch(),
):
"""
Trains the model for a given run.
Args:
run (Run): The run object containing the model, optimizer, and other training parameters.
compute_context (ComputeContext, optional): The compute context for training. Defaults to LocalTorch(),
Can be set to distribute Bsub() to using LSF cluster.
"""
logger.info("Starting/resuming training for run %s...", run)

# create run
Expand Down

0 comments on commit e45ce1b

Please sign in to comment.