Skip to content
This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

badcase #41

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
5 changes: 2 additions & 3 deletions configs/palm_16b_tp1_zero.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
from colossalai.zero.shard_utils import TensorShardStrategy

SEQ_LENGTH = 2048
BATCH_SIZE = 4
BATCH_SIZE = 12
NUM_EPOCHS = 1
# WARMUP_EPOCHS = 1

parallel = dict(
# tensor=dict(mode="1d", size=2),
)

model = dict(
type="palm_16b",
use_grad_checkpoint=True,
use_act_offload=True,
use_act_offload=False,
)

zero = dict(
Expand Down
2 changes: 1 addition & 1 deletion configs/palm_16b_tp2_zero.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from colossalai.zero.shard_utils import TensorShardStrategy

SEQ_LENGTH = 2048
BATCH_SIZE = 4
BATCH_SIZE = 12
NUM_EPOCHS = 1
# WARMUP_EPOCHS = 1

Expand Down
2 changes: 1 addition & 1 deletion configs/palm_16b_tp4.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from colossalai.zero.shard_utils import TensorShardStrategy

SEQ_LENGTH = 2048
BATCH_SIZE = 4
BATCH_SIZE = 8
NUM_EPOCHS = 1
# WARMUP_EPOCHS = 1

Expand Down
6 changes: 3 additions & 3 deletions configs/palm_8b_zero.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from colossalai.zero.shard_utils import TensorShardStrategy

SEQ_LENGTH = 2048
BATCH_SIZE = 8
BATCH_SIZE = 16
NUM_EPOCHS = 1
# WARMUP_EPOCHS = 1

parallel = dict(
tensor=dict(mode="1d", size=2),
tensor=dict(mode="2.5d", size=4, depth = 1),
)

model = dict(
Expand All @@ -18,7 +18,7 @@
zero = dict(
model_config=dict(
shard_strategy=TensorShardStrategy(),
tensor_placement_policy='auto',
tensor_placement_policy='cpu',
),
optimizer_config=dict(
gpu_margin_mem_ratio = 0.8,
Expand Down
28 changes: 28 additions & 0 deletions configs/palm_8b_zero_2d_badcase.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from colossalai.zero.shard_utils import TensorShardStrategy

SEQ_LENGTH = 2048
BATCH_SIZE = 16
NUM_EPOCHS = 1
# WARMUP_EPOCHS = 1

parallel = dict(
tensor=dict(mode="2d", size=8),
)

model = dict(
type="palm_8b",
use_grad_checkpoint=True,
use_act_offload=False,
)

zero = dict(
model_config=dict(
shard_strategy=TensorShardStrategy(),
tensor_placement_policy='cpu',
),
optimizer_config=dict(
gpu_margin_mem_ratio = 0.8,
initial_scale=2**5,
)
)

28 changes: 28 additions & 0 deletions configs/palm_8b_zero_3d_badcase.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from colossalai.zero.shard_utils import TensorShardStrategy

SEQ_LENGTH = 2048
BATCH_SIZE = 16
NUM_EPOCHS = 1
# WARMUP_EPOCHS = 1

parallel = dict(
tensor=dict(mode="3d", size=8),
)

model = dict(
type="palm_8b",
use_grad_checkpoint=True,
use_act_offload=False,
)

zero = dict(
model_config=dict(
shard_strategy=TensorShardStrategy(),
tensor_placement_policy='cpu',
),
optimizer_config=dict(
gpu_margin_mem_ratio = 0.8,
initial_scale=2**5,
)
)

31 changes: 31 additions & 0 deletions configs/palm_8b_zero_gemini_badcase.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from colossalai.zero.shard_utils import TensorShardStrategy

SEQ_LENGTH = 2048
BATCH_SIZE = 16
NUM_EPOCHS = 1
# WARMUP_EPOCHS = 1

parallel = dict(
tensor=dict(mode="1d", size=2),
)

model = dict(
type="palm_8b",
use_grad_checkpoint=True,
use_act_offload=False,
)


# auto will fail
# cpu will pass
zero = dict(
model_config=dict(
shard_strategy=TensorShardStrategy(),
tensor_placement_policy='cpu',
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

auto will fail

),
optimizer_config=dict(
gpu_margin_mem_ratio = 0.8,
initial_scale=2**5,
)
)

2 changes: 1 addition & 1 deletion data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def build_data(**args):
), f"Invalid dataset name. dataset should be in {_datasets.keys()} or use default wikitext"
builder = _datasets[gpc.config.dataset]
else:
builder = _datasets["wikitext"]
builder = _datasets["test"]
return builder(**args)


Expand Down
16 changes: 12 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from data import build_data
from model import build_loss, build_model
from utils import AutoregressiveWrapper, calc_local_model_size, calc_mem
from colossalai.utils import colo_set_process_memory_fraction, colo_device_memory_capacity
from colossalai.utils import colo_set_process_memory_fraction, colo_device_memory_capacity, colo_set_cpu_memory_capacity


def limit_cuda_memory(size_in_GB: int):
Expand All @@ -24,11 +24,18 @@ def limit_cuda_memory(size_in_GB: int):
colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
logger = get_dist_logger()
logger.info("Using {} GB of GPU memory".format(size_in_GB))


def limit_cpu_memory(size_in_GB: int):
colo_set_cpu_memory_capacity(size_in_GB * 1024 ** 3)

def train_palm():
assert torch.cuda.is_available()
# set to 40GB, if you are using a high-end GPU.
# limit cuda memory of each GPU to 40GB, if you are using a high-end GPU.
limit_cuda_memory(40)

# limit the cpu memory of the CPU to 312 GB
# limit_cpu_memory(312)

disable_existing_loggers()
parser = colossalai.get_default_parser()
parser.add_argument("--from_torch", default=False, action="store_true")
Expand Down Expand Up @@ -81,6 +88,7 @@ def train_palm():
else:
numel = calc_local_model_size(model)

# global Tera FLOating Points operations per iteration.
tflop = numel * batch_size * seq_len \
* gpc.get_world_size(ParallelMode.MODEL) * gpc.get_world_size(ParallelMode.DATA) * 8 / (1024 ** 4)

Expand Down Expand Up @@ -144,7 +152,7 @@ def batch_data_process_func(batch_data):
hooks.LogMetricByEpochHook(logger=logger),
hooks.LogMetricByStepHook(),
hooks.LossHook(),
hooks.ThroughputHook(ignored_steps=10, tflop_per_step = tflop),
hooks.ThroughputHook(ignored_steps=10, tflop_per_step = tflop, use_local = False),
# hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False),
hooks.LogMemoryByEpochHook(logger),
# hooks.SaveCheckpointHook(checkpoint_dir="./palm.ckpt", model=model),
Expand Down