Improve README, code refactoring #2

Open · wants to merge 3 commits into base: `master`
26 changes: 26 additions & 0 deletions README.md
@@ -1,2 +1,28 @@
# dbt-cloud-plugin
DBT Cloud Plugin for Airflow

## Configuration

Copy the `dbt_cloud_plugin` directory into Airflow's `plugins` directory.

Create a new Airflow connection (the example DAG uses the connection ID `dbt_cloud`) with the following dictionary as its `Extra` parameter, and leave the connection type blank. A programmatic alternative is sketched after the JSON example.
```
{
"dbt_cloud_api_token": "123abcdefg456",
"dbt_cloud_account_id": 12345678
}
```
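
If you prefer to create the connection from code rather than through the Airflow UI, a minimal sketch is shown below (assuming Airflow 1.10-style imports; the token and account ID are placeholders):

```python
# Sketch only: registers a "dbt_cloud" connection with the credentials in its Extra field.
# Replace the placeholder token and account ID with your own values.
import json

from airflow import settings
from airflow.models import Connection

conn = Connection(
    conn_id='dbt_cloud',  # matches dbt_cloud_conn_id in the example DAG
    extra=json.dumps({
        'dbt_cloud_api_token': '123abcdefg456',
        'dbt_cloud_account_id': 12345678,
    }),
)

session = settings.Session()
session.add(conn)  # note: does not check whether the connection already exists
session.commit()
```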

To obtain your API token, log into your [dbt Cloud account](https://cloud.getdbt.com), click your avatar in the top-right corner, select `My Account`, and then open `API Access` in the left sidebar.

Note: API Access is not available on the _Free_ plan.


To verify that the connection is set up correctly, open a shell on your Airflow host and run:

`airflow test --dry_run dbt_cloud_dag run_dbt_cloud_job 2019-01-01`
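
For a complete example see `examples/dbt_cloud_hourly_dag.py` in this PR; a condensed sketch (the job name, schedule, and import paths are assumptions to adapt to your environment) looks like this:

```python
# Condensed sketch based on examples/dbt_cloud_hourly_dag.py.
# Job name, schedule, and start date are placeholders; the import paths assume
# the plugin package is importable directly rather than via the plugin manager.
from datetime import datetime

from airflow import DAG
from dbt_cloud_plugin.operators.dbt_cloud_run_job_operator import DbtCloudRunJobOperator
from dbt_cloud_plugin.sensors.dbt_cloud_job_sensor import DbtCloudRunSensor

default_args = {
    'owner': 'airflow',
    'start_date': datetime(2019, 1, 1),
}

dag = DAG(
    'dbt_cloud_dag',
    max_active_runs=1,
    catchup=False,
    schedule_interval='0 * * * *',
    default_args=default_args,
)

# Triggers the dbt Cloud job by name and returns the run ID (pushed to XCom).
run_dbt_cloud_job = DbtCloudRunJobOperator(
    task_id='run_dbt_cloud_job',
    dbt_cloud_conn_id='dbt_cloud',
    job_name='My dbt Cloud Job',
    dag=dag,
)

# Polls the triggered run until it reaches a terminal state.
watch_dbt_cloud_job = DbtCloudRunSensor(
    task_id='watch_dbt_cloud_job',
    dbt_cloud_conn_id='dbt_cloud',
    run_id="{{ task_instance.xcom_pull(task_ids='run_dbt_cloud_job') }}",
    dag=dag,
)

run_dbt_cloud_job >> watch_dbt_cloud_job
```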



----
MIT License
4 changes: 3 additions & 1 deletion dbt_cloud_plugin/__init__.py
@@ -1,7 +1,9 @@
from airflow.plugins_manager import AirflowPlugin

from dbt_cloud_plugin.hooks.dbt_cloud_hook import DbtCloudHook
from dbt_cloud_plugin.operators.dbt_cloud_run_job_operator import DbtCloudRunJobOperator
from dbt_cloud_plugin.sensors.dbt_cloud_run_sensor import DbtCloudRunSensor
from dbt_cloud_plugin.sensors.dbt_cloud_job_sensor import DbtCloudRunSensor


class DbtCloudPlugin(AirflowPlugin):
name = "dbt_cloud_plugin"
38 changes: 29 additions & 9 deletions dbt_cloud_plugin/dbt_cloud/dbt_cloud.py
@@ -2,6 +2,8 @@
import json
import requests
import time
from airflow.exceptions import AirflowException


class DbtCloud(object):
"""
@@ -20,7 +22,7 @@ def __init__(self, account_id, api_token):

def _get(self, url_suffix):
url = self.api_base + url_suffix
headers = {'Authorization': 'Token %s' % self.api_token}
headers = {'Authorization': f'Token {self.api_token}'}
response = requests.get(url, headers=headers)
if response.status_code == 200:
return json.loads(response.content)
@@ -29,41 +31,59 @@ def _get(self, url_suffix):

def _post(self, url_suffix, data=None):
url = self.api_base + url_suffix
headers = {'Authorization': 'token %s' % self.api_token}
headers = {'Authorization': f'Token {self.api_token}'}
response = requests.post(url, headers=headers, data=data)
if response.status_code == 200:
return json.loads(response.content)
else:
raise RuntimeError(response.content)

def list_jobs(self):
return self._get('/accounts/%s/jobs/' % self.account_id).get('data')
return self._get(
f'/accounts/{self.account_id}/jobs/'
).get('data')

def get_run(self, run_id):
return self._get('/accounts/%s/runs/%s/' % (self.account_id, run_id)).get('data')
return self._get(
f'/accounts/{self.account_id}/runs/{run_id}/'
).get('data')

def trigger_job_run(self, job_id, data=None):
return self._post(url_suffix='/accounts/%s/jobs/%s/run/' % (self.account_id, job_id), data=data).get('data')
return self._post(
url_suffix=f'/accounts/{self.account_id}/jobs/{job_id}/run/',
data=data
).get('data')

def try_get_run(self, run_id, max_tries=3):
for i in range(max_tries):
try:
run = self.get_run(run_id)
return run
except RuntimeError as e:
print("Encountered a runtime error while fetching status for {}".format(run_id))
print(
'Encountered a runtime error while '
f'fetching status for {run_id}'
)
time.sleep(10)

raise RuntimeError("Too many failures ({}) while querying for run status".format(run_id))
raise RuntimeError(
f'Too many failures while querying for status of run {run_id}'
)

def run_job(self, job_name, data=None):
jobs = self.list_jobs()

job_matches = [j for j in jobs if j['name'] == job_name]

if len(job_matches) != 1:
raise AirflowException("{} jobs found for {}".format(len(job_matches), job_name))
raise AirflowException(
f'{len(job_matches)} jobs found for {job_name}'
)

job_def = job_matches[0]
trigger_resp = self.trigger_job_run(job_id=job_def['id'], data=data)
trigger_resp = self.trigger_job_run(
job_id=job_def['id'],
data=data
)

return trigger_resp
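
As a quick orientation for the refactored wrapper above, here is a rough, hypothetical sketch of using `DbtCloud` directly, outside Airflow (the account ID, token, and job name are placeholders):

```python
# Sketch only: placeholder credentials and job name.
from dbt_cloud_plugin.dbt_cloud.dbt_cloud import DbtCloud

dbt = DbtCloud(account_id=12345678, api_token='123abcdefg456')

# run_job() looks up the job by name via list_jobs(), triggers it,
# and returns the "data" payload from the dbt Cloud API response.
trigger_resp = dbt.run_job('My dbt Cloud Job', data={'cause': 'Manual trigger'})

# try_get_run() retries get_run() a few times before raising RuntimeError.
run = dbt.try_get_run(trigger_resp['id'])
print(run)
```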
14 changes: 11 additions & 3 deletions dbt_cloud_plugin/hooks/dbt_cloud_hook.py
@@ -1,7 +1,9 @@
from dbt_cloud_plugin.dbt_cloud.dbt_cloud import DbtCloud
from airflow.hooks.base_hook import BaseHook
from airflow.exceptions import AirflowException

from dbt_cloud_plugin.dbt_cloud.dbt_cloud import DbtCloud


class RunStatus:
queued = 1
dequeued = 2
@@ -23,6 +25,7 @@ class RunStatus:
def lookup(cls, status):
return cls.LOOKUP.get(status, 'Unknown')


class DbtCloudHook(BaseHook):
"""
Interact with dbt Cloud.
@@ -36,11 +39,16 @@ def get_conn(self):
if 'dbt_cloud_api_token' in conn.extra_dejson:
dbt_cloud_api_token = conn.extra_dejson['dbt_cloud_api_token']
else:
raise AirflowException('No dbt Cloud API Token was supplied in dbt Cloud connection.')
raise AirflowException(
'No dbt Cloud API Token was supplied in dbt Cloud connection.'
)

if 'dbt_cloud_account_id' in conn.extra_dejson:
dbt_cloud_account_id = conn.extra_dejson['dbt_cloud_account_id']
else:
raise AirflowException('No dbt Cloud Account ID was supplied in dbt Cloud connection.')
raise AirflowException(
'No dbt Cloud Account ID was supplied in dbt Cloud connection.'
)

return DbtCloud(dbt_cloud_account_id, dbt_cloud_api_token)

31 changes: 18 additions & 13 deletions dbt_cloud_plugin/operators/dbt_cloud_run_job_operator.py
@@ -1,20 +1,16 @@
# -*- coding: utf-8 -*-
import json
import requests
import time

from airflow.models import BaseOperator
from dbt_cloud_plugin.hooks.dbt_cloud_hook import DbtCloudHook
from airflow.utils.decorators import apply_defaults
from airflow.exceptions import AirflowException

from dbt_cloud_plugin.hooks.dbt_cloud_hook import DbtCloudHook


class DbtCloudRunJobOperator(BaseOperator):
"""
Operator to run a dbt cloud job.
:param dbt_cloud_conn_id: dbt Cloud connection ID.
:type dbt_cloud_conn_id: string
:param project_id: dbt Cloud project ID.
:type project_id: int
:param job_name: dbt Cloud job name.
:type job_name: string
"""
@@ -27,7 +23,8 @@ def __init__(self,
super(DbtCloudRunJobOperator, self).__init__(*args, **kwargs)

if dbt_cloud_conn_id is None:
raise AirflowException('No valid dbt cloud connection ID was supplied.')
raise AirflowException('No valid dbt Cloud '
'connection ID was supplied.')

if job_name is None:
raise AirflowException('No job name was supplied.')
@@ -37,15 +34,23 @@ def __init__(self,

def execute(self, **kwargs):

self.log.info('Attempting to trigger a run of dbt cloud job: {}'.format(self.job_name))
self.log.info(
f'Attempting to trigger a run of dbt Cloud job: {self.job_name}'
)

try:
dbt_cloud_hook = DbtCloudHook(dbt_cloud_conn_id=self.dbt_cloud_conn_id)
dbt_cloud = dbt_cloud_hook.get_conn()
data = {'cause':'Kicked off via Airflow'}

data = {'cause': 'Kicked off via Airflow'}
trigger_resp = dbt_cloud.run_job(self.job_name, data=data)
self.log.info('Triggered Run ID {}'.format(trigger_resp['id']))
triggered_run_id = trigger_resp['id']

self.log.info(f'Triggered Run ID {triggered_run_id}')

except RuntimeError as e:
raise AirflowException("Error while triggering job {}: {}".format(self.job_name, e))
raise AirflowException(
f'Error while triggering job {self.job_name}: {e}'
)

return trigger_resp['id']
return triggered_run_id
22 changes: 13 additions & 9 deletions dbt_cloud_plugin/sensors/dbt_cloud_job_sensor.py
@@ -1,7 +1,9 @@
from dbt_cloud_plugin.hooks.dbt_cloud_hook import DbtCloudHook
from airflow.sensors.base_sensor_operator import BaseSensorOperator
from airflow.utils.decorators import apply_defaults
from airflow.exceptions import AirflowException
from airflow.utils.decorators import apply_defaults
from airflow.sensors.base_sensor_operator import BaseSensorOperator

from dbt_cloud_plugin.hooks.dbt_cloud_hook import DbtCloudHook


class DbtCloudRunSensor(BaseSensorOperator):
"""
@@ -23,25 +25,27 @@ def __init__(self,
super(DbtCloudRunSensor, self).__init__(*args, **kwargs)

if dbt_cloud_conn_id is None:
raise AirflowException('No valid dbt cloud connection ID was supplied.')
raise AirflowException('No valid dbt Cloud connection ID was supplied.')

if run_id is None:
raise AirflowException('No dbt cloud run ID was supplied.')
raise AirflowException('No dbt Cloud Run ID was supplied.')

self.dbt_cloud_conn_id = dbt_cloud_conn_id
self.run_id = run_id

def poke(self, context):
self.log.info('Sensor checking state of dbt cloud run ID: %s', self.run_id)
self.log.info(f'Sensor checking state of dbt Cloud Run ID: {self.run_id}')
dbt_cloud_hook = DbtCloudHook(dbt_cloud_conn_id=self.dbt_cloud_conn_id)
run_status = dbt_cloud_hook.get_run_status(run_id=self.run_id)
self.log.info('State of Run ID {}: {}'.format(self.run_id, run_status))
self.log.info(f'State of Run ID {self.run_id}: {run_status}')

TERMINAL_RUN_STATES = ['Success', 'Error', 'Cancelled']
TERMINAL_RUN_STATES = ['Success', 'Cancelled']
FAILED_RUN_STATES = ['Error']

if run_status in FAILED_RUN_STATES:
return AirflowException('dbt cloud Run ID {} Failed.'.format(self.run_id))
raise AirflowException(
f'dbt Cloud Run ID {self.run_id} failed.'
)
if run_status in TERMINAL_RUN_STATES:
return True
else:
12 changes: 10 additions & 2 deletions examples/dbt_cloud_hourly_dag.py
@@ -19,7 +19,15 @@
'provide_context': True
}

dag = DAG('dbt_cloud_hourly_dag', concurrency=1, max_active_runs=1, catchup=False, schedule_interval='0 * * * *', default_args=default_args)
dag = DAG(
'dbt_cloud_hourly_dag',
concurrency=1,
max_active_runs=1,
catchup=False,
schedule_interval='0 * * * *',
default_args=default_args
)

dag.doc_md = __doc__

# Run hourly DAG through dbt cloud.
@@ -33,7 +41,7 @@
watch_dbt_cloud_job = DbtCloudRunSensor(
task_id='watch_dbt_cloud_job',
dbt_cloud_conn_id='dbt_cloud',
job_id="{{ task_instance.xcom_pull(task_ids='run_dbt_cloud_job', dag_id='dbt_cloud_hourly_dag', key='return_value') }}",
run_id="{{ task_instance.xcom_pull(task_ids='run_dbt_cloud_job', dag_id='dbt_cloud_hourly_dag', key='return_value') }}",
sla=timedelta(minutes=45),
dag=dag)
