configure using different models for chat
NickSavage committed Jan 29, 2025
1 parent 93e7264 commit 49ddc9f
Showing 5 changed files with 73 additions and 11 deletions.
63 changes: 57 additions & 6 deletions go-backend/handlers/chat.go
@@ -151,16 +151,23 @@ func (s *Handler) PostChatMessageRoute(w http.ResponseWriter, r *http.Request) {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	message, err = s.GetChatCompletion(userID, message.ConversationID)

	model, err := s.QueryLLMModel(message.ConfigurationID)
	if err != nil {
		log.Printf("error getting model: %v", err)
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	client := llms.NewClientFromModel(s.DB, model)

	message, err = s.GetChatCompletion(userID, client, message.ConversationID)
	if err != nil {
		log.Printf("error getting chat completion: %v", err)
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	if newConversation {
		client := llms.NewDefaultClient(s.DB)

		summary, err := llms.CreateConversationSummary(client, message)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
@@ -206,8 +213,54 @@ func (s *Handler) WriteConversationSummary(userID int, summary models.Conversati
	return nil
}

func (s *Handler) QueryLLMModel(configurationID int) (models.LLMModel, error) {
	log.Printf("querying llm model: %v", configurationID)
	query := `
	SELECT
		m.id,
		m.name,
		m.model_identifier,
		m.description,
		m.is_active,
		p.id as provider_id,
		p.name as provider_name,
		p.base_url,
		p.api_key_required,
		p.api_key
	FROM llm_models m
	JOIN user_llm_configurations uc ON m.id = uc.model_id
	JOIN llm_providers p ON m.provider_id = p.id
	WHERE uc.id = $1
	`
	var model models.LLMModel
	var provider models.LLMProvider

	err := s.DB.QueryRow(query, configurationID).Scan(
		&model.ID,
		&model.Name,
		&model.ModelIdentifier,
		&model.Description,
		&model.IsActive,
		&provider.ID,
		&provider.Name,
		&provider.BaseURL,
		&provider.APIKeyRequired,
		&provider.APIKey,
	)
	if err != nil {
		log.Printf("error querying llm model: %v", err)
		return models.LLMModel{}, err
	}

	model.Provider = &provider
	model.ProviderID = provider.ID

	return model, nil
}

func (s *Handler) AddChatMessage(userID int, message models.ChatCompletion) (models.ChatCompletion, error) {
	var nextSequence int

	err := s.DB.QueryRow(`
		SELECT COALESCE(MAX(sequence_number), 0) + 1
		FROM chat_completions
@@ -260,7 +313,7 @@ func (s *Handler) AddChatMessage(userID int, message models.ChatCompletion) (mod
	return insertedMessage, nil
}

func (s *Handler) GetChatCompletion(userID int, conversationID string) (models.ChatCompletion, error) {
func (s *Handler) GetChatCompletion(userID int, client *models.LLMClient, conversationID string) (models.ChatCompletion, error) {

	messages, err := s.GetChatMessagesInConversation(userID, conversationID)
	if err != nil {
@@ -280,8 +333,6 @@ func (s *Handler) GetChatCompletion(userID int, conversationID string) (models.C
		return models.ChatCompletion{}, fmt.Errorf("failed to process response")
	}

	client := llms.NewDefaultClient(s.DB)

	completion, err := llms.ChatCompletion(client, messages)
	if err != nil {
		log.Printf("error generating chat completion: %v", err)
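Taken together, the handler now builds the client from the message's ConfigurationID before requesting a completion. A minimal sketch of how a caller might fall back to the default client when no configuration is set; the helper name, the import paths, and the zero-value check are assumptions, not part of this commit:

package handlers

import (
	"go-backend/llms"
	"go-backend/models"
)

// clientFor is a hypothetical helper: resolve the user's configured
// model when ConfigurationID is set, otherwise fall back to the
// default client built from environment variables.
func (s *Handler) clientFor(message models.ChatCompletion) (*models.LLMClient, error) {
	if message.ConfigurationID == 0 {
		return llms.NewDefaultClient(s.DB), nil
	}
	model, err := s.QueryLLMModel(message.ConfigurationID)
	if err != nil {
		return nil, err
	}
	return llms.NewClientFromModel(s.DB, model), nil
}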
4 changes: 2 additions & 2 deletions go-backend/llms/chat.go
@@ -15,7 +15,7 @@ func ChatCompletion(c *models.LLMClient, pastMessages []models.ChatCompletion) (
		return models.ChatCompletion{
			Role:    "assistant",
			Content: "This is a mock response for testing",
			Model:   models.MODEL,
			Model:   c.Model.ModelIdentifier,
			Tokens:  100,
		}, nil
	}
@@ -82,7 +82,7 @@ func CreateConversationSummary(c *models.LLMClient, message models.ChatCompletio
		ID:        id,
		Title:     resp.Choices[0].Message.Content,
		CreatedAt: created,
		Model:     models.MODEL,
		Model:     c.Model.ModelIdentifier,
	}
	return result, nil
}
11 changes: 10 additions & 1 deletion go-backend/llms/client.go
@@ -10,6 +10,15 @@ import (
	openai "github.com/sashabaranov/go-openai"
)

func NewClientFromModel(db *sql.DB, model models.LLMModel) *models.LLMClient {
	config := openai.DefaultConfig(model.Provider.APIKey)
	config.BaseURL = model.Provider.BaseURL

	client := NewClient(db, config)
	client.Model = &model
	return client
}

func NewDefaultClient(db *sql.DB) *models.LLMClient {
	config := openai.DefaultConfig(os.Getenv("ZETTEL_LLM_KEY"))
	config.BaseURL = os.Getenv("ZETTEL_LLM_ENDPOINT")
@@ -43,7 +52,7 @@ func ExecuteLLMRequest(c *models.LLMClient, messages []openai.ChatCompletionMess
	resp, err := c.Client.CreateChatCompletion(
		context.Background(),
		openai.ChatCompletionRequest{
			Model:    models.MODEL,
			Model:    c.Model.ModelIdentifier,
			Messages: messages,
		},
	)
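NewClientFromModel points the OpenAI client at whatever provider the selected model row carries, instead of the ZETTEL_LLM_KEY / ZETTEL_LLM_ENDPOINT environment variables used by NewDefaultClient. A sketch of building a client for a self-hosted, OpenAI-compatible endpoint; the provider and model values are invented for illustration (import paths assumed from the repo layout), and real rows come from llm_providers and llm_models via QueryLLMModel:

package main

import (
	"database/sql"

	"go-backend/llms"
	"go-backend/models"
)

// exampleClient wires a hypothetical self-hosted provider into the
// new per-model constructor; any OpenAI-compatible BaseURL works.
func exampleClient(db *sql.DB) *models.LLMClient {
	provider := models.LLMProvider{
		Name:    "local-ollama",              // hypothetical provider row
		BaseURL: "http://localhost:11434/v1", // hypothetical endpoint
		APIKey:  "unused",
	}
	model := models.LLMModel{
		Name:            "Llama 3 8B", // hypothetical model row
		ModelIdentifier: "llama3:8b",  // becomes the request's Model field
		Provider:        &provider,
	}
	return llms.NewClientFromModel(db, model)
}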
4 changes: 3 additions & 1 deletion go-backend/models/chat.go
@@ -13,6 +13,7 @@ type LLMProvider struct {
	Name           string    `json:"name"`
	BaseURL        string    `json:"base_url"`
	APIKeyRequired bool      `json:"api_key_required"`
	APIKey         string    `json:"api_key,omitempty"`
	CreatedAt      time.Time `json:"created_at"`
	UpdatedAt      time.Time `json:"updated_at"`
}
@@ -60,6 +61,7 @@ type LLMClient struct {
	Client         *openai.Client
	Testing        bool
	EmbeddingQueue *LLMRequestQueue
	Model          *LLMModel
}

func NewEmbeddingQueue(db *sql.DB) *LLMRequestQueue {
@@ -115,7 +117,7 @@ type ChatCompletion struct {
	ReferencedCardPKs []int         `json:"referenced_card_pks"`
	ReferencedCards   []PartialCard `json:"cards"`
	UserQuery         string        `json:"user_query"`
	ModelID           int           `json:"model_id"`
	ConfigurationID   int           `json:"configuration_id"`
}

type ChatData struct {
2 changes: 1 addition & 1 deletion zettelkasten-front/src/api/chat.ts
@@ -42,7 +42,7 @@ export function postChatMessage(
    conversation_id: conversationId, // Will be undefined for new conversations
    user_query: content,
    referenced_card_pks: contextCards?.map((card) => card.id),
    model_id: configurationId,
    configuration_id: configurationId,
  };

  return fetch(url, {
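On the wire this is just a renamed key: the backend binds configuration_id to ChatCompletion.ConfigurationID via the struct tag above. A small sketch of that round-trip, with an illustrative request body (import path assumed from the repo layout):

package main

import (
	"encoding/json"
	"fmt"

	"go-backend/models"
)

func main() {
	// Illustrative body after the rename; the frontend previously
	// sent model_id here.
	body := []byte(`{"user_query": "hello", "configuration_id": 4}`)

	var message models.ChatCompletion
	if err := json.Unmarshal(body, &message); err != nil {
		panic(err)
	}
	fmt.Println(message.ConfigurationID) // prints 4
}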
