diff --git a/internal/compiler/query_catalog.go b/internal/compiler/query_catalog.go index 80b59d876c..e71a709861 100644 --- a/internal/compiler/query_catalog.go +++ b/internal/compiler/query_catalog.go @@ -2,6 +2,7 @@ package compiler import ( "fmt" + "strings" "github.com/sqlc-dev/sqlc/internal/sql/ast" "github.com/sqlc-dev/sqlc/internal/sql/catalog" @@ -77,6 +78,35 @@ func ConvertColumn(rel *ast.TableName, c *catalog.Column) *Column { } } +func RevertConvertColumn(c *Column) *catalog.Column { + out := &catalog.Column{ + Name: c.Name, + IsNotNull: c.NotNull, + IsUnsigned: c.Unsigned, + IsArray: c.IsArray, + ArrayDims: c.ArrayDims, + Length: c.Length, + } + if c.Type != nil { + out.Type = *c.Type + } + dataTypes := strings.Split(c.DataType, ".") + if len(dataTypes) == 1 { + out.Type.Name = dataTypes[0] + } else if len(dataTypes) == 2 { + out.Type.Schema = dataTypes[0] + out.Type.Name = dataTypes[1] + } + return out +} + +func RevertConvertColumns(columns []*Column) (out []*catalog.Column) { + for i := range columns { + out = append(out, RevertConvertColumn(columns[i])) + } + return +} + func (qc QueryCatalog) GetTable(rel *ast.TableName) (*Table, error) { cte, exists := qc.ctes[rel.Name] if exists { diff --git a/internal/compiler/resolve.go b/internal/compiler/resolve.go index b1fbb1990e..fe5793fdd1 100644 --- a/internal/compiler/resolve.go +++ b/internal/compiler/resolve.go @@ -67,9 +67,20 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, continue } // If the table name doesn't exist, first check if it's a CTE - if _, qcerr := qc.GetTable(fqn); qcerr != nil { + cteTable, qcerr := qc.GetTable(fqn) + if qcerr != nil { return nil, err } + err = indexTable(catalog.Table{ + Rel: cteTable.Rel, + Columns: RevertConvertColumns(cteTable.Columns), + }) + if err != nil { + return nil, err + } + if rv.Alias != nil { + aliasMap[*rv.Alias.Aliasname] = fqn + } continue } err = indexTable(table)