forked from GoogleCloudPlatform/professional-services
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbigquery_component.py
executable file
·642 lines (542 loc) · 26.5 KB
/
bigquery_component.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
# Copyright 2019 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module to handle BigQuery related utilities such as creating a client,
validating the existence of dataset & table, creating a table, starting a
load job, updating the status of a running load job etc."""
import csv
import logging
import os
import time
from uuid import uuid4
from google.api_core import exceptions
from google.auth import exceptions as auth_exceptions
from google.cloud import bigquery
import custom_exceptions
from gcp_service import GCPService
from properties_reader import PropertiesReader
logger = logging.getLogger('Hive2BigQuery')
class BigQueryComponent(GCPService):
    """BigQuery client wrapper exposing the helpers used by the migration.

    The helpers cover table creation and deletion, dataset lookups, load
    jobs, polling of running load jobs and the comparison metrics table
    that is written once a migration has finished.

    Attributes:
        project_id (str): GCP Project ID.
        client (google.cloud.bigquery.client.Client): BigQuery client.
    """

    def __init__(self, project_id):
        """Initializes the component for the given project.

        Args:
            project_id (str): GCP Project ID.
        """
        logger.debug("Initializing BigQuery Component")
        # The base class receives the project and a service description.
        super().__init__(project_id, "BigQuery service")
def get_client(self):
    """Builds a BigQuery client for this component's project.

    Returns:
        google.cloud.bigquery.client.Client: BigQuery client

    Raises:
        custom_exceptions.ConnectionError: If the default credentials
            cannot be resolved.
    """
    logger.debug("Getting BigQuery client")
    try:
        bq_client = bigquery.Client(project=self.project_id)
    except auth_exceptions.DefaultCredentialsError as error:
        # Surface credential problems as the project's own error type.
        raise custom_exceptions.ConnectionError from error
    return bq_client
def check_dataset_exists(self, dataset_id):
    """Tells whether the given BigQuery dataset can be fetched.

    Args:
        dataset_id (str): BigQuery dataset id.

    Returns:
        boolean: True if exists, False if doesn't exist.
    """
    reference = self.client.dataset(dataset_id)
    try:
        self.client.get_dataset(reference)
    except exceptions.NotFound as error:
        # A missing dataset is logged together with its traceback.
        logger.exception(error)
        return False
    return True
def check_bq_table_exists(self, dataset_id, table_name):
    """Tells whether the given BigQuery table can be fetched.

    Args:
        dataset_id (str): BigQuery dataset id.
        table_name (str): BigQuery table name.

    Returns:
        boolean: True if exists, False if doesn't exist.
    """
    reference = self.client.dataset(dataset_id).table(table_name)
    try:
        self.client.get_table(reference)
    except exceptions.NotFound:
        return False
    return True
def get_dataset_location(self, dataset_id):
    """Looks up the location of a BigQuery dataset.

    Args:
        dataset_id (str): BigQuery dataset id.

    Returns:
        str: Location of the dataset.
    """
    dataset = self.client.get_dataset(self.client.dataset(dataset_id))
    return dataset.location
def create_table(self, dataset_id, table_name, schema):
    """Creates a BigQuery table with the given schema.

    Args:
        dataset_id (str): BigQuery dataset id.
        table_name (str): BigQuery table name.
        schema (List[google.cloud.bigquery.schema.SchemaField]): Schema
            of the table to be created.
    """
    reference = self.client.dataset(dataset_id).table(table_name)
    self.client.create_table(bigquery.Table(reference, schema))
def delete_table(self, dataset_id, table_name):
    """Deletes a BigQuery table, tolerating a table that is already gone.

    Args:
        dataset_id (str): BigQuery dataset id.
        table_name (str): BigQuery table name.
    """
    reference = self.client.dataset(dataset_id).table(table_name)
    try:
        self.client.delete_table(reference)
    except exceptions.NotFound as error:
        # Nothing to delete; record it at debug level and carry on.
        logger.debug(error)
        logger.debug("Table %s not found in %s dataset. No need to delete",
                     table_name, dataset_id)
    else:
        logger.debug("Deleted table %s from %s dataset", table_name,
                     dataset_id)
def check_bq_write_mode(self, mysql_component, hive_table_model,
                        bq_table_model):
    """Validates the configured ``bq_table_write_mode``.

    * ``overwrite``: drops the tracking table, updates the tracking meta
      table with "DELETE", deletes the BigQuery table and marks the run
      as a first run.
    * ``create``: fails if this is not a first run or if the BigQuery
      table already exists.
    * any other value (handled as append): fails on a first run; on later
      runs fails if the tracking table has RUNNING/DONE jobs but the
      BigQuery table is missing.

    Args:
        mysql_component (:class:`MySQLComponent`): Instance of
            MySQLComponent to connect to MySQL.
        hive_table_model (:class:`HiveTableModel`): Wrapper to Hive table
            details.
        bq_table_model (:class:`BigQueryTableModel`): Wrapper to BigQuery
            table details.

    Raises:
        google.api_core.exceptions.AlreadyExists: In create mode, when the
            tracking table or the BigQuery table already exists.
        google.api_core.exceptions.NotFound: In append mode, when the
            tracking table or the BigQuery table is missing.
    """
    mode = PropertiesReader.get('bq_table_write_mode')

    if mode == "overwrite":
        logger.debug("Deleting tracking table and BigQuery table...")
        mysql_component.drop_table(hive_table_model.tracking_table_name)
        mysql_component.update_tracking_meta_table(hive_table_model, "DELETE")
        self.delete_table(bq_table_model.dataset_id,
                          bq_table_model.table_name)
        # Everything was wiped, so the migration restarts from scratch.
        hive_table_model.is_first_run = True
        return

    if mode == "create":
        if not hive_table_model.is_first_run:
            raise exceptions.AlreadyExists(
                "Tracking Table {} already exist".format(
                    hive_table_model.tracking_table_name))
        if self.check_bq_table_exists(bq_table_model.dataset_id,
                                      bq_table_model.table_name):
            raise exceptions.AlreadyExists(
                "BigQuery Table {} already exist in {} dataset".format(
                    bq_table_model.table_name, bq_table_model.dataset_id))
        return

    # Remaining modes are handled as append.
    # NOTE(review): the source lost its indentation; the trailing `else`
    # is read as pairing with the first-run check - confirm.
    if hive_table_model.is_first_run is not False:
        raise exceptions.NotFound(
            "Tracking Table {} doesn't exist".format(
                hive_table_model.tracking_table_name))

    count_query = (
        "SELECT COUNT(*) FROM {} WHERE "
        "bq_job_status='RUNNING' OR "
        "bq_job_status='DONE'").format(hive_table_model.tracking_table_name)
    rows = mysql_component.execute_query(count_query)
    has_tracked_jobs = rows[0][0] != 0
    if has_tracked_jobs and not self.check_bq_table_exists(
            bq_table_model.dataset_id, bq_table_model.table_name):
        raise exceptions.NotFound(
            "Found the tracking table but BigQuery Table {} "
            "doesn't exist in {} dataset. Clean up the "
            "resources and try again".format(
                bq_table_model.table_name,
                bq_table_model.dataset_id))
def start_load_job(self, bq_table_model, source_uri, job_id):
    """Submits a BigQuery load job without waiting for it to finish.

    The job loads the data file at the given GCS URI into the BigQuery
    table described by ``bq_table_model`` under the given job ID.

    Args:
        bq_table_model (:class:`BigQueryTableModel`): Wrapper to BigQuery
            table details.
        source_uri (str): URI of data file to be loaded.
        job_id (str): BigQuery job ID.
    """
    dataset_ref = self.client.dataset(bq_table_model.dataset_id)

    # load_config is of type: google.cloud.bigquery.job.LoadJobConfig
    load_config = bigquery.LoadJobConfig()

    if bq_table_model.is_partitioned:
        # Day-based partitioning on the model's partition column.
        load_config.time_partitioning = bigquery.table.TimePartitioning(
            type_=bigquery.table.TimePartitioningType.DAY,
            field=bq_table_model.partition_column)
        # NOTE(review): the source lost its indentation; clustering is
        # read as applying to partitioned tables only - confirm.
        if bq_table_model.is_clustered:
            load_config.clustering_fields = bq_table_model.clustering_columns

    # File format of the staged data; anything else is treated as Avro.
    data_format = bq_table_model.data_format
    columnar_formats = {
        "ORC": bigquery.SourceFormat.ORC,
        "Parquet": bigquery.SourceFormat.PARQUET,
    }
    if data_format in columnar_formats:
        load_config.source_format = columnar_formats[data_format]
    else:
        load_config.source_format = bigquery.SourceFormat.AVRO
        # NOTE(review): read as part of the Avro branch - confirm.
        load_config.use_avro_logical_types = True

    destination = dataset_ref.table(bq_table_model.table_name)
    self.client.load_table_from_uri(
        source_uri, destination, job_config=load_config, job_id=job_id)
def get_bq_table_row_count(self, bq_table_model, clause=''):
    """Queries the migrated BigQuery table to get a count of rows.

    Args:
        bq_table_model (:class:`BigQueryTableModel`): Wrapper to BigQuery
            table details.
        clause (str): where clause to filter the table on partitions, if
            any. It is appended verbatim to the query.

    Returns:
        int: Number of rows reported by the query, or 0 if the query
            yields no result row.
    """
    query = "SELECT COUNT(*) AS n_rows FROM {0}.{1} {2}".format(
        bq_table_model.dataset_id, bq_table_model.table_name, clause)
    # Default keeps the return value defined even for an empty result set,
    # which previously raised UnboundLocalError.
    n_rows = 0
    for row in self.client.query(query).result():
        n_rows = row.n_rows
    return n_rows
def load_gcs_to_bq(self, mysql_component, hive_table_model, bq_table_model):
    """Starts BigQuery load jobs for the files already staged in GCS.

    Reads from the tracking table the files whose GCS copy is 'DONE' and
    whose BigQuery job is still 'TODO', starts one load job per file and
    records the job ID with status 'RUNNING' in the tracking table.

    Args:
        mysql_component (:class:`MySQLComponent`): Instance of
            MySQLComponent to connect to MySQL.
        hive_table_model (:class:`HiveTableModel`): Wrapper to Hive table
            details.
        bq_table_model (:class:`BigQueryTableModel`): Wrapper to BigQuery
            table details.
    """
    logger.info(
        "Fetching information about files to load to BigQuery from "
        "tracking table...")
    select_query = (
        "SELECT gcs_file_path FROM {} WHERE gcs_copy_status='DONE' "
        "AND bq_job_status='TODO'").format(
            hive_table_model.tracking_table_name)
    staged_files = mysql_component.execute_query(select_query)
    if not staged_files:
        logger.info("No gcs files to load to BigQuery")

    update_template = (
        "UPDATE {0} SET bq_job_id='{1}',bq_job_status='RUNNING' "
        "WHERE gcs_file_path='{2}'")
    for record in staged_files:
        file_uri = record[0]
        load_job_id = "hive2bq-job-{}".format(uuid4())
        # The load job runs asynchronously; only its ID is tracked here.
        self.start_load_job(bq_table_model, file_uri, load_job_id)
        # Marks the file's job as RUNNING in the tracking table.
        mysql_component.execute_transaction(update_template.format(
            hive_table_model.tracking_table_name, load_job_id, file_uri))
        logger.info(
            "Updated BigQuery load job ID {} status TODO --> RUNNING for "
            "file path {}".format(load_job_id, file_uri))
def update_bq_job_status(self, mysql_component, gcs_component,
                         hive_table_model, bq_table_model, gcs_bucket_name):
    """Updates the status of running BigQuery load jobs.

    Queries the tracking table for load jobs marked 'RUNNING' and polls
    each job. A job that finished without errors is marked 'DONE' and its
    GCS data file is deleted. A job that finished with errors is marked
    'TODO' with bq_job_retries increased by 1, or 'FAILED' once the
    maximum number of retries is reached. While any polled job has not
    finished, waits for 1 minute and polls again.

    Args:
        mysql_component (:class:`MySQLComponent`): Instance of
            MySQLComponent to connect to MySQL.
        gcs_component (:class:`GCSStorageComponent`): Instance of
            GCSStorageComponent to do GCS operations.
        hive_table_model (:class:`HiveTableModel`): Wrapper to Hive table
            details.
        bq_table_model (:class:`BigQueryTableModel`): Wrapper to BigQuery
            table details.
        gcs_bucket_name (str): GCS bucket name.
    """
    # Update this value to increase the maximum number of load job retries.
    bq_load_job_max_retries = 3

    tracking_table = hive_table_model.tracking_table_name
    running_jobs_query = (
        "SELECT gcs_file_path,bq_job_id,bq_job_retries FROM {} WHERE "
        "bq_job_status='RUNNING'").format(tracking_table)

    logger.info(
        "Fetching information about BigQuery load jobs from tracking "
        "table...")
    results = mysql_component.execute_query(running_jobs_query)
    if not results:
        logger.info(
            "No BigQuery job is in RUNNING state. No values to update")

    # Fetched once on first use instead of once per polled job.
    job_location = None

    # Waits till all the load jobs finish.
    while results:
        if job_location is None:
            job_location = self.get_dataset_location(
                bq_table_model.dataset_id)

        unfinished_jobs = 0
        for gcs_file_path, bq_job_id, bq_job_retries in results:
            # Gets information about the tracked job.
            job = self.client.get_job(bq_job_id, location=job_location)

            if job.state != 'DONE':
                # Any state other than DONE (e.g. PENDING or RUNNING)
                # means the job still has to be polled again; otherwise
                # its tracking row would be left in RUNNING forever.
                unfinished_jobs += 1
                if job.state != 'RUNNING':
                    logger.debug(
                        "job id %s job state %s", bq_job_id, job.state)
                continue

            if job.errors is None:
                # Job finished successfully.
                query = "UPDATE {0} SET bq_job_status='DONE' WHERE " \
                        "bq_job_id='{1}'".format(tracking_table, bq_job_id)
                mysql_component.execute_transaction(query)
                logger.info(
                    "Updated BigQuery load job {} status RUNNING --> "
                    "DONE".format(bq_job_id))
                # Deletes the data file in GCS.
                gcs_component.delete_file(gcs_bucket_name, gcs_file_path)
            elif bq_job_retries >= bq_load_job_max_retries:
                # Job finished with error and no retries are left.
                query = "UPDATE {0} SET bq_job_status='FAILED' " \
                        "WHERE bq_job_id='{1}'".format(
                            tracking_table, bq_job_id)
                mysql_component.execute_transaction(query)
                logger.info(
                    "BigQuery job {} failed.Tried for a maximum "
                    "of {} times.Updated status RUNNING --> "
                    "FAILED".format(bq_job_id, bq_load_job_max_retries))
            else:
                # Job finished with error; queue it for another attempt.
                query = "UPDATE {0} SET bq_job_status='TODO'," \
                        "bq_job_retries={1} WHERE " \
                        "bq_job_id='{2}'".format(
                            tracking_table, bq_job_retries + 1, bq_job_id)
                mysql_component.execute_transaction(query)
                logger.info(
                    "BigQuery job {} failed.Updated status "
                    "RUNNING --> TODO & increased retries count "
                    "by 1".format(bq_job_id))

        if unfinished_jobs == 0:
            logger.info(
                "No BigQuery job is in RUNNING state. No values to update")
            break

        logger.info("Waiting for 1 min..")
        time.sleep(60)
        logger.info(
            "Fetching information about BigQuery load jobs from tracking "
            "table...")
        results = mysql_component.execute_query(running_jobs_query)
@staticmethod
def generate_metrics_table_schema(columns_list):
    """Builds the schema of the BigQuery comparison metrics table.

    The schema starts with the fixed 'operation', 'table_name' and
    'column_count' fields, followed by one required STRING field per
    entry of ``columns_list``.

    Args:
        columns_list (List[str]): List of column names.

    Returns:
        List[google.cloud.bigquery.schema.SchemaField]: Schema of the
        comparison metrics table.
    """
    fixed_fields = [
        bigquery.SchemaField(
            'operation', 'STRING', mode='REQUIRED',
            description='operation'),
        bigquery.SchemaField(
            'table_name', 'STRING', mode='REQUIRED',
            description='Table name'),
        bigquery.SchemaField(
            'column_count', 'STRING', mode='REQUIRED',
            description='Number of columns'),
    ]
    column_fields = [
        bigquery.SchemaField(str(name), 'STRING', mode='REQUIRED')
        for name in columns_list
    ]
    return fixed_fields + column_fields
@staticmethod
def analyze_hive_table(hive_table_model):
    """Collects the comparison metrics of the Hive table.

    Args:
        hive_table_model (:class:`HiveTableModel`): Wrapper to Hive table
            details.

    Returns:
        dict: A dictionary of metrics about the Hive table, with the keys
        'operation', 'table_name', 'num_cols' (as a string) and 'schema'
        (the model's flat schema object itself, not a copy).
    """
    return {
        'operation': "Hive",
        'table_name': hive_table_model.table_name,
        'num_cols': str(hive_table_model.n_cols),
        'schema': hive_table_model.flat_schema,
    }
@staticmethod
def analyze_bq_table(bq_table_model):
    """Collects the comparison metrics of the BigQuery table.

    Args:
        bq_table_model (:class:`BigQueryTableModel`): Wrapper to BigQuery
            table details.

    Returns:
        dict: A dictionary of metrics about the BigQuery table, with the
        keys 'operation', 'table_name', 'num_cols' (as a string) and
        'schema' (the model's flat schema object itself, not a copy).
    """
    return {
        'operation': "BigQuery",
        'table_name': bq_table_model.table_name,
        'num_cols': str(bq_table_model.n_cols),
        'schema': bq_table_model.flat_schema,
    }
@staticmethod
def append_row_to_metrics_file(filename, row, columns_list):
    """Appends one comparison metrics row to the CSV file.

    The written record is the row's 'operation', 'table_name' and
    'num_cols' values followed by ``row['schema'][column]`` for every
    column in ``columns_list``, in that order.

    Args:
        filename (str): Comparison metrics CSV filename. The file is
            created if it does not exist.
        row (dict): Data to be written to CSV file.
        columns_list (List[str]): List of flattened column names.

    Raises:
        KeyError: If ``row`` or ``row['schema']`` lacks a required key.
    """
    schema = row['schema']
    data = [row['operation'], row['table_name'], row['num_cols']]
    data.extend(schema[column] for column in columns_list)
    # newline='' lets the csv module control line endings; without it the
    # writer's '\r\n' becomes '\r\r\n' on platforms that translate '\n'.
    with open(filename, 'a', newline='') as csv_file:
        csv.writer(csv_file).writerow(data)
@staticmethod
def do_health_checks(hive_table_analysis,
                     bq_table_analysis, columns_list):
    """Populates the Health checks values by comparing Hive and BigQuery
    tables.

    The column-count check passes when both analyses report the same
    'num_cols'. A column passes when the pair [Hive type, BigQuery type]
    appears as a row of 'validations.csv', which is read from the current
    working directory. Hive types containing 'array_' are first reduced
    to their last two '_'-separated parts. The input dictionaries are not
    modified.

    Args:
        hive_table_analysis (dict): A dictionary of metrics about the
            Hive table.
        bq_table_analysis (dict): A dictionary of metrics about the
            BigQuery table.
        columns_list (List[str]): List of flattened column names.

    Returns:
        dict: Health check row with the keys 'operation', 'table_name',
        'num_cols' ("Pass"/"Fail") and 'schema' (column name mapped to
        "Pass"/"Fail").
    """

    def read_validations():
        """Reads the set of Hive-BigQuery data type validation rules into a
        list."""
        validations_csv_filename = 'validations.csv'
        # newline='' is the csv module's documented way to open its files.
        with open(validations_csv_filename, 'r', newline='') as file_content:
            return list(csv.reader(file_content))

    logger.debug("Reading the validations CSV file")
    validation_rules = read_validations()

    healths = {
        "operation": "Health Check",
        "table_name": "NA",
        "num_cols": "Fail",
        "schema": {}
    }
    if hive_table_analysis['num_cols'] == bq_table_analysis['num_cols']:
        healths["num_cols"] = "Pass"

    hive_schema = hive_table_analysis['schema']
    bq_schema = bq_table_analysis['schema']
    for item in columns_list:
        hive_type = hive_schema[item]
        # Reduce strings of array_ in the data type field. The reduced
        # value stays local so the caller's schema is left untouched.
        if 'array_' in hive_type:
            hive_type = '_'.join(hive_type.split('_')[-2:])
        matched = [hive_type, bq_schema[item]] in validation_rules
        healths['schema'][str(item)] = "Pass" if matched else "Fail"
    return healths
def load_csv_to_bigquery(self, csv_uri, dataset_id, table_name):
    """Loads the metrics CSV into the BigQuery comparison table.

    Starts a CSV load job, blocks until it completes and logs the number
    of rows of the destination table.

    Args:
        csv_uri (str): Cloud Storage URI of the metrics CSV file.
        dataset_id (str): BigQuery dataset ID.
        table_name (str): BigQuery comparison metrics table name.
    """
    metrics_table_ref = self.client.dataset(dataset_id).table(table_name)

    csv_config = bigquery.LoadJobConfig()
    csv_config.source_format = bigquery.SourceFormat.CSV

    load_job = self.client.load_table_from_uri(
        csv_uri, metrics_table_ref, job_config=csv_config)
    logger.info('Loading metrics data to BigQuery... Job {}'.format(
        load_job.job_id))
    # Blocks until the load job has completed.
    load_job.result()

    loaded_table = self.client.get_table(metrics_table_ref)
    logger.info(
        "Loaded {} rows in metrics table\nMigrated data successfully from "
        "Hive to BigQuery\nComparison metrics of tables available in "
        "BigQuery table {}".format(loaded_table.num_rows, table_name))
def write_metrics_to_bigquery(self, gcs_component,
                              hive_table_model, bq_table_model):
    """Writes the Hive/BigQuery comparison metrics to BigQuery.

    Appends three rows to the local metrics CSV (Hive metrics, BigQuery
    metrics and their health checks), creates the metrics table, uploads
    the CSV to GCS, removes the local copy, loads the uploaded file into
    the metrics table and finally deletes the uploaded file.

    Args:
        gcs_component (:class:`GCSStorageComponent`): Instance of
            GCSStorageComponent to do GCS operations.
        hive_table_model (:class:`HiveTableModel`): Wrapper to Hive table
            details.
        bq_table_model (:class:`BigQueryTableModel`): Wrapper to BigQuery
            table details.
    """
    comparison_table = PropertiesReader.get('hive_bq_comparison_table')
    comparison_csv = PropertiesReader.get('hive_bq_comparison_csv')

    logger.info("Analyzing the Hive and BigQuery tables...")
    # The flattened Hive column names drive the layout of every CSV row.
    column_names = hive_table_model.flat_schema.keys()

    hive_metrics = self.analyze_hive_table(hive_table_model)
    self.append_row_to_metrics_file(
        comparison_csv, hive_metrics, column_names)
    logger.debug("Analyzed Hive table metrics")

    bq_metrics = self.analyze_bq_table(bq_table_model)
    self.append_row_to_metrics_file(
        comparison_csv, bq_metrics, column_names)
    logger.debug("Analyzed BigQuery table metrics")

    # Health checks compare the Hive metrics against the BigQuery ones.
    health_row = self.do_health_checks(
        hive_metrics, bq_metrics, column_names)
    self.append_row_to_metrics_file(
        comparison_csv, health_row, column_names)
    logger.debug("Health checks are done")

    logger.debug("Getting metrics table schema")
    logger.debug("Creating BigQuery metrics table")
    metrics_schema = self.generate_metrics_table_schema(column_names)
    self.create_table(
        bq_table_model.dataset_id, comparison_table, metrics_schema)

    # Stages the metrics CSV in the GCS bucket.
    staged_blob = "BQ_staging/{}".format(comparison_csv)
    staged_uri = gcs_component.upload_file(
        PropertiesReader.get('gcs_bucket_name'), comparison_csv,
        staged_blob)
    logger.debug("metrics CSV file is uploaded at %s", staged_uri)

    # The local copy is no longer needed once it has been uploaded.
    os.remove(comparison_csv)

    self.load_csv_to_bigquery(
        staged_uri, bq_table_model.dataset_id, comparison_table)

    # Cleans up the staged metrics CSV in the GCS bucket.
    gcs_component.delete_file(
        PropertiesReader.get('gcs_bucket_name'), staged_blob)
    logger.debug("Deleting metrics CSV file at %s", staged_uri)