rename base_model to transformers_model
clefourrier committed Dec 12, 2024
1 parent e1bd34f commit 3eb7d0f
Showing 10 changed files with 55 additions and 34 deletions.
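For downstream code, the practical effect is an import-path and class-name change, softened by the deprecation shims added in the renamed `transformers_model.py` below. A minimal compatibility sketch for user code; the try/except guard is illustrative and not part of this commit:

```python
# Prefer the post-rename names; fall back on older lighteval releases.
try:
    from lighteval.models.transformers.transformers_model import (
        TransformersModel,
        TransformersModelConfig,
    )
except ImportError:  # pre-rename lighteval
    from lighteval.models.transformers.base_model import (
        BaseModel as TransformersModel,
        BaseModelConfig as TransformersModelConfig,
    )
```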
6 changes: 3 additions & 3 deletions docs/source/package_reference/models.mdx
@@ -6,9 +6,9 @@


## Accelerate and Transformers Models
-### BaseModel
-[[autodoc]] models.transformers.base_model.BaseModelConfig
-[[autodoc]] models.transformers.base_model.BaseModel
+### TransformersModel
+[[autodoc]] models.transformers.base_model.TransformersModelConfig
+[[autodoc]] models.transformers.base_model.TransformersModel

### AdapterModel
[[autodoc]] models.transformers.adapter_model.AdapterModelConfig
6 changes: 3 additions & 3 deletions src/lighteval/main_accelerate.py
@@ -109,8 +109,8 @@ def accelerate( # noqa C901
    from lighteval.logging.evaluation_tracker import EvaluationTracker
    from lighteval.models.model_input import GenerationParameters
    from lighteval.models.transformers.adapter_model import AdapterModelConfig
-    from lighteval.models.transformers.base_model import BaseModelConfig, BitsAndBytesConfig
    from lighteval.models.transformers.delta_model import DeltaModelConfig
+    from lighteval.models.transformers.transformers_model import BitsAndBytesConfig, TransformersModelConfig
    from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters

    accelerator = Accelerator(kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=3000))])
@@ -183,13 +183,13 @@ def accelerate( # noqa C901
        elif config["merged_weights"]["base_model"] not in ["", None]:
            raise ValueError("You can't specify a base model if you are not using delta/adapter weights")
        else:
-            model_config = BaseModelConfig(**args_dict)
+            model_config = TransformersModelConfig(**args_dict)
    else:
        model_args_dict: dict = {k.split("=")[0]: k.split("=")[1] if "=" in k else True for k in model_args.split(",")}
        model_args_dict["accelerator"] = accelerator
        model_args_dict["use_chat_template"] = use_chat_template
        model_args_dict["compile"] = bool(model_args_dict["compile"]) if "compile" in model_args_dict else False
-        model_config = BaseModelConfig(**model_args_dict)
+        model_config = TransformersModelConfig(**model_args_dict)

    pipeline = Pipeline(
        tasks=tasks,
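The `model_args` string parsed above follows a `key=value,flag` convention: `key=value` pairs become string entries and bare keys become boolean flags. A standalone sketch of that comprehension; the argument values are illustrative:

```python
# Demo of the model_args parsing in main_accelerate.py above:
# "key=value" pairs become strings, bare keys become True flags.
model_args = "pretrained=gpt2,revision=main,compile"
model_args_dict: dict = {
    k.split("=")[0]: k.split("=")[1] if "=" in k else True
    for k in model_args.split(",")
}
assert model_args_dict == {"pretrained": "gpt2", "revision": "main", "compile": True}
```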
14 changes: 7 additions & 7 deletions src/lighteval/models/model_loader.py
@@ -32,8 +32,8 @@
from lighteval.models.endpoints.openai_model import OpenAIClient, OpenAIModelConfig
from lighteval.models.endpoints.tgi_model import ModelClient, TGIModelConfig
from lighteval.models.transformers.adapter_model import AdapterModel, AdapterModelConfig
-from lighteval.models.transformers.base_model import BaseModel, BaseModelConfig
from lighteval.models.transformers.delta_model import DeltaModel, DeltaModelConfig
+from lighteval.models.transformers.transformers_model import TransformersModel, TransformersModelConfig
from lighteval.models.vllm.vllm_model import VLLMModel, VLLMModelConfig
from lighteval.utils.imports import (
    NO_TGI_ERROR_MSG,
@@ -50,7 +50,7 @@

def load_model( # noqa: C901
    config: Union[
-        BaseModelConfig,
+        TransformersModelConfig,
        AdapterModelConfig,
        DeltaModelConfig,
        TGIModelConfig,
@@ -60,7 +60,7 @@ def load_model( # noqa: C901
        OpenAIModelConfig,
    ],
    env_config: EnvConfig,
-) -> Union[BaseModel, AdapterModel, DeltaModel, ModelClient, DummyModel]:
+) -> Union[TransformersModel, AdapterModel, DeltaModel, ModelClient, DummyModel]:
"""Will load either a model from an inference server or a model from a checkpoint, depending
on the config type.
@@ -74,7 +74,7 @@ def load_model( # noqa: C901
        ValueError: If you did not specify a base model when using delta weights or adapter weights
    Returns:
-        Union[BaseModel, AdapterModel, DeltaModel, ModelClient]: The model that will be evaluated
+        Union[TransformersModel, AdapterModel, DeltaModel, ModelClient]: The model that will be evaluated
    """
    # Inference server loading
    if isinstance(config, TGIModelConfig):
@@ -83,7 +83,7 @@ def load_model( # noqa: C901
    if isinstance(config, InferenceEndpointModelConfig) or isinstance(config, InferenceModelConfig):
        return load_model_with_inference_endpoints(config, env_config=env_config)

-    if isinstance(config, BaseModelConfig):
+    if isinstance(config, TransformersModelConfig):
        return load_model_with_accelerate_or_default(config=config, env_config=env_config)

    if isinstance(config, DummyModelConfig):
@@ -123,7 +123,7 @@ def load_model_with_inference_endpoints(config: InferenceEndpointModelConfig, en


def load_model_with_accelerate_or_default(
-    config: Union[AdapterModelConfig, BaseModelConfig, DeltaModelConfig], env_config: EnvConfig
+    config: Union[AdapterModelConfig, TransformersModelConfig, DeltaModelConfig], env_config: EnvConfig
):
    if isinstance(config, AdapterModelConfig):
        model = AdapterModel(config=config, env_config=env_config)
@@ -135,7 +135,7 @@ def load_model_with_accelerate_or_default(
        model = VLLMModel(config=config, env_config=env_config)
        return model
    else:
-        model = BaseModel(config=config, env_config=env_config)
+        model = TransformersModel(config=config, env_config=env_config)

    return model

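`load_model` stays a thin type-dispatch layer: callers switch backends by switching config classes, so only the config name changes here. A usage sketch through the renamed classes, mirroring `tests/models/test_base_model.py` at the bottom of this diff:

```python
# Load a tiny test model through the renamed config and loader.
from lighteval.models.model_loader import load_model
from lighteval.models.transformers.transformers_model import (
    TransformersModel,
    TransformersModelConfig,
)
from lighteval.utils.utils import EnvConfig

model_config = TransformersModelConfig("hf-internal-testing/tiny-random-LlamaForCausalLM")
model: TransformersModel = load_model(config=model_config, env_config=EnvConfig(cache_dir="."))
```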
2 changes: 1 addition & 1 deletion src/lighteval/models/nanotron/nanotron_model.py
@@ -48,7 +48,7 @@
    LoglikelihoodResponse,
    LoglikelihoodSingleTokenResponse,
)
-from lighteval.models.transformers.base_model import LightevalModel, ModelInfo
+from lighteval.models.transformers.transformers_model import LightevalModel, ModelInfo
from lighteval.tasks.requests import (
    GreedyUntilRequest,
    LoglikelihoodRequest,
6 changes: 3 additions & 3 deletions src/lighteval/models/transformers/adapter_model.py
@@ -27,7 +27,7 @@
import torch
from transformers import AutoModelForCausalLM, PreTrainedTokenizer

-from lighteval.models.transformers.base_model import BaseModel, BaseModelConfig
+from lighteval.models.transformers.transformers_model import TransformersModel, TransformersModelConfig
from lighteval.models.utils import _get_dtype
from lighteval.utils.imports import NO_PEFT_ERROR_MSG, is_peft_available
from lighteval.utils.utils import EnvConfig
@@ -40,7 +40,7 @@


@dataclass
-class AdapterModelConfig(BaseModelConfig):
+class AdapterModelConfig(TransformersModelConfig):
    # Adapter models have the specificity that they look at the base model (= the parent) for the tokenizer and config
    base_model: str = None

@@ -57,7 +57,7 @@ def init_configs(self, env_config: EnvConfig):
        return self._init_configs(self.base_model, env_config)


-class AdapterModel(BaseModel):
+class AdapterModel(TransformersModel):
    def _create_auto_tokenizer(self, config: AdapterModelConfig, env_config: EnvConfig) -> PreTrainedTokenizer:
        # By default, we look at the model config for the model stored in `base_model`
        # (= the parent model, not the model of interest)
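The comment in `AdapterModelConfig` captures the design point: adapter repos typically ship only adapter weights, so the tokenizer (and model config) are resolved against `base_model`, the parent, rather than the adapter repo. A rough sketch of that pattern using `peft` directly; the repo IDs are placeholders, not from this commit:

```python
# Tokenizer comes from the base model; weights come from base + adapter.
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model_id = "org/base-model"      # hypothetical parent model
adapter_id = "user/base-model-lora"   # hypothetical adapter repo

tokenizer = AutoTokenizer.from_pretrained(base_model_id)  # base, not adapter
model = AutoModelForCausalLM.from_pretrained(base_model_id)
model = PeftModel.from_pretrained(model, adapter_id)  # apply adapter weights
```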
6 changes: 3 additions & 3 deletions src/lighteval/models/transformers/delta_model.py
@@ -28,7 +28,7 @@
from tqdm import tqdm
from transformers import AutoModelForCausalLM

-from lighteval.models.transformers.base_model import BaseModel, BaseModelConfig
+from lighteval.models.transformers.transformers_model import TransformersModel, TransformersModelConfig
from lighteval.models.utils import _get_dtype, _get_model_sha
from lighteval.utils.utils import EnvConfig

@@ -37,7 +37,7 @@


@dataclass
-class DeltaModelConfig(BaseModelConfig):
+class DeltaModelConfig(TransformersModelConfig):
    # Delta models look at the pretrained (= the delta weights) for the tokenizer and model config
    base_model: str = None

@@ -53,7 +53,7 @@ def get_model_sha(self):
        return _get_model_sha(repo_id=self.pretrained, revision="main")


-class DeltaModel(BaseModel):
+class DeltaModel(TransformersModel):
    def _create_auto_model(
        self,
        config: DeltaModelConfig,
31 changes: 26 additions & 5 deletions src/lighteval/models/transformers/{base_model.py → transformers_model.py}
@@ -86,7 +86,7 @@


@dataclass
-class BaseModelConfig:
+class TransformersModelConfig:
    """
    Base configuration class for models.
@@ -228,11 +228,21 @@ def get_model_sha(self):
        return _get_model_sha(repo_id=self.pretrained, revision=self.revision)


-class BaseModel(LightevalModel):
+@dataclass
+class BaseModelConfig(TransformersModelConfig):
+    def __post_init__(self):
+        super()
+
+        logger.warning(
+            "Careful, BaseModelConfig is deprecated and will be removed, you should use TransformersModelConfig instead!"
+        )
+
+
+class TransformersModel(LightevalModel):
    def __init__(
        self,
        env_config: EnvConfig,
-        config: BaseModelConfig,
+        config: TransformersModelConfig,
    ):
        """Initializes a HuggingFace `AutoModel` and `AutoTokenizer` for evaluation."""
        self._config = config.init_configs(env_config)
@@ -403,7 +413,9 @@ def init_model_parallel(self, model_parallel: bool | None = None) -> Tuple[bool,
        )
        return model_parallel, max_mem_this_process, device_map

-    def _create_auto_model(self, config: BaseModelConfig, env_config: EnvConfig) -> transformers.PreTrainedModel:
+    def _create_auto_model(
+        self, config: TransformersModelConfig, env_config: EnvConfig
+    ) -> transformers.PreTrainedModel:
        """
        Creates an instance of the pretrained HF model.
@@ -440,7 +452,7 @@ def _create_auto_model(self, config: BaseModelConfig, env_config: EnvConfig) ->
        return model

    def _create_auto_tokenizer(
-        self, config: BaseModelConfig, env_config: EnvConfig
+        self, config: TransformersModelConfig, env_config: EnvConfig
    ) -> transformers.PreTrainedTokenizer:
        return self._create_auto_tokenizer_with_name(
            model_name=config.pretrained,
@@ -1324,6 +1336,15 @@ def _loglikelihood_single_token(
        return dataset.get_original_order(res)


+class BaseModel(TransformersModel):
+    def __post_init__(self):
+        super()
+
+        logger.warning(
+            "Careful, the BaseModel name is deprecated and will be removed, you should use TransformersModel instead!"
+        )
+
+
class MultiTokenEOSCriteria(transformers.StoppingCriteria):
"""Criteria to stop on the specified multi-token sequence."""

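A note on the two shims added above: a bare `super()` call builds a proxy object and invokes nothing on it, and `__post_init__` only runs automatically for dataclasses, so the `BaseModelConfig` shim warns without chaining to any parent `__post_init__`, while the plain-class `BaseModel` alias never triggers its warning at all. A minimal sketch of a config shim that chains correctly, with simplified fields; illustrative, not the commit's code:

```python
import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)


@dataclass
class TransformersModelConfig:
    pretrained: str = ""

    def __post_init__(self):
        pass  # parent-side validation would live here


@dataclass
class BaseModelConfig(TransformersModelConfig):
    def __post_init__(self):
        super().__post_init__()  # chain to the parent, not just `super()`
        logger.warning(
            "BaseModelConfig is deprecated, use TransformersModelConfig instead!"
        )
```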
6 changes: 3 additions & 3 deletions src/lighteval/pipeline.py
@@ -33,7 +33,7 @@

from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.metrics.utils.metric_utils import MetricCategory
-from lighteval.models.model_loader import BaseModel, load_model
+from lighteval.models.model_loader import TransformersModel, load_model
from lighteval.models.model_output import ModelResponse
from lighteval.tasks.lighteval_task import LightevalTask, create_requests_from_tasks
from lighteval.tasks.registry import Registry, taskinfo_selector
@@ -180,10 +180,10 @@ def _init_model(self, model_config, model):
                )
            else:
                return load_model(config=model_config, env_config=self.pipeline_parameters.env_config)
-        if isinstance(model, BaseModel):
+        if isinstance(model, TransformersModel):
            return model
        else:
-            return BaseModel.from_model(
+            return TransformersModel.from_model(
                model=model,
                use_chat_template=self.pipeline_parameters.use_chat_template,
                env_config=self.pipeline_parameters.env_config,
6 changes: 3 additions & 3 deletions src/lighteval/tasks/lighteval_task.py
@@ -41,7 +41,7 @@
    apply_target_perplexity_metric,
)
from lighteval.metrics.metrics import Metric, MetricCategory, Metrics
-from lighteval.models.transformers.base_model import BaseModel
+from lighteval.models.transformers.transformers_model import TransformersModel
from lighteval.tasks.prompt_manager import PromptManager
from lighteval.tasks.requests import (
    Doc,
@@ -578,7 +578,7 @@ def create_requests_from_tasks( # noqa: C901
    task_dict: dict[str, LightevalTask],
    fewshot_dict: dict[str, list[Tuple[int, bool]]],
    num_fewshot_seeds: int,
-    lm: BaseModel,
+    lm: TransformersModel,
    max_samples: int | None,
    evaluation_tracker: "EvaluationTracker",
    use_chat_template: bool,
@@ -594,7 +594,7 @@ def create_requests_from_tasks( # noqa: C901
        fewshot_dict (dict[str, list[Tuple[int, bool]]]): A dictionary of few
            shot examples.
        num_fewshot_seeds (int): number of few shot seeds.
-        lm (BaseModel): language model class that will be used to eventually
+        lm (TransformersModel): language model class that will be used to eventually
            truncate the few shot examples (we need the maximum input size of the
            model)
        max_samples (int): maximum number of samples.
6 changes: 3 additions & 3 deletions tests/models/test_base_model.py
@@ -21,13 +21,13 @@
# SOFTWARE.

from lighteval.models.model_loader import load_model
-from lighteval.models.transformers.base_model import BaseModel, BaseModelConfig
+from lighteval.models.transformers.transformers_model import TransformersModel, TransformersModelConfig
from lighteval.utils.utils import EnvConfig


def test_empty_requests():
-    model_config = BaseModelConfig("hf-internal-testing/tiny-random-LlamaForCausalLM")
-    model: BaseModel = load_model(config=model_config, env_config=EnvConfig(cache_dir="."))
+    model_config = TransformersModelConfig("hf-internal-testing/tiny-random-LlamaForCausalLM")
+    model: TransformersModel = load_model(config=model_config, env_config=EnvConfig(cache_dir="."))

    assert model.loglikelihood([]) == []
    assert model.loglikelihood_single_token([]) == []
