This repository contains a custom implementation of the LLaMA 2 model, as described in the paper "Llama 2: Open Foundation and Fine-Tuned Chat Models" (arXiv). The implementation reproduces and extends several of the features that distinguish LLaMA 2: RMS-Normalization, the SwiGLU activation function, Rotary Positional Embeddings (RoPE), a longer context window combined with Grouped-Query Attention (GQA), and KV-caching.
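Of the features listed above, RoPE is the least self-explanatory. The following is a minimal pure-Python sketch, not the repository's PyTorch code (`rope_rotate` and `dot` are hypothetical helper names). It rotates each adjacent pair of dimensions of a query or key vector by an angle proportional to the token position, using the RoFormer frequencies theta_i = 10000^(-2i/d).

```python
import math
from typing import List

def rope_rotate(vec: List[float], pos: int, base: float = 10000.0) -> List[float]:
    """Rotate each adjacent pair (vec[2i], vec[2i+1]) by the angle pos * theta_i,
    where theta_i = base ** (-2i / d) and d = len(vec)."""
    d = len(vec)
    if d % 2:
        raise ValueError("RoPE needs an even head dimension")
    out: List[float] = []
    for i in range(d // 2):
        angle = pos * base ** (-2.0 * i / d)
        c, s = math.cos(angle), math.sin(angle)
        x0, x1 = vec[2 * i], vec[2 * i + 1]
        # Standard 2-D rotation of the pair (x0, x1).
        out += [x0 * c - x1 * s, x0 * s + x1 * c]
    return out

def dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))

q, k = [0.1, 0.2, 0.3, 0.4], [0.5, -0.1, 0.2, 0.7]
# Both scores use the same offset between positions (2), so they agree
# up to floating-point round-off even though the absolute positions differ.
print(dot(rope_rotate(q, 5), rope_rotate(k, 3)))
print(dot(rope_rotate(q, 12), rope_rotate(k, 10)))
```

Because a rotation preserves vector length and rotating q by angle a and k by angle b leaves their dot product dependent only on a - b, the attention score encodes relative rather than absolute position.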
This project aims to build the LLaMA 2 architecture from scratch, incorporating essential advancements in transformer models. Key enhancements include RMS-Normalization, SwiGLU activation, Rotary Positional Embeddings, and advanced attention mechanisms like Grouped-Query Attention, all designed to improve model performance, particularly in handling longer context windows and enhancing the model's positional understanding.
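Two of these building blocks, RMS-Normalization and SwiGLU, fit in a few lines of pure Python. This is a minimal illustration under stated assumptions, not the repository's PyTorch modules: the function names and the `w1`/`w3`/`w2` weight layout (matrices stored as lists of rows) are hypothetical.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # list of rows

def rms_norm(x: Vector, weight: Vector, eps: float = 1e-6) -> Vector:
    """Divide x by its root mean square and scale by a learned per-feature weight.
    Unlike LayerNorm, there is no mean subtraction and no bias."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [w * v * inv_rms for w, v in zip(weight, x)]

def silu(v: float) -> float:
    """SiLU / Swish: v * sigmoid(v), written so math.exp never overflows."""
    if v >= 0:
        return v / (1.0 + math.exp(-v))
    e = math.exp(v)
    return v * e / (1.0 + e)

def matvec(W: Matrix, x: Vector) -> Vector:
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def swiglu_ffn(x: Vector, w1: Matrix, w3: Matrix, w2: Matrix) -> Vector:
    """Feed-forward block: w2 @ (silu(w1 @ x) * (w3 @ x)), with * elementwise."""
    gate = [silu(v) for v in matvec(w1, x)]  # gating branch
    up = matvec(w3, x)                       # linear branch
    return matvec(w2, [g * u for g, u in zip(gate, up)])
```

The gate branch lets the network scale each hidden unit by a smooth, input-dependent factor instead of the hard cutoff of ReLU, at the cost of a third weight matrix.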
- RMS-Normalization: A simplified version of layer normalization that rescales activations by their root mean square (without mean-centering), stabilizing layer activations and aiding convergence.
- SwiGLU Activation Function: Replaces the ReLU in the feed-forward blocks with a gated SiLU unit, which improves model quality.
- Rotary Positional Embeddings (RoPE): Encodes position by rotating query and key vectors by position-dependent angles, so attention scores depend on the relative distance between tokens; introduced in RoFormer: Enhanced Transformer with Rotary Position Embedding (arXiv).
- Increased Context Length with GQA: Expands the context window to 4096 tokens and uses grouped-query attention, in which groups of query heads share key/value heads, to reduce the memory cost of long sequences.
- KV-Cache: Caches the keys and values of already-processed tokens so each decoding step only computes them for the new token, speeding up generation.
- Inference with Top-P Sampling: Samples from the smallest set of tokens whose cumulative probability reaches a threshold p, so the number of candidate tokens adapts to the shape of the distribution.
- Data & Training Utilities: The project adds PyTorch wrappers for pretraining on any `.txt` file.
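The repository's `.txt` wrappers are not reproduced here. As a hypothetical sketch of the preparation step such wrappers typically perform, the code below tokenizes a file and cuts the token stream into fixed-length windows whose targets are the inputs shifted by one. `encode` stands in for a tokenizer's encode function (for example, the SentencePiece tokenizer), and all names are assumptions.

```python
from typing import Callable, List, Tuple

Pair = Tuple[List[int], List[int]]

def make_training_pairs(token_ids: List[int], block_size: int) -> List[Pair]:
    """Slice one long token stream into non-overlapping (input, target) windows
    for next-token prediction; each target is its input shifted left by one."""
    if block_size < 1:
        raise ValueError("block_size must be positive")
    pairs: List[Pair] = []
    # A window needs block_size + 1 tokens: block_size inputs plus one extra target.
    for start in range(0, len(token_ids) - block_size, block_size):
        chunk = token_ids[start:start + block_size + 1]
        pairs.append((chunk[:-1], chunk[1:]))
    return pairs

def load_txt_pairs(path: str, encode: Callable[[str], List[int]], block_size: int) -> List[Pair]:
    """Read a UTF-8 .txt file, tokenize it with `encode`, and window the ids."""
    with open(path, encoding="utf-8") as f:
        return make_training_pairs(encode(f.read()), block_size)

print(make_training_pairs(list(range(10)), 4))
# prints [([0, 1, 2, 3], [1, 2, 3, 4]), ([4, 5, 6, 7], [5, 6, 7, 8])]
```

Trailing tokens that cannot fill a whole window are dropped. In PyTorch, these pairs would typically be served through a `torch.utils.data.Dataset` and batched by a `DataLoader`.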
To install the necessary dependencies, clone this repository and run:
git clone https://github.com/abdallah197/llama2-from-scratch.git
cd llama2-from-scratch
pip install -r requirements.txt
This section walks through running inference with the LLaMA 2 model. Follow these steps to set up and run inference:
- Tokenizer: Begin by downloading the LLaMA 2 SentencePiece tokenizer model, which is needed to preprocess your input text. You can find the tokenizer here. Place the downloaded model in a directory accessible to your project.
- Model Weights: You have two options for obtaining the model weights:
  - Download Pre-trained Weights: Follow the instructions provided here to download the official LLaMA model weights.
  - Train Your Own Model: Alternatively, you can train your own LLaMA 2 model using this repository.
- Configuration: Configure your inference settings in the `config.py` file. This file should include settings such as the path to the model weights, the tokenizer model, and any other inference parameters like the maximum sequence length.
Once you have set up the tokenizer and the model weights and configured your inference settings, you can run inference by passing one or more prompts on the command line. The repository currently supports only top-p sampling:
python inference.py "Your first prompt" "Your second prompt"
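As a rough illustration of top-p (nucleus) sampling, here is a pure-Python sketch over a list of token probabilities. `sample_top_p` is a hypothetical name and not necessarily how `inference.py` implements it.

```python
import random
from typing import List, Optional, Sequence

def sample_top_p(probs: Sequence[float], p: float, rng: Optional[random.Random] = None) -> int:
    """Nucleus (top-p) sampling: keep the most probable tokens until their
    cumulative probability reaches p, renormalize, and draw one token index."""
    if not 0.0 < p <= 1.0:
        raise ValueError("p must be in (0, 1]")
    rng = rng or random.Random()
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept: List[int] = []
    cumulative = 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= p:
            break
    # Draw from the kept tokens in proportion to their renormalized probability.
    r = rng.random() * sum(probs[i] for i in kept)
    acc = 0.0
    for i in kept:
        acc += probs[i]
        if r < acc:
            return i
    return kept[-1]  # guard against floating-point round-off

probs = [0.5, 0.3, 0.15, 0.05]
# With p = 0.7 only tokens 0 and 1 (cumulative 0.8) can be drawn.
print(sample_top_p(probs, 0.7, random.Random(0)))
```

In a full decoder, `probs` would come from a softmax over the model's logits, usually divided by a temperature first; a peaked distribution then keeps only a few candidates, while a flat one keeps many.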
The configuration for the model and training is defined using Python dataclasses in `config.py`. To adapt it to your dataset and training needs, modify the respective fields of the dataclass instances before initializing your model or starting training; for instance, increase the number of layers and attention heads in ModelArgs. There are five main config dataclasses:
- **ModelArgs**: Contains configurations related to the model architecture, such as the number of layers and attention heads.
- **DataArgs**: Includes settings for data processing and input data, such as the file path to the dataset.
- **TrainArgs**: Encompasses parameters for training, such as learning rate and number of epochs.
- **InferenceArgs**: Holds configurations specific to the inference process.
- **DeepSpeedArgs**: Manages configurations for distributed computing using DeepSpeed.
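Adjusting a configuration might look like the hypothetical sketch below. The simplified `ModelArgs` here is an assumption (check `config.py` for the real field names); its defaults mirror the settings summarized in this README, and `dataclasses.replace` derives a larger variant without mutating the original.

```python
from dataclasses import dataclass, replace
from typing import Optional

@dataclass
class ModelArgs:
    # Hypothetical field names; defaults mirror the settings summarized in this README.
    dim: int = 2048
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None  # None means one key/value head per query head
    vocab_size: int = -1              # placeholder until the tokenizer is loaded
    mode: str = "train"               # "train" or "inference"

    def __post_init__(self) -> None:
        # Catch shape mismatches early, before any weights are allocated.
        if self.dim % self.n_heads:
            raise ValueError("dim must be divisible by n_heads")
        if self.n_heads % (self.n_kv_heads or self.n_heads):
            raise ValueError("n_heads must be a multiple of n_kv_heads")
        if self.mode not in ("train", "inference"):
            raise ValueError("mode must be 'train' or 'inference'")

base = ModelArgs()
# More layers and heads; the per-head size becomes 2048 // 64 = 32.
bigger = replace(base, n_layers=48, n_heads=64)
print(bigger)
```

Because `replace` calls `__init__`, the `__post_init__` checks also run on the derived config, so a head count that does not divide `dim` fails immediately.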
I scaled down the original model hyperparameters to fit my compute budget. The main configuration settings are:
- Model Dimensionality: 2048
- Number of Transformer Layers: 32
- Number of Query Attention Heads: 32
- Number of Key/Value Heads (n_kv_heads): Optional; a value smaller than the number of query heads gives grouped-query attention.
- Vocabulary Size: Set dynamically when the LLaMA 2 SentencePiece tokenizer is loaded.
- Operating Mode: 'train' or 'inference'; inference mode enables the KV-cache.
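The interaction between n_kv_heads and the KV-cache can be shown with a pure-Python sketch for one layer and one sequence. `GQAKVCache` is a hypothetical class, not the repository's implementation: it stores only n_kv_heads key/value vectors per token, and each group of n_heads // n_kv_heads query heads reads the same cached head.

```python
import math
from typing import List

Vec = List[float]

class GQAKVCache:
    """Single-layer, single-sequence KV cache with grouped-query attention."""

    def __init__(self, n_heads: int, n_kv_heads: int) -> None:
        if n_heads % n_kv_heads:
            raise ValueError("n_heads must be a multiple of n_kv_heads")
        self.n_heads, self.n_kv_heads = n_heads, n_kv_heads
        self.group = n_heads // n_kv_heads
        self.keys: List[List[Vec]] = [[] for _ in range(n_kv_heads)]  # keys[kv_head][position]
        self.values: List[List[Vec]] = [[] for _ in range(n_kv_heads)]

    def step(self, q: List[Vec], k: List[Vec], v: List[Vec]) -> List[Vec]:
        """Process one new token: cache its k/v, then attend over all cached positions.

        q holds n_heads query vectors; k and v hold n_kv_heads vectors each.
        Past keys/values are read from the cache instead of recomputed, and no
        causal mask is needed because the cache holds only past tokens plus this one."""
        for h in range(self.n_kv_heads):
            self.keys[h].append(k[h])
            self.values[h].append(v[h])
        out: List[Vec] = []
        for h in range(self.n_heads):
            kv = h // self.group  # query head h shares key/value head kv
            scale = 1.0 / math.sqrt(len(q[h]))
            scores = [scale * sum(a * b for a, b in zip(q[h], key)) for key in self.keys[kv]]
            m = max(scores)  # subtract the max for a numerically stable softmax
            weights = [math.exp(s - m) for s in scores]
            total = sum(weights)
            vals = self.values[kv]
            out.append([sum(w * val[d] for w, val in zip(weights, vals)) / total
                        for d in range(len(vals[0]))])
        return out

cache = GQAKVCache(n_heads=4, n_kv_heads=2)
print(cache.step([[1.0, 0.0]] * 4, [[1.0, 0.0], [0.0, 1.0]], [[1.0, 2.0], [3.0, 4.0]]))
```

With n_heads=32 and n_kv_heads=8, for example, the cache holds 8 rather than 32 key/value vectors per token per layer, a 4x memory saving that matters most for long contexts.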
This repository supports distributed computing using DeepSpeed. To enable it, follow these steps:
- Set DeepSpeedArgs.deepspeed to True: In your configuration file (`config.py`), set `DeepSpeedArgs.deepspeed` to `True` to train with DeepSpeed.
- Populate the DeepSpeed Config File: Create or edit the `deepspeed_config.json` file with your DeepSpeed settings, such as the per-GPU batch size (DeepSpeed combines it with the gradient accumulation steps and the number of GPUs to get the global batch size), optimizer parameters, and any other DeepSpeed-specific options.
By configuring DeepSpeedArgs and populating deepspeed_config.json, you can enable distributed computing using DeepSpeed in your training process.
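As a starting point, the hypothetical helper below writes a minimal `deepspeed_config.json`. The keys are standard DeepSpeed configuration options, but the values are illustrative assumptions, not settings taken from this repository. `train_batch_size` is omitted because DeepSpeed can derive it from the micro-batch size, the accumulation steps, and the number of GPUs supplied by the launcher. The `bf16` section assumes hardware with bfloat16 support; use the `fp16` section otherwise.

```python
import json

def build_deepspeed_config(micro_batch_per_gpu: int = 4, grad_accum_steps: int = 8,
                           lr: float = 3e-4) -> dict:
    """Return a minimal DeepSpeed config dict; all values are illustrative."""
    return {
        "train_micro_batch_size_per_gpu": micro_batch_per_gpu,
        "gradient_accumulation_steps": grad_accum_steps,
        "optimizer": {
            "type": "AdamW",
            "params": {"lr": lr, "betas": [0.9, 0.95], "weight_decay": 0.1},
        },
        "bf16": {"enabled": True},
        "zero_optimization": {"stage": 2},  # shard optimizer states and gradients
        "gradient_clipping": 1.0,
    }

def write_deepspeed_config(path: str = "deepspeed_config.json", **kwargs) -> None:
    """Serialize the config to JSON at `path`."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(build_deepspeed_config(**kwargs), f, indent=2)
```

Calling `write_deepspeed_config()` produces the file in the current directory; adjust the batch sizes and ZeRO stage to your GPU memory.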
This project has been inspired and informed by various resources and individuals in the AI and machine learning community. We'd like to extend our gratitude to the following:
- Andrej Karpathy for his tutorial on training a GPT from scratch. His insights into neural network architectures and training methodologies have been invaluable.
- Umar Jamil for his guide on training LLaMA 2 from scratch. This resource provided practical insights and a foundational understanding necessary for this implementation.
- The Meta LLaMA GitHub repository has been an essential resource for understanding the intricacies of the LLaMA 2 model and its implementation.
- The DeepSpeed Megatron-LM GPT-2 tutorial, which details how to integrate DeepSpeed when training a PyTorch-based model.
I am grateful for the knowledge shared by these individuals and communities, which has significantly contributed to the development of this project.