Skip to content

Commit

Permalink
feat: add tool choice definition, call options add tools and tool choice
Browse files Browse the repository at this point in the history
  • Loading branch information
N3kox committed Jan 24, 2025
1 parent 4f2aa22 commit ae56b50
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 0 deletions.
27 changes: 27 additions & 0 deletions components/model/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package model

import "github.com/cloudwego/eino/schema"

// Options is the common options for the model.
type Options struct {
// Temperature is the temperature for the model, which controls the randomness of the model.
Expand All @@ -28,6 +30,10 @@ type Options struct {
TopP *float32
// Stop is the stop words for the model, which controls the stopping condition of the model.
Stop []string
// Tools is a list of tools the model may call.
Tools []*schema.ToolInfo
// ToolChoice controls which tool is called by the model.
ToolChoice *schema.ToolChoice
}

// Option is the call option for ChatModel component.
Expand Down Expand Up @@ -82,6 +88,27 @@ func WithStop(stop []string) Option {
}
}

// WithTools is the option to set tools for the model.
func WithTools(tools []*schema.ToolInfo) Option {
if tools == nil {
tools = []*schema.ToolInfo{}
}
return Option{
apply: func(opts *Options) {
opts.Tools = tools
},
}
}

// WithToolChoice is the option to set tool choice for the model.
func WithToolChoice(toolChoice schema.ToolChoice) Option {
return Option{
apply: func(opts *Options) {
opts.ToolChoice = &toolChoice
},
}
}

// WrapImplSpecificOptFn is the option to wrap the implementation specific option function.
func WrapImplSpecificOptFn[T any](optFn func(*T)) Option {
return Option{
Expand Down
22 changes: 22 additions & 0 deletions components/model/option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package model
import (
"testing"

"github.com/cloudwego/eino/schema"
"github.com/smartystreets/goconvey/convey"
)

Expand All @@ -33,6 +34,8 @@ func TestOptions(t *testing.T) {
defaultTemperature float32 = 1.0
defaultMaxTokens = 1000
defaultTopP float32 = 0.5
tools = []*schema.ToolInfo{{Name: "asd"}, {Name: "qwe"}}
toolChoice = schema.ToolChoiceForced
)

opts := GetCommonOptions(
Expand All @@ -47,6 +50,8 @@ func TestOptions(t *testing.T) {
WithMaxTokens(maxToken),
WithTopP(topP),
WithStop([]string{"hello", "bye"}),
WithTools(tools),
WithToolChoice(toolChoice),
)

convey.So(opts, convey.ShouldResemble, &Options{
Expand All @@ -55,8 +60,25 @@ func TestOptions(t *testing.T) {
MaxTokens: &maxToken,
TopP: &topP,
Stop: []string{"hello", "bye"},
Tools: tools,
ToolChoice: &toolChoice,
})
})

convey.Convey("test nil tools option", t, func() {
opts := GetCommonOptions(
&Options{
Tools: []*schema.ToolInfo{
{Name: "asd"},
{Name: "qwe"},
},
},
WithTools(nil),
)

convey.So(opts.Tools, convey.ShouldNotBeNil)
convey.So(len(opts.Tools), convey.ShouldEqual, 0)
})
}

type implOption struct {
Expand Down
17 changes: 17 additions & 0 deletions schema/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,23 @@ const (
Boolean DataType = "boolean"
)

// ToolChoice controls how the model calls tools (if any).
type ToolChoice string

const (
// ToolChoiceForbidden indicates that the model should not call any tools.
// Corresponds to "none" in OpenAI Chat Completion.
ToolChoiceForbidden ToolChoice = "forbidden"

// ToolChoiceAllowed indicates that the model can choose to generate a message or call one or more tools.
// Corresponds to "auto" in OpenAI Chat Completion.
ToolChoiceAllowed ToolChoice = "allowed"

// ToolChoiceForced indicates that the model must call one or more tools.
// Corresponds to "required" in OpenAI Chat Completion.
ToolChoiceForced ToolChoice = "forced"
)

// ToolInfo is the information of a tool.
type ToolInfo struct {
// The unique name of the tool that clearly communicates its purpose.
Expand Down

0 comments on commit ae56b50

Please sign in to comment.