diff --git a/README.md b/README.md index 405d534..0de4c92 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,7 @@ There is a single option (for now), to output the data types in lowercase in the # TODO - Error handling +- Unit testing - Merging with existing yaml definition files - Generate the files for a complete dataset rather than a single table - Option to output to stdout diff --git a/bq2dbt/bq2dbt.py b/bq2dbt/bq2dbt.py old mode 100644 new mode 100755 index eef0980..6627c74 --- a/bq2dbt/bq2dbt.py +++ b/bq2dbt/bq2dbt.py @@ -9,6 +9,7 @@ import argparse import os import logging +import re from google.cloud import bigquery @@ -18,15 +19,34 @@ ) logger = logging.getLogger(__name__) +case_convert_regex = re.compile(r'(? str: + """ + Converts a string from CamelCase to snake_case. + + Args: + input_string (str): The CamelCase string to be converted. + + Returns: + str: The string converted to snake_case. + """ + return case_convert_regex.sub('_', input_string).lower() def bq2dbt(): parser = argparse.ArgumentParser(description="Generate YAML and SQL output for a BigQuery table.") parser.add_argument("table_id", help="Complete BigQuery table ID (project.dataset.table)") parser.add_argument("-l", "--lower", action="store_true", help="Lowercase type names in YAML file") + parser.add_argument("--snake", action="store_true", help="Convert field names to snake_case") + parser.add_argument("--prefix", help="Prefix to add to columns names", default=None) + parser.add_argument("--suffix", help="Suffix to add to column names", default=None) args = parser.parse_args() project_id, dataset_id, table_name = args.table_id.split(".") + prefix = args.prefix + suffix = args.suffix logger.info(f"Starting generation of YAML and SQL for table {args.table_id}...") @@ -91,8 +111,12 @@ def bq2dbt(): # Iterate through the query results and add them to the YAML data for field in fields: data_type = field.data_type.split('<')[0] + + destination = convert_to_snake_case(field.field_path) if args.snake else field.field_path + destination = "_".join(filter(None, [prefix, destination, suffix])) + field_info = { - "name": field.field_path, + "name": destination, "data_type": data_type.lower() if args.lower else data_type, "description": field.description } @@ -105,7 +129,10 @@ def bq2dbt(): yaml_data["models"][0]["columns"].append(field_info) if '.' not in field.field_path: - sql_columns.append(f"`{field.field_path}`") + if destination != field.field_path: + sql_columns.append(f"`{field.field_path}` AS `{destination}`") + else: + sql_columns.append(f"`{field.field_path}`") # Generate the YAML output yaml_output = yaml.dump(yaml_data, default_flow_style=False, sort_keys=False) @@ -114,7 +141,7 @@ def bq2dbt(): sql_columns_statement = "\n , ".join(sql_columns) sql_output = f""" SELECT - {sql_columns_statement} + {sql_columns_statement} FROM `{project_id}.{dataset_id}.{table_name}` -- Don't leave this in your DBT Statement """