Skip to content

Commit

Permalink
[Feature] Add initial support for Postgres (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
gwenwindflower authored Apr 20, 2024
2 parents b7130ea + 075e2b4 commit 11071cf
Show file tree
Hide file tree
Showing 18 changed files with 385 additions and 95 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
# Go workspace file
go.work

# Project specific
build
test_build
tbd

Expand Down
108 changes: 55 additions & 53 deletions fetch_dbt_profiles.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,55 +9,14 @@ import (
)

type DbtProfile struct {
Target string `yaml:"target"`
Outputs map[string]struct {
ConnType string `yaml:"type"`
Account string `yaml:"account"`
User string `yaml:"user"`
Role string `yaml:"role"`
Authenticator string `yaml:"authenticator"`
Database string `yaml:"database"`
Schema string `yaml:"schema"`
Project string `yaml:"project"`
Dataset string `yaml:"dataset"`
Path string `yaml:"path"`
Threads int `yaml:"threads"`
Password string `yaml:"password"`
Port int `yaml:"port"`
Warehouse string `yaml:"warehouse"`
Method string `yaml:"method"`
Host string `yaml:"host"`
PrivateKey string `yaml:"private_key"`
PrivateKeyPath string `yaml:"private_key_path"`
PrivateKeyPassphrase string `yaml:"private_key_passphrase"`
ClientSessionKeepAlive bool `yaml:"client_session_keep_alive"`
QueryTag string `yaml:"query_tag"`
ConnectRetries int `yaml:"connect_retries"`
ConnectTimeout int `yaml:"connect_timeout"`
RetryOnDatabaseErrors bool `yaml:"retry_on_database_errors"`
RetryAll bool `yaml:"retry_all"`
ReuseConnections bool `yaml:"reuse_connections"`
Extensions []string `yaml:"extensions"`
RefreshToken string `yaml:"refresh_token"`
ClientID string `yaml:"client_id"`
ClientSecret string `yaml:"client_secret"`
TokenURI string `yaml:"token_uri"`
Token string `yaml:"token"`
Priority string `yaml:"priority"`
Keyfile string `yaml:"keyfile"`
JobExecutionTimeoutSeconds int `yaml:"job_execution_timeout_seconds"`
JobCreationTimeoutSeconds int `yaml:"job_creation_timeout_seconds"`
JobRetryDeadlineSeconds int `yaml:"job_retry_deadline_seconds"`
Location string `yaml:"location"`
MaximumBytesBilled int `yaml:"maximum_bytes_billed"`
Scopes []string `yaml:"scopes"`
ImpersonateServiceAccount string `yaml:"impersonate_service_account"`
ExecutionProject string `yaml:"execution_project"`
GcsBucket string `yaml:"gcs_bucket"`
DataprocRegion string `yaml:"dataproc_region"`
DataprocClusterName string `yaml:"dataproc_cluster_name"`
DataprocBatch map[string]interface{} `yaml:"dataproc_batch"`
KeyfileJson map[string]struct {
Settings map[string]struct {
S3Region string `yaml:"s3_region"`
S3AccessKeyID string `yaml:"s3_access_key_id"`
S3SecretAccessKey string `yaml:"s3_secret_access_key"`
} `yaml:"settings"`
DataprocBatch map[string]interface{} `yaml:"dataproc_batch"`
KeyfileJson map[string]struct {
Type string `yaml:"type"`
ProjectId string `yaml:"project_id"`
PrivateKeyId string `yaml:"private_key_id"`
Expand All @@ -69,12 +28,55 @@ type DbtProfile struct {
AuthProviderX509CertUrl string `yaml:"auth_provider_x509_cert_url"`
ClientX509CertUrl string `yaml:"client_x509_cert_url"`
} `yaml:"keyfile_json"`
Settings map[string]struct {
S3Region string `yaml:"s3_region"`
S3AccessKeyID string `yaml:"s3_access_key_id"`
S3SecretAccessKey string `yaml:"s3_secret_access_key"`
} `yaml:"settings"`
PrivateKeyPath string `yaml:"private_key_path"`
DbName string `yaml:"dbname"`
Database string `yaml:"database"`
Account string `yaml:"account"`
Schema string `yaml:"schema"`
QueryTag string `yaml:"query_tag"`
Dataset string `yaml:"dataset"`
Path string `yaml:"path"`
Role string `yaml:"role"`
Password string `yaml:"password"`
User string `yaml:"user"`
Warehouse string `yaml:"warehouse"`
Method string `yaml:"method"`
Host string `yaml:"host"`
PrivateKey string `yaml:"private_key"`
Location string `yaml:"location"`
ConnType string `yaml:"type"`
Authenticator string `yaml:"authenticator"`
Project string `yaml:"project"`
SslMode string `yaml:"sslmode"`
DataprocClusterName string `yaml:"dataproc_cluster_name"`
DataprocRegion string `yaml:"dataproc_region"`
GcsBucket string `yaml:"gcs_bucket"`
ExecutionProject string `yaml:"execution_project"`
PrivateKeyPassphrase string `yaml:"private_key_passphrase"`
RefreshToken string `yaml:"refresh_token"`
ClientID string `yaml:"client_id"`
ClientSecret string `yaml:"client_secret"`
TokenURI string `yaml:"token_uri"`
Token string `yaml:"token"`
Priority string `yaml:"priority"`
Keyfile string `yaml:"keyfile"`
ImpersonateServiceAccount string `yaml:"impersonate_service_account"`
Extensions []string `yaml:"extensions"`
Scopes []string `yaml:"scopes"`
JobCreationTimeoutSeconds int `yaml:"job_creation_timeout_seconds"`
MaximumBytesBilled int `yaml:"maximum_bytes_billed"`
JobRetryDeadlineSeconds int `yaml:"job_retry_deadline_seconds"`
JobExecutionTimeoutSeconds int `yaml:"job_execution_timeout_seconds"`
ConnectTimeout int `yaml:"connect_timeout"`
ConnectRetries int `yaml:"connect_retries"`
Port int `yaml:"port"`
Threads int `yaml:"threads"`
ReuseConnections bool `yaml:"reuse_connections"`
RetryAll bool `yaml:"retry_all"`
RetryOnDatabaseErrors bool `yaml:"retry_on_database_errors"`
ClientSessionKeepAlive bool `yaml:"client_session_keep_alive"`
} `yaml:"outputs"`
Target string `yaml:"target"`
}

type DbtProfiles map[string]DbtProfile
Expand Down
3 changes: 3 additions & 0 deletions fetch_dbt_profiles_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,7 @@ func TestFetchDbtProfiles(t *testing.T) {
if profiles["dwarf"].Outputs["dev"].ConnType != "duckdb" {
t.Fatalf("Expected duckdb, got %s\n", profiles["dwarf"].Outputs["dev"].ConnType)
}
if profiles["ent"].Outputs["dev"].ConnType != "postgres" {
t.Fatalf("Expected postgres, got %s\n", profiles["ent"].Outputs["dev"].ConnType)
}
}
89 changes: 74 additions & 15 deletions forms.go
Original file line number Diff line number Diff line change
@@ -1,32 +1,38 @@
package main

import (
"errors"
"fmt"
"strconv"

"github.com/charmbracelet/huh"
"github.com/fatih/color"
)

type FormResponse struct {
Confirm bool
Warehouse string
Path string
Username string
Account string
BuildDir string
SslMode string
Database string
Schema string
Project string
Dataset string
Path string
BuildDir string
GenerateDescriptions bool
ProjectName string
Warehouse string
Account string
GroqKeyEnvVar string
UseDbtProfile bool
Password string
DbtProfileName string
DbtProfileOutput string
CreateProfile bool
ScaffoldProject bool
ProjectName string
Port string
Host string
Prefix string
GenerateDescriptions bool
ScaffoldProject bool
CreateProfile bool
UseDbtProfile bool
Confirm bool
}

var not_empty = func(s string) error {
Expand All @@ -52,6 +58,8 @@ func Forms(ps DbtProfiles) (FormResponse, error) {
BuildDir: "build",
GroqKeyEnvVar: "GROQ_API_KEY",
Prefix: "stg",
Host: "localhost",
Port: "5432",
}
pinkUnderline := color.New(color.FgMagenta).Add(color.Bold, color.Underline).SprintFunc()
greenBold := color.New(color.FgGreen).Add(color.Bold).SprintFunc()
Expand All @@ -70,9 +78,9 @@ To prepare, make sure you have the following:
*_OR_*
✴︎ The necessary %s for your warehouse
_See README for warehouse-specific requirements_
_See_ %s _for warehouse-specific requirements_:
https://github.com/gwenwindflower/tbd
`, greenBold(Version), pinkUnderline("existing dbt profile"), pinkUnderline("connection details"))),
`, greenBold(Version), pinkUnderline("existing dbt profile"), pinkUnderline("connection details"), greenBold("README"))),
),

huh.NewGroup(
Expand Down Expand Up @@ -137,6 +145,7 @@ https://github.com/gwenwindflower/tbd
huh.NewOption("Snowflake", "snowflake"),
huh.NewOption("BigQuery", "bigquery"),
huh.NewOption("DuckDB", "duckdb"),
huh.NewOption("Postgres", "postgres"),
).
Value(&dfr.Warehouse),
).WithHideFunc(func() bool {
Expand Down Expand Up @@ -193,17 +202,67 @@ Relative to pwd e.g. if db is in this dir -> cool_ducks.db`).
huh.NewInput().
Title("What is the *database* you want to generate?").
Value(&dfr.Database).
Placeholder("duckdb").
Placeholder("gimli_corp").
Validate(not_empty),
huh.NewInput().
Title("What is the *schema* you want to generate?").
Value(&dfr.Schema).
Placeholder("raw").
Placeholder("moria").
Validate(not_empty),
).WithHideFunc(func() bool {
return dfr.Warehouse != "duckdb"
}),

huh.NewGroup(
huh.NewInput().
Title("What is your Postgres *host*?").
Value(&dfr.Host).
Validate(not_empty),
huh.NewInput().
Title("What is your Postgres *port*?").
Value(&dfr.Port).
Validate(func(s string) error {
port, err := strconv.Atoi(s)
if err != nil || port < 1000 || port > 9999 {
return errors.New("port must be a 4-digit number")
}
return nil
}),
huh.NewInput().
Title("What is your Postgres *username*?").
Value(&dfr.Username).
Placeholder("galadriel").
Validate(not_empty),
huh.NewInput().
Title("What is your Postgres *password*?").
Value(&dfr.Password).
Validate(not_empty).
EchoMode(huh.EchoModePassword),
huh.NewInput().
Title("What is the *database* you want to generate?").
Value(&dfr.Database).
Placeholder("lothlorien").
Validate(not_empty),
huh.NewInput().
Title("What is the *schema* you want to generate?").
Value(&dfr.Schema).
Placeholder("mallorn_trees").
Validate(not_empty),
huh.NewSelect[string]().
Title("What ssl mode do you want to use?").
Value(&dfr.SslMode).
Options(
huh.NewOption("Disable", "disable"),
huh.NewOption("Require", "require"),
huh.NewOption("Verify-ca", "verify-ca"),
huh.NewOption("Verify-full", "verify-full"),
huh.NewOption("Prefer", "prefer"),
huh.NewOption("Allow", "allow")).
Validate(not_empty),
).WithHideFunc(func() bool {
return dfr.Warehouse != "postgres"
}),

huh.NewGroup(
huh.NewNote().
Title(fmt.Sprintf("🤖 %s LLM generation 🦙✨", redBold("Experimental"))).
Expand All @@ -212,7 +271,7 @@ Currently generates:
✴︎ column %s
✴︎ relevant %s
Requires a %s stored in an env var
_Requires a_ %s _stored in an env var_:
Get one at https://groq.com.`, yellowItalic("Optional"), pinkUnderline("descriptions"), pinkUnderline("tests"), greenBoldItalic("Groq API key"))),
huh.NewConfirm().
Title("Do you want to infer descriptions and tests?").
Expand Down
26 changes: 13 additions & 13 deletions generate_column_desc.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ import (
)

type Payload struct {
Messages []Message `json:"messages"`
Stop interface{} `json:"stop"`
Model string `json:"model"`
Messages []Message `json:"messages"`
Temp float64 `json:"temperature"`
Tokens int `json:"max_tokens"`
TopP int `json:"top_p"`
Stream bool `json:"stream"`
Stop interface{} `json:"stop"`
}

type Message struct {
Expand All @@ -32,18 +32,18 @@ type Message struct {
}

type GroqResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
Model string `json:"model"`
Choices []struct {
Index int `json:"index"`
Message struct {
SystemFingerprint interface{} `json:"system_fingerprint"`
ID string `json:"id"`
Object string `json:"object"`
Model string `json:"model"`
Choices []struct {
Logprobs interface{} `json:"logprobs"`
Message struct {
Role string `json:"role"`
Content string `json:"content"`
} `json:"message"`
Logprobs interface{} `json:"logprobs"`
FinishReason string `json:"finish_reason"`
FinishReason string `json:"finish_reason"`
Index int `json:"index"`
} `json:"choices"`
Usage struct {
PromptTokens int `json:"prompt_tokens"`
Expand All @@ -53,7 +53,7 @@ type GroqResponse struct {
TotalTokens int `json:"total_tokens"`
TotalTime float64 `json:"total_time"`
} `json:"usage"`
SystemFingerprint interface{} `json:"system_fingerprint"`
Created int `json:"created"`
}

// Groq API constants
Expand Down Expand Up @@ -156,7 +156,7 @@ func GetGroqResponse(prompt string) (GroqResponse, error) {
Content: prompt,
},
},
Model: "Llama3-70B-8192",
Model: "llama3-70b-8192",
Temp: 0.5,
Tokens: 2048,
TopP: 1,
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ require (
github.com/charmbracelet/huh v0.3.1-0.20240306161957-71f31c155b08
github.com/fatih/color v1.16.0
github.com/jarcoal/httpmock v1.3.1
github.com/lib/pq v1.10.9
github.com/marcboeker/go-duckdb v1.6.3
github.com/schollz/progressbar/v3 v3.14.2
github.com/snowflakedb/gosnowflake v1.9.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/marcboeker/go-duckdb v1.6.3 h1:5qRxB3BosFXRjfQWNP0OOqEQFXllo6o7fHGrNA7NSuM=
Expand Down
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ import (

type Elapsed struct {
DbStart time.Time
DbElapsed float64
ProcessingStart time.Time
DbElapsed float64
ProcessingElapsed float64
}

Expand Down
Loading

0 comments on commit 11071cf

Please sign in to comment.