From a1b29ccc2963d54d8da4c1c24fa5f642d78fda28 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Philip=20Dub=C3=A9?=
Date: Mon, 9 Dec 2024 18:39:12 +0000
Subject: [PATCH] don't duplicate mapping ch ColumnType to QValueKind

---
 flow/connectors/clickhouse/clickhouse.go | 18 ++++++-------
 flow/e2e/clickhouse/clickhouse.go        | 32 ++++++++----------------
 2 files changed, 20 insertions(+), 30 deletions(-)

diff --git a/flow/connectors/clickhouse/clickhouse.go b/flow/connectors/clickhouse/clickhouse.go
index 8e43dbfb1..432a650b4 100644
--- a/flow/connectors/clickhouse/clickhouse.go
+++ b/flow/connectors/clickhouse/clickhouse.go
@@ -544,14 +544,7 @@ func (c *ClickHouseConnector) GetVersion(ctx context.Context) (string, error) {
 	return clickhouseVersion.Version.String(), nil
 }
 
-func (c *ClickHouseConnector) getTableSchemaForTable(ctx context.Context, tableName string) (*protos.TableSchema, error) {
-	// TODO sanitize
-	q, err := c.database.Query(ctx, fmt.Sprintf("select * from %s limit 0", tableName))
-	if err != nil {
-		return nil, err
-	}
-
-	columns := q.ColumnTypes()
+func GetTableSchemaForTable(tableName string, columns []driver.ColumnType) (*protos.TableSchema, error) {
 	colFields := make([]*protos.FieldDescription, 0, len(columns))
 	for _, column := range columns {
 		var qkind qvalue.QValueKind
@@ -576,6 +569,7 @@ func (c *ClickHouseConnector) getTableSchemaForTable(ctx context.Context, tableN
 			Name:         column.Name(),
 			Type:         string(qkind),
 			TypeModifier: -1,
+			Nullable:     column.Nullable(),
 		})
 	}
 
@@ -594,7 +588,13 @@ func (c *ClickHouseConnector) GetTableSchema(
 ) (map[string]*protos.TableSchema, error) {
 	res := make(map[string]*protos.TableSchema, len(tableIdentifiers))
 	for _, tableName := range tableIdentifiers {
-		tableSchema, err := c.getTableSchemaForTable(ctx, tableName)
+		rows, err := c.database.Query(ctx, fmt.Sprintf("select * from %s limit 0", tableName))
+		if err != nil {
+			return nil, err
+		}
+
+		tableSchema, err := GetTableSchemaForTable(tableName, rows.ColumnTypes())
+		rows.Close()
 		if err != nil {
 			return nil, err
 		}
diff --git a/flow/e2e/clickhouse/clickhouse.go b/flow/e2e/clickhouse/clickhouse.go
index 975676152..a36a5335b 100644
--- a/flow/e2e/clickhouse/clickhouse.go
+++ b/flow/e2e/clickhouse/clickhouse.go
@@ -96,6 +96,7 @@ func (s ClickHouseSuite) GetRows(table string, cols string) (*model.QRecordBatch
 	if err != nil {
 		return nil, err
 	}
+	defer ch.Close()
 
 	rows, err := ch.Query(
 		context.Background(),
@@ -104,36 +105,25 @@ func (s ClickHouseSuite) GetRows(table string, cols string) (*model.QRecordBatch
 	if err != nil {
 		return nil, err
 	}
+	defer rows.Close()
 
 	batch := &model.QRecordBatch{}
 	types := rows.ColumnTypes()
 	row := make([]interface{}, 0, len(types))
-	for _, ty := range types {
-		nullable := ty.Nullable()
+	tableSchema, err := connclickhouse.GetTableSchemaForTable(table, types)
+	if err != nil {
+		return nil, err
+	}
+
+	for idx, ty := range types {
+		fieldDesc := tableSchema.Columns[idx]
 		row = append(row, reflect.New(ty.ScanType()).Interface())
-		var qkind qvalue.QValueKind
-		switch ty.DatabaseTypeName() {
-		case "String", "Nullable(String)":
-			qkind = qvalue.QValueKindString
-		case "Int32", "Nullable(Int32)":
-			qkind = qvalue.QValueKindInt32
-		case "DateTime64(6)", "Nullable(DateTime64(6))":
-			qkind = qvalue.QValueKindTimestamp
-		case "Date32", "Nullable(Date32)":
-			qkind = qvalue.QValueKindDate
-		default:
-			if strings.Contains(ty.DatabaseTypeName(), "Decimal") {
-				qkind = qvalue.QValueKindNumeric
-			} else {
-				return nil, fmt.Errorf("failed to resolve QValueKind for %s", ty.DatabaseTypeName())
-			}
-		}
 		batch.Schema.Fields = append(batch.Schema.Fields, qvalue.QField{
 			Name:      ty.Name(),
-			Type:      qkind,
+			Type:      qvalue.QValueKind(fieldDesc.Type),
 			Precision: 0,
 			Scale:     0,
-			Nullable:  nullable,
+			Nullable:  fieldDesc.Nullable,
 		})
 	}
 