-
-
Notifications
You must be signed in to change notification settings - Fork 896
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* orpo trainer * rl handling for orpo * support for remove_unused_columns * orpo fixes * fix loader for orpo * chore: lint * fix default for remove_unused_columns * roll ORPO into the main AxolotlTrainer so it can be compatible with some of the other techniques like relora * better handling of system message for orpo * revert system prompt changes for chat templtes * no need for else condition * split dataset parsing into it's own component
- Loading branch information
Showing
14 changed files
with
451 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
""" | ||
module for base dataset transform strategies | ||
""" | ||
|
||
import importlib | ||
import logging | ||
|
||
LOG = logging.getLogger("axolotl") | ||
|
||
|
||
def load(strategy, cfg, module_base=None, **kwargs): | ||
try: | ||
load_fn = strategy.split(".")[-1] | ||
strategy = ".".join(strategy.split(".")[:-1]) | ||
mod = importlib.import_module(f".{strategy}", module_base) | ||
func = getattr(mod, load_fn) | ||
return func(cfg, **kwargs) | ||
except Exception: # pylint: disable=broad-exception-caught | ||
LOG.warning(f"unable to load strategy {strategy}") | ||
return None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,8 @@ | ||
""" | ||
module for DPO style dataset transform strategies | ||
""" | ||
from functools import partial | ||
|
||
import importlib | ||
import logging | ||
from ..base import load as load_base | ||
|
||
LOG = logging.getLogger("axolotl") | ||
|
||
|
||
def load(strategy, cfg, **kwargs): | ||
try: | ||
load_fn = strategy.split(".")[-1] | ||
strategy = ".".join(strategy.split(".")[:-1]) | ||
mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies.dpo") | ||
func = getattr(mod, load_fn) | ||
return func(cfg, **kwargs) | ||
except Exception: # pylint: disable=broad-exception-caught | ||
LOG.warning(f"unable to load strategy {strategy}") | ||
return None | ||
load = partial(load_base, module="axolotl.prompt_strategies.dpo") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
""" | ||
module for ORPO style dataset transform strategies | ||
""" | ||
|
||
from functools import partial | ||
|
||
from ..base import load as load_base | ||
|
||
load = partial(load_base, module="axolotl.prompt_strategies.orpo") |
Oops, something went wrong.