Add LoRA fine-tuning for PET (#379)
* Added fine-tuning L2 loss

* Updated hypers

* Updated default hypers

* Fix of no fine-tuning weights effect

* Linting fix

* Added new way of updating the model state dict

* Minor fix of the state_dict update

* Minor fix of the state_dict naming convention

* Added all_species update on fine-tuning

* Added LoRA for PET fine-tuning

* Minor fix

* Minor fix

* Minor fix

* Added parameters reset for LoRALayer

* Slightly changed weights addition in LoRA

* Allow mtt eval on cuda

* Fixed the bug with manually provided continue_from

* Fixed negative remaining scheduler steps

* Fixed device mismatch while writing the outputs.xyz during mtt eval

* Linting fix

* Removed redundant hyperparameter

* Linting fix

* Fixed device choice in mtt eval

* Added docs for fine-tuning

* Updated docs for fine-tuning

* Docs linting fix

* Docs fix

* Docs fix

* Updated docs format

* Tiny docs update

* Finishing docs

* Updated default hypers for PET

* Removed leftover test code

* Added PEFT definition to docs
abmazitov authored Nov 8, 2024
1 parent 5eefa32 commit a69d9ab
Showing 14 changed files with 330 additions and 48 deletions.
97 changes: 97 additions & 0 deletions docs/src/advanced-concepts/fine-tuning.rst
@@ -0,0 +1,97 @@
Fine-tuning
===========

.. warning::

  This section of the documentation is only relevant for the PET model so far.

This section describes the process of fine-tuning a pre-trained model to
adapt it to new tasks or datasets. Fine-tuning is a common technique used
in transfer learning, where a model is trained on a large dataset and then
fine-tuned on a smaller dataset to improve its performance on specific tasks.
So far, the fine-tuning capabilities are only available for the PET model.


Fine-Tuning PET Model with LoRA
-------------------------------

Fine-tuning a PET model using LoRA (Low-Rank Adaptation) trains only a small
number of additional parameters, which can improve the model's performance on
a specific task at a lower computational cost than retraining all the weights.
Below are the steps to fine-tune a PET model from
``metatrain.experimental.pet`` using LoRA.

What is LoRA?
^^^^^^^^^^^^^

LoRA (Low-Rank Adaptation) is a Parameter-Efficient Fine-Tuning (PEFT)
technique used to adapt pre-trained models to new tasks by introducing low-rank
matrices into the model's architecture. This approach reduces the number of
trainable parameters, making the fine-tuning process more efficient and less
resource-intensive. LoRA is particularly useful in scenarios where computational
resources are limited or when quick adaptation to new tasks is required.

Given a pre-trained model with the weight matrix :math:`W_0`, LoRA introduces
low-rank matrices :math:`A` and :math:`B` of rank :math:`r` such that the
new weight matrix :math:`W` is computed as:

.. math::

  W = W_0 + \frac{\alpha}{r} A B

where :math:`\alpha` is a scaling factor that controls the influence
of the low-rank matrices on the model's weights. By adjusting the rank :math:`r`
and the scaling factor :math:`\alpha`, you can control how strongly the
fine-tuned model departs from the pre-trained one.
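
As a purely illustrative sketch (not the code used by PET), the update rule
above can be written in plain Python with nested lists. The helper names
``matmul`` and ``lora_merge`` and the toy matrices are assumptions made for
this example only.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def lora_merge(w0, a, b, alpha):
    """Return W = W0 + (alpha / r) * A B, where r is the shared inner dimension."""
    r = len(b)  # A has shape (d, r) and B has shape (r, k)
    scale = alpha / r
    ab = matmul(a, b)
    return [
        [w + scale * u for w, u in zip(w_row, u_row)]
        for w_row, u_row in zip(w0, ab)
    ]


# Toy example: a 3 x 3 pre-trained matrix and a rank-1 update.
w0 = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
a = [[1.0], [2.0], [3.0]]  # shape (3, 1)
b = [[0.5, 0.0, -0.5]]  # shape (1, 3)
w = lora_merge(w0, a, b, alpha=0.5)
print(w)
```

Here the rank is 1 and :math:`\alpha = 0.5`, so the product :math:`A B` is
scaled by 0.5 before being added to :math:`W_0`.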

Prerequisites
^^^^^^^^^^^^^

1. Train the Base Model. You can train the base model using the command
   ``mtt train options.yaml``. Alternatively, you can use a pre-trained
   foundational model, if you have access to its state dict.

2. Define Paths in ``options.yaml``. Specify the paths to ``model_state_dict``,
   ``all_species.npy``, and ``self_contributions.npy`` in the ``training``
   section of the ``options.yaml`` file:

   .. code-block:: yaml

     training:
       MODEL_TO_START_WITH: <path_to_model_state_dict>
       ALL_SPECIES_PATH: <path_to_all_species.npy>
       SELF_CONTRIBUTIONS_PATH: <path_to_self_contributions.npy>

   These parameters are relevant for the outputs of the PET model. If you are
   not familiar with their meaning, please refer to the :ref:`architecture-pet`
   model documentation.

3. Set the LoRA parameters in the ``architecture.model``
   section of the ``options.yaml``:

   .. code-block:: yaml

     architecture:
       model:
         LORA_RANK: <desired_rank>
         LORA_ALPHA: <desired_alpha>
         USE_LORA_PEFT: True

   These parameters control whether to use LoRA for pre-trained model fine-tuning
   (``USE_LORA_PEFT``), the rank of the low-rank matrices introduced by LoRA
   (``LORA_RANK``), and the scaling factor for the low-rank matrices
   (``LORA_ALPHA``).

4. Run ``mtt train options.yaml`` to fine-tune the model.
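
Before launching the run, it can help to check that the options from steps 2
and 3 are all present. The following is a minimal sketch, assuming the options
have already been loaded into a Python dictionary with the same layout as the
snippets above. ``check_finetuning_options`` is a hypothetical helper and not
part of ``metatrain``, and the file names are placeholders.

```python
REQUIRED_PATHS = ("MODEL_TO_START_WITH", "ALL_SPECIES_PATH", "SELF_CONTRIBUTIONS_PATH")


def check_finetuning_options(options):
    """Return a list of problems; an empty list means nothing was found."""
    problems = []

    # Step 2: the three paths should be given as strings.
    training = options.get("training", {})
    for key in REQUIRED_PATHS:
        if not isinstance(training.get(key), str):
            problems.append(f"training.{key} should be a path string")

    # Step 3: the LoRA hypers only matter when USE_LORA_PEFT is switched on.
    model = options.get("architecture", {}).get("model", {})
    if model.get("USE_LORA_PEFT", False):
        rank = model.get("LORA_RANK")
        if isinstance(rank, bool) or not isinstance(rank, int) or rank < 1:
            problems.append("architecture.model.LORA_RANK should be a positive integer")
        alpha = model.get("LORA_ALPHA")
        if isinstance(alpha, bool) or not isinstance(alpha, (int, float)):
            problems.append("architecture.model.LORA_ALPHA should be a number")

    return problems


# Placeholder values mirroring the YAML snippets above.
options = {
    "training": {
        "MODEL_TO_START_WITH": "model_state_dict",
        "ALL_SPECIES_PATH": "all_species.npy",
        "SELF_CONTRIBUTIONS_PATH": "self_contributions.npy",
    },
    "architecture": {
        "model": {"USE_LORA_PEFT": True, "LORA_RANK": 4, "LORA_ALPHA": 0.5},
    },
}
problems = check_finetuning_options(options)
print(problems or "options look complete")
```

The check only looks at types and presence. It does not verify that the files
exist or that they belong to the same base model.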

Fine-Tuning Options
^^^^^^^^^^^^^^^^^^^

When ``USE_LORA_PEFT`` is set to ``True``, the original model's weights will be
frozen, and only the adapter layers introduced by LoRA will be trained. This
allows for efficient fine-tuning with fewer parameters. If ``USE_LORA_PEFT`` is
set to ``False``, all the weights of the model will be trained. This can lead to
better performance on the specific task but may require more computational
resources, and the model may be prone to overfitting (i.e. losing accuracy on
the original training set).
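
To give a feeling for the difference between the two settings, the sketch
below counts trainable parameters in both cases. It assumes that every adapted
weight matrix of shape :math:`(d, k)` receives factors :math:`A` of shape
:math:`(d, r)` and :math:`B` of shape :math:`(r, k)`, and it ignores biases.
The layer shapes are invented for the example and do not describe the actual
PET architecture.

```python
def full_params(shapes):
    """Number of weights trained when all (d, k) matrices are updated."""
    return sum(d * k for d, k in shapes)


def lora_params(shapes, rank):
    """Number of weights in the A (d x r) and B (r x k) factors only."""
    return sum(rank * (d + k) for d, k in shapes)


# Hypothetical model: four square weight matrices of size 128 x 128.
shapes = [(128, 128)] * 4
full = full_params(shapes)
lora = lora_params(shapes, rank=4)
print(f"full fine-tuning: {full} weights, LoRA (rank 4): {lora} weights")
```

With these assumed shapes, LoRA at rank 4 trains 4096 weights instead of
65536, which is 6.25% of the full count.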

1 change: 1 addition & 0 deletions docs/src/advanced-concepts/index.rst
@@ -12,3 +12,4 @@ such as output naming, auxiliary outputs, and wrapper models.
    auxiliary-outputs
    multi-gpu
    auto-restarting
+   fine-tuning
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -65,7 +65,7 @@ alchemical-model = [
   "torch_alchemical @ git+https://github.com/abmazitov/torch_alchemical.git@51ff519",
 ]
 pet = [
-  "pet @ git+https://github.com/lab-cosmo/pet@7eddb2e",
+  "pet @ git+https://github.com/lab-cosmo/pet@ee90692",
 ]
 gap = [
   "rascaline-torch @ git+https://github.com/luthaf/rascaline@5326b6e#subdirectory=python/rascaline-torch",
6 changes: 5 additions & 1 deletion src/metatrain/cli/eval.py
@@ -186,7 +186,11 @@ def _eval_targets(
     # Infer the device and dtype from the model
     model_tensor = next(itertools.chain(model.parameters(), model.buffers()))
     dtype = model_tensor.dtype
-    device = model_tensor.device
+    device = "cpu"
+    if torch.cuda.is_available() and "cuda" in model.capabilities().supported_devices:
+        device = "cuda"
+    logger.info(f"Running on device {device} with dtype {dtype}")
+    model.to(dtype=dtype, device=device)
 
     # Create a dataloader
     dataloader = torch.utils.data.DataLoader(
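
The device choice added in this hunk can be read as a small pure function. The
following is a minimal sketch of that rule, with the two inputs passed in
explicitly so that it runs without torch. `pick_eval_device` is a hypothetical
name and not part of metatrain.

```python
def pick_eval_device(cuda_available, supported_devices):
    """Use CUDA only if it is available and the model declares support for it."""
    if cuda_available and "cuda" in supported_devices:
        return "cuda"
    return "cpu"


# A model that supports both devices, on machines with and without a GPU.
print(pick_eval_device(True, ["cuda", "cpu"]))
print(pick_eval_device(False, ["cuda", "cpu"]))
```

In the hunk itself, the two inputs come from `torch.cuda.is_available()` and
`model.capabilities().supported_devices`.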
2 changes: 2 additions & 0 deletions src/metatrain/cli/train.py
@@ -128,6 +128,8 @@ def _process_continue_from(continue_from: str) -> Optional[str]:
             # process and the other processes might detect it by mistake if they're
             # still executing this function
             time.sleep(3)
+    else:
+        new_continue_from = continue_from
 
     return new_continue_from
 
9 changes: 7 additions & 2 deletions src/metatrain/experimental/pet/default-hypers.yaml
@@ -34,11 +34,14 @@ architecture:
     TARGET_INDEX_KEY: target_index
     RESIDUAL_FACTOR: 0.5
     USE_ZBL: False
+    USE_LORA_PEFT: False
+    LORA_RANK: 4
+    LORA_ALPHA: 0.5
 
   training:
     INITIAL_LR: 1e-4
-    EPOCH_NUM: 1000
-    EPOCHS_WARMUP: 50
+    EPOCH_NUM_ATOMIC: 1000000000
+    EPOCHS_WARMUP_ATOMIC: 100000000
     SCHEDULER_STEP_SIZE_ATOMIC: 500000000 # structural version is called "SCHEDULER_STEP_SIZE"
     GLOBAL_AUG: True
     SLIDING_FACTOR: 0.7
@@ -50,6 +53,8 @@ architecture:
     RANDOM_SEED: 0
     CUDA_DETERMINISTIC: False
     MODEL_TO_START_WITH: null
+    ALL_SPECIES_PATH: null
+    SELF_CONTRIBUTIONS_PATH: null
     SUPPORT_MISSING_VALUES: False
     USE_WEIGHT_DECAY: False
     WEIGHT_DECAY: 0.0
12 changes: 7 additions & 5 deletions src/metatrain/experimental/pet/model.py
@@ -21,7 +21,8 @@
 
 from ...utils.additive import ZBL
 from ...utils.dtype import dtype_to_str
-from .utils import systems_to_batch_dict
+from .utils import systems_to_batch_dict, update_state_dict
+from .utils.fine_tuning import LoRAWrapper
 
 
 logger = logging.getLogger(__name__)
@@ -149,11 +150,12 @@ def load_checkpoint(cls, path: Union[str, Path]) -> "PET":
 
         ARCHITECTURAL_HYPERS = Hypers(model.hypers)
         raw_pet = RawPET(ARCHITECTURAL_HYPERS, 0.0, len(model.atomic_types))
+        if ARCHITECTURAL_HYPERS.USE_LORA_PEFT:
+            lora_rank = ARCHITECTURAL_HYPERS.LORA_RANK
+            lora_alpha = ARCHITECTURAL_HYPERS.LORA_ALPHA
+            raw_pet = LoRAWrapper(raw_pet, lora_rank, lora_alpha)
 
-        new_state_dict = {}
-        for name, value in state_dict.items():
-            name = name.replace("model.pet_model.", "")
-            new_state_dict[name] = value
+        new_state_dict = update_state_dict(state_dict)
 
         dtype = next(iter(new_state_dict.values())).dtype
         raw_pet.to(dtype).load_state_dict(new_state_dict)
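
The body of `update_state_dict` is not shown in this part of the diff. As a
rough sketch of what the removed loop did (dropping the `model.pet_model.`
prefix from every key), a stand-alone equivalent could look like the code
below. `strip_key_prefix` is a hypothetical name chosen for this example, and
the real helper may do more than this.

```python
def strip_key_prefix(state_dict, prefix="model.pet_model."):
    """Return a copy of the state dict with the prefix removed from each key."""
    return {name.replace(prefix, ""): value for name, value in state_dict.items()}


# Toy state dict with placeholder values instead of tensors.
state_dict = {"model.pet_model.embedding.weight": 1, "other.bias": 2}
new_state_dict = strip_key_prefix(state_dict)
print(new_state_dict)
```

Keys that do not contain the prefix are left unchanged, as in the removed loop.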
29 changes: 29 additions & 0 deletions src/metatrain/experimental/pet/schema-hypers.json
@@ -126,6 +126,15 @@
         },
         "USE_ZBL": {
           "type": "boolean"
+        },
+        "USE_LORA_PEFT": {
+          "type": "boolean"
+        },
+        "LORA_RANK": {
+          "type": "integer"
+        },
+        "LORA_ALPHA": {
+          "type": "number"
         }
       },
       "additionalProperties": false
@@ -191,6 +200,26 @@
             }
           ]
         },
+        "ALL_SPECIES_PATH": {
+          "oneOf": [
+            {
+              "type": "string"
+            },
+            {
+              "type": "null"
+            }
+          ]
+        },
+        "SELF_CONTRIBUTIONS_PATH": {
+          "oneOf": [
+            {
+              "type": "string"
+            },
+            {
+              "type": "null"
+            }
+          ]
+        },
         "SUPPORT_MISSING_VALUES": {
           "type": "boolean"
         },