Skip to content

Commit

Permalink
don't duplicate mapping ch ColumnType to QValueKind
Browse files Browse the repository at this point in the history
  • Loading branch information
serprex committed Dec 9, 2024
1 parent 8c312e3 commit a1b29cc
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 30 deletions.
18 changes: 9 additions & 9 deletions flow/connectors/clickhouse/clickhouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -544,14 +544,7 @@ func (c *ClickHouseConnector) GetVersion(ctx context.Context) (string, error) {
return clickhouseVersion.Version.String(), nil
}

func (c *ClickHouseConnector) getTableSchemaForTable(ctx context.Context, tableName string) (*protos.TableSchema, error) {
// TODO sanitize
q, err := c.database.Query(ctx, fmt.Sprintf("select * from %s limit 0", tableName))
if err != nil {
return nil, err
}

columns := q.ColumnTypes()
func GetTableSchemaForTable(tableName string, columns []driver.ColumnType) (*protos.TableSchema, error) {
colFields := make([]*protos.FieldDescription, 0, len(columns))
for _, column := range columns {
var qkind qvalue.QValueKind
Expand All @@ -576,6 +569,7 @@ func (c *ClickHouseConnector) getTableSchemaForTable(ctx context.Context, tableN
Name: column.Name(),
Type: string(qkind),
TypeModifier: -1,
Nullable: column.Nullable(),
})
}

Expand All @@ -594,7 +588,13 @@ func (c *ClickHouseConnector) GetTableSchema(
) (map[string]*protos.TableSchema, error) {
res := make(map[string]*protos.TableSchema, len(tableIdentifiers))
for _, tableName := range tableIdentifiers {
tableSchema, err := c.getTableSchemaForTable(ctx, tableName)
rows, err := c.database.Query(ctx, fmt.Sprintf("select * from %s limit 0", tableName))
if err != nil {
return nil, err
}

tableSchema, err := GetTableSchemaForTable(tableName, rows.ColumnTypes())
rows.Close()
if err != nil {
return nil, err
}
Expand Down
32 changes: 11 additions & 21 deletions flow/e2e/clickhouse/clickhouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ func (s ClickHouseSuite) GetRows(table string, cols string) (*model.QRecordBatch
if err != nil {
return nil, err
}
defer ch.Close()

rows, err := ch.Query(
context.Background(),
Expand All @@ -104,36 +105,25 @@ func (s ClickHouseSuite) GetRows(table string, cols string) (*model.QRecordBatch
if err != nil {
return nil, err
}
defer rows.Close()

batch := &model.QRecordBatch{}
types := rows.ColumnTypes()
row := make([]interface{}, 0, len(types))
for _, ty := range types {
nullable := ty.Nullable()
tableSchema, err := connclickhouse.GetTableSchemaForTable(table, types)
if err != nil {
return nil, err
}

for idx, ty := range types {
fieldDesc := tableSchema.Columns[idx]
row = append(row, reflect.New(ty.ScanType()).Interface())
var qkind qvalue.QValueKind
switch ty.DatabaseTypeName() {
case "String", "Nullable(String)":
qkind = qvalue.QValueKindString
case "Int32", "Nullable(Int32)":
qkind = qvalue.QValueKindInt32
case "DateTime64(6)", "Nullable(DateTime64(6))":
qkind = qvalue.QValueKindTimestamp
case "Date32", "Nullable(Date32)":
qkind = qvalue.QValueKindDate
default:
if strings.Contains(ty.DatabaseTypeName(), "Decimal") {
qkind = qvalue.QValueKindNumeric
} else {
return nil, fmt.Errorf("failed to resolve QValueKind for %s", ty.DatabaseTypeName())
}
}
batch.Schema.Fields = append(batch.Schema.Fields, qvalue.QField{
Name: ty.Name(),
Type: qkind,
Type: qvalue.QValueKind(fieldDesc.Type),
Precision: 0,
Scale: 0,
Nullable: nullable,
Nullable: fieldDesc.Nullable,
})
}

Expand Down

0 comments on commit a1b29cc

Please sign in to comment.