diff --git a/.gitignore b/.gitignore index 7bd960c..02b5ca9 100644 --- a/.gitignore +++ b/.gitignore @@ -24,6 +24,8 @@ # Go workspace file go.work +# Project specific +build test_build tbd diff --git a/fetch_dbt_profiles.go b/fetch_dbt_profiles.go index 155148f..05e7b6f 100644 --- a/fetch_dbt_profiles.go +++ b/fetch_dbt_profiles.go @@ -9,55 +9,14 @@ import ( ) type DbtProfile struct { - Target string `yaml:"target"` Outputs map[string]struct { - ConnType string `yaml:"type"` - Account string `yaml:"account"` - User string `yaml:"user"` - Role string `yaml:"role"` - Authenticator string `yaml:"authenticator"` - Database string `yaml:"database"` - Schema string `yaml:"schema"` - Project string `yaml:"project"` - Dataset string `yaml:"dataset"` - Path string `yaml:"path"` - Threads int `yaml:"threads"` - Password string `yaml:"password"` - Port int `yaml:"port"` - Warehouse string `yaml:"warehouse"` - Method string `yaml:"method"` - Host string `yaml:"host"` - PrivateKey string `yaml:"private_key"` - PrivateKeyPath string `yaml:"private_key_path"` - PrivateKeyPassphrase string `yaml:"private_key_passphrase"` - ClientSessionKeepAlive bool `yaml:"client_session_keep_alive"` - QueryTag string `yaml:"query_tag"` - ConnectRetries int `yaml:"connect_retries"` - ConnectTimeout int `yaml:"connect_timeout"` - RetryOnDatabaseErrors bool `yaml:"retry_on_database_errors"` - RetryAll bool `yaml:"retry_all"` - ReuseConnections bool `yaml:"reuse_connections"` - Extensions []string `yaml:"extensions"` - RefreshToken string `yaml:"refresh_token"` - ClientID string `yaml:"client_id"` - ClientSecret string `yaml:"client_secret"` - TokenURI string `yaml:"token_uri"` - Token string `yaml:"token"` - Priority string `yaml:"priority"` - Keyfile string `yaml:"keyfile"` - JobExecutionTimeoutSeconds int `yaml:"job_execution_timeout_seconds"` - JobCreationTimeoutSeconds int `yaml:"job_creation_timeout_seconds"` - JobRetryDeadlineSeconds int `yaml:"job_retry_deadline_seconds"` - Location string `yaml:"location"` - MaximumBytesBilled int `yaml:"maximum_bytes_billed"` - Scopes []string `yaml:"scopes"` - ImpersonateServiceAccount string `yaml:"impersonate_service_account"` - ExecutionProject string `yaml:"execution_project"` - GcsBucket string `yaml:"gcs_bucket"` - DataprocRegion string `yaml:"dataproc_region"` - DataprocClusterName string `yaml:"dataproc_cluster_name"` - DataprocBatch map[string]interface{} `yaml:"dataproc_batch"` - KeyfileJson map[string]struct { + Settings map[string]struct { + S3Region string `yaml:"s3_region"` + S3AccessKeyID string `yaml:"s3_access_key_id"` + S3SecretAccessKey string `yaml:"s3_secret_access_key"` + } `yaml:"settings"` + DataprocBatch map[string]interface{} `yaml:"dataproc_batch"` + KeyfileJson map[string]struct { Type string `yaml:"type"` ProjectId string `yaml:"project_id"` PrivateKeyId string `yaml:"private_key_id"` @@ -69,12 +28,55 @@ type DbtProfile struct { AuthProviderX509CertUrl string `yaml:"auth_provider_x509_cert_url"` ClientX509CertUrl string `yaml:"client_x509_cert_url"` } `yaml:"keyfile_json"` - Settings map[string]struct { - S3Region string `yaml:"s3_region"` - S3AccessKeyID string `yaml:"s3_access_key_id"` - S3SecretAccessKey string `yaml:"s3_secret_access_key"` - } `yaml:"settings"` + PrivateKeyPath string `yaml:"private_key_path"` + DbName string `yaml:"dbname"` + Database string `yaml:"database"` + Account string `yaml:"account"` + Schema string `yaml:"schema"` + QueryTag string `yaml:"query_tag"` + Dataset string `yaml:"dataset"` + Path string `yaml:"path"` + Role string `yaml:"role"` + Password string `yaml:"password"` + User string `yaml:"user"` + Warehouse string `yaml:"warehouse"` + Method string `yaml:"method"` + Host string `yaml:"host"` + PrivateKey string `yaml:"private_key"` + Location string `yaml:"location"` + ConnType string `yaml:"type"` + Authenticator string `yaml:"authenticator"` + Project string `yaml:"project"` + SslMode string `yaml:"sslmode"` + DataprocClusterName string `yaml:"dataproc_cluster_name"` + DataprocRegion string `yaml:"dataproc_region"` + GcsBucket string `yaml:"gcs_bucket"` + ExecutionProject string `yaml:"execution_project"` + PrivateKeyPassphrase string `yaml:"private_key_passphrase"` + RefreshToken string `yaml:"refresh_token"` + ClientID string `yaml:"client_id"` + ClientSecret string `yaml:"client_secret"` + TokenURI string `yaml:"token_uri"` + Token string `yaml:"token"` + Priority string `yaml:"priority"` + Keyfile string `yaml:"keyfile"` + ImpersonateServiceAccount string `yaml:"impersonate_service_account"` + Extensions []string `yaml:"extensions"` + Scopes []string `yaml:"scopes"` + JobCreationTimeoutSeconds int `yaml:"job_creation_timeout_seconds"` + MaximumBytesBilled int `yaml:"maximum_bytes_billed"` + JobRetryDeadlineSeconds int `yaml:"job_retry_deadline_seconds"` + JobExecutionTimeoutSeconds int `yaml:"job_execution_timeout_seconds"` + ConnectTimeout int `yaml:"connect_timeout"` + ConnectRetries int `yaml:"connect_retries"` + Port int `yaml:"port"` + Threads int `yaml:"threads"` + ReuseConnections bool `yaml:"reuse_connections"` + RetryAll bool `yaml:"retry_all"` + RetryOnDatabaseErrors bool `yaml:"retry_on_database_errors"` + ClientSessionKeepAlive bool `yaml:"client_session_keep_alive"` } `yaml:"outputs"` + Target string `yaml:"target"` } type DbtProfiles map[string]DbtProfile diff --git a/fetch_dbt_profiles_test.go b/fetch_dbt_profiles_test.go index e826364..5f0438d 100644 --- a/fetch_dbt_profiles_test.go +++ b/fetch_dbt_profiles_test.go @@ -25,4 +25,7 @@ func TestFetchDbtProfiles(t *testing.T) { if profiles["dwarf"].Outputs["dev"].ConnType != "duckdb" { t.Fatalf("Expected duckdb, got %s\n", profiles["dwarf"].Outputs["dev"].ConnType) } + if profiles["ent"].Outputs["dev"].ConnType != "postgres" { + t.Fatalf("Expected postgres, got %s\n", profiles["ent"].Outputs["dev"].ConnType) + } } diff --git a/forms.go b/forms.go index 64157be..bf89b36 100644 --- a/forms.go +++ b/forms.go @@ -1,32 +1,38 @@ package main import ( + "errors" "fmt" + "strconv" "github.com/charmbracelet/huh" "github.com/fatih/color" ) type FormResponse struct { - Confirm bool - Warehouse string + Path string Username string - Account string + BuildDir string + SslMode string Database string Schema string Project string Dataset string - Path string - BuildDir string - GenerateDescriptions bool + ProjectName string + Warehouse string + Account string GroqKeyEnvVar string - UseDbtProfile bool + Password string DbtProfileName string DbtProfileOutput string - CreateProfile bool - ScaffoldProject bool - ProjectName string + Port string + Host string Prefix string + GenerateDescriptions bool + ScaffoldProject bool + CreateProfile bool + UseDbtProfile bool + Confirm bool } var not_empty = func(s string) error { @@ -52,6 +58,8 @@ func Forms(ps DbtProfiles) (FormResponse, error) { BuildDir: "build", GroqKeyEnvVar: "GROQ_API_KEY", Prefix: "stg", + Host: "localhost", + Port: "5432", } pinkUnderline := color.New(color.FgMagenta).Add(color.Bold, color.Underline).SprintFunc() greenBold := color.New(color.FgGreen).Add(color.Bold).SprintFunc() @@ -70,9 +78,9 @@ To prepare, make sure you have the following: *_OR_* ✴︎ The necessary %s for your warehouse -_See README for warehouse-specific requirements_ +_See_ %s _for warehouse-specific requirements_: https://github.com/gwenwindflower/tbd -`, greenBold(Version), pinkUnderline("existing dbt profile"), pinkUnderline("connection details"))), +`, greenBold(Version), pinkUnderline("existing dbt profile"), pinkUnderline("connection details"), greenBold("README"))), ), huh.NewGroup( @@ -137,6 +145,7 @@ https://github.com/gwenwindflower/tbd huh.NewOption("Snowflake", "snowflake"), huh.NewOption("BigQuery", "bigquery"), huh.NewOption("DuckDB", "duckdb"), + huh.NewOption("Postgres", "postgres"), ). Value(&dfr.Warehouse), ).WithHideFunc(func() bool { @@ -193,17 +202,67 @@ Relative to pwd e.g. if db is in this dir -> cool_ducks.db`). huh.NewInput(). Title("What is the *database* you want to generate?"). Value(&dfr.Database). - Placeholder("duckdb"). + Placeholder("gimli_corp"). Validate(not_empty), huh.NewInput(). Title("What is the *schema* you want to generate?"). Value(&dfr.Schema). - Placeholder("raw"). + Placeholder("moria"). Validate(not_empty), ).WithHideFunc(func() bool { return dfr.Warehouse != "duckdb" }), + huh.NewGroup( + huh.NewInput(). + Title("What is your Postgres *host*?"). + Value(&dfr.Host). + Validate(not_empty), + huh.NewInput(). + Title("What is your Postgres *port*?"). + Value(&dfr.Port). + Validate(func(s string) error { + port, err := strconv.Atoi(s) + if err != nil || port < 1000 || port > 9999 { + return errors.New("port must be a 4-digit number") + } + return nil + }), + huh.NewInput(). + Title("What is your Postgres *username*?"). + Value(&dfr.Username). + Placeholder("galadriel"). + Validate(not_empty), + huh.NewInput(). + Title("What is your Postgres *password*?"). + Value(&dfr.Password). + Validate(not_empty). + EchoMode(huh.EchoModePassword), + huh.NewInput(). + Title("What is the *database* you want to generate?"). + Value(&dfr.Database). + Placeholder("lothlorien"). + Validate(not_empty), + huh.NewInput(). + Title("What is the *schema* you want to generate?"). + Value(&dfr.Schema). + Placeholder("mallorn_trees"). + Validate(not_empty), + huh.NewSelect[string](). + Title("What ssl mode do you want to use?"). + Value(&dfr.SslMode). + Options( + huh.NewOption("Disable", "disable"), + huh.NewOption("Require", "require"), + huh.NewOption("Verify-ca", "verify-ca"), + huh.NewOption("Verify-full", "verify-full"), + huh.NewOption("Prefer", "prefer"), + huh.NewOption("Allow", "allow")). + Validate(not_empty), + ).WithHideFunc(func() bool { + return dfr.Warehouse != "postgres" + }), + huh.NewGroup( huh.NewNote(). Title(fmt.Sprintf("🤖 %s LLM generation 🦙✨", redBold("Experimental"))). @@ -212,7 +271,7 @@ Currently generates: ✴︎ column %s ✴︎ relevant %s -Requires a %s stored in an env var +_Requires a_ %s _stored in an env var_: Get one at https://groq.com.`, yellowItalic("Optional"), pinkUnderline("descriptions"), pinkUnderline("tests"), greenBoldItalic("Groq API key"))), huh.NewConfirm(). Title("Do you want to infer descriptions and tests?"). diff --git a/generate_column_desc.go b/generate_column_desc.go index 4fd7593..4c91d1e 100644 --- a/generate_column_desc.go +++ b/generate_column_desc.go @@ -17,13 +17,13 @@ import ( ) type Payload struct { - Messages []Message `json:"messages"` + Stop interface{} `json:"stop"` Model string `json:"model"` + Messages []Message `json:"messages"` Temp float64 `json:"temperature"` Tokens int `json:"max_tokens"` TopP int `json:"top_p"` Stream bool `json:"stream"` - Stop interface{} `json:"stop"` } type Message struct { @@ -32,18 +32,18 @@ type Message struct { } type GroqResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created int `json:"created"` - Model string `json:"model"` - Choices []struct { - Index int `json:"index"` - Message struct { + SystemFingerprint interface{} `json:"system_fingerprint"` + ID string `json:"id"` + Object string `json:"object"` + Model string `json:"model"` + Choices []struct { + Logprobs interface{} `json:"logprobs"` + Message struct { Role string `json:"role"` Content string `json:"content"` } `json:"message"` - Logprobs interface{} `json:"logprobs"` - FinishReason string `json:"finish_reason"` + FinishReason string `json:"finish_reason"` + Index int `json:"index"` } `json:"choices"` Usage struct { PromptTokens int `json:"prompt_tokens"` @@ -53,7 +53,7 @@ type GroqResponse struct { TotalTokens int `json:"total_tokens"` TotalTime float64 `json:"total_time"` } `json:"usage"` - SystemFingerprint interface{} `json:"system_fingerprint"` + Created int `json:"created"` } // Groq API constants @@ -156,7 +156,7 @@ func GetGroqResponse(prompt string) (GroqResponse, error) { Content: prompt, }, }, - Model: "Llama3-70B-8192", + Model: "llama3-70b-8192", Temp: 0.5, Tokens: 2048, TopP: 1, diff --git a/go.mod b/go.mod index fdd149e..a42ad55 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/charmbracelet/huh v0.3.1-0.20240306161957-71f31c155b08 github.com/fatih/color v1.16.0 github.com/jarcoal/httpmock v1.3.1 + github.com/lib/pq v1.10.9 github.com/marcboeker/go-duckdb v1.6.3 github.com/schollz/progressbar/v3 v3.14.2 github.com/snowflakedb/gosnowflake v1.9.0 diff --git a/go.sum b/go.sum index 602e113..4773741 100644 --- a/go.sum +++ b/go.sum @@ -194,6 +194,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/marcboeker/go-duckdb v1.6.3 h1:5qRxB3BosFXRjfQWNP0OOqEQFXllo6o7fHGrNA7NSuM= diff --git a/main.go b/main.go index c794e1d..63f40d7 100644 --- a/main.go +++ b/main.go @@ -11,8 +11,8 @@ import ( type Elapsed struct { DbStart time.Time - DbElapsed float64 ProcessingStart time.Time + DbElapsed float64 ProcessingElapsed float64 } diff --git a/set_connection_details.go b/set_connection_details.go index 431557d..27ab2d3 100644 --- a/set_connection_details.go +++ b/set_connection_details.go @@ -4,6 +4,7 @@ import ( "log" "os" "path/filepath" + "strconv" "github.com/gwenwindflower/tbd/shared" ) @@ -43,6 +44,25 @@ func SetConnectionDetails(fr FormResponse, ps DbtProfiles) shared.ConnectionDeta Schema: fr.Schema, } } + case "postgres": + var sslMode string + if profile.Outputs[fr.DbtProfileOutput].SslMode == "" { + sslMode = "disable" + } else { + sslMode = profile.Outputs[fr.DbtProfileOutput].SslMode + } + { + cd = shared.ConnectionDetails{ + ConnType: profile.Outputs[fr.DbtProfileOutput].ConnType, + Host: profile.Outputs[fr.DbtProfileOutput].Host, + Port: profile.Outputs[fr.DbtProfileOutput].Port, + Username: profile.Outputs[fr.DbtProfileOutput].User, + Password: profile.Outputs[fr.DbtProfileOutput].Password, + Database: profile.Outputs[fr.DbtProfileOutput].Database, + SslMode: sslMode, + Schema: fr.Schema, + } + } default: { log.Fatalf("Unsupported connection type %v\n", profile.Outputs[fr.DbtProfileOutput].ConnType) @@ -86,6 +106,23 @@ func SetConnectionDetails(fr FormResponse, ps DbtProfiles) shared.ConnectionDeta Schema: fr.Schema, } } + case "postgres": + port, err := strconv.Atoi(fr.Port) + if err != nil || port < 1000 || port > 9999 { + log.Fatalf("Port must be a 4-digit number\n") + } + { + cd = shared.ConnectionDetails{ + ConnType: fr.Warehouse, + Host: fr.Host, + Port: port, + Username: fr.Username, + Password: fr.Password, + Database: fr.Database, + Schema: fr.Schema, + SslMode: fr.SslMode, + } + } default: { log.Fatalf("Unsupported connection type %v\n", fr.Warehouse) diff --git a/set_connection_details_test.go b/set_connection_details_test.go index 3c83412..91ee7cc 100644 --- a/set_connection_details_test.go +++ b/set_connection_details_test.go @@ -126,3 +126,38 @@ func TestSetConnectionDetailsWithDuckDBWithoutDbtProfile(t *testing.T) { t.Errorf("got %v, want %v", connectionDetails, want) } } + +func TestSetConnectionDetailsPostgresWithoutDbtProfile(t *testing.T) { + formResponse := FormResponse{ + UseDbtProfile: false, + Warehouse: "postgres", + Host: "localhost", + Port: "5432", + Username: "treebeard", + Password: "entmoot", + Database: "fangorn", + Schema: "huorns", + SslMode: "disable", + GenerateDescriptions: false, + BuildDir: "test_build", + Confirm: true, + } + ps, err := FetchDbtProfiles() + if err != nil { + t.Errorf("Error fetching dbt profiles: %v", err) + } + cd := SetConnectionDetails(formResponse, ps) + want := shared.ConnectionDetails{ + ConnType: "postgres", + Host: "localhost", + Port: 5432, + Username: "treebeard", + Password: "entmoot", + Database: "fangorn", + Schema: "huorns", + SslMode: "disable", + } + if cd != want { + t.Errorf("got %v, want %v", cd, want) + } +} diff --git a/shared/types.go b/shared/types.go index 3d0f8e4..cda3ea4 100644 --- a/shared/types.go +++ b/shared/types.go @@ -10,8 +10,8 @@ type Column struct { type SourceTable struct { DataTypeGroups map[string][]Column `yaml:"-"` Name string `yaml:"name"` - Columns []Column `yaml:"columns"` Schema string `yaml:"-"` + Columns []Column `yaml:"columns"` } type SourceTables struct { @@ -19,13 +19,17 @@ type SourceTables struct { } type ConnectionDetails struct { - ConnType string + Dataset string Username string Account string Database string Schema string Project string - Dataset string + ConnType string Path string ProjectName string + Host string + Password string + SslMode string + Port int } diff --git a/sourcerer/connect_to_db.go b/sourcerer/connect_to_db.go index c620e8d..90b31d4 100644 --- a/sourcerer/connect_to_db.go +++ b/sourcerer/connect_to_db.go @@ -9,6 +9,7 @@ import ( "time" "cloud.google.com/go/bigquery" + _ "github.com/lib/pq" _ "github.com/marcboeker/go-duckdb" _ "github.com/snowflakedb/gosnowflake" ) @@ -60,3 +61,14 @@ func (dc *DuckConn) ConnectToDb(ctx context.Context) (err error) { } return err } + +func (pgc *PgConn) ConnectToDb(ctx context.Context) (err error) { + _, pgc.Cancel = context.WithTimeout(ctx, 1*time.Minute) + defer pgc.Cancel() + connStr := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", pgc.Host, pgc.Port, pgc.Username, pgc.Password, pgc.Database, pgc.SslMode) + pgc.Db, err = sql.Open("postgres", connStr) + if err != nil { + log.Fatalf("Could not connect to Postgres %v\n", err) + } + return err +} diff --git a/sourcerer/connect_to_db_test.go b/sourcerer/connect_to_db_test.go index 32ff39e..e1181ef 100644 --- a/sourcerer/connect_to_db_test.go +++ b/sourcerer/connect_to_db_test.go @@ -7,7 +7,7 @@ import ( "github.com/gwenwindflower/tbd/shared" ) -func TestConnectToDb(t *testing.T) { +func TestConnectToDbSnowflake(t *testing.T) { cd := shared.ConnectionDetails{ ConnType: "snowflake", Account: "dunedain.snowflakecomputing.com", @@ -19,7 +19,7 @@ func TestConnectToDb(t *testing.T) { if err != nil { t.Errorf("GetConn failed: %v", err) } - SfConn, ok := conn.(*SfConn) + sfc, ok := conn.(*SfConn) if !ok { t.Errorf("conn not of type SfConn: %v", err) } @@ -27,10 +27,43 @@ func TestConnectToDb(t *testing.T) { if err != nil { t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) } - SfConn.Db = db - defer SfConn.Db.Close() + sfc.Db = db + defer sfc.Db.Close() mock.ExpectBegin() - if _, err := SfConn.Db.Begin(); err != nil { + if _, err := sfc.Db.Begin(); err != nil { + t.Errorf("error '%s' was not expected, while pinging db", err) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestConnectToDbPostgres(t *testing.T) { + cd := shared.ConnectionDetails{ + ConnType: "postgres", + Host: "localhost", + Port: 5432, + Username: "frodo", + Password: "0nering", + Database: "shire", + Schema: "hobbiton", + } + conn, err := GetConn(cd) + if err != nil { + t.Errorf("GetConn failed: %v", err) + } + pgc, ok := conn.(*PgConn) + if !ok { + t.Errorf("conn not of type PgConn: %v", err) + } + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + pgc.Db = db + defer pgc.Db.Close() + mock.ExpectBegin() + if _, err := pgc.Db.Begin(); err != nil { t.Errorf("error '%s' was not expected, while pinging db", err) } if err := mock.ExpectationsWereMet(); err != nil { diff --git a/sourcerer/get_columns.go b/sourcerer/get_columns.go index 157c30f..639b1d0 100644 --- a/sourcerer/get_columns.go +++ b/sourcerer/get_columns.go @@ -77,3 +77,21 @@ func (dc *DuckConn) GetColumns(ctx context.Context, t shared.SourceTable) ([]sha } return cs, nil } + +func (pgc *PgConn) GetColumns(ctx context.Context, t shared.SourceTable) ([]shared.Column, error) { + var cs []shared.Column + q := fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = '%s' AND table_name = '%s'", pgc.Schema, t.Name) + rows, err := pgc.Db.QueryContext(ctx, q) + if err != nil { + log.Fatalf("Error fetching columns for table %s: %v\n", t.Name, err) + } + defer rows.Close() + for rows.Next() { + c := shared.Column{} + if err := rows.Scan(&c.Name, &c.DataType); err != nil { + log.Fatalf("Error scanning columns for table %s: %v\n", t.Name, err) + } + cs = append(cs, c) + } + return cs, nil +} diff --git a/sourcerer/get_conn.go b/sourcerer/get_conn.go index 3d22d01..0173952 100644 --- a/sourcerer/get_conn.go +++ b/sourcerer/get_conn.go @@ -18,34 +18,45 @@ type DbConn interface { } type SfConn struct { + Db *sql.DB + Cancel context.CancelFunc Account string Username string Database string Schema string - Db *sql.DB - Cancel context.CancelFunc } type BqConn struct { - Project string - Dataset string Bq *bigquery.Client Cancel context.CancelFunc + Project string + Dataset string } type DuckConn struct { + Db *sql.DB + Cancel context.CancelFunc Path string Database string Schema string +} + +type PgConn struct { Db *sql.DB Cancel context.CancelFunc + Host string + Username string + Password string + Database string + Schema string + SslMode string + Port int } func GetConn(cd shared.ConnectionDetails) (DbConn, error) { switch cd.ConnType { case "snowflake": { - // TODO: Why do I need to use a pointer here? return &SfConn{ Account: strings.ToUpper(cd.Account), Username: strings.ToUpper(cd.Username), @@ -68,6 +79,18 @@ func GetConn(cd shared.ConnectionDetails) (DbConn, error) { Schema: cd.Schema, }, nil } + case "postgres": + { + return &PgConn{ + Host: cd.Host, + Port: cd.Port, + Username: cd.Username, + Password: cd.Password, + Database: cd.Database, + Schema: cd.Schema, + SslMode: cd.SslMode, + }, nil + } default: return nil, errors.New("unsupported connection type") } diff --git a/sourcerer/get_conn_test.go b/sourcerer/get_conn_test.go index ea35d8f..5647dc5 100644 --- a/sourcerer/get_conn_test.go +++ b/sourcerer/get_conn_test.go @@ -75,3 +75,29 @@ func TestGetConnDuckDB(t *testing.T) { t.Errorf("GetConn failed: Account is not correct") } } + +func TestGetConnPostgres(t *testing.T) { + cd := shared.ConnectionDetails{ + ConnType: "postgres", + Host: "localhost", + Port: 5432, + Username: "frodo", + Password: "0nering", + Database: "shire", + Schema: "hobbiton", + } + conn, err := GetConn(cd) + if err != nil { + t.Errorf("GetConn failed: %v", err) + } + if conn == nil { + t.Errorf("GetConn failed: conn is nil") + } + pgc, ok := conn.(*PgConn) + if !ok { + t.Errorf("GetConn failed: conn is not of type PgConn") + } + if pgc.Host != "localhost" { + t.Errorf("GetConn failed: Host is not correct") + } +} diff --git a/sourcerer/get_sources_tables.go b/sourcerer/get_sources_tables.go index e7854cf..df1ee90 100644 --- a/sourcerer/get_sources_tables.go +++ b/sourcerer/get_sources_tables.go @@ -66,3 +66,23 @@ func (dc *DuckConn) GetSourceTables(ctx context.Context) (shared.SourceTables, e } return ts, nil } + +func (pgc *PgConn) GetSourceTables(ctx context.Context) (shared.SourceTables, error) { + ts := shared.SourceTables{} + defer pgc.Cancel() + q := fmt.Sprintf("SELECT table_name FROM information_schema.tables WHERE table_schema = '%s'", pgc.Schema) + rows, err := pgc.Db.QueryContext(ctx, q) + if err != nil { + log.Fatalf("Error fetching tables: %v\n", err) + } + defer rows.Close() + for rows.Next() { + var table shared.SourceTable + if err := rows.Scan(&table.Name); err != nil { + log.Fatalf("Error scanning tables: %v\n", err) + } + table.Schema = pgc.Schema + ts.SourceTables = append(ts.SourceTables, table) + } + return ts, nil +} diff --git a/test_helpers.go b/test_helpers.go index e44c41b..b96bd60 100644 --- a/test_helpers.go +++ b/test_helpers.go @@ -40,6 +40,19 @@ dwarf: database: khazad_dum schema: balins_tomb threads: 4 + +ent: + target: dev + outputs: + dev: + type: postgres + host: localhost + port: 5432 + user: treebeard + password: entmoot + database: fangorn + schema: huorns + threads: 2 `) tmpDir := t.TempDir() err := os.Mkdir(filepath.Join(tmpDir, ".dbt"), 0755)