Update llm_engine.py #33332

Merged
merged 11 commits on Nov 10, 2024
79 changes: 73 additions & 6 deletions src/transformers/agents/llm_engine.py
@@ -68,25 +68,92 @@ def get_clean_message_list(message_list: List[Dict[str, str]], role_conversions:
 
 
 class HfApiEngine:
-    """This engine leverages Hugging Face's Inference API service, either serverless or with a dedicated endpoint."""
+    """
+    A class to interact with Hugging Face's Inference API for language models.
+
+    This engine allows you to communicate with Hugging Face's models using the Inference API.
+    It can be used either in serverless mode or with a dedicated endpoint, supporting features
+    like stop sequences and grammar customization.
+
-    def __init__(self, model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct"):
+    Args:
+        model (str, optional): The Hugging Face model ID to be used for inference. This can be a path or model
+            identifier from the Hugging Face model hub (default is "meta-llama/Meta-Llama-3.1-8B-Instruct").
+        token (str, optional): The Hugging Face API token for authentication. If not provided, the class will use
+            the token stored in the Hugging Face CLI configuration.
+        max_tokens (int, optional): The maximum number of tokens allowed in the output (default is 1500).
+        timeout (int, optional): Timeout for the API request, in seconds (default is 120).
+
+    Attributes:
+        model (str): The model ID being used for inference.
+        client (InferenceClient): The Hugging Face Inference API client for communicating with the language model.
+
+    Raises:
+        ValueError: If the model name is not provided.
+    """
+    def __init__(
+        self,
+        model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct",
+        token: Optional[str] = None,
+        max_tokens: int = 1500,
+        timeout: int = 120
+    ):
+        """
+        Initializes the HfApiEngine.
+
+        Args:
+            model (str, optional): The Hugging Face model to use (default is 'meta-llama/Meta-Llama-3.1-8B-Instruct').
+            token (str, optional): The Hugging Face API token for authentication.
+            max_tokens (int, optional): The maximum number of tokens allowed in the response (default is 1500).
+            timeout (int, optional): The API request timeout, in seconds (default is 120).
+        """
+        if not model:
+            raise ValueError("Model name must be provided.")
 
         self.model = model
-        self.client = InferenceClient(self.model, timeout=120)
+        self.client = InferenceClient(self.model, token=token, timeout=timeout)
+        self.max_tokens = max_tokens
 
     def __call__(
         self, messages: List[Dict[str, str]], stop_sequences: List[str] = [], grammar: Optional[str] = None
     ) -> str:
+        """
+        Processes the input messages and returns the model's response.
+
+        This method sends a list of messages to the Hugging Face Inference API, optionally
+        with stop sequences and grammar customization.
+
+        Args:
+            messages (List[Dict[str, str]]): A list of message dictionaries to be processed.
+                Each dictionary should have the structure {"role": "user/system", "content": "message content"}.
+            stop_sequences (List[str], optional): A list of strings that will stop the generation
+                if encountered in the model's output. Defaults to an empty list.
+            grammar (str, optional): The grammar or formatting structure to use in the model's response.
+                Defaults to None, which means no specific grammar.
+
+        Returns:
+            str: The text content of the model's response.
+
+        Examples:
+            >>> engine = HfApiEngine(
+            ...     model="meta-llama/Meta-Llama-3.1-8B-Instruct",
+            ...     token="your_hf_token_here",
+            ...     max_tokens=2000
+            ... )
+            >>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
+            >>> response = engine(messages, stop_sequences=["END"])
+            >>> print(response)
+            "Quantum mechanics is the branch of physics that studies..."
+        """
         # Get clean message list
         messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)
 
-        # Get LLM output
+        # Send messages to the Hugging Face Inference API
         if grammar is not None:
             response = self.client.chat_completion(
-                messages, stop=stop_sequences, max_tokens=1500, response_format=grammar
+                messages, stop=stop_sequences, max_tokens=self.max_tokens, response_format=grammar
             )
         else:
-            response = self.client.chat_completion(messages, stop=stop_sequences, max_tokens=1500)
+            response = self.client.chat_completion(messages, stop=stop_sequences, max_tokens=self.max_tokens)
 
         response = response.choices[0].message.content
 
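As a quick reference, here is a minimal usage sketch of the engine after this change. It assumes the import path at the time of this PR (`transformers.agents.llm_engine`) and a valid API token; reading the token from an `HF_TOKEN` environment variable is purely illustrative, since passing `token=None` falls back to the token stored by the Hugging Face CLI.

```python
import os

from transformers.agents.llm_engine import HfApiEngine

# Example values only; substitute your own model ID and token source.
engine = HfApiEngine(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    token=os.environ.get("HF_TOKEN"),  # None falls back to the CLI-stored token
    max_tokens=2000,
    timeout=60,
)

messages = [
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": "Explain stop sequences in one sentence."},
]

# Generation halts early if any stop sequence appears in the model's output.
print(engine(messages, stop_sequences=["<|eot_id|>"]))
```

Making `token`, `max_tokens`, and `timeout` constructor arguments removes the hard-coded 1500-token and 120-second values from the old implementation, so they can now be tuned per engine instance.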
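The `grammar` argument is forwarded unchanged to `InferenceClient.chat_completion` as `response_format`. Below is a sketch of constrained JSON output; the payload shape (a `{"type": "json", "value": <JSON schema>}` dict, as accepted by TGI-style backends through `huggingface_hub`) is an assumption that may vary with library version, and note that the signature above types `grammar` as `str`.

```python
from transformers.agents.llm_engine import HfApiEngine

engine = HfApiEngine()  # defaults shown in the diff above

# Assumed grammar payload: TGI-style backends accept a JSON-schema grammar
# via `response_format`; the exact shape may differ across versions.
json_grammar = {
    "type": "json",
    "value": {
        "type": "object",
        "properties": {"answer": {"type": "string"}},
        "required": ["answer"],
    },
}

response = engine(
    [{"role": "user", "content": "Answer with a JSON object containing an 'answer' key."}],
    grammar=json_grammar,
)
print(response)  # e.g. '{"answer": "..."}'
```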