From 614bb5e542024fb37c327fd416ebb999ca856d8e Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sun, 8 Dec 2024 11:20:18 +0100 Subject: [PATCH] More jinja Signed-off-by: Ettore Di Giacinto --- pkg/templates/evaluator.go | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/pkg/templates/evaluator.go b/pkg/templates/evaluator.go index 63b16f6d0689..e7c426048d65 100644 --- a/pkg/templates/evaluator.go +++ b/pkg/templates/evaluator.go @@ -95,28 +95,31 @@ func (e *Evaluator) evaluateTemplateForChatMessage(templateName string, messageD return e.cache.EvaluateTemplate(ChatMessageTemplate, templateName, messageData) } -func (e *Evaluator) templateJinjaChat(templateName string, messageData []ChatMessageTemplateData) (string, error) { +func (e *Evaluator) templateJinjaChat(templateName string, messageData []ChatMessageTemplateData, funcs []functions.Function) (string, error) { conversation := make(map[string]interface{}) messages := make([]map[string]interface{}, len(messageData)) - for _, message := range messageData { - // TODO: this is not correct, we have to map jinja tokenizer template from transformers to our own + // convert from ChatMessageTemplateData to what the jinja template expects + for _, message := range messageData { + // TODO: this seems to cover minimum text templates. Can be expanded to cover more complex interactions + var data []byte + data, _ = json.Marshal(message.FunctionCall) messages = append(messages, map[string]interface{}{ - //"role": message.Role, - "role": message.RoleName, - "content": message.Content, - "FunctionCall": message.FunctionCall, - "FunctionName": message.FunctionName, - "LastMessage": message.LastMessage, - "Function": message.Function, - "MessageIndex": message.MessageIndex, + "role": message.RoleName, + "content": message.Content, + "tool_call": string(data), }) } conversation["messages"] = messages + // if tools are detected, add these + if len(funcs) > 0 { + conversation["tools"] = funcs + } + return e.cache.EvaluateJinjaTemplate(ChatMessageTemplate, templateName, conversation) } @@ -152,7 +155,7 @@ func (e *Evaluator) TemplateMessages(messages []schema.Message, config *config.B }) } - templatedInput, err := e.templateJinjaChat(config.TemplateConfig.ChatMessage, messageData) + templatedInput, err := e.templateJinjaChat(config.TemplateConfig.ChatMessage, messageData, funcs) if err == nil { return templatedInput }