Update README with DBRX (#1069)
hanlint authored Mar 27, 2024
1 parent f044d6c commit 8a69bd7
Showing 3 changed files with 286 additions and 0 deletions.
14 changes: 14 additions & 0 deletions README.md
@@ -40,6 +40,20 @@ You'll find in this repo:
* `mcli/` - launch any of these workloads using [MCLI](https://docs.mosaicml.com/projects/mcli/en/latest/) and the [MosaicML platform](https://www.mosaicml.com/platform)
* `TUTORIAL.md` - a deeper dive into the repo, example workflows, and FAQs

# DBRX

DBRX is a state-of-the-art open-source LLM trained by the Databricks Mosaic team. It uses a Mixture-of-Experts (MoE) architecture and was trained with optimized versions of [Composer](https://github.com/mosaicml/composer), LLM Foundry, and [MegaBlocks](https://github.com/databricks/megablocks). The model has 132B total parameters and 36B active parameters. We have released two DBRX models:


| Model | Context Length | Download |
| ------------------ | -------------- | -------------------------------------------------- |
| DBRX Base | 32768 | https://huggingface.co/databricks/dbrx-base |
| DBRX Instruct | 32768 | https://huggingface.co/databricks/dbrx-instruct |

Our model weights and code are licensed for both researchers and commercial entities. The Databricks Open Source License can be found at [LICENSE](https://github.com/databricks/dbrx/LICENSE), and our Acceptable Use Policy can be found [here](https://www.databricks.com/legal/acceptable-use-policy-open-model).

For more information about the DBRX models, see https://github.com/databricks/dbrx.
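
As a quick orientation, the snippet below sketches how the released weights can be loaded through the Hugging Face `transformers` API. The arguments follow the model card at the time of release (notably `trust_remote_code`); treat them as assumptions, since newer `transformers` versions or gated-access settings may require different options.

```python
# Minimal sketch: load DBRX Instruct from the Hugging Face Hub.
# Assumes you have accepted the license on the model page; a gated repo may
# also require passing an access token. Fitting the model in bfloat16 needs
# several 80GB GPUs (or CPU offload).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "databricks/dbrx-instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,  # weights are released in bfloat16
    device_map="auto",           # shard across available GPUs
    trust_remote_code=True,
)

inputs = tokenizer("What is a Mixture-of-Experts model?", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```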

# MPT

Mosaic Pretrained Transformers (MPT) are GPT-style models with some special features -- Flash Attention for efficiency, ALiBi for context length extrapolation, and stability improvements to mitigate loss spikes. As part of MosaicML's Foundation series, we have open-sourced several MPT models:
132 changes: 132 additions & 0 deletions scripts/train/yamls/finetune/dbrx-full-ft.yaml
@@ -0,0 +1,132 @@
# Note: This requires ~64x80GB GPUs
max_seq_len: 4096
icl_seq_len: 1024

# Run Name
run_name: # If left blank, will be read from env var $RUN_NAME

# Model
model:
  name: hf_causal_lm
  pretrained: true
  init_device: mixed
  use_auth_token: true
  config_overrides: {}
  use_flash_attention_2: true
  pretrained_model_name_or_path: databricks/dbrx-instruct

# Tokenizer
tokenizer:
  name: databricks/dbrx-instruct
  kwargs:
    model_max_length: ${max_seq_len}
    trust_remote_code: true
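
# The ${max_seq_len} references in this file are OmegaConf-style interpolations,
# which is how LLM Foundry parses its training YAMLs. A hedged sketch of how the
# file resolves (it assumes the YAML is saved locally as dbrx-full-ft.yaml):

```python
# Sketch: resolve the ${max_seq_len} interpolation the way an OmegaConf-based
# trainer would.
from omegaconf import OmegaConf

cfg = OmegaConf.load("dbrx-full-ft.yaml")
resolved = OmegaConf.to_container(cfg, resolve=True)

print(resolved["tokenizer"]["kwargs"]["model_max_length"])  # 4096, from max_seq_len
print(resolved["train_loader"]["dataset"]["max_seq_len"])   # 4096 as well
```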

# Dataloaders
train_loader:
  name: finetuning
  dataset:
    split: train
    hf_name: mosaicml/dolly_hhrlhf
    shuffle: true
    max_seq_len: ${max_seq_len}
    eos_token_id: 0
    packing_ratio: auto
    allow_pad_trimming: false
    decoder_only_format: true
  drop_last: true
  pin_memory: true
  num_workers: 8
  prefetch_factor: 2
  persistent_workers: true

eval_loader:
  name: finetuning
  dataset:
    split: test
    hf_name: mosaicml/dolly_hhrlhf
    shuffle: false
    max_seq_len: ${max_seq_len}
    packing_ratio: null
    allow_pad_trimming: false
    decoder_only_format: true
  drop_last: true
  pin_memory: true
  num_workers: 8
  prefetch_factor: 2
  persistent_workers: true
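
# Both loaders fine-tune on the mosaicml/dolly_hhrlhf instruction dataset from the
# Hugging Face Hub. A quick way to inspect it is sketched below; the column schema
# is an assumption, so check the dataset card if it has changed.

```python
# Sketch: peek at the fine-tuning data referenced by hf_name above.
from datasets import load_dataset

ds = load_dataset("mosaicml/dolly_hhrlhf", split="train")
print(ds)     # row count and column names (expected: prompt/response-style fields)
print(ds[0])  # one example record
```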

# Optimization
optimizer:
lr: 0.000001
name: decoupled_lionw
betas:
- 0.9
- 0.95
weight_decay: 1.0e-06

scheduler:
  name: cosine_with_warmup
  alpha_f: 0
  t_warmup: 0.02dur
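
# The scheduler above warms up linearly over the first 2% of training
# (t_warmup: 0.02dur) and then decays the learning rate along a cosine curve to
# alpha_f times the peak value. A small standalone sketch of that shape, not
# Composer's implementation, just the same formula:

```python
# Sketch of a cosine-with-warmup schedule: linear warmup for the first
# warmup_frac of training, then cosine decay from peak_lr to alpha_f * peak_lr.
import math

def cosine_with_warmup(step: int, total_steps: int, peak_lr: float,
                       warmup_frac: float = 0.02, alpha_f: float = 0.0) -> float:
    warmup_steps = int(warmup_frac * total_steps)
    if step < warmup_steps:
        return peak_lr * step / max(warmup_steps, 1)
    progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return peak_lr * (alpha_f + (1.0 - alpha_f) * cosine)

# Peak LR from the optimizer block above (1e-6 for full fine-tuning).
for step in (0, 50, 1000, 2000):
    print(step, cosine_with_warmup(step, total_steps=2000, peak_lr=1e-6))
```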

algorithms:
  gradient_clipping:
    clipping_type: norm
    clipping_threshold: 1

max_duration: 2ep
eval_interval: 1ep
global_train_batch_size: 64
eval_first: false
# eval_subset_num_batches: -1

# System
seed: 17
device_train_microbatch_size: 1
device_eval_batch_size: 1
precision: amp_bf16
autoresume: true
dist_timeout: 3600
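
# With global_train_batch_size: 64, device_train_microbatch_size: 1, and the
# ~64 GPUs noted at the top of this file, each optimizer step processes one
# microbatch per device with no gradient accumulation. The arithmetic:

```python
# Sketch: how the global batch decomposes across devices and microbatches
# for this config (64 GPUs assumed, per the note at the top of the file).
global_train_batch_size = 64
num_gpus = 64
device_train_microbatch_size = 1

per_device_batch = global_train_batch_size // num_gpus                # 1 sequence per GPU per step
grad_accum_steps = per_device_batch // device_train_microbatch_size   # 1 -> no accumulation
print(per_device_batch, grad_accum_steps)
```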

# FSDP
fsdp_config:
  mixed_precision: PURE
  state_dict_type: sharded
  limit_all_gathers: true
  sharding_strategy: FULL_SHARD
  activation_cpu_offload: false
  activation_checkpointing: true
  activation_checkpointing_reentrant: false

# Logging
progress_bar: false
log_to_console: true
console_log_interval: 1ba

# Callbacks
callbacks:
  lr_monitor: {}
  speed_monitor:
    window_size: 1
  memory_monitor: {}
  hf_checkpointer:
    overwrite: true
    precision: bfloat16
    save_folder: ./{run_name}/checkpoints
    save_interval: 1dur
  runtime_estimator: {}
# Checkpoint to local filesystem or remote object store
# save_interval: 5000ba
# save_num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK
# save_folder: ./{run_name}/checkpoints
# save_folder: s3://my-bucket/my-folder/{run_name}/checkpoints
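
# The hf_checkpointer callback exports Hugging Face-format checkpoints in
# bfloat16 under save_folder. A hedged sketch of reloading one; the exact
# subdirectory layout depends on the LLM Foundry version, so the path below is
# purely illustrative -- inspect save_folder to see what was written for your run.

```python
# Sketch: reload a checkpoint exported by the hf_checkpointer callback.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt_path = "./my-run/checkpoints/huggingface/latest"  # illustrative path, not guaranteed
tokenizer = AutoTokenizer.from_pretrained(ckpt_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    ckpt_path, torch_dtype=torch.bfloat16, trust_remote_code=True
)
```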

# Logging
# loggers:
#   wandb:
#     name:
#     group:
#   mlflow:
#     tracking_uri:
#     experiment_name:
140 changes: 140 additions & 0 deletions scripts/train/yamls/finetune/dbrx-lora-ft.yaml
@@ -0,0 +1,140 @@
# Note: This requires ~16x80GB GPUs
max_seq_len: 4096
icl_seq_len: 1024

# Run Name
run_name: # If left blank, will be read from env var $RUN_NAME

# Model
model:
  name: hf_causal_lm
  pretrained: true
  init_device: mixed
  peft_config:
    r: 64
    peft_type: LORA
    task_type: CAUSAL_LM
    lora_alpha: 128
    lora_dropout: 0.05
    target_modules:
      - Wqkv
  use_auth_token: true
  config_overrides: {}
  use_flash_attention_2: true
  pretrained_model_name_or_path: databricks/dbrx-instruct
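
# The peft_config block above maps onto Hugging Face PEFT's LoraConfig. A sketch
# of the equivalent object for orientation (not the foundry code path); applying
# it to the full 132B-parameter model still requires multi-GPU sharding.

```python
# Sketch: the peft_config block above, expressed as a PEFT LoraConfig.
from peft import LoraConfig, TaskType

lora_config = LoraConfig(
    r=64,
    lora_alpha=128,
    lora_dropout=0.05,
    target_modules=["Wqkv"],       # the fused attention projection targeted above
    task_type=TaskType.CAUSAL_LM,
)
print(lora_config)
```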

# Tokenizer
tokenizer:
  name: databricks/dbrx-instruct
  kwargs:
    model_max_length: ${max_seq_len}
    trust_remote_code: true

# Dataloaders
train_loader:
  name: finetuning
  dataset:
    split: train
    hf_name: mosaicml/dolly_hhrlhf
    shuffle: true
    max_seq_len: ${max_seq_len}
    eos_token_id: 0
    packing_ratio: auto
    allow_pad_trimming: false
    decoder_only_format: true
  drop_last: true
  pin_memory: true
  num_workers: 8
  prefetch_factor: 2
  persistent_workers: true

eval_loader:
  name: finetuning
  dataset:
    split: test
    hf_name: mosaicml/dolly_hhrlhf
    shuffle: false
    max_seq_len: ${max_seq_len}
    packing_ratio: null
    allow_pad_trimming: false
    decoder_only_format: true
  drop_last: true
  pin_memory: true
  num_workers: 8
  prefetch_factor: 2
  persistent_workers: true

# Optimization
optimizer:
lr: 0.0001
name: decoupled_lionw
betas:
- 0.9
- 0.95
weight_decay: 1.0e-06

scheduler:
  name: cosine_with_warmup
  alpha_f: 0
  t_warmup: 0.02dur

algorithms:
  gradient_clipping:
    clipping_type: norm
    clipping_threshold: 1

max_duration: 2ep
eval_interval: 1ep
global_train_batch_size: 16
eval_first: false
# eval_subset_num_batches: -1

# System
seed: 17
device_train_microbatch_size: 1
device_eval_batch_size: 1
precision: amp_bf16
autoresume: true
dist_timeout: 3600

# FSDP
fsdp_config:
  mixed_precision: PURE
  state_dict_type: sharded
  limit_all_gathers: true
  sharding_strategy: FULL_SHARD
  activation_cpu_offload: false
  activation_checkpointing: true
  activation_checkpointing_reentrant: false

# Logging
progress_bar: false
log_to_console: true
console_log_interval: 1ba

# Callbacks
callbacks:
  lr_monitor: {}
  speed_monitor:
    window_size: 1
  memory_monitor: {}
  hf_checkpointer:
    overwrite: true
    precision: bfloat16
    save_folder: ./{run_name}/checkpoints
    save_interval: 1dur
  runtime_estimator: {}
# Checkpoint to local filesystem or remote object store
# save_interval: 5000ba
# save_num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK
# save_folder: ./{run_name}/checkpoints
# save_folder: s3://my-bucket/my-folder/{run_name}/checkpoints

# Logging
# loggers:
#   wandb:
#     name:
#     group:
#   mlflow:
#     tracking_uri:
#     experiment_name:
