Commit

* Added fine-tuning L2 loss
* Updated hypers
* Updated default hypers
* Fix of no fine-tuning weights effect
* Linting fix
* Added new way of updating the model state dict
* Minor fix of the state_dict update
* Minor fix of the state_dict naming convention
* Added all_species update on fine-tuning
* Added LoRA for PET fine-tuning
* Minor fix
* Minor fix
* Minor fix
* Added parameters reset for LoRALayer
* Slightly changed weights addition in LoRA
* Allow mtt eval on cuda
* Fixed the bug with manually provided continue_from
* Fixed negative remaining scheduler steps
* Fixed device mismatch while writing the outputs.xyz during mtt eval
* Linting fix
* Removed redundant hyperparameter
* Linting fix
* Fixed device choice in mtt eval
* Added docs for fine-tuning
* Updated docs for fine-tuning
* Docs linting fix
* Docs fix
* Docs fix
* Updated docs format
* Tiny docs update
* Finishing docs
* Updated default hypers for PET
* Removed leftover test code
* Added PEFT definition to docs

Fine-tuning
===========

.. warning::

  This section of the documentation is only relevant for the PET model so far.

This section describes the process of fine-tuning a pre-trained model to
adapt it to new tasks or datasets. Fine-tuning is a common technique used
in transfer learning, where a model is trained on a large dataset and then
fine-tuned on a smaller dataset to improve its performance on specific tasks.
So far, the fine-tuning capabilities are only available for the PET model.

Fine-Tuning PET Model with LoRA
-------------------------------

Fine-tuning a PET model using LoRA (Low-Rank Adaptation) can significantly
enhance the model's performance on specific tasks while reducing the
computational cost. Below are the steps to fine-tune a PET model from
``metatrain.experimental.pet`` using LoRA.

What is LoRA?
^^^^^^^^^^^^^

LoRA (Low-Rank Adaptation) is a Parameter-Efficient Fine-Tuning (PEFT)
technique used to adapt pre-trained models to new tasks by introducing low-rank
matrices into the model's architecture. This approach reduces the number of
trainable parameters, making the fine-tuning process more efficient and less
resource-intensive. LoRA is particularly useful in scenarios where computational
resources are limited or when quick adaptation to new tasks is required.

Given a pre-trained model with the weights matrix :math:`W_0`, LoRA introduces
low-rank matrices :math:`A` and :math:`B` of rank :math:`r` such that the
new weights matrix :math:`W` is computed as:

.. math::

  W = W_0 + \frac{\alpha}{r} A B

where :math:`\alpha` is a scaling factor that controls the influence
of the low-rank matrices on the model's weights. By adjusting the rank :math:`r`
and the scaling factor :math:`\alpha`, you can fine-tune the model
to achieve better performance on specific tasks.
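
As an illustration only (this is not the metatrain or PET implementation, and
the class name ``LoRALinear`` and its initialization choices are assumptions),
the update above can be sketched in a few lines of PyTorch, where the
pre-trained weights are frozen and only :math:`A` and :math:`B` are trained:

.. code-block:: python

  import torch
  import torch.nn as nn


  class LoRALinear(nn.Module):
      """Minimal LoRA wrapper around a frozen linear layer (illustrative only)."""

      def __init__(self, base: nn.Linear, rank: int, alpha: float):
          super().__init__()
          self.base = base
          # Freeze the pre-trained weights W_0.
          for p in self.base.parameters():
              p.requires_grad = False
          # Low-rank factors: A is (in_features x r), B is (r x out_features),
          # so x @ A @ B has the same shape as the frozen layer's output.
          self.A = nn.Parameter(0.01 * torch.randn(base.in_features, rank))
          self.B = nn.Parameter(torch.zeros(rank, base.out_features))
          self.scaling = alpha / rank

      def forward(self, x: torch.Tensor) -> torch.Tensor:
          # y = base(x) + (alpha / r) * (x @ A @ B)
          return self.base(x) + self.scaling * (x @ self.A @ self.B)


  # Only A and B are trainable; the wrapped weights stay frozen.
  layer = LoRALinear(nn.Linear(64, 64), rank=4, alpha=0.5)
  print(sum(p.numel() for p in layer.parameters() if p.requires_grad))

Since :math:`B` is initialized to zero, the adapted layer starts out identical
to the pre-trained one, and the low-rank update is learned during fine-tuning.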

Prerequisites
^^^^^^^^^^^^^

1. Train the Base Model. You can train the base model using the command
   ``mtt train options.yaml``. Alternatively, you can use a pre-trained
   foundational model if you have access to its state dict.

2. Define Paths in ``options.yaml``. Specify the paths to ``model_state_dict``,
   ``all_species.npy``, and ``self_contributions.npy`` in the ``training``
   section of the ``options.yaml`` file:

   .. code-block:: yaml

     training:
       MODEL_TO_START_WITH: <path_to_model_state_dict>
       ALL_SPECIES_PATH: <path_to_all_species.npy>
       SELF_CONTRIBUTIONS_PATH: <path_to_self_contributions.npy>

   These parameters are relevant for the outputs of the PET model. If you are
   not familiar with their meaning, please refer to the :ref:`architecture-pet`
   model documentation.

3. Set the LoRA parameters in the ``architecture.model``
   section of the ``options.yaml``:

   .. code-block:: yaml

     architecture:
       model:
         LORA_RANK: <desired_rank>
         LORA_ALPHA: <desired_alpha>
         USE_LORA_PEFT: True

   These parameters control whether to use LoRA for pre-trained model fine-tuning
   (``USE_LORA_PEFT``), the rank of the low-rank matrices introduced by LoRA
   (``LORA_RANK``), and the scaling factor for the low-rank matrices
   (``LORA_ALPHA``). A combined example ``options.yaml`` is sketched after this
   list.

4. Run ``mtt train options.yaml`` to fine-tune the model.
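
Putting the snippets from steps 2 and 3 together, a fine-tuning ``options.yaml``
might look like the sketch below. The paths and the values of ``LORA_RANK`` and
``LORA_ALPHA`` are placeholders, and any other sections of the file (for example,
the dataset definitions) are omitted and stay as in a regular training run:

.. code-block:: yaml

  # Hypothetical fine-tuning configuration, assembled from the snippets above.
  training:
    MODEL_TO_START_WITH: pet_run/model_state_dict
    ALL_SPECIES_PATH: pet_run/all_species.npy
    SELF_CONTRIBUTIONS_PATH: pet_run/self_contributions.npy

  architecture:
    model:
      USE_LORA_PEFT: True
      LORA_RANK: 4
      LORA_ALPHA: 0.5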

Fine-Tuning Options
^^^^^^^^^^^^^^^^^^^

When ``USE_LORA_PEFT`` is set to ``True``, the original model's weights will be
frozen, and only the adapter layers introduced by LoRA will be trained. This
allows for efficient fine-tuning with fewer parameters. If ``USE_LORA_PEFT`` is
set to ``False``, all the weights of the model will be trained. This can lead to
better performance on the specific task but may require more computational
resources, and the model may be prone to overfitting (i.e. losing accuracy on
the original training set).
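
Conceptually, the freezing behavior of ``USE_LORA_PEFT: True`` amounts to turning
off gradients for everything except the LoRA adapters. The sketch below is not the
metatrain implementation; it assumes the adapter parameters can be recognized by
having ``lora`` in their name, which is only an illustrative convention:

.. code-block:: python

  import torch.nn as nn


  def freeze_all_but_lora(model: nn.Module) -> None:
      """Train only the parameters that belong to LoRA adapters (assumed naming)."""
      for name, param in model.named_parameters():
          param.requires_grad = "lora" in name.lower()

      trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
      total = sum(p.numel() for p in model.parameters())
      print(f"trainable parameters: {trainable} / {total}")

With ``USE_LORA_PEFT: False``, all parameters keep ``requires_grad = True`` and the
full model is updated during fine-tuning.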