
Commit 37d3f25

refactor event types (#87)
tovganesh authored Dec 1, 2022
1 parent 4c36c66 commit 37d3f25
Showing 4 changed files with 62 additions and 49 deletions.
45 changes: 26 additions & 19 deletions dbt/adapters/hive/cloudera_tracking.py
@@ -27,6 +27,16 @@
from dbt.events import AdapterLogger
from decouple import config

# all event types
class TrackingEventType:
DEBUG = "debug_and_fetch_permission"
OPEN = "open"
CLOSE = "close"
START_QUERY = "start_query"
END_QUERY = "end_query"
INCREMENTAL = "incremental"
MODEL_ACCESS = "model_access"

# global logger
logger = AdapterLogger("Tracker")

@@ -45,6 +55,9 @@
# Json object to store dbt deployment environment variables
dbt_deployment_env_info = {}

# Json object for warehouse information
warehouse_info = { "warehouse_version": { "version": "NA", "build": "NA" } }

def populate_platform_info(cred: Credentials, ver):
"""
populate platform info to be passed on for tracking
@@ -57,9 +70,7 @@ def populate_platform_info(cred: Credentials, ver):
platform_info["system"] = platform.system()
# Architecture e.g. x86_64 ,arm, AMD64
platform_info["machine"] = platform.machine()
# Full platform info e.g.
# Linux-2.6.32-32-server-x86_64-with-Ubuntu-10.04-lucid,
# Windows-2008ServerR2-6.1.7601-SP1
# Full platform info e.g. Linux-2.6.32-32-server-x86_64-with-Ubuntu-10.04-lucid, Windows-2008ServerR2-6.1.7601-SP1
platform_info["platform"] = platform.platform()
# dbt core version
platform_info[
@@ -76,24 +87,24 @@ def populate_dbt_deployment_env_info():
default_value = "{}" # if environment variables doesn't exist add empty json as default
dbt_deployment_env_info["dbt_deployment_env"] = json.loads(os.environ.get('DBT_DEPLOYMENT_ENV', default_value))

def populate_unique_ids(cred: Credentials):
def populate_unique_ids(cred: Credentials, userkey="username"):
host = str(cred.host).encode()
user = str(cred.username).encode()
user = str(getattr(cred, userkey)).encode()
timestamp = str(time.time()).encode()

# dbt invocation id
if active_user:
unique_ids["id"] = active_user.invocation_id
else:
unique_ids["id"] = "N/A"

# hashed host name
unique_ids["unique_host_hash"] = hashlib.md5(host).hexdigest()
# hashed username
unique_ids["unique_user_hash"] = hashlib.md5(user).hexdigest()
# hashed session
unique_ids["unique_session_hash"] = hashlib.md5(host + user + timestamp).hexdigest()


def generate_profile_info(self):
if not profile_info:
# name of dbt project in profiles
@@ -103,13 +114,15 @@ def generate_profile_info(self):
# number of threads in profiles
profile_info["no_of_threads"] = self.profile.threads

def populate_warehouse_info(w_info):
warehouse_info["warehouse_version"]["version"] = w_info["version"]
warehouse_info["warehouse_version"]["build"] = w_info["build"]

def _merge_keys(source_dict, dest_dict):
for key, value in source_dict.items():
dest_dict[key] = value
return dest_dict


def _get_sql_type(sql):
if not sql:
return ""
@@ -121,11 +134,10 @@ def _get_sql_type(sql):
else:
sql_words = words[0].strip().split()

sql_type = " ".join(sql_words[: min(2, len(sql_words))]).lower()
sql_type = " ".join(sql_words[:min(2, len(sql_words))]).lower()

return sql_type


def fix_tracking_payload(given_payload):
"""
The payload for an event
@@ -140,7 +152,7 @@
if "sql" in desired_payload:
desired_payload["sql_type"] = _get_sql_type(desired_payload["sql"])
del desired_payload["sql"]

desired_keys = [
"auth",
"connection_state",
@@ -150,7 +162,7 @@
"model_type",
"permissions",
"profile_name",
"sql_type",
"sql_type"
]

for key in desired_keys:
@@ -176,10 +188,7 @@ def track_usage(tracking_payload):

global usage_tracking

logger.debug(
f"Usage tracking flag {usage_tracking}."
f"To turn on/off use usage_tracking flag in profiles.yml"
)
logger.debug(f"Usage tracking flag {usage_tracking}. To turn on/off use usage_tracking flag in profiles.yml")

# if usage_tracking is disabled, quit
if not usage_tracking:
@@ -193,6 +202,7 @@ def track_usage(tracking_payload):
tracking_payload = _merge_keys(platform_info, tracking_payload)
tracking_payload = _merge_keys(dbt_deployment_env_info, tracking_payload)
tracking_payload = _merge_keys(profile_info, tracking_payload)
tracking_payload = _merge_keys(warehouse_info, tracking_payload)

# form the tracking data
tracking_data = {"data": json.dumps(tracking_payload)}
@@ -227,10 +237,7 @@ def _tracking_func(data):

try:
res = requests.post(
SNOWPLOW_ENDPOINT,
data=data,
headers=headers,
timeout=SNOWPLOW_TIMEOUT
SNOWPLOW_ENDPOINT, data=data, headers=headers, timeout=SNOWPLOW_TIMEOUT
)
except Exception as err:
logger.debug(f"Usage tracking error. {err}")
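
Taken together, the changes to cloudera_tracking.py centralize event names in the TrackingEventType class and fold warehouse_info into every outgoing payload, next to the platform, profile, and deployment-env info. The sketch below is a simplified, self-contained illustration of that flow, not the adapter module itself; the real track_usage also merges profile and deployment-env info, honours the usage_tracking flag, and posts the payload to Snowplow.

class TrackingEventType:
    START_QUERY = "start_query"
    END_QUERY = "end_query"

# module-level dicts populated once and merged into every event payload
platform_info = {"system": "Linux", "machine": "x86_64"}  # placeholder values
warehouse_info = {"warehouse_version": {"version": "3", "build": "3.1.3000"}}  # placeholder values

def _merge_keys(source_dict, dest_dict):
    # same shape as the helper above: copy every key of source_dict into dest_dict
    for key, value in source_dict.items():
        dest_dict[key] = value
    return dest_dict

def _get_sql_type(sql):
    # simplified _get_sql_type: first (at most) two words of the statement, lowercased
    if not sql:
        return ""
    sql_words = sql.strip().split()
    return " ".join(sql_words[:min(2, len(sql_words))]).lower()

def track_usage(tracking_payload):
    tracking_payload = _merge_keys(platform_info, tracking_payload)
    tracking_payload = _merge_keys(warehouse_info, tracking_payload)
    return tracking_payload

payload = {
    "event_type": TrackingEventType.START_QUERY,
    "sql_type": _get_sql_type("create table t as select 1"),  # -> "create table"
    "profile_name": "my_profile",
}
print(track_usage(payload))
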
37 changes: 33 additions & 4 deletions dbt/adapters/hive/connections.py
@@ -168,6 +168,8 @@ def description(self):
class HiveConnectionManager(SQLConnectionManager):
TYPE = "hive"

hive_version = None

def __init__(self, profile: AdapterRequiredConfig):
super().__init__(profile)
# generate profile related object for instrumentation.
@@ -225,6 +227,8 @@ def open(cls, connection):
connection.state = ConnectionState.OPEN
connection.handle = HiveConnectionWrapper(hive_conn)
connection.handle.cursor()

HiveConnectionManager.fetch_hive_version(connection.handle)
except Exception as exc:
logger.debug("Connection error: {}".format(exc))
connection_ex = exc
@@ -234,7 +238,7 @@ def open(cls, connection):

# track usage
payload = {
"event_type": "dbt_hive_open",
"event_type": tracker.TrackingEventType.OPEN,
"auth": auth_type,
"connection_state": connection.state,
"elapsed_time": "{:.2f}".format(
@@ -271,6 +275,31 @@ def exception_handler(self, sql: str):
def cancel(self, connection):
connection.handle.cancel()

@classmethod
def fetch_hive_version(cls, connection):

if HiveConnectionManager.hive_version:
return HiveConnectionManager.hive_version

try:
sql = "select version()"
cursor = connection.cursor()
cursor.execute(sql)

res = cursor.fetchall()

HiveConnectionManager.hive_version = res[0][0].split(".")[0].strip()

tracker.populate_warehouse_info({ "version": HiveConnectionManager.hive_version, "build": res[0][0] })
except Exception as ex:
# we couldn't get the hive warehouse version
logger.debug(f"Cannot get hive version. Error: {ex}")
HiveConnectionManager.hive_version = "NA"

tracker.populate_warehouse_info({ "version": HiveConnectionManager.hive_version, "build": "NA" })

logger.debug(f"HIVE VERSION {'HiveConnectionManager.hive_version'}")

@classmethod
def close(cls, connection):
try:
@@ -283,7 +312,7 @@ def close(cls, connection):
connection_close_end_time = time.time()

payload = {
"event_type": "dbt_hive_close",
"event_type": tracker.TrackingEventType.CLOSE,
"connection_state": ConnectionState.CLOSED,
"elapsed_time": "{:.2f}".format(
connection_close_end_time - connection_close_start_time
@@ -330,7 +359,7 @@ def add_query(

# track usage
payload = {
"event_type": "dbt_hive_start_query",
"event_type": tracker.TrackingEventType.START_QUERY,
"sql": log_sql,
"profile_name": self.profile.profile_name
}
@@ -368,7 +397,7 @@ def add_query(
elapsed_time = time.time() - pre

payload = {
"event_type": "dbt_hive_end_query",
"event_type": tracker.TrackingEventType.END_QUERY,
"sql": log_sql,
"elapsed_time": "{:.2f}".format(elapsed_time),
"status": query_status,
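
The new fetch_hive_version caches the result of "select version()" on the class, so the warehouse is queried at most once per process and later connections reuse the cached value. Below is a minimal sketch of that caching pattern with a stand-in connection object instead of a real Hive handle; the version string is hypothetical, and the real method reports the result via tracker.populate_warehouse_info.

class FakeCursor:
    def execute(self, sql):
        pass

    def fetchall(self):
        return [("3.1.3000.2022.0.1",)]  # hypothetical "select version()" result

class FakeConnection:
    def cursor(self):
        return FakeCursor()

class VersionedConnectionManager:
    hive_version = None  # class-level cache shared by all connections

    @classmethod
    def fetch_hive_version(cls, connection):
        if cls.hive_version:
            return cls.hive_version
        try:
            cursor = connection.cursor()
            cursor.execute("select version()")
            build = cursor.fetchall()[0][0]
            cls.hive_version = build.split(".")[0].strip()  # major version, e.g. "3"
            # the adapter would call tracker.populate_warehouse_info({"version": ..., "build": build}) here
        except Exception:
            cls.hive_version = "NA"  # could not determine the warehouse version
        return cls.hive_version

print(VersionedConnectionManager.fetch_hive_version(FakeConnection()))  # "3", runs the query
print(VersionedConnectionManager.fetch_hive_version(FakeConnection()))  # "3", served from the cache
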
25 changes: 1 addition & 24 deletions dbt/adapters/hive/impl.py
@@ -417,7 +417,7 @@ def debug_query(self) -> None:
permissions_json = permissions_object

payload = {
"event_type": "dbt_hive_debug_and_fetch_permissions",
"event_type": tracker.TrackingEventType.DEBUG,
"permissions": permissions_json,
}
tracker.track_usage(payload)
@@ -427,29 +427,6 @@ def debug_query(self) -> None:
)
self.connections.get_thread_connection().handle.close()

# query warehouse version
try:
sql_query = "select version()"
_, table = self.execute(sql_query, True, True)
version_object = []
json_funcs = [c.jsonify for c in table.column_types]

for row in table.rows:
values = tuple(json_funcs[i](d) for i, d in enumerate(row))
version_object.append(OrderedDict(zip(row.keys(), values)))

version_json = version_object

payload = {
"event_type": "dbt_hive_warehouse",
"warehouse_version": version_json,
}
tracker.track_usage(payload)
except Exception as ex:
logger.error(
f"Failed to fetch warehouse version. Exception: {ex}"
)

self.connections.get_thread_connection().handle.close()

###
4 changes: 2 additions & 2 deletions dbt/adapters/hive/relation.py
@@ -53,7 +53,7 @@ def __post_init__(self):
if self.type:
tracker.track_usage(
{
"event_type": "dbt_hive_model_access",
"event_type": tracker.TrackingEventType.MODEL_ACCESS,
"model_name": self.render(),
"model_type": self.type,
"incremental_strategy": "",
@@ -72,7 +72,7 @@ def log_relation(self, incremental_strategy):
if self.type:
tracker.track_usage(
{
"event_type": "dbt_hive_new_incremental",
"event_type": tracker.TrackingEventType.INCREMENTAL,
"model_name": self.render(),
"model_type": self.type,
"incremental_strategy": incremental_strategy,
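
In relation.py the only change is at the call sites: the hard-coded "dbt_hive_model_access" and "dbt_hive_new_incremental" strings give way to the shared constants, so a mistyped event name now raises an AttributeError instead of silently emitting a mislabelled event. A small sketch of that call-site pattern, with stand-ins for the tracker module:

class TrackingEventType:  # stand-in for tracker.TrackingEventType
    MODEL_ACCESS = "model_access"
    INCREMENTAL = "incremental"

def track_usage(payload):  # stand-in for tracker.track_usage
    print(payload)

def log_model_access(model_name, model_type):
    track_usage({
        "event_type": TrackingEventType.MODEL_ACCESS,  # was "dbt_hive_model_access"
        "model_name": model_name,
        "model_type": model_type,
        "incremental_strategy": "",
    })

log_model_access("my_schema.my_model", "table")
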
