
Merge branch 'main' into function_prompts
clefourrier authored Jul 4, 2024
2 parents f275f60 + 7fcaab3 commit cde6c04
Showing 9 changed files with 109 additions and 111 deletions.
22 changes: 20 additions & 2 deletions README.md
@@ -139,7 +139,7 @@ accelerate launch --multi_gpu --num_processes=<num_gpus> run_evals_accelerate.py
--output_dir output_dir
```

Examples of possible configuration files are provided in `examples/model_configs`.
You can find the template of the expected model configuration in [examples/model_configs/base_model.yaml](./examples/model_configs/base_model.yaml).

### Evaluating a large model with pipeline parallelism

@@ -182,6 +182,25 @@ python run_evals_accelerate.py \
--output_dir output_dir
```

### Evaluate the model on a server/container.

An alternative to launching the evaluation locally is to serve the model on a TGI-compatible server/container and then run the evaluation by sending requests to the server. The command is the same as before, except you specify a path to a yaml config file (detailed below):

```shell
python run_evals_accelerate.py \
--model_config_path="/path/to/config/file"\
--tasks <task parameters> \
--output_dir output_dir
```

There are two types of configuration files that can be provided for running on the server:

1. [endpoint_model.yaml](./examples/model_configs/endpoint_model.yaml): This configuration allows you to launch the model using [HuggingFace's Inference Endpoints](https://huggingface.co/inference-endpoints/dedicated). You can specify in the configuration file all the relevant parameters, and then `lighteval` will automatically deploy the endpoint, run the evaluation, and finally delete the endpoint (unless you specify an endpoint that was already launched, in which case the endpoint won't be deleted afterwards).

2. [tgi_model.yaml](./examples/model_configs/tgi_model.yaml): This configuration lets you specify the URL of a model running in a TGI container, such as one deployed on HuggingFace's serverless inference.

Templates for these configurations can be found in [examples/model_configs](./examples/model_configs/).
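
For illustration, here is a minimal sketch of what a TGI configuration could look like, using the field names from the `tgi_model.yaml` template and the keys read by `create_model_config` (the server address below is a placeholder, not a value from the repository):

```yaml
model:
  type: "tgi"  # route requests to an already running TGI server instead of loading a model locally
  instance:
    inference_server_address: "http://localhost:8080"  # placeholder: the address of your TGI server
    inference_server_auth: null  # optional auth token, if your server requires one
    model_id: null  # only needed if the container was launched with model_id pointing to a local directory
```

The `endpoint_model.yaml` template follows the same `model`/`instance` layout but adds the parameters needed to deploy an Inference Endpoint (accelerator, region, instance size and type, optional `image_url` and `env_vars`) plus a `generation` section.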

### Evaluate a model on extended, community, or custom tasks.

Independently of the default tasks provided in `lighteval` that you will find in the `tasks_table.jsonl` file, you can use `lighteval` to evaluate models on tasks that require special processing (or have been added by the community). These tasks have their own evaluation suites and are defined as follows:
@@ -190,7 +209,6 @@ Independently of the default tasks provided in `lighteval` that you will find in
* `community`: tasks that have been added by the community. See the [`community_tasks`](./community_tasks) folder for examples.
* `custom`: tasks that are defined locally and not present in the core library. Use this suite if you want to experiment with designing a special metric or task.


For example, to run an extended task like `ifeval`, you can run:
```shell
python run_evals_accelerate.py \
5 changes: 4 additions & 1 deletion examples/model_configs/endpoint_model.yaml
@@ -5,7 +5,7 @@ model:
model: "meta-llama/Llama-2-7b-hf"
revision: "main"
dtype: "float16" # can be any of "awq", "eetq", "gptq", "4bit' or "8bit" (will use bitsandbytes), "bfloat16" or "float16"
reuse_existing: false # if true, ignore all params in instance
reuse_existing: false # if true, ignore all params in instance, and don't delete the endpoint after evaluation
instance:
accelerator: "gpu"
region: "eu-west-1"
Expand All @@ -15,5 +15,8 @@ model:
framework: "pytorch"
endpoint_type: "protected"
namespace: null # The namespace under which to launch the endpoint. Defaults to the current user's namespace
image_url: null # Optionally specify the docker image to use when launching the endpoint model, e.g., a later release of the TGI container with support for newer models.
env_vars:
null # Optional environment variables to include when launching the endpoint. e.g., `MAX_INPUT_LENGTH: 2048`
generation:
add_special_tokens: true
1 change: 1 addition & 0 deletions examples/model_configs/tgi_model.yaml
@@ -3,3 +3,4 @@ model:
instance:
inference_server_address: ""
inference_server_auth: null
model_id: null # Optional, only required if the TGI container was launched with model_id pointing to a local directory
5 changes: 3 additions & 2 deletions run_evals_accelerate.py
@@ -20,10 +20,11 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

""" Example run command:
"""Example run command:
accelerate config
accelerate launch run_evals_accelerate.py --tasks="leaderboard|hellaswag|5|1" --output_dir "/scratch/evals" --model_args "pretrained=gpt2"
"""

import argparse

from lighteval.main_accelerate import CACHE_DIR, main
@@ -70,7 +71,7 @@ def get_parser():
"--tasks",
type=str,
default=None,
help="Id of a task, e.g. 'original|mmlu:abstract_algebra|5' or path to a texte file with a list of tasks",
help="Id of a task, e.g. 'original|mmlu:abstract_algebra|5|0' or path to a texte file with a list of tasks",
)
parser.add_argument("--num_fewshot_seeds", type=int, default=1, help="Number of trials the few shots")
return parser
35 changes: 18 additions & 17 deletions src/lighteval/models/endpoint_model.py
@@ -92,8 +92,9 @@ def __init__(
"MAX_TOTAL_TOKENS": "2048",
"MODEL_ID": "/repository",
**config.get_dtype_args(),
**config.get_custom_env_vars(),
},
"url": "ghcr.io/huggingface/text-generation-inference:1.1.0",
"url": (config.image_url or "ghcr.io/huggingface/text-generation-inference:1.1.0"),
},
)
hlog("Deploying your endpoint. Please wait.")
@@ -149,7 +150,7 @@ def max_length(self):
self._max_length = 2048
return self._max_length

def __async_process_request(
def _async_process_request(
self, context: str, stop_tokens: list[str], max_tokens: int
) -> Coroutine[None, list[TextGenerationOutput], str]:
# Todo: add an option to launch with conversational instead for chat prompts
@@ -165,7 +166,7 @@ def __async_process_request(

return generated_text

def __process_request(self, context: str, stop_tokens: list[str], max_tokens: int) -> TextGenerationOutput:
def _process_request(self, context: str, stop_tokens: list[str], max_tokens: int) -> TextGenerationOutput:
# Todo: add an option to launch with conversational instead for chat prompts
# https://huggingface.co/docs/huggingface_hub/v0.20.3/en/package_reference/inference_client#huggingface_hub.AsyncInferenceClient.conversational
generated_text = self.client.text_generation(
@@ -179,13 +180,13 @@ def __process_request(self, context: str, stop_tokens: list[str], max_tokens: in

return generated_text

async def __async_process_batch_generate(
async def _async_process_batch_generate(
self,
requests: list[GreedyUntilRequest],
) -> list[TextGenerationOutput]:
return await asyncio.gather(
*[
self.__async_process_request(
self._async_process_request(
context=request.context,
stop_tokens=as_list(request.stop_sequence),
max_tokens=request.generation_size,
@@ -194,25 +195,25 @@ async def __async_process_batch_generate(
]
)

def __process_batch_generate(
def _process_batch_generate(
self,
requests: list[GreedyUntilRequest],
) -> list[TextGenerationOutput]:
return [
self.__process_request(
self._process_request(
context=request.context,
stop_tokens=as_list(request.stop_sequence),
max_tokens=request.generation_size,
)
for request in requests
]

async def __async_process_batch_logprob(
async def _async_process_batch_logprob(
self, requests: list[LoglikelihoodRequest], rolling: bool = False
) -> list[TextGenerationOutput]:
return await asyncio.gather(
*[
self.__async_process_request(
self._async_process_request(
context=request.context if rolling else request.context + request.choice,
stop_tokens=[],
max_tokens=1,
@@ -221,11 +222,11 @@ async def __async_process_batch_logprob(
]
)

def __process_batch_logprob(
def _process_batch_logprob(
self, requests: list[LoglikelihoodRequest], rolling: bool = False
) -> list[TextGenerationOutput]:
return [
self.__process_request(
self._process_request(
context=request.context if rolling else request.context + request.choice,
stop_tokens=[],
max_tokens=1,
@@ -267,9 +268,9 @@ def greedy_until(
)

if self.use_async:
responses = asyncio.run(self.__async_process_batch_generate(batch))
responses = asyncio.run(self._async_process_batch_generate(batch))
else:
responses = self.__process_batch_generate(batch)
responses = self._process_batch_generate(batch)
for response in responses:
results.append(
GenerateReturn(
@@ -303,9 +304,9 @@ def loglikelihood(

for batch in tqdm(dataloader, desc="Loglikelihoods", position=1, leave=False, disable=self.disable_tqdm):
if self.use_async:
responses = asyncio.run(self.__async_process_batch_logprob(batch))
responses = asyncio.run(self._async_process_batch_logprob(batch))
else:
responses = self.__process_batch_logprob(batch)
responses = self._process_batch_logprob(batch)
for cur_request, response in zip(batch, responses):
cont_toks = torch.tensor(cur_request.tokenized_continuation)
len_choice = len(cont_toks)
@@ -351,9 +352,9 @@ def loglikelihood_rolling(
dataloader, desc="Loglikelihoods, rolling", position=1, leave=False, disable=self.disable_tqdm
):
if self.use_async:
responses = asyncio.run(self.__async_process_batch_logprob(batch, rolling=True))
responses = asyncio.run(self._async_process_batch_logprob(batch, rolling=True))
else:
responses = self.__process_batch_logprob(batch, rolling=True)
responses = self._process_batch_logprob(batch, rolling=True)
for response in responses:
logits = [t.logprob for t in response.details.tokens[:-1]]

17 changes: 13 additions & 4 deletions src/lighteval/models/model_config.py
@@ -200,6 +200,7 @@ def init_configs(self, env_config: EnvConfig):
class TGIModelConfig:
inference_server_address: str
inference_server_auth: str
model_id: str


@dataclass
@@ -224,6 +225,8 @@ class InferenceEndpointModelConfig:
add_special_tokens: bool = True
revision: str = "main"
namespace: str = None # The namespace under which to launch the endpoint. Defaults to the current user's namespace
image_url: str = None
env_vars: dict = None

def get_dtype_args(self) -> Dict[str, str]:
model_dtype = self.model_dtype.lower()
@@ -237,14 +240,17 @@ def get_dtype_args(self) -> Dict[str, str]:
return {"DTYPE": model_dtype}
return {}

def get_custom_env_vars(self) -> Dict[str, str]:
return {k: str(v) for k, v in self.env_vars.items()} if self.env_vars else {}

@staticmethod
def nullable_keys() -> list[str]:
"""
Returns the list of optional keys in an endpoint model configuration. By default, the code requires that all the
keys be specified in the configuration in order to launch the endpoint. This function returns the list of keys
that are not required and can remain None.
"""
return ["namespace"]
return ["namespace", "env_vars", "image_url"]


def create_model_config(args: Namespace, accelerator: Union["Accelerator", None]) -> BaseModelConfig: # noqa: C901
@@ -271,16 +277,17 @@ def create_model_config(args: Namespace, accelerator: Union["Accelerator", None]

return BaseModelConfig(**args_dict)

if args.model_config:
if hasattr(args, "model_config") and args.model_config:
config = args.model_config["model"]
else:
with open(args.model_config_path, "r") as f:
config = yaml.safe_load(f)["model"]

if config["type"] == "tgi":
return TGIModelConfig(
inference_server_address=args["instance"]["inference_server_address"],
inference_server_auth=args["instance"]["inference_server_auth"],
inference_server_address=config["instance"]["inference_server_address"],
inference_server_auth=config["instance"]["inference_server_auth"],
model_id=config["instance"]["model_id"],
)

if config["type"] == "endpoint":
Expand All @@ -303,6 +310,8 @@ def create_model_config(args: Namespace, accelerator: Union["Accelerator", None]
instance_size=config["instance"]["instance_size"],
instance_type=config["instance"]["instance_type"],
namespace=config["instance"]["namespace"],
image_url=config["instance"].get("image_url", None),
env_vars=config["instance"].get("env_vars", None),
)
return InferenceModelConfig(model=config["base_params"]["endpoint_name"])

6 changes: 4 additions & 2 deletions src/lighteval/models/model_loader.py
@@ -88,10 +88,12 @@ def load_model_with_tgi(config: TGIModelConfig):
raise ImportError(NO_TGI_ERROR_MSG)

hlog(f"Load model from inference server: {config.inference_server_address}")
model = ModelClient(address=config.inference_server_address, auth_token=config.inference_server_auth)
model = ModelClient(
address=config.inference_server_address, auth_token=config.inference_server_auth, model_id=config.model_id
)
model_name = str(model.model_info["model_id"])
model_sha = model.model_info["model_sha"]
model_precision = model.model_info["dtype"]
model_precision = model.model_info["model_dtype"]
model_size = -1
model_info = ModelInfo(
model_name=model_name,