From 8dcd40ac7844d4c520d6b5885be5ccb080ed0b55 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 15 Sep 2023 04:03:32 -0400 Subject: [PATCH] prevent cli functions from getting fired on import (#581) --- src/axolotl/cli/inference.py | 3 ++- src/axolotl/cli/merge_lora.py | 3 ++- src/axolotl/cli/shard.py | 3 ++- src/axolotl/cli/train.py | 3 ++- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/axolotl/cli/inference.py b/src/axolotl/cli/inference.py index 1a5a1a2be4..f3daac83dd 100644 --- a/src/axolotl/cli/inference.py +++ b/src/axolotl/cli/inference.py @@ -23,4 +23,5 @@ def do_cli(config: Path = Path("examples/"), **kwargs): do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args) -fire.Fire(do_cli) +if __name__ == "__main__": + fire.Fire(do_cli) diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py index 473aa8260c..79b7112b56 100644 --- a/src/axolotl/cli/merge_lora.py +++ b/src/axolotl/cli/merge_lora.py @@ -23,4 +23,5 @@ def do_cli(config: Path = Path("examples/"), **kwargs): do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args) -fire.Fire(do_cli) +if __name__ == "__main__": + fire.Fire(do_cli) diff --git a/src/axolotl/cli/shard.py b/src/axolotl/cli/shard.py index ad7d9a1368..85901b0f2a 100644 --- a/src/axolotl/cli/shard.py +++ b/src/axolotl/cli/shard.py @@ -38,4 +38,5 @@ def do_cli(config: Path = Path("examples/"), **kwargs): shard(cfg=parsed_cfg, cli_args=parsed_cli_args) -fire.Fire(do_cli) +if __name__ == "__main__": + fire.Fire(do_cli) diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 166af2595b..72a9250c8d 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -32,4 +32,5 @@ def do_cli(config: Path = Path("examples/"), **kwargs): train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta) -fire.Fire(do_cli) +if __name__ == "__main__": + fire.Fire(do_cli)