
Commit

Merge e02ba76 into sapling-pr-archive-EntilZha
EntilZha authored Jan 28, 2025
2 parents (b1c12dd + e02ba76), commit c2f1e48
Showing 8 changed files with 219 additions and 221 deletions.
apps/main/lingua_train.py (2 changes: 1 addition & 1 deletion)
@@ -544,7 +544,7 @@ def train(args: TrainArgs):
if args.eval is not None and every_n_steps(
train_state, args.checkpoint.eval.every, acc_step=0
):
from apps.main.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval
from bytelatent.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval

eval_args = dataclass_from_dict(EvalArgs, args.eval)

bytelatent/args.py (87 changes: 86 additions & 1 deletion)
@@ -5,6 +5,7 @@

import numpy as np
import yaml
from omegaconf import OmegaConf
from pydantic import BaseModel, ConfigDict

from bytelatent.checkpoint import CheckpointArgs
@@ -39,6 +40,19 @@ def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]:
return np.random.default_rng((seed, rank, world_size)).bit_generator.state


def parse_args(args_cls):
cli_args = OmegaConf.from_cli()
file_cfg = OmegaConf.load(cli_args.config)
# Remove the 'config' key, since the target args class does not define such a field
del cli_args.config

default_cfg = OmegaConf.create(args_cls().model_dump())
cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
pydantic_args = args_cls.model_validate(cfg)
return pydantic_args


def distribute_data_to_rank(
*,
dataset_path: str,
@@ -71,6 +85,22 @@ def distribute_data_to_rank(
return rank_to_arrow_iterator_params[rank]


class PackedCausalTransformerGeneratorArgs(BaseModel):
model_config = ConfigDict(extra="forbid")
temperature: float = 0.0
top_p: float | None = None
top_k: float | None = None
max_gen_len: int = 512 # Maximum number of tokens to generate
max_tokens: int = 1024 # Maximum number of tokens that can go through the model
max_prompt_len: int | None = None
until: list[str] = []
compile_prefilling: bool = False
reduce_generation_overhead: bool = False
show_progress: bool = False
dtype: str | None = "bf16"
device: str | None = "cuda"


class DataloaderArgs(BaseModel):
model_config = ConfigDict(extra="forbid")
s3_profile: str | None = None
Expand Down Expand Up @@ -168,6 +198,58 @@ def build_from_rank(
return packing_iterator


class LMHarnessArgs(BaseModel):
model_config = ConfigDict(extra="forbid")
tasks: list[Any] | None = None
num_fewshot: int | None = None
device: str | None = None
use_cache: str | None = None
cache_requests: bool = False
rewrite_requests_cache: bool = False
delete_requests_cache: bool = False
limit: int | float | None = None
bootstrap_iters: int = 100000
check_integrity: bool = False
write_out: bool = False
log_samples: bool = True
system_instruction: str | None = None
apply_chat_template: bool | str = False
fewshot_as_multiturn: bool = False
gen_kwargs: str | None = None
verbosity: str = "INFO"
predict_only: bool = False
random_seed: int = 0
numpy_random_seed: int = 1234
torch_random_seed: int = 1234
fewshot_random_seed: int = 1234


class ValidationArgs(BaseModel):
model_config = ConfigDict(extra="forbid")
# If None, the whole validation file is used. Note that this step count is GPU dependent
# (100 max steps on 8 GPUs = 800 steps on 1 GPU).
max_steps: int | None = None
use_val_from_train_src: bool = True # Use the validation set from training sources
root_dir: str = ""
sources: list[str] = [] # Other sources to eval on


class EvalArgs(BaseModel):
model_config = ConfigDict(extra="forbid")
dump_dir: str
ckpt_dir: str
metric_log_dir: str | None = None
generator: PackedCausalTransformerGeneratorArgs = (
PackedCausalTransformerGeneratorArgs()
)

harness: LMHarnessArgs | None = LMHarnessArgs()
validation: ValidationArgs | None = ValidationArgs()

global_step: int | None = None # for in-training evaluation
s3_profile: str | None = None


class TrainArgs(BaseModel):
model_config = ConfigDict(extra="forbid")
name: str = "lingua"
@@ -186,6 +268,9 @@ class TrainArgs(BaseModel):

# Number of optimizer steps to take
steps: int = 1000
# If not None, halt training after this many steps (useful for debugging)
max_steps: int | None = None

data: DataloaderArgs = DataloaderArgs()
optim: OptimArgs = OptimArgs()
@@ -203,7 +288,7 @@ class TrainArgs(BaseModel):

# If set to None, eval runs locally; otherwise a new job is launched with the given number of GPUs
async_eval_gpus: int | None = None
eval: Any | None = None
eval: EvalArgs | None = None
eval_on_gpus: int | None = None

def dump_to_yaml_file(
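The additions above wire a config pipeline in which pydantic defaults, a YAML file, and CLI overrides are merged by OmegaConf and then validated back into the pydantic model. Below is a minimal sketch of how an entry point might call the new parse_args helper; the module name, config path, and override shown on the command line are illustrative assumptions, not part of this commit:

# Hypothetical entry point built on bytelatent.args.parse_args, e.g. run as:
#   python -m apps.main.train config=configs/debug.yaml steps=2000
from bytelatent.args import TrainArgs, parse_args


def main() -> None:
    # Precedence: pydantic defaults < YAML file given via config=... < CLI overrides
    args: TrainArgs = parse_args(TrainArgs)
    print(args.steps, args.eval)


if __name__ == "__main__":
    main()

Because the models set extra="forbid", a mistyped override key fails validation in model_validate instead of being silently ignored.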
bytelatent/checkpoint.py (9 changes: 7 additions & 2 deletions)
@@ -7,6 +7,7 @@
from pathlib import Path
from typing import List, Optional, Tuple

import fsspec
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
@@ -21,6 +22,7 @@
set_state_dict,
)

from bytelatent.data.file_util import get_fs
from bytelatent.distributed import get_is_master

logger = logging.getLogger("CHECKPOINT")
@@ -51,13 +53,14 @@ class CheckpointArgs(BaseModel):
path: str | None = None
init_ckpt_path: str | None = None
continue_training_from_init: bool = False
s3_profile: str | None = None


def _get_key_step(name: str):
return int(re.findall(RE_DIGITS, name)[-1])


def consolidate_checkpoints(ckpt_dir: str):
def consolidate_checkpoints(fs: fsspec.AbstractFileSystem, ckpt_dir: str):
"""
Consolidates all FSDP checkpoints in a directory into a single file.
The consolidated checkpoint is saved in a subdirectory of ckpt_dir.
@@ -102,15 +105,17 @@ def load_from_checkpoint(
dcp.load(state_dict, checkpoint_id=ckpt_dir)


# TODO: Rewrite the file operations here to use fsspec to enable s3 writing.
class CheckpointManager:
def __init__(self, args: CheckpointArgs):
self.path = args.path
self.fs = get_fs(self.path, s3_profile=args.s3_profile)
self.dump_every = args.dump
self.eval_every = args.eval
self.init_ckpt_path = args.init_ckpt_path
self.continue_training_from_init = args.continue_training_from_init

assert os.path.exists(
assert self.fs.exists(
self.path
), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)"

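The new s3_profile field and the get_fs call let CheckpointManager run the same existence checks against local disk or S3 through fsspec. A small sketch of the scheme-based resolution this relies on; the real get_fs lives in bytelatent/data/file_util.py and may behave differently, so treat the helper below as an assumption:

import fsspec


def resolve_fs(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem:
    # Assumed behavior: pick the filesystem from the path's scheme, passing the
    # named AWS profile through for S3 URLs (requires the s3fs package).
    if path.startswith("s3://"):
        return fsspec.filesystem("s3", profile=s3_profile)
    return fsspec.filesystem("file")


fs = resolve_fs("/tmp/checkpoints/run1")
print(fs.exists("/tmp/checkpoints/run1"))  # the same .exists() call CheckpointManager now makes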
bytelatent/data/data_types.py (10 changes: 0 additions & 10 deletions)
@@ -40,16 +40,6 @@ class BltPackTokensState(BaseModel):
n_views: int = 2


class DataLoaderState(BaseModel):
model_config = ConfigDict(extra="forbid")
multi_choice_state: MultiChoiceState
pack_tokens_state: BltPackTokensState
prefetch_state: PrefetchState


BltIterator = Iterator[tuple[BltExample, DataLoaderState]]


class BltSequence(BaseModel):
tokens: list[int]
mask: list[bool]
bytelatent/data/iterators/multiprocess_iterator.py (15 changes: 15 additions & 0 deletions)
@@ -128,13 +128,24 @@ def __init__(
self.producer = None
self.stop_iterating_event = None
self.state_dumped_event = None
self.force_shutdown = False

def shutdown(self):
if self.producer is not None:
# Kill the producer process; after this the buffered state can no longer be trusted
self.producer.kill()
self.force_shutdown = True

def get_state(self) -> MultiprocessIteratorState:
"""
This is slightly unusual in that it effectively destroys the current iterator; it is necessary
to halt the background process and let it write its state back to the main loop
so that no data is lost.
"""
if self.force_shutdown:
raise ValueError(
"State will be invalid if shutdown was forced before state persisted."
)
if self.producer is None:
serialized_prefetch_buffer = json.dumps(
[b.to_python_dict() for b in self.prefetch_buffer]
@@ -187,6 +198,10 @@ def get_state(self) -> MultiprocessIteratorState:
)

def create_iter(self):
if self.force_shutdown:
raise ValueError(
"Iterator may be invalid if shutdown was forced before state persisted."
)
logging.info("Main thread: Creating MP iterator")
# First yield from the stored prefetch buffer.
if self.prefetch_buffer is not None:
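The force_shutdown flag makes the failure mode explicit: once the producer process has been killed, neither the saved state nor a new iterator can be trusted, so both get_state and create_iter raise. A stripped-down illustration of that guard pattern; everything apart from the force_shutdown idea is hypothetical and much simpler than the real MultiprocessIterator:

class GuardedIterator:
    """Toy stand-in that only demonstrates the force_shutdown guard."""

    def __init__(self, items):
        self._items = list(items)
        self.force_shutdown = False

    def shutdown(self) -> None:
        # The real class kills the producer process here.
        self.force_shutdown = True

    def get_state(self) -> dict:
        if self.force_shutdown:
            raise ValueError("State is invalid after a forced shutdown.")
        return {"remaining": list(self._items)}


it = GuardedIterator(range(3))
it.shutdown()
try:
    it.get_state()
except ValueError as err:
    print(err)  # the guard rejects state requests after a forced shutdown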
(Diffs for the remaining 3 changed files are not shown here.)
