feat: implement destination unnesting for BigQuery (#2)
* wip unnest for bigquery

* fix for timestamp

* feat: add unnesting to streaming

* chore: added tests on breaking schema changes

* chore: linting

* chore: linting

* feat: add assertion to detect new fields not in bigquery schema

---------

Co-authored-by: Anas El Mhamdi <[email protected]>
aballiet and anaselmhamdi authored Dec 18, 2024
1 parent cdaed9f commit f7f586a
Showing 8 changed files with 456 additions and 71 deletions.
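The change introduces an unnest mode for the BigQuery destinations: instead of landing the raw JSON payload in a single _source_data column, the destination expands it into typed columns declared in record_schema. As a purely hypothetical illustration of how such a configuration might look, the field names below (unnest, record_schema, name, type, mode, description) come from the diff, while the surrounding structure and the example values are placeholders.

# Hypothetical destination config sketch; values are placeholders, not part of the commit.
destination_config = {
    "name": "bigquery",
    "config": {
        "unnest": True,
        "record_schema": [
            {"name": "id", "type": "INTEGER", "mode": "REQUIRED", "description": "Primary key"},
            {"name": "created_at", "type": "TIMESTAMP", "mode": "NULLABLE", "description": "Creation time"},
            {"name": "payload", "type": "STRING", "mode": "NULLABLE", "description": "Remaining fields as JSON"},
        ],
    },
}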
34 changes: 31 additions & 3 deletions bizon/destinations/bigquery/src/config.py
@@ -1,11 +1,13 @@
from enum import Enum
from typing import Literal, Optional

import polars as pl
from pydantic import BaseModel, Field

from bizon.destinations.config import (
AbstractDestinationConfig,
AbstractDestinationDetailsConfig,
DestinationColumn,
DestinationTypes,
)

@@ -42,11 +44,37 @@ class BigQueryColumnMode(str, Enum):
REPEATED = "REPEATED"


class BigQueryColumn(BaseModel):
BIGQUERY_TO_POLARS_TYPE_MAPPING = {
"STRING": pl.String,
"BYTES": pl.Binary,
"INTEGER": pl.Int64,
"INT64": pl.Int64,
"FLOAT": pl.Float64,
"FLOAT64": pl.Float64,
"NUMERIC": pl.Float64, # Can be refined for precision with Decimal128 if needed
"BIGNUMERIC": pl.Float64, # Similar to NUMERIC
"BOOLEAN": pl.Boolean,
"BOOL": pl.Boolean,
"TIMESTAMP": pl.String, # We use BigQuery internal parsing to convert to datetime
"DATE": pl.String, # We use BigQuery internal parsing to convert to datetime
"DATETIME": pl.String, # We use BigQuery internal parsing to convert to datetime
"TIME": pl.Time,
"GEOGRAPHY": pl.Object, # Polars doesn't natively support geography types
"ARRAY": pl.List, # Requires additional handling for element types
"STRUCT": pl.Struct, # TODO
"JSON": pl.Object, # TODO
}


class BigQueryColumn(DestinationColumn):
name: str = Field(..., description="Name of the column")
type: BigQueryColumnType = Field(..., description="Type of the column")
mode: BigQueryColumnMode = Field(..., description="Mode of the column")
description: Optional[str] = Field(..., description="Description of the column")
description: Optional[str] = Field(None, description="Description of the column")

@property
def polars_type(self):
return BIGQUERY_TO_POLARS_TYPE_MAPPING.get(self.type.upper())


class BigQueryAuthentication(BaseModel):
@@ -87,5 +115,5 @@ class BigQueryConfigDetails(AbstractDestinationDetailsConfig):

class BigQueryConfig(AbstractDestinationConfig):
name: Literal[DestinationTypes.BIGQUERY]
buffer_size: Optional[int] = 2000
buffer_size: Optional[int] = 400
config: BigQueryConfigDetails
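As a rough, standalone illustration of what the new polars_type property resolves to, the sketch below performs the same dictionary lookup and uses the result in a cast. It assumes only that polars is installed; the reduced TYPE_MAP is illustrative and the full mapping is the BIGQUERY_TO_POLARS_TYPE_MAPPING table above.

import polars as pl

# Reduced copy of the mapping for illustration; the full table lives in config.py above.
TYPE_MAP = {"STRING": pl.String, "INTEGER": pl.Int64, "TIMESTAMP": pl.String}

def polars_type(bq_type: str):
    # Look up the polars dtype for a BigQuery column type, as the property does.
    return TYPE_MAP.get(bq_type.upper())

df = pl.DataFrame({"id": ["1", "2"]})
print(df.with_columns(pl.col("id").cast(polars_type("integer"))))  # id becomes Int64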
55 changes: 43 additions & 12 deletions bizon/destinations/bigquery/src/destination.py
@@ -1,5 +1,4 @@
import io
import json
import os
import tempfile
import traceback
@@ -17,7 +16,7 @@
from bizon.engine.backend.backend import AbstractBackend
from bizon.source.config import SourceSyncModes

from .config import BigQueryConfigDetails
from .config import BigQueryColumn, BigQueryConfigDetails


class BigQueryDestination(AbstractDestination):
@@ -60,16 +59,30 @@ def temp_table_id(self) -> str:

def get_bigquery_schema(self, df_destination_records: pl.DataFrame) -> List[bigquery.SchemaField]:

return [
bigquery.SchemaField("_source_record_id", "STRING", mode="REQUIRED"),
bigquery.SchemaField("_source_timestamp", "TIMESTAMP", mode="REQUIRED"),
bigquery.SchemaField("_source_data", "STRING", mode="NULLABLE"),
bigquery.SchemaField("_bizon_extracted_at", "TIMESTAMP", mode="REQUIRED"),
bigquery.SchemaField(
"_bizon_loaded_at", "TIMESTAMP", mode="REQUIRED", default_value_expression="CURRENT_TIMESTAMP()"
),
bigquery.SchemaField("_bizon_id", "STRING", mode="REQUIRED"),
]
# Case we unnest the data
if self.config.unnest:
return [
bigquery.SchemaField(
col.name,
col.type,
mode=col.mode,
description=col.description,
)
for col in self.config.record_schema
]

# Case we don't unnest the data
else:
return [
bigquery.SchemaField("_source_record_id", "STRING", mode="REQUIRED"),
bigquery.SchemaField("_source_timestamp", "TIMESTAMP", mode="REQUIRED"),
bigquery.SchemaField("_source_data", "STRING", mode="NULLABLE"),
bigquery.SchemaField("_bizon_extracted_at", "TIMESTAMP", mode="REQUIRED"),
bigquery.SchemaField(
"_bizon_loaded_at", "TIMESTAMP", mode="REQUIRED", default_value_expression="CURRENT_TIMESTAMP()"
),
bigquery.SchemaField("_bizon_id", "STRING", mode="REQUIRED"),
]

def check_connection(self) -> bool:
dataset_ref = DatasetReference(self.project_id, self.dataset_id)
@@ -108,6 +121,24 @@ def convert_and_upload_to_buffer(self, df_destination_records: pl.DataFrame) ->

raise NotImplementedError(f"Buffer format {self.buffer_format} is not supported")

@staticmethod
def unnest_data(df_destination_records: pl.DataFrame, record_schema: list[BigQueryColumn]) -> pl.DataFrame:
"""Unnest the source_data field into separate columns"""

# Check if the schema matches the expected schema
source_data_fields = pl.DataFrame(df_destination_records['source_data'].str.json_decode()).schema["source_data"].fields

record_schema_fields = [col.name for col in record_schema]

for field in source_data_fields:
assert field.name in record_schema_fields, f"Column {field.name} not found in BigQuery schema"

# Parse the JSON and unnest the fields to polar type
return df_destination_records.select(
pl.col("source_data").str.json_path_match(f"$.{col.name}").cast(col.polars_type).alias(col.name)
for col in record_schema
)

def load_to_bigquery(self, gcs_file: str, df_destination_records: pl.DataFrame):

# We always partition by the loaded_at field
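The unnest_data helper added above decodes the JSON held in source_data, asserts that every incoming field is declared in the configured schema, and then extracts each declared column with a cast to its polars type. The following is a minimal self-contained sketch of the same steps using plain polars; the two-column schema and sample rows are made up for illustration.

import polars as pl

# Illustrative (name, polars dtype) pairs standing in for the configured BigQueryColumn objects.
record_schema = [("id", pl.Int64), ("name", pl.String)]

df = pl.DataFrame({"source_data": ['{"id": "1", "name": "a"}', '{"id": "2", "name": "b"}']})

# Fail fast if the source starts emitting fields the declared schema does not know about.
declared = {name for name, _ in record_schema}
for field in df["source_data"].str.json_decode().struct.fields:
    assert field in declared, f"Column {field} not found in declared schema"

# Pull each declared field out of the JSON payload and cast it to its target dtype.
unnested = df.select(
    pl.col("source_data").str.json_path_match(f"$.{name}").cast(dtype).alias(name)
    for name, dtype in record_schema
)
print(unnested)  # two typed columns: id (Int64), name (String)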
9 changes: 8 additions & 1 deletion bizon/destinations/bigquery_streaming/src/config.py
@@ -1,8 +1,9 @@
from enum import Enum
from typing import Literal, Optional

from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field

from bizon.destinations.bigquery.src.config import BigQueryColumn
from bizon.destinations.config import (
AbstractDestinationConfig,
AbstractDestinationDetailsConfig,
@@ -34,8 +35,14 @@ class BigQueryStreamingConfigDetails(AbstractDestinationDetailsConfig):
time_partitioning: Optional[TimePartitioning] = Field(
default=TimePartitioning.DAY, description="BigQuery Time partitioning type"
)
time_partitioning_field: Optional[str] = Field(
"_bizon_loaded_at", description="Field to partition by. You can use a transformation to create this field."
)
authentication: Optional[BigQueryAuthentication] = None
bq_max_rows_per_request: Optional[int] = Field(30000, description="Max rows per buffer streaming request.")
record_schema: Optional[list[BigQueryColumn]] = Field(
default=None, description="Schema for the records. Required if unnest is set to true."
)


class BigQueryStreamingConfig(AbstractDestinationConfig):
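The description on record_schema states that it is required when unnest is set to true, but the diff itself does not show a validator enforcing this. Purely as a hypothetical sketch (not part of this commit), the rule could be expressed with a pydantic v2 validator along these lines; the model name and field shapes are stand-ins.

from typing import Optional

from pydantic import BaseModel, model_validator

class DetailsSketch(BaseModel):
    # Hypothetical stand-in for the destination details model; field names mirror the diff.
    unnest: bool = False
    record_schema: Optional[list[dict]] = None

    @model_validator(mode="after")
    def require_schema_when_unnesting(self):
        if self.unnest and not self.record_schema:
            raise ValueError("record_schema is required when unnest is set to true")
        return self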
76 changes: 51 additions & 25 deletions bizon/destinations/bigquery_streaming/src/destination.py
@@ -12,6 +12,7 @@
ProtoRows,
ProtoSchema,
)
from google.protobuf.json_format import ParseDict
from google.protobuf.message import Message

from bizon.common.models import SyncMetadata
@@ -48,17 +49,29 @@ def table_id(self) -> str:

def get_bigquery_schema(self) -> List[bigquery.SchemaField]:

# we keep raw data in the column source_data
return [
bigquery.SchemaField("_source_record_id", "STRING", mode="REQUIRED"),
bigquery.SchemaField("_source_timestamp", "TIMESTAMP", mode="REQUIRED"),
bigquery.SchemaField("_source_data", "STRING", mode="NULLABLE"),
bigquery.SchemaField("_bizon_extracted_at", "TIMESTAMP", mode="REQUIRED"),
bigquery.SchemaField(
"_bizon_loaded_at", "TIMESTAMP", mode="REQUIRED", default_value_expression="CURRENT_TIMESTAMP()"
),
bigquery.SchemaField("_bizon_id", "STRING", mode="REQUIRED"),
]
if self.config.unnest:
return [
bigquery.SchemaField(
col.name,
col.type,
mode=col.mode,
description=col.description,
)
for col in self.config.record_schema
]

# Case we don't unnest the data
else:
return [
bigquery.SchemaField("_source_record_id", "STRING", mode="REQUIRED"),
bigquery.SchemaField("_source_timestamp", "TIMESTAMP", mode="REQUIRED"),
bigquery.SchemaField("_source_data", "STRING", mode="NULLABLE"),
bigquery.SchemaField("_bizon_extracted_at", "TIMESTAMP", mode="REQUIRED"),
bigquery.SchemaField(
"_bizon_loaded_at", "TIMESTAMP", mode="REQUIRED", default_value_expression="CURRENT_TIMESTAMP()"
),
bigquery.SchemaField("_bizon_id", "STRING", mode="REQUIRED"),
]

def check_connection(self) -> bool:
dataset_ref = DatasetReference(self.project_id, self.dataset_id)
@@ -90,14 +103,8 @@ def append_rows_to_stream(

@staticmethod
def to_protobuf_serialization(TableRowClass: Type[Message], row: dict) -> bytes:
"""Convert a row to a protobuf serialization"""
record = TableRowClass()
record._bizon_id = row["bizon_id"]
record._bizon_extracted_at = row["bizon_extracted_at"].strftime("%Y-%m-%d %H:%M:%S.%f")
record._bizon_loaded_at = row["bizon_loaded_at"].strftime("%Y-%m-%d %H:%M:%S.%f")
record._source_record_id = row["source_record_id"]
record._source_timestamp = row["source_timestamp"].strftime("%Y-%m-%d %H:%M:%S.%f")
record._source_data = row["source_data"]
"""Convert a row to a Protobuf serialization."""
record = ParseDict(row, TableRowClass())
return record.SerializeToString()

def load_to_bigquery_via_streaming(self, df_destination_records: pl.DataFrame) -> str:
@@ -107,7 +114,9 @@ def load_to_bigquery_via_streaming(self, df_destination_records: pl.DataFrame) -
# Create table if it doesn't exist
schema = self.get_bigquery_schema()
table = bigquery.Table(self.table_id, schema=schema)
time_partitioning = TimePartitioning(field="_bizon_loaded_at", type_=self.config.time_partitioning)
time_partitioning = TimePartitioning(
field=self.config.time_partitioning_field, type_=self.config.time_partitioning
)
table.time_partitioning = time_partitioning

table = self.bq_client.create_table(table, exists_ok=True)
@@ -119,12 +128,29 @@ def load_to_bigquery_via_streaming(self, df_destination_records: pl.DataFrame) -
stream_name = f"{parent}/_default"

# Generating the protocol buffer representation of the message descriptor.
proto_schema, TableRow = get_proto_schema_and_class(clustering_keys)
proto_schema, TableRow = get_proto_schema_and_class(schema, clustering_keys)

serialized_rows = [
self.to_protobuf_serialization(TableRowClass=TableRow, row=row)
for row in df_destination_records.iter_rows(named=True)
]
if self.config.unnest:
serialized_rows = [
self.to_protobuf_serialization(TableRowClass=TableRow, row=row)
for row in df_destination_records["source_data"].str.json_decode().to_list()
]
else:
df_destination_records = df_destination_records.rename(
{
"bizon_id": "_bizon_id",
"bizon_extracted_at": "_bizon_extracted_at",
"bizon_loaded_at": "_bizon_loaded_at",
"source_record_id": "_source_record_id",
"source_timestamp": "_source_timestamp",
"source_data": "_source_data",
}
)

serialized_rows = [
self.to_protobuf_serialization(TableRowClass=TableRow, row=row)
for row in df_destination_records.iter_rows(named=True)
]

results = []
with ThreadPoolExecutor() as executor:
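The streaming path now builds each protobuf row with ParseDict instead of assigning fields one by one, so the same serialization code serves both the unnested schema and the fixed internal schema (after the column rename). Below is a small standalone illustration of the ParseDict call; Struct is used as a stand-in for the dynamically generated TableRow class, since Struct accepts arbitrary keys.

from google.protobuf.json_format import ParseDict
from google.protobuf.struct_pb2 import Struct

# In the destination, the target message is the TableRow class generated from the BigQuery schema.
row = {"id": "42", "name": "example"}
message = ParseDict(row, Struct())
payload = message.SerializeToString()
print(len(payload))  # bytes ready to append to the BigQuery write stream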
58 changes: 32 additions & 26 deletions bizon/destinations/bigquery_streaming/src/proto_utils.py
@@ -1,5 +1,6 @@
from typing import List, Tuple, Type

from google.cloud.bigquery import SchemaField
from google.cloud.bigquery_storage_v1.types import ProtoSchema
from google.protobuf.descriptor_pb2 import (
DescriptorProto,
@@ -11,7 +12,30 @@
from google.protobuf.message_factory import GetMessageClassesForFiles


def get_proto_schema_and_class(clustering_keys: List[str] = None) -> Tuple[ProtoSchema, Type[Message]]:
def map_bq_type_to_field_descriptor(bq_type: str) -> int:
"""Map BigQuery type to Protobuf FieldDescriptorProto type."""
type_map = {
"STRING": FieldDescriptorProto.TYPE_STRING, # STRING -> TYPE_STRING
"BYTES": FieldDescriptorProto.TYPE_BYTES, # BYTES -> TYPE_BYTES
"INTEGER": FieldDescriptorProto.TYPE_INT64, # INTEGER -> TYPE_INT64
"FLOAT": FieldDescriptorProto.TYPE_DOUBLE, # FLOAT -> TYPE_DOUBLE
"NUMERIC": FieldDescriptorProto.TYPE_STRING, # NUMERIC -> TYPE_STRING (use string to handle precision)
"BIGNUMERIC": FieldDescriptorProto.TYPE_STRING, # BIGNUMERIC -> TYPE_STRING
"BOOLEAN": FieldDescriptorProto.TYPE_BOOL, # BOOLEAN -> TYPE_BOOL
"DATE": FieldDescriptorProto.TYPE_STRING, # DATE -> TYPE_STRING
"DATETIME": FieldDescriptorProto.TYPE_STRING, # DATETIME -> TYPE_STRING
"TIME": FieldDescriptorProto.TYPE_STRING, # TIME -> TYPE_STRING
"TIMESTAMP": FieldDescriptorProto.TYPE_INT64, # TIMESTAMP -> TYPE_INT64 (Unix epoch time)
"RECORD": FieldDescriptorProto.TYPE_MESSAGE, # RECORD -> TYPE_MESSAGE (nested message)
}

return type_map.get(bq_type, FieldDescriptorProto.TYPE_STRING) # Default to TYPE_STRING


def get_proto_schema_and_class(
bq_schema: List[SchemaField], clustering_keys: List[str] = None
) -> Tuple[ProtoSchema, Type[Message]]:
"""Generate a ProtoSchema and a TableRow class for unnested BigQuery schema."""
# Define the FileDescriptorProto
file_descriptor_proto = FileDescriptorProto()
file_descriptor_proto.name = "dynamic.proto"
@@ -26,32 +50,14 @@ def get_proto_schema_and_class(clustering_keys: List[str] = None) -> Tuple[Proto

# https://stackoverflow.com/questions/70489919/protobuf-type-for-bigquery-timestamp-field
fields = [
{"name": "_bizon_id", "type": FieldDescriptorProto.TYPE_STRING, "label": FieldDescriptorProto.LABEL_REQUIRED},
{
"name": "_bizon_extracted_at",
"type": FieldDescriptorProto.TYPE_STRING,
"label": FieldDescriptorProto.LABEL_REQUIRED,
},
{
"name": "_bizon_loaded_at",
"type": FieldDescriptorProto.TYPE_STRING,
"label": FieldDescriptorProto.LABEL_REQUIRED,
},
{
"name": "_source_record_id",
"type": FieldDescriptorProto.TYPE_STRING,
"label": FieldDescriptorProto.LABEL_REQUIRED,
},
{
"name": "_source_timestamp",
"type": FieldDescriptorProto.TYPE_STRING,
"label": FieldDescriptorProto.LABEL_REQUIRED,
},
{
"name": "_source_data",
"type": FieldDescriptorProto.TYPE_STRING,
"label": FieldDescriptorProto.LABEL_OPTIONAL,
},
"name": col.name,
"type": map_bq_type_to_field_descriptor(col.field_type),
"label": (
FieldDescriptorProto.LABEL_REQUIRED if col.mode == "REQUIRED" else FieldDescriptorProto.LABEL_OPTIONAL
),
}
for col in bq_schema
]

if clustering_keys:
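get_proto_schema_and_class now derives the protobuf fields from the BigQuery schema instead of a hard-coded list. The sketch below shows the general dynamic-descriptor pattern this relies on (registering a FileDescriptorProto in a descriptor pool and resolving a message class from it); the message name, field, and numbering are illustrative rather than the helper's exact output.

from google.protobuf import descriptor_pool
from google.protobuf.descriptor_pb2 import FieldDescriptorProto, FileDescriptorProto
from google.protobuf.message_factory import GetMessageClassesForFiles

# Describe a one-message .proto file in memory.
file_proto = FileDescriptorProto(name="dynamic.proto", package="dynamic")
message_proto = file_proto.message_type.add()
message_proto.name = "TableRow"

field = message_proto.field.add()
field.name = "id"
field.number = 1
field.type = FieldDescriptorProto.TYPE_INT64
field.label = FieldDescriptorProto.LABEL_OPTIONAL

# Register the file and resolve the generated message class.
pool = descriptor_pool.DescriptorPool()
pool.Add(file_proto)
TableRow = GetMessageClassesForFiles(["dynamic.proto"], pool)["dynamic.TableRow"]

print(TableRow(id=7).SerializeToString())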
3 changes: 1 addition & 2 deletions bizon/destinations/config.py
@@ -15,8 +15,7 @@ class DestinationTypes(str, Enum):
class DestinationColumn(BaseModel, ABC):
name: str = Field(..., description="Name of the column")
type: str = Field(..., description="Type of the column")
mode: Optional[str] = Field(..., description="Mode of the column")
description: Optional[str] = Field(..., description="Description of the column")
description: Optional[str] = Field(None, description="Description of the column")


class AbstractDestinationDetailsConfig(BaseModel):
