[legacy] move trainer to legacy (#4545)
* [legacy] move trainer to legacy

* [doc] update docs related to trainer

* [test] ignore legacy test
ver217 committed Sep 5, 2023
1 parent 807e01a commit 89fe027
Showing 32 changed files with 63 additions and 153 deletions.
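For downstream code, the practical effect of this commit is an import-path change: the trainer and its hooks now live under the `colossalai.legacy` namespace. A minimal migration sketch (assuming nothing but the module path moved, which is what the hunks below show):

```python
# Old import paths, removed by this commit:
# from colossalai.trainer import Trainer, hooks
# from colossalai.trainer.hooks import BaseHook

# New import paths under the legacy namespace:
from colossalai.legacy.trainer import Trainer, hooks
from colossalai.legacy.trainer.hooks import BaseHook
```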
Empty file added colossalai/legacy/__init__.py
File renamed without changes.
@@ -1,14 +1,13 @@
-from typing import Union, List, Any
+from typing import Any, List, Union

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from colossalai.engine import Engine
+from colossalai.legacy.trainer.hooks import BaseHook
from colossalai.logging import DistributedLogger
-from colossalai.utils import MultiTimer
-from colossalai.utils import is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
-from colossalai.trainer.hooks import BaseHook
+from colossalai.utils import MultiTimer, is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0


class Trainer:
@@ -1,7 +1,12 @@
from ._base_hook import BaseHook
from ._checkpoint_hook import SaveCheckpointHook
-from ._log_hook import (LogMemoryByEpochHook, LogMetricByEpochHook, LogMetricByStepHook, LogTimingByEpochHook,
-                        TensorboardHook)
+from ._log_hook import (
+    LogMemoryByEpochHook,
+    LogMetricByEpochHook,
+    LogMetricByStepHook,
+    LogTimingByEpochHook,
+    TensorboardHook,
+)
from ._lr_scheduler_hook import LRSchedulerHook
from ._metric_hook import AccuracyHook, LossHook, MetricHook, ThroughputHook

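For context, the hook classes re-exported above are typically gathered into a list and passed to `Trainer.fit`, as the documentation hunks later in this commit show. A minimal sketch against the new import path; `engine`, `logger`, the dataloaders and `NUM_EPOCHS` are placeholders assumed to come from the usual `colossalai.initialize` setup and are not part of this diff:

```python
from colossalai.legacy.trainer import Trainer, hooks
from colossalai.nn.metric import Accuracy

# `engine`, `logger`, `train_dataloader`, `test_dataloader` and `NUM_EPOCHS`
# are assumed to exist already (built via colossalai.initialize and friends).
trainer = Trainer(engine=engine, logger=logger)

hook_list = [
    hooks.LossHook(),
    hooks.AccuracyHook(accuracy_func=Accuracy()),
    hooks.LogMetricByEpochHook(logger),
]

trainer.fit(
    train_dataloader=train_dataloader,
    epochs=NUM_EPOCHS,
    test_dataloader=test_dataloader,
    hooks=hook_list,
    display_progress=True,
)
```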
File renamed without changes.
@@ -1,11 +1,12 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
-from colossalai.logging import get_dist_logger
+
+from colossalai.legacy.trainer.hooks import BaseHook
+from colossalai.logging import get_dist_logger
from colossalai.registry import HOOKS
-from colossalai.trainer.hooks import BaseHook
from colossalai.utils.checkpointing import save_checkpoint

from ._lr_scheduler_hook import LRSchedulerHook


File renamed without changes.
@@ -3,17 +3,17 @@

import os
import os.path as osp

from typing import List

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.registry import HOOKS
+from colossalai.legacy.trainer.hooks._metric_hook import ThroughputMetric
from colossalai.logging import DistributedLogger
-from colossalai.utils import report_memory_usage, is_dp_rank_0, \
-    is_tp_rank_0, is_no_pp_or_last_stage, MultiTimer
+from colossalai.registry import HOOKS
+from colossalai.utils import MultiTimer, is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0, report_memory_usage

from ._base_hook import BaseHook
from ._commons_ import _format_number
-from colossalai.trainer.hooks._metric_hook import ThroughputMetric


class LogByEpochHook(BaseHook):
@@ -1,6 +1,7 @@
-from colossalai.registry import HOOKS
from torch import Tensor

+from colossalai.registry import HOOKS
+
from ._metric_hook import LearningRateMetric, MetricHook


@@ -6,6 +6,7 @@

import torch
import torch.distributed as dist
+
from colossalai.communication import all_reduce
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
@@ -19,8 +20,8 @@
class Metric(ABC):
    """A basic class of metric collectors. It collects a specific
    metric during training or evaluation and would always be used with
-    :class:`MetricHook` to help it update its states and show the
-    metric. So please use corresponding hook class to make the metric
+    :class:`MetricHook` to help it update its states and show the
+    metric. So please use corresponding hook class to make the metric
    collector works.
    Args:
@@ -220,9 +221,9 @@ def is_better(a, b) -> bool:


class MetricHook(BaseHook):
-    """Specialized hook classes for :class:`Metric`.
-    Some help metric collectors initialize, reset and
-    update their states. Others are used to display and
+    """Specialized hook classes for :class:`Metric`.
+    Some help metric collectors initialize, reset and
+    update their states. Others are used to display and
    record the metric.
    Args:
@@ -43,7 +43,7 @@ from colossalai.engine.schedule import (InterleavedPipelineSchedule,
                                        PipelineSchedule)
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks
from colossalai.utils.timer import MultiTimer
from model_zoo.gpt import GPTLMLoss
from torch.nn import functional as F
@@ -268,3 +268,4 @@ def train():
        return_output_label=False,
    )
```
+<!-- doc-test-command: echo -->
@@ -38,7 +38,7 @@ from colossalai.builder import build_pipeline_model
from colossalai.engine.schedule import (InterleavedPipelineSchedule,
                                        PipelineSchedule)
from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks
from colossalai.utils import MultiTimer, get_dataloader
from timm.models import vision_transformer as vit
from torchvision import transforms
@@ -245,3 +245,4 @@ def train():
                hooks=hook_list,
                display_progress=True)
```
+<!-- doc-test-command: echo -->
@@ -79,7 +79,7 @@ from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.lr_scheduler import LinearWarmupLR
from colossalai.nn.metric import Accuracy
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks
```

- Other modules
@@ -644,3 +644,4 @@ torchrun --standalone --nproc_per_node <NUM_GPUs> train_hybrid.py --config ./co
# If your torch >= 1.9.0
# python -m torch.distributed.run --standalone --nproc_per_node= <NUM_GPUs> train_hybrid.py --config ./configs/config_hybrid_parallel.py
```
+<!-- doc-test-command: echo -->
7 changes: 4 additions & 3 deletions docs/source/en/basics/engine_trainer.md
@@ -64,7 +64,7 @@ Trainer is a more high-level wrapper for the user to execute training with fewer

```python
from colossalai.logging import get_dist_logger
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks

# build components and initialize with colossalai.initialize
...
@@ -107,7 +107,7 @@ If you want to customize your own hook class, you can inherit `hooks.BaseHook` a

```python
from colossalai.logging import get_dist_logger
-from colossalai.trainer import hooks
+from colossalai.legacy.trainer import hooks

class LogMessageHook(hooks.BaseHook):

@@ -345,7 +345,7 @@ If you wish to train with a trainer object, you can follow the code snippet belo

```python
from colossalai.nn.metric import Accuracy
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks


# create a trainer object
@@ -387,3 +387,4 @@ python -m torch.distributed.launch --nproc_per_node <num_gpus> --master_addr loc
# with trainer
python -m torch.distributed.launch --nproc_per_node <num_gpus> --master_addr localhost --master_port 29500 run_resnet_cifar10_with_trainer.py
```
+<!-- doc-test-command: echo -->
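The engine_trainer.md hunk above truncates the custom-hook example right after the class declaration. A hedged sketch of how such a hook can be written against the relocated module; the `priority` argument and the `before_train`/`after_train` callback names are assumed from the `hooks.BaseHook` interface rather than taken from this diff:

```python
from colossalai.legacy.trainer import hooks
from colossalai.logging import get_dist_logger


class LogMessageHook(hooks.BaseHook):
    """A toy hook that logs a message when training starts and finishes."""

    def __init__(self, priority=10):
        # priority controls the order in which attached hooks fire (assumed API)
        super().__init__(priority=priority)
        self._logger = get_dist_logger()

    def before_train(self, trainer):
        # runs once before the training loop starts
        self._logger.info('training starts')

    def after_train(self, trainer):
        # runs once after the training loop finishes
        self._logger.info('training finished')
```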
3 changes: 2 additions & 1 deletion docs/source/en/basics/model_checkpoint.md
@@ -41,7 +41,7 @@ for epoch in range(num_epochs):

#### Save when using trainer
```python
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks
model = ...
engine, _, _, _ = colossalai.initialize(model=model, ...)
trainer = Trainer(engine, ...)
@@ -61,3 +61,4 @@ model = ...
load_checkpoint('xxx.pt', model)
... # train or test
```
+<!-- doc-test-command: echo -->
2 changes: 1 addition & 1 deletion docs/source/en/features/mixed_precision_training.md
@@ -267,7 +267,7 @@ from pathlib import Path
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils import get_dataloader
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks
from colossalai.nn.lr_scheduler import LinearWarmupLR
from timm.models import vit_base_patch16_224
from torchvision import datasets, transforms
3 changes: 2 additions & 1 deletion docs/source/en/features/pipeline_parallel.md
@@ -79,7 +79,7 @@ import colossalai.nn as col_nn

from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks
from colossalai.utils import MultiTimer, get_dataloader
from colossalai.context import ParallelMode
from colossalai.pipeline.pipelinable import PipelinableContext
@@ -157,3 +157,4 @@ trainer.fit(train_dataloader=train_dataloader,
```

We use `2` pipeline stages and the batch will be split into `4` micro batches.
+<!-- doc-test-command: echo -->
@@ -43,7 +43,7 @@ from colossalai.engine.schedule import (InterleavedPipelineSchedule,
                                        PipelineSchedule)
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks
from colossalai.utils.timer import MultiTimer
from model_zoo.gpt import GPTLMLoss
from torch.nn import functional as F
@@ -273,3 +273,4 @@ def train():
        return_output_label=False,
    )
```
+<!-- doc-test-command: echo -->
@@ -36,7 +36,7 @@ from colossalai.builder import build_pipeline_model
from colossalai.engine.schedule import (InterleavedPipelineSchedule,
                                        PipelineSchedule)
from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks
from colossalai.utils import MultiTimer, get_dataloader
from timm.models import vision_transformer as vit
from torchvision import transforms
@@ -244,3 +244,4 @@ def train():
                hooks=hook_list,
                display_progress=True)
```
+<!-- doc-test-command: echo -->
@@ -74,7 +74,7 @@ from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.lr_scheduler import LinearWarmupLR
from colossalai.nn.metric import Accuracy
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks
```

- 其他模块
@@ -589,3 +589,4 @@ torchrun --standalone --nproc_per_node <NUM_GPUs> train_hybrid.py --config ./co
# If your torch >= 1.9.0
# python -m torch.distributed.run --standalone --nproc_per_node= <NUM_GPUs> train_hybrid.py --config ./configs/config_hybrid_parallel.py
```
+<!-- doc-test-command: echo -->
7 changes: 4 additions & 3 deletions docs/source/zh-Hans/basics/engine_trainer.md
@@ -61,7 +61,7 @@ Trainer 的参数 `schedule` 默认值是 `None` 。在大多数情况下,除

```python
from colossalai.logging import get_dist_logger
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks

# build components and initialize with colossalai.initialize
...
@@ -104,7 +104,7 @@ trainer.fit(

```python
from colossalai.logging import get_dist_logger
-from colossalai.trainer import hooks
+from colossalai.legacy.trainer import hooks

class LogMessageHook(hooks.BaseHook):

@@ -341,7 +341,7 @@ for epoch in range(gpc.config.NUM_EPOCHS):

```python
from colossalai.nn.metric import Accuracy
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks


# create a trainer object
@@ -384,3 +384,4 @@ python -m torch.distributed.launch --nproc_per_node <num_gpus> --master_addr loc
# with trainer
python -m torch.distributed.launch --nproc_per_node <num_gpus> --master_addr localhost --master_port 29500 run_resnet_cifar10_with_trainer.py
```
+<!-- doc-test-command: echo -->
3 changes: 2 additions & 1 deletion docs/source/zh-Hans/basics/model_checkpoint.md
@@ -41,7 +41,7 @@ for epoch in range(num_epochs):

#### 用 trainer 保存
```python
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks
model = ...
engine, _, _, _ = colossalai.initialize(model=model, ...)
trainer = Trainer(engine, ...)
@@ -61,3 +61,4 @@ model = ...
load_checkpoint('xxx.pt', model)
... # train or test
```
+<!-- doc-test-command: echo -->
2 changes: 1 addition & 1 deletion docs/source/zh-Hans/features/mixed_precision_training.md
@@ -245,7 +245,7 @@ from pathlib import Path
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils import get_dataloader
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks
from colossalai.nn.lr_scheduler import LinearWarmupLR
from timm.models import vit_base_patch16_224
from torchvision import datasets, transforms
3 changes: 2 additions & 1 deletion docs/source/zh-Hans/features/pipeline_parallel.md
@@ -78,7 +78,7 @@ import colossalai.nn as col_nn

from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks
from colossalai.utils import MultiTimer, get_dataloader
from colossalai.context import ParallelMode
from colossalai.pipeline.pipelinable import PipelinableContext
@@ -156,3 +156,4 @@ trainer.fit(train_dataloader=train_dataloader,
```

我们使用 `2` 个流水段,并且 batch 将被切分为 `4` 个 micro batches。
+<!-- doc-test-command: echo -->
2 changes: 1 addition & 1 deletion examples/language/gpt/titans/train_gpt.py
@@ -10,9 +10,9 @@
import colossalai.utils as utils
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
+from colossalai.legacy.trainer import Trainer, hooks
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn import LinearWarmupLR
-from colossalai.trainer import Trainer, hooks
from colossalai.utils import colo_set_process_memory_fraction, is_using_pp
from colossalai.utils.timer import MultiTimer
from colossalai.zero.legacy.init_ctx import ZeroInitContext
2 changes: 1 addition & 1 deletion pytest.ini
@@ -4,4 +4,4 @@ markers =
    gpu: tests which requires a single GPU
    dist: tests which are run in a multi-GPU or multi-machine environment
    experiment: tests for experimental features
-addopts = --ignore=tests/test_analyzer --ignore=tests/test_auto_parallel --ignore=tests/test_autochunk --ignore=tests/test_moe --ignore=tests/test_fx
+addopts = --ignore=tests/test_analyzer --ignore=tests/test_auto_parallel --ignore=tests/test_autochunk --ignore=tests/test_moe --ignore=tests/test_fx --ignore=tests/test_legacy