Skip to content

Commit

Permalink
More jinja
Browse files Browse the repository at this point in the history
Signed-off-by: Ettore Di Giacinto <[email protected]>
  • Loading branch information
mudler committed Dec 8, 2024
1 parent b7d9988 commit 614bb5e
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions pkg/templates/evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,28 +95,31 @@ func (e *Evaluator) evaluateTemplateForChatMessage(templateName string, messageD
return e.cache.EvaluateTemplate(ChatMessageTemplate, templateName, messageData)
}

func (e *Evaluator) templateJinjaChat(templateName string, messageData []ChatMessageTemplateData) (string, error) {
func (e *Evaluator) templateJinjaChat(templateName string, messageData []ChatMessageTemplateData, funcs []functions.Function) (string, error) {

conversation := make(map[string]interface{})
messages := make([]map[string]interface{}, len(messageData))

for _, message := range messageData {
// TODO: this is not correct, we have to map jinja tokenizer template from transformers to our own
// convert from ChatMessageTemplateData to what the jinja template expects

for _, message := range messageData {
// TODO: this seems to cover minimum text templates. Can be expanded to cover more complex interactions
var data []byte
data, _ = json.Marshal(message.FunctionCall)
messages = append(messages, map[string]interface{}{
//"role": message.Role,
"role": message.RoleName,
"content": message.Content,
"FunctionCall": message.FunctionCall,
"FunctionName": message.FunctionName,
"LastMessage": message.LastMessage,
"Function": message.Function,
"MessageIndex": message.MessageIndex,
"role": message.RoleName,
"content": message.Content,
"tool_call": string(data),
})
}

conversation["messages"] = messages

// if tools are detected, add these
if len(funcs) > 0 {
conversation["tools"] = funcs
}

return e.cache.EvaluateJinjaTemplate(ChatMessageTemplate, templateName, conversation)
}

Expand Down Expand Up @@ -152,7 +155,7 @@ func (e *Evaluator) TemplateMessages(messages []schema.Message, config *config.B
})
}

templatedInput, err := e.templateJinjaChat(config.TemplateConfig.ChatMessage, messageData)
templatedInput, err := e.templateJinjaChat(config.TemplateConfig.ChatMessage, messageData, funcs)
if err == nil {
return templatedInput
}
Expand Down

0 comments on commit 614bb5e

Please sign in to comment.