diff --git a/flow/connectors/clickhouse/clickhouse.go b/flow/connectors/clickhouse/clickhouse.go
index bf7e1b4c9..d99462042 100644
--- a/flow/connectors/clickhouse/clickhouse.go
+++ b/flow/connectors/clickhouse/clickhouse.go
@@ -23,6 +23,7 @@ import (
 	metadataStore "github.com/PeerDB-io/peer-flow/connectors/external_metadata"
 	"github.com/PeerDB-io/peer-flow/connectors/utils"
 	"github.com/PeerDB-io/peer-flow/generated/protos"
+	"github.com/PeerDB-io/peer-flow/model/qvalue"
 	"github.com/PeerDB-io/peer-flow/peerdbenv"
 	"github.com/PeerDB-io/peer-flow/shared"
 )
@@ -542,3 +543,84 @@ func (c *ClickHouseConnector) GetVersion(ctx context.Context) (string, error) {
 	c.logger.Info("[clickhouse] version", slog.Any("version", clickhouseVersion.DisplayName))
 	return clickhouseVersion.Version.String(), nil
 }
+
+// GetTableSchemaForTable builds a protos.TableSchema in the Q type system for
+// tableName from ClickHouse column metadata. Each column's database type name
+// is mapped to a QValueKind; any otherwise-unmatched name containing "Decimal"
+// maps to numeric. An error is returned for the first column whose type has no
+// mapping. Every field gets TypeModifier -1 and the column's own nullability.
+func GetTableSchemaForTable(tableName string, columns []driver.ColumnType) (*protos.TableSchema, error) {
+	colFields := make([]*protos.FieldDescription, 0, len(columns))
+	for _, column := range columns {
+		var qkind qvalue.QValueKind
+		switch column.DatabaseTypeName() {
+		case "String", "Nullable(String)":
+			qkind = qvalue.QValueKindString
+		case "Bool", "Nullable(Bool)":
+			qkind = qvalue.QValueKindBoolean
+		case "Int16", "Nullable(Int16)":
+			qkind = qvalue.QValueKindInt16
+		case "Int32", "Nullable(Int32)":
+			qkind = qvalue.QValueKindInt32
+		case "Int64", "Nullable(Int64)":
+			qkind = qvalue.QValueKindInt64
+		case "UUID", "Nullable(UUID)":
+			qkind = qvalue.QValueKindUUID
+		case "DateTime64(6)", "Nullable(DateTime64(6))":
+			qkind = qvalue.QValueKindTimestamp
+		case "Date32", "Nullable(Date32)":
+			qkind = qvalue.QValueKindDate
+		default:
+			// Matched by substring, presumably because Decimal type names are parameterized.
+			if !strings.Contains(column.DatabaseTypeName(), "Decimal") {
+				return nil, fmt.Errorf("failed to resolve QValueKind for %s", column.DatabaseTypeName())
+			}
+			qkind = qvalue.QValueKindNumeric
+		}
+
+		colFields = append(colFields, &protos.FieldDescription{
+			Name:         column.Name(),
+			Type:         string(qkind),
+			TypeModifier: -1,
+			Nullable:     column.Nullable(),
+		})
+	}
+
+	return &protos.TableSchema{
+		TableIdentifier: tableName,
+		Columns:         colFields,
+		System:          protos.TypeSystem_Q,
+	}, nil
+}
+
+// GetTableSchema returns the schema of every table in tableIdentifiers, keyed
+// by table identifier. For each table it runs a zero-row select and converts
+// the reported column types with GetTableSchemaForTable. The _env and _system
+// arguments are not used. The first failure aborts the call and is returned
+// wrapped with the name of the table that caused it.
+func (c *ClickHouseConnector) GetTableSchema(
+	ctx context.Context,
+	_env map[string]string,
+	_system protos.TypeSystem,
+	tableIdentifiers []string,
+) (map[string]*protos.TableSchema, error) {
+	res := make(map[string]*protos.TableSchema, len(tableIdentifiers))
+	for _, tableName := range tableIdentifiers {
+		// NOTE(review): tableName is interpolated into the query unquoted, so
+		// callers must pass trusted identifiers - confirm against call sites.
+		rows, err := c.database.Query(ctx, fmt.Sprintf("select * from %s limit 0", tableName))
+		if err != nil {
+			return nil, fmt.Errorf("failed to query columns of table %s: %w", tableName, err)
+		}
+
+		tableSchema, err := GetTableSchemaForTable(tableName, rows.ColumnTypes())
+		// Close before checking err so the rows are released on both paths.
+		rows.Close()
+		if err != nil {
+			return nil, fmt.Errorf("failed to build schema for table %s: %w", tableName, err)
+		}
+		res[tableName] = tableSchema
+	}
+
+	return res, nil
+}
diff --git a/flow/connectors/core.go b/flow/connectors/core.go
index 0991a5097..afdf24494 100644
--- a/flow/connectors/core.go
+++ b/flow/connectors/core.go
@@ -470,6 +470,7 @@ var (
 
 	_ GetTableSchemaConnector = &connpostgres.PostgresConnector{}
 	_ GetTableSchemaConnector = &connsnowflake.SnowflakeConnector{}
+	_ GetTableSchemaConnector = &connclickhouse.ClickHouseConnector{}
 
 	_ NormalizedTablesConnector = &connpostgres.PostgresConnector{}
 	_ NormalizedTablesConnector = &connbigquery.BigQueryConnector{}
diff --git a/flow/e2e/clickhouse/clickhouse.go b/flow/e2e/clickhouse/clickhouse.go
index 975676152..a36a5335b 100644
--- a/flow/e2e/clickhouse/clickhouse.go
+++ b/flow/e2e/clickhouse/clickhouse.go
@@ -96,6 +96,7 @@ func (s ClickHouseSuite) GetRows(table string, cols string) (*model.QRecordBatch
 	if err != nil {
 		return nil, err
 	}
+	defer ch.Close()
 
 	rows, err := ch.Query(
 		context.Background(),
@@ -104,36 +105,25 @@ func (s ClickHouseSuite) GetRows(table string, cols string) (*model.QRecordBatch
 	if err != nil {
 		return nil, err
 	}
+	defer rows.Close()
 
 	batch := &model.QRecordBatch{}
 	types := rows.ColumnTypes()
 	row := make([]interface{}, 0, len(types))
-	for _, ty := range types {
-		nullable := ty.Nullable()
+	tableSchema, err := connclickhouse.GetTableSchemaForTable(table, types)
+	if err != nil {
+		return nil, err
+	}
+
+	for idx, ty := range types {
+		fieldDesc := tableSchema.Columns[idx]
 		row = append(row, reflect.New(ty.ScanType()).Interface())
-		var qkind qvalue.QValueKind
-		switch ty.DatabaseTypeName() {
-		case "String", "Nullable(String)":
-			qkind = qvalue.QValueKindString
-		case "Int32", "Nullable(Int32)":
-			qkind = qvalue.QValueKindInt32
-		case "DateTime64(6)", "Nullable(DateTime64(6))":
-			qkind = qvalue.QValueKindTimestamp
-		case "Date32", "Nullable(Date32)":
-			qkind = qvalue.QValueKindDate
-		default:
-			if strings.Contains(ty.DatabaseTypeName(), "Decimal") {
-				qkind = qvalue.QValueKindNumeric
-			} else {
-				return nil, fmt.Errorf("failed to resolve QValueKind for %s", ty.DatabaseTypeName())
-			}
-		}
 		batch.Schema.Fields = append(batch.Schema.Fields, qvalue.QField{
 			Name:      ty.Name(),
-			Type:      qkind,
+			Type:      qvalue.QValueKind(fieldDesc.Type),
 			Precision: 0,
 			Scale:     0,
-			Nullable:  nullable,
+			Nullable:  fieldDesc.Nullable,
 		})
 	}
 