diff --git a/forms.go b/forms.go index 95b4174..15f8f17 100644 --- a/forms.go +++ b/forms.go @@ -154,7 +154,7 @@ Relative to pwd e.g. if db is in this dir -> cool_ducks.db`). huh.NewGroup( huh.NewNote(). Title("🚧🚨 Choose your build directory carefully! 🚨🚧"). - Description(`_I highly recommend choosing a new or empty directory to build into._ + Description(`Choose a _new_ or _empty_ directory. If you use an existing directory, tbd will overwrite any existing files of the same name.`), ), diff --git a/generate_column_desc.go b/generate_column_desc.go index 4d7b227..504ded7 100644 --- a/generate_column_desc.go +++ b/generate_column_desc.go @@ -136,6 +136,7 @@ func GenerateColumnDescriptions(tables shared.SourceTables) { }(i, j) } } + wg.Wait() } func GetGroqResponse(prompt string) (GroqResponse, error) { diff --git a/generate_column_desc_test.go b/generate_column_desc_test.go new file mode 100644 index 0000000..ede81cc --- /dev/null +++ b/generate_column_desc_test.go @@ -0,0 +1,47 @@ +package main + +import ( + "testing" + + "github.com/jarcoal/httpmock" +) + +func TestGetGroqResponse(t *testing.T) { + prompt := "Who destroyed Orthanc" + httpmock.Activate() + defer httpmock.DeactivateAndReset() + httpmock.RegisterResponder("POST", "https://api.groq.com/openai/v1/chat/completions", + httpmock.NewStringResponder(200, `{"choices": [{"index": 0, "message": {"role": "assistant","content": "Treebeard and the Ents destroyed Orthanc."}}]}`)) + GroqResponse, err := GetGroqResponse(prompt) + if err != nil { + t.Error("expected", nil, "got", err) + } + info := httpmock.GetCallCountInfo() + if info["POST https://api.groq.com/openai/v1/chat/completions"] != 1 { + t.Error("expected", 1, "got", info["POST https://api.groq.com/openai/v1/chat/completions"]) + } + expected := "Treebeard and the Ents destroyed Orthanc." + if GroqResponse.Choices[0].Message.Content != expected { + t.Error("expected", expected, "got", GroqResponse.Choices[0].Message.Content) + } +} + +func TestGenerateColumnDescriptions(t *testing.T) { + ts := CreateTempSourceTables() + httpmock.Activate() + defer httpmock.DeactivateAndReset() + httpmock.RegisterResponder("POST", "https://api.groq.com/openai/v1/chat/completions", + httpmock.NewStringResponder(200, `{"choices": [{"index": 0, "message": {"role": "assistant","content": "lord of rivendell"}}]}`)) + GenerateColumnDescriptions(ts) + + info := httpmock.GetCallCountInfo() + if info["POST https://api.groq.com/openai/v1/chat/completions"] != 2 { + t.Error("expected", 2, "got", info["POST https://api.groq.com/openai/v1/chat/completions"]) + } + + expected := "lord of rivendell" + desc := ts.SourceTables[0].Columns[0].Description + if desc != expected { + t.Error("expected", expected, "got", desc) + } +} diff --git a/get_dbt_profile_test.go b/get_dbt_profile_test.go index e146ac3..fb17ec9 100644 --- a/get_dbt_profile_test.go +++ b/get_dbt_profile_test.go @@ -9,7 +9,8 @@ func TestGetDbtProfile(t *testing.T) { CreateTempDbtProfile(t) defer os.RemoveAll(os.Getenv("HOME")) defer os.Unsetenv("HOME") - // Profile exists + + // Profile exists profile, err := GetDbtProfile("elf") if err != nil { t.Errorf("GetDbtProfile returned an error for an existing profile: %v", err) @@ -20,6 +21,8 @@ func TestGetDbtProfile(t *testing.T) { if profile.Outputs["dev"].ConnType != "snowflake" { t.Errorf("Expected connection type 'snowflake', got '%s'", profile.Outputs["dev"].ConnType) } + + // Profile exists, DuckDB profile, err = GetDbtProfile("dwarf") if err != nil { t.Errorf("GetDbtProfile returned an error for an existing profile: %v", err) @@ -33,6 +36,7 @@ func TestGetDbtProfile(t *testing.T) { if profile.Outputs["dev"].Schema != "balins_tomb" { t.Errorf("Expected schema 'balins_tomb', got '%s'", profile.Outputs["dev"].Schema) } + // If using dbt profile with DuckDB, path should be unedited if profile.Outputs["dev"].Path != "/usr/local/var/dwarf.db" { t.Errorf("Expected path '/usr/local/var/dwarf.db', got '%s'", profile.Outputs["dev"].Path) } diff --git a/go.mod b/go.mod index 3e3d90b..d688ac7 100644 --- a/go.mod +++ b/go.mod @@ -64,6 +64,7 @@ require ( github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect github.com/googleapis/gax-go/v2 v2.12.3 // indirect github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c // indirect + github.com/jarcoal/httpmock v1.3.1 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/klauspost/compress v1.17.7 // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect diff --git a/go.sum b/go.sum index 4d8e53c..27e09be 100644 --- a/go.sum +++ b/go.sum @@ -172,6 +172,8 @@ github.com/googleapis/gax-go/v2 v2.12.3 h1:5/zPPDvw8Q1SuXjrqrZslrqT7dL/uJT2CQii/ github.com/googleapis/gax-go/v2 v2.12.3/go.mod h1:AKloxT6GtNbaLm8QTNSidHUVsHYcBHwWRvkNFJUQcS4= github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c h1:6rhixN/i8ZofjG1Y75iExal34USq5p+wiN1tpie8IrU= github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c/go.mod h1:NMPJylDgVpX0MLRlPy15sqSwOFv/U1GZ2m21JhFfek0= +github.com/jarcoal/httpmock v1.3.1 h1:iUx3whfZWVf3jT01hQTO/Eo5sAYtB2/rqaUuOtpInww= +github.com/jarcoal/httpmock v1.3.1/go.mod h1:3yb8rc4BI7TCBhFY8ng0gjuLKJNquuDNiPaZjnENuYg= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= diff --git a/main.go b/main.go index 34bc94e..da8975c 100644 --- a/main.go +++ b/main.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "log" - "sync" "time" "github.com/charmbracelet/huh/spinner" @@ -52,17 +51,10 @@ func main() { GenerateColumnDescriptions(ts) } PrepBuildDir(bd) - var wg sync.WaitGroup - wg.Add(2) - go func() { - defer wg.Done() - WriteYAML(ts, bd) - }() - go func() { - defer wg.Done() - WriteStagingModels(ts, bd) - }() - wg.Wait() + err = WriteFiles(ts, bd) + if err != nil { + log.Fatalf("Error writing files: %v\n", err) + } }).Title("🏎️✨ Generating YAML and SQL files...").Run() if err != nil { log.Fatalf("Error running spinner action: %v\n", err) diff --git a/set_connection_details_test.go b/set_connection_details_test.go index 90cc9fc..49fcdba 100644 --- a/set_connection_details_test.go +++ b/set_connection_details_test.go @@ -95,9 +95,14 @@ func TestSetConnectionDetailsWithDuckDBWithoutDbtProfile(t *testing.T) { Confirm: true, } connectionDetails := SetConnectionDetails(formResponse) + wd, err := os.Getwd() + if err != nil { + t.Errorf("Failed to get working directory: %v", err) + } + p := wd + "/dwarf.db" want := shared.ConnectionDetails{ ConnType: "duckdb", - Path: "/Users/winnie/dev/tbd/dwarf.db", + Path: p, Database: "khazad_dum", Schema: "balins_tomb", } diff --git a/test_helpers.go b/test_helpers.go index 391fc19..888d36b 100644 --- a/test_helpers.go +++ b/test_helpers.go @@ -4,6 +4,8 @@ import ( "os" "path/filepath" "testing" + + "github.com/gwenwindflower/tbd/shared" ) func CreateTempDbtProfile(t *testing.T) string { @@ -52,3 +54,21 @@ dwarf: os.Setenv("HOME", tmpDir) return tmpDir } + +func CreateTempSourceTables() shared.SourceTables { + return shared.SourceTables{ + SourceTables: []shared.SourceTable{ + { + Name: "arwen", + Columns: []shared.Column{ + { + Name: "elrond", + DataType: "string", + Description: "my dad", + Tests: []string{"unique", "not_null"}, + }, + }, + }, + }, + } +} diff --git a/write_files.go b/write_files.go new file mode 100644 index 0000000..01f8034 --- /dev/null +++ b/write_files.go @@ -0,0 +1,26 @@ +package main + +import ( + "errors" + "sync" + + "github.com/gwenwindflower/tbd/shared" +) + +func WriteFiles(ts shared.SourceTables, bd string) error { + if len(ts.SourceTables) == 0 { + return errors.New("no tables to write") + } + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + WriteYAML(ts, bd) + }() + go func() { + defer wg.Done() + WriteStagingModels(ts, bd) + }() + wg.Wait() + return nil +} diff --git a/write_files_test.go b/write_files_test.go new file mode 100644 index 0000000..6049018 --- /dev/null +++ b/write_files_test.go @@ -0,0 +1,42 @@ +package main + +import ( + "strings" + "testing" + + "github.com/gwenwindflower/tbd/shared" +) + +func TestWriteFiles(t *testing.T) { + ts := shared.SourceTables{ + SourceTables: []shared.SourceTable{ + { + Name: "table1", + Columns: []shared.Column{ + { + Name: "column1", + DataType: "type1", + }, + }, + }, + }, + } + bd := t.TempDir() + WriteFiles(ts, bd) +} + +func TestWriteFilesError(t *testing.T) { + ts := shared.SourceTables{ + SourceTables: []shared.SourceTable{}, + } + bd := t.TempDir() + + err := WriteFiles(ts, bd) + if err == nil { + t.Error("expected error, got nil") + } else { + if !strings.Contains(err.Error(), "no tables to write") { + t.Errorf("expected error to contain 'no tables to write', got %v", err) + } + } +}