Skip to content

Commit

Permalink
Wire through the shell name into AI suggestions so that we can get mo…
Browse files Browse the repository at this point in the history
…re precise AI suggestions for the current shell
  • Loading branch information
ddworken committed Feb 19, 2024
1 parent 339da47 commit 0787840
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 34 deletions.
19 changes: 7 additions & 12 deletions client/ai/ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,20 @@ import (

var mostRecentQuery string

func DebouncedGetAiSuggestions(ctx context.Context, query string, numberCompletions int) ([]string, error) {
func DebouncedGetAiSuggestions(ctx context.Context, shellName, query string, numberCompletions int) ([]string, error) {
mostRecentQuery = query
time.Sleep(time.Millisecond * 300)
if mostRecentQuery == query {
return GetAiSuggestions(ctx, query, numberCompletions)
return GetAiSuggestions(ctx, shellName, query, numberCompletions)
}
return nil, nil
}

func GetAiSuggestions(ctx context.Context, query string, numberCompletions int) ([]string, error) {
func GetAiSuggestions(ctx context.Context, shellName, query string, numberCompletions int) ([]string, error) {
if os.Getenv("OPENAI_API_KEY") == "" {
return GetAiSuggestionsViaHishtoryApi(ctx, query, numberCompletions)
return GetAiSuggestionsViaHishtoryApi(ctx, shellName, query, numberCompletions)
} else {
suggestions, _, err := ai.GetAiSuggestionsViaOpenAiApi(query, getShellName(), getOsName(), numberCompletions)
suggestions, _, err := ai.GetAiSuggestionsViaOpenAiApi(query, shellName, getOsName(), numberCompletions)
return suggestions, err
}
}
Expand All @@ -55,20 +55,15 @@ func getOsName() string {
}
}

func getShellName() string {
// TODO: Wire the real shell name in here
return "bash"
}

func GetAiSuggestionsViaHishtoryApi(ctx context.Context, query string, numberCompletions int) ([]string, error) {
func GetAiSuggestionsViaHishtoryApi(ctx context.Context, shellName, query string, numberCompletions int) ([]string, error) {
hctx.GetLogger().Infof("Running OpenAI query for %#v", query)
req := ai.AiSuggestionRequest{
DeviceId: hctx.GetConf(ctx).DeviceId,
UserId: data.UserId(hctx.GetConf(ctx).UserSecret),
Query: query,
NumberCompletions: numberCompletions,
OsName: getOsName(),
ShellName: getShellName(),
ShellName: shellName,
}
reqData, err := json.Marshal(req)
if err != nil {
Expand Down
6 changes: 5 additions & 1 deletion client/cmd/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ var tqueryCmd = &cobra.Command{
DisableFlagParsing: true,
Run: func(cmd *cobra.Command, args []string) {
ctx := hctx.MakeContext()
lib.CheckFatalError(tui.TuiQuery(ctx, strings.Join(args, " ")))
shellName := "bash"
if os.Getenv("HISHTORY_SHELL_NAME") != "" {
shellName = os.Getenv("HISHTORY_SHELL_NAME")
}
lib.CheckFatalError(tui.TuiQuery(ctx, shellName, strings.Join(args, " ")))
},
}

Expand Down
2 changes: 1 addition & 1 deletion client/lib/config.fish
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ end
function __hishtory_on_control_r
set -l tmp (mktemp -t fish.XXXXXX)
set -x init_query (commandline -b)
HISHTORY_TERM_INTEGRATION=1 hishtory tquery $init_query > $tmp
HISHTORY_TERM_INTEGRATION=1 HISHTORY_SHELL_NAME=fish hishtory tquery $init_query > $tmp
set -l res $status
commandline -f repaint
if [ -s $tmp ]
Expand Down
2 changes: 1 addition & 1 deletion client/lib/config.sh
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ PROMPT_COMMAND="__hishtory_postcommand; $PROMPT_COMMAND"
export HISTTIMEFORMAT=$HISTTIMEFORMAT

__history_control_r() {
READLINE_LINE=$(HISHTORY_TERM_INTEGRATION=1 hishtory tquery "$READLINE_LINE")
READLINE_LINE=$(HISHTORY_TERM_INTEGRATION=1 HISHTORY_SHELL_NAME=bash hishtory tquery "$READLINE_LINE")
READLINE_POINT=0x7FFFFFFF
}

Expand Down
2 changes: 1 addition & 1 deletion client/lib/config.zsh
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ function _hishtory_precmd() {
}

_hishtory_widget() {
BUFFER=$(HISHTORY_TERM_INTEGRATION=1 hishtory tquery $BUFFER)
BUFFER=$(HISHTORY_TERM_INTEGRATION=1 HISHTORY_SHELL_NAME=zsh hishtory tquery $BUFFER)
CURSOR=${#BUFFER}
zle reset-prompt
}
Expand Down
39 changes: 21 additions & 18 deletions client/tui/tui.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,9 @@ type model struct {

// A banner from the backend to be displayed. Generally an empty string.
banner string

// The currently executing shell. Defaults to bash if not specified. Used for more precise AI suggestions.
shellName string
}

type doneDownloadingMsg struct{}
Expand All @@ -205,7 +208,7 @@ type asyncQueryFinishedMsg struct {
overriddenSearchQuery *string
}

func initialModel(ctx context.Context, initialQuery string) model {
func initialModel(ctx context.Context, shellName, initialQuery string) model {
s := spinner.New()
s.Spinner = spinner.Dot
s.Style = lipgloss.NewStyle().Foreground(lipgloss.Color("205"))
Expand All @@ -231,7 +234,7 @@ func initialModel(ctx context.Context, initialQuery string) model {
queryInput.SetValue(initialQuery)
}
CURRENT_QUERY_FOR_HIGHLIGHTING = initialQuery
return model{ctx: ctx, spinner: s, isLoading: true, table: nil, tableEntries: []*data.HistoryEntry{}, runQuery: &initialQuery, queryInput: queryInput, help: help.New()}
return model{ctx: ctx, spinner: s, isLoading: true, table: nil, tableEntries: []*data.HistoryEntry{}, runQuery: &initialQuery, queryInput: queryInput, help: help.New(), shellName: shellName}
}

func (m model) Init() tea.Cmd {
Expand All @@ -252,7 +255,7 @@ func updateTable(m model, rows []table.Row, entries []*data.HistoryEntry, search
initialCursor = m.table.Cursor()
}
if forceUpdateTable || m.table == nil {
t, err := makeTable(m.ctx, rows)
t, err := makeTable(m.ctx, m.shellName, rows)
if err != nil {
m.fatalErr = err
return m
Expand Down Expand Up @@ -299,7 +302,7 @@ func runQueryAndUpdateTable(m model, forceUpdateTable, maintainCursor bool) tea.
// The default filter was cleared for this session, so don't apply it
defaultFilter = ""
}
rows, entries, searchErr := getRows(m.ctx, conf.DisplayedColumns, defaultFilter, query, PADDED_NUM_ENTRIES)
rows, entries, searchErr := getRows(m.ctx, conf.DisplayedColumns, m.shellName, defaultFilter, query, PADDED_NUM_ENTRIES)
return asyncQueryFinishedMsg{queryId, rows, entries, searchErr, forceUpdateTable, maintainCursor, nil}
}
}
Expand Down Expand Up @@ -493,8 +496,8 @@ func renderNullableTable(m model, helpText string) string {
return baseStyle.Render(m.table.View())
}

func getRowsFromAiSuggestions(ctx context.Context, columnNames []string, query string) ([]table.Row, []*data.HistoryEntry, error) {
suggestions, err := ai.DebouncedGetAiSuggestions(ctx, strings.TrimPrefix(query, "?"), 5)
func getRowsFromAiSuggestions(ctx context.Context, columnNames []string, shellName, query string) ([]table.Row, []*data.HistoryEntry, error) {
suggestions, err := ai.DebouncedGetAiSuggestions(ctx, shellName, strings.TrimPrefix(query, "?"), 5)
if err != nil {
hctx.GetLogger().Infof("failed to get AI query suggestions: %v", err)
return nil, nil, fmt.Errorf("failed to get AI query suggestions: %w", err)
Expand Down Expand Up @@ -525,11 +528,11 @@ func getRowsFromAiSuggestions(ctx context.Context, columnNames []string, query s
return rows, entries, nil
}

func getRows(ctx context.Context, columnNames []string, defaultFilter, query string, numEntries int) ([]table.Row, []*data.HistoryEntry, error) {
func getRows(ctx context.Context, columnNames []string, shellName, defaultFilter, query string, numEntries int) ([]table.Row, []*data.HistoryEntry, error) {
db := hctx.GetDb(ctx)
config := hctx.GetConf(ctx)
if config.AiCompletion && !config.IsOffline && strings.HasPrefix(query, "?") && len(query) > 1 {
return getRowsFromAiSuggestions(ctx, columnNames, query)
return getRowsFromAiSuggestions(ctx, columnNames, shellName, query)
}
searchResults, err := lib.Search(ctx, db, defaultFilter+" "+query, numEntries)
if err != nil {
Expand Down Expand Up @@ -588,10 +591,10 @@ func getTerminalSize() (int, int, error) {

var bigQueryResults []table.Row

func makeTableColumns(ctx context.Context, columnNames []string, rows []table.Row) ([]table.Column, error) {
func makeTableColumns(ctx context.Context, shellName string, columnNames []string, rows []table.Row) ([]table.Column, error) {
// Handle an initial query with no results
if len(rows) == 0 || len(rows[0]) == 0 {
allRows, _, err := getRows(ctx, columnNames, hctx.GetConf(ctx).DefaultFilter, "", 25)
allRows, _, err := getRows(ctx, columnNames, shellName, hctx.GetConf(ctx).DefaultFilter, "", 25)
if err != nil {
return nil, err
}
Expand All @@ -604,7 +607,7 @@ func makeTableColumns(ctx context.Context, columnNames []string, rows []table.Ro
}
allRows = append(allRows, row)
}
return makeTableColumns(ctx, columnNames, allRows)
return makeTableColumns(ctx, shellName, columnNames, allRows)
}

// Calculate the minimum amount of space that we need for each column for the current actual search
Expand All @@ -617,7 +620,7 @@ func makeTableColumns(ctx context.Context, columnNames []string, rows []table.Ro

// Calculate the maximum column width that is useful for each column if we search for the empty string
if bigQueryResults == nil {
bigRows, _, err := getRows(ctx, columnNames, "", "", 1000)
bigRows, _, err := getRows(ctx, columnNames, shellName, "", "", 1000)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -678,9 +681,9 @@ func min(a, b int) int {
return b
}

func makeTable(ctx context.Context, rows []table.Row) (table.Model, error) {
func makeTable(ctx context.Context, shellName string, rows []table.Row) (table.Model, error) {
config := hctx.GetConf(ctx)
columns, err := makeTableColumns(ctx, config.DisplayedColumns, rows)
columns, err := makeTableColumns(ctx, shellName, config.DisplayedColumns, rows)
if err != nil {
return table.Model{}, err
}
Expand Down Expand Up @@ -887,22 +890,22 @@ func configureColorProfile(ctx context.Context) {
}
}

func TuiQuery(ctx context.Context, initialQuery string) error {
func TuiQuery(ctx context.Context, shellName, initialQuery string) error {
configureColorProfile(ctx)
p := tea.NewProgram(initialModel(ctx, initialQuery), tea.WithOutput(os.Stderr))
p := tea.NewProgram(initialModel(ctx, shellName, initialQuery), tea.WithOutput(os.Stderr))
// Async: Get the initial set of rows
go func() {
LAST_DISPATCHED_QUERY_ID++
queryId := LAST_DISPATCHED_QUERY_ID
LAST_DISPATCHED_QUERY_TIMESTAMP = time.Now()
conf := hctx.GetConf(ctx)
rows, entries, err := getRows(ctx, conf.DisplayedColumns, conf.DefaultFilter, initialQuery, PADDED_NUM_ENTRIES)
rows, entries, err := getRows(ctx, conf.DisplayedColumns, shellName, conf.DefaultFilter, initialQuery, PADDED_NUM_ENTRIES)
if err == nil || initialQuery == "" {
p.Send(asyncQueryFinishedMsg{queryId: queryId, rows: rows, entries: entries, searchErr: err, forceUpdateTable: true, maintainCursor: false, overriddenSearchQuery: nil})
} else {
// initialQuery is likely invalid in some way, let's just drop it
emptyQuery := ""
rows, entries, err := getRows(ctx, hctx.GetConf(ctx).DisplayedColumns, conf.DefaultFilter, emptyQuery, PADDED_NUM_ENTRIES)
rows, entries, err := getRows(ctx, hctx.GetConf(ctx).DisplayedColumns, shellName, conf.DefaultFilter, emptyQuery, PADDED_NUM_ENTRIES)
p.Send(asyncQueryFinishedMsg{queryId: queryId, rows: rows, entries: entries, searchErr: err, forceUpdateTable: true, maintainCursor: false, overriddenSearchQuery: &emptyQuery})
}
}()
Expand Down

0 comments on commit 0787840

Please sign in to comment.