Skip to content

Commit

Permalink
remove bellm
Browse files Browse the repository at this point in the history
  • Loading branch information
SeanLee97 committed Feb 23, 2024
1 parent 74e5ba2 commit 019df4a
Show file tree
Hide file tree
Showing 6 changed files with 5 additions and 1,945 deletions.
17 changes: 5 additions & 12 deletions angle_emb/angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
)
from peft.tuners.lora import LoraLayer

from . import bellm
from .utils import logger


Expand Down Expand Up @@ -990,7 +989,6 @@ class AnglE:
:param apply_bfloat16: Optional[bool]. Whether load using torch.bfloat16. Default None.
:param torch_dtype: Optional[torch.dtype]. Specify torch_dtype. Default None.
:param device: Optional[str]. Specify device. Default None.
:param bellm_class_name: Optional[str]. Specify bellm class name. Default None.
:param kbit_kwargs: Optional[Dict]. kwargs for kbit. Default None.
details refer to: https://huggingface.co/docs/peft/package_reference/peft_model#peft.prepare_model_for_kbit_training
:param **kwargs: Any.
Expand All @@ -1012,7 +1010,6 @@ def __init__(self,
apply_bfloat16: Optional[bool] = None,
torch_dtype: Optional[torch.dtype] = None,
device: Optional[str] = None,
bellm_class_name: Optional[str] = None,
kbit_kwargs: Optional[Dict] = None,
**kwargs: Any):
super().__init__()
Expand All @@ -1025,17 +1022,14 @@ def __init__(self,
self.device = device
else:
self.device = set_device()
self.is_bellm = bellm.check_bellm(bellm_class_name)
if self.is_bellm:
logger.info('BeLLM detected!')
if is_llm is None:
self.is_llm = check_llm(model_name_or_path) or self.is_bellm
self.is_llm = check_llm(model_name_or_path)
if self.is_llm:
logger.info('LLM detected, automatically set is_llm=True.'
'If it is wrong, you can manually set `is_llm`.')
self.apply_lora = apply_lora
if self.apply_lora is None:
if self.is_llm or self.is_bellm:
if self.is_llm:
self.apply_lora = True
logger.info('LLM detected, automatically set apply_lora=True.'
'If it is wrong, you can manually set `apply_lora`.')
Expand All @@ -1047,8 +1041,7 @@ def __init__(self,
self.gpu_count = 0

self.prompt = None
if self.is_llm and not self.is_bellm:
# do not set prompt for bellm
if self.is_llm:
logger.info('LLM detected, automatically set prompt. '
'You can change this setting by manually configuring the `set_prompt()` function.')
self.set_prompt()
Expand Down Expand Up @@ -1080,13 +1073,13 @@ def __init__(self,
kbit_kwargs = kbit_kwargs if kbit_kwargs is not None else {}
if self.is_llm:
device_map = "auto"
MODEL_CLASS = getattr(bellm, bellm_class_name) if self.is_bellm else AutoModelForCausalLM
MODEL_CLASS = AutoModelForCausalLM
if train_mode and self.gpu_count > 1:
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
# LLM
if self.apply_lora:
lora_config['bias'] = "none"
lora_config['task_type'] = TaskType.FEATURE_EXTRACTION if self.is_bellm else TaskType.CAUSAL_LM
lora_config['task_type'] = TaskType.CAUSAL_LM

if load_kbit == 4:
model = MODEL_CLASS.from_pretrained(
Expand Down
23 changes: 0 additions & 23 deletions angle_emb/bellm/__init__.py

This file was deleted.

251 changes: 0 additions & 251 deletions angle_emb/bellm/modeling_llama.py

This file was deleted.

Loading

0 comments on commit 019df4a

Please sign in to comment.