-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* ref: cern-sis/issues#119
- Loading branch information
Showing
12 changed files
with
347 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import datetime | ||
|
||
import pendulum | ||
from airflow.decorators import dag, task | ||
from airflow.exceptions import AirflowException | ||
from airflow.providers.http.hooks.http import HttpHook | ||
from annual_reports.utils import get_endpoint, get_subjects_by_year | ||
from common.models.annual_reports.annual_reports import Categories | ||
from common.operators.sqlalchemy_operator import sqlalchemy_task | ||
from executor_config import kubernetes_executor_config | ||
from sqlalchemy.sql import func | ||
from tenacity import retry_if_exception_type, stop_after_attempt | ||
|
||
current_year = datetime.datetime.now().year | ||
years = list(range(2004, current_year + 1)) | ||
|
||
|
||
@dag( | ||
start_date=pendulum.today("UTC").add(days=-1), | ||
schedule="@monthly", | ||
) | ||
def annual_reports_categories_dag(): | ||
@task(executor_config=kubernetes_executor_config) | ||
def fetch_categories_report_count(year, **kwargs): | ||
endpoint = get_endpoint(year=year, key="subjects") | ||
http_hook = HttpHook(http_conn_id="cds", method="GET") | ||
response = http_hook.run_with_advanced_retry( | ||
endpoint=endpoint, | ||
_retry_args={ | ||
"stop": stop_after_attempt(3), | ||
"retry": retry_if_exception_type(AirflowException), | ||
}, | ||
) | ||
subjects = get_subjects_by_year(response.content) | ||
return {year: subjects} | ||
|
||
@sqlalchemy_task(conn_id="superset") | ||
def populate_categories_report_count(entry, session, **kwargs): | ||
for year, subjects in entry.items(): | ||
for category, count in subjects.items(): | ||
record = ( | ||
session.query(Categories) | ||
.filter_by(category=category, year=year) | ||
.first() | ||
) | ||
if record: | ||
record.count = int(count) | ||
record.year = int(year) | ||
record.updated_at = func.now() | ||
else: | ||
new_record = Categories( | ||
year=int(year), | ||
category=category, | ||
count=int(count), | ||
) | ||
session.add(new_record) | ||
|
||
previous_task = None | ||
for year in years: | ||
fetch_task = fetch_categories_report_count.override( | ||
task_id=f"fetch_report_{year}" | ||
)(year=year) | ||
populate_task = populate_categories_report_count.override( | ||
task_id=f"populate_report_{year}" | ||
)(entry=fetch_task) | ||
if previous_task: | ||
previous_task >> fetch_task | ||
fetch_task >> populate_task | ||
previous_task = populate_task | ||
|
||
|
||
annual_reports_categories = annual_reports_categories_dag() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import datetime | ||
|
||
import pendulum | ||
from airflow.decorators import dag, task | ||
from airflow.exceptions import AirflowException | ||
from airflow.providers.http.hooks.http import HttpHook | ||
from annual_reports.utils import get_endpoint, get_publications_by_year | ||
from common.models.annual_reports.annual_reports import Journals, Publications | ||
from common.operators.sqlalchemy_operator import sqlalchemy_task | ||
from executor_config import kubernetes_executor_config | ||
from sqlalchemy.sql import func | ||
from tenacity import retry_if_exception_type, stop_after_attempt | ||
|
||
current_year = datetime.datetime.now().year | ||
years = list(range(2004, current_year + 1)) | ||
|
||
|
||
@dag(start_date=pendulum.today("UTC").add(days=-1), schedule="@monthly") | ||
def annual_reports_publications_dag(): | ||
@task(executor_config=kubernetes_executor_config) | ||
def fetch_publication_report_count(year, **kwargs): | ||
endpoint = get_endpoint(key="publications", year=year) | ||
http_hook = HttpHook(http_conn_id="cds", method="GET") | ||
response = http_hook.run_with_advanced_retry( | ||
endpoint=endpoint, | ||
_retry_args={ | ||
"stop": stop_after_attempt(3), | ||
"retry": retry_if_exception_type(AirflowException), | ||
}, | ||
) | ||
publication_report_count, journals = get_publications_by_year(response.content) | ||
return {year: (publication_report_count, journals)} | ||
|
||
@sqlalchemy_task(conn_id="superset") | ||
def process_results(results, session, **kwargs): | ||
for year, values in results.items(): | ||
publications, journals = values | ||
populate_publication_report_count(publications, year, session) | ||
populate_journal_report_count(journals, year, session) | ||
|
||
def populate_publication_report_count(publications, year, session, **kwargs): | ||
record = session.query(Publications).filter_by(year=year).first() | ||
if record: | ||
record.publications = publications["publications"] | ||
record.journals = publications["journals"] | ||
record.contributions = publications["contributions"] | ||
record.theses = publications["theses"] | ||
record.rest = publications["rest"] | ||
record.year = year | ||
record.updated_at = func.now() | ||
else: | ||
new_record = Publications( | ||
year=year, | ||
publications=publications["publications"], | ||
journals=publications["journals"], | ||
contributions=publications["contributions"], | ||
theses=publications["theses"], | ||
rest=publications["rest"], | ||
) | ||
session.add(new_record) | ||
|
||
def populate_journal_report_count(journals, year, session, **kwargs): | ||
for journal, count in journals.items(): | ||
record = ( | ||
session.query(Journals).filter_by(year=year, journal=journal).first() | ||
) | ||
if record: | ||
record.journal = journal | ||
record.count = count | ||
record.year = year | ||
record.updated_at = func.now() | ||
else: | ||
new_record = Journals( | ||
year=year, | ||
journal=journal, | ||
count=count, | ||
) | ||
session.add(new_record) | ||
|
||
previous_task = None | ||
for year in years: | ||
fetch_task = fetch_publication_report_count.override( | ||
task_id=f"fetch_report_{year}" | ||
)(year=year) | ||
process_task = process_results.override(task_id=f"process_results_{year}")( | ||
results=fetch_task | ||
) | ||
if previous_task: | ||
previous_task >> fetch_task | ||
fetch_task >> process_task | ||
previous_task = process_task | ||
|
||
|
||
annual_reports_publications = annual_reports_publications_dag() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
ONLY_CERN_SUBJECTS = [ | ||
"Particle Physics - Experiment", | ||
"Particle Physics - Phenomenology", | ||
"Particle Physics - Theory", | ||
"Particle Physics - Lattice", | ||
"Nuclear Physics - Experiment", | ||
"Nuclear Physics - Theory", | ||
"General Relativity and Cosmology", | ||
"General Theoretical Physics", | ||
"Detectors and Experimental Techniques", | ||
"Accelerators and Storage Rings", | ||
"Health Physics and Radiation Effects", | ||
"Computing and Computers", | ||
"Mathematical Physics and Mathematics", | ||
"Astrophysics and Astronomy", | ||
"Nonlinear Systems", | ||
"Condensed Matter", | ||
"Other Fields of Physics", | ||
"Chemical Physics and Chemistry", | ||
"Engineering", | ||
"Information Transfer and Management", | ||
"Physics in General", | ||
"Commerce, Economics, Social Science", | ||
"Biography, Geography, History", | ||
"Other Subjects", | ||
"Science in General", | ||
"Quantum Technology", | ||
"Education and Outreach", | ||
"Popular Science", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import os | ||
import xml.etree.ElementTree as ET | ||
|
||
from common.exceptions import VariableValueIsMissing | ||
|
||
from .constants import ONLY_CERN_SUBJECTS | ||
|
||
|
||
def get_endpoint(key, year): | ||
cds_token = os.environ.get("CDS_TOKEN", "") | ||
if cds_token: | ||
base = f"tools/custom_query_summary.py?start={year}&end={year}&apikey={cds_token}&refresh=1&repeated_values=0" | ||
endpoints = {"publications": base, "subjects": base + "&otag=65017a"} | ||
return endpoints[key] | ||
raise VariableValueIsMissing("CDS_TOKEN") | ||
|
||
|
||
def get_publications_by_year(content): | ||
root = ET.fromstring(content) | ||
yearly_report = root.find("yearly_report") | ||
publication_report_count = yearly_report.attrib | ||
del publication_report_count["year"] | ||
journals = {} | ||
for journal in yearly_report.findall("line"): | ||
name = journal.find("result").text | ||
if "TOTAL" in name: | ||
continue | ||
journals[name] = journal.find("nb").text | ||
return publication_report_count, journals | ||
|
||
|
||
def get_subjects_by_year(content): | ||
root = ET.fromstring(content) | ||
yearly_report = root.find("yearly_report") | ||
subjects = {} | ||
for subject in yearly_report.findall("line"): | ||
name = subject.find("result").text | ||
if "TOTAL" in name or name not in ONLY_CERN_SUBJECTS: | ||
continue | ||
subjects[name] = subject.find("nb").text | ||
return subjects |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from sqlalchemy import Column, DateTime, Integer, String, func | ||
from sqlalchemy.ext.declarative import declarative_base | ||
|
||
Base = declarative_base() | ||
|
||
|
||
class Publications(Base): | ||
__tablename__ = "annual_reports_publications" | ||
|
||
id = Column(Integer, primary_key=True) | ||
year = Column(Integer, nullable=False) | ||
publications = Column(Integer, nullable=False) | ||
journals = Column(Integer, nullable=False) | ||
contributions = Column(Integer, nullable=False) | ||
theses = Column(Integer, nullable=False) | ||
rest = Column(Integer, nullable=False) | ||
created_at = Column(DateTime, default=func.now()) | ||
updated_at = Column(DateTime, default=func.now()) | ||
|
||
|
||
class Categories(Base): | ||
__tablename__ = "annual_reports_categories" | ||
|
||
id = Column(Integer, primary_key=True) | ||
category = Column(String, nullable=False) | ||
count = Column(Integer, nullable=False) | ||
year = Column(Integer, nullable=False) | ||
created_at = Column(DateTime, default=func.now()) | ||
updated_at = Column(DateTime, default=func.now()) | ||
|
||
|
||
class Journals(Base): | ||
__tablename__ = "annual_reports_journals" | ||
|
||
id = Column(Integer, primary_key=True) | ||
journal = Column(String, nullable=False) | ||
count = Column(Integer, nullable=False) | ||
year = Column(Integer, nullable=False) | ||
created_at = Column(DateTime, default=func.now()) | ||
updated_at = Column(DateTime, default=func.now()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
56 changes: 56 additions & 0 deletions
56
dags/migrations/versions/fc3ffc0db6db_database_revision_for_annual_reports.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
"""Database revision for Annual Reports | ||
Revision ID: fc3ffc0db6db | ||
Revises: 101f23913167 | ||
Create Date: 2024-08-06 10:53:05.078428 | ||
""" | ||
from typing import Sequence, Union | ||
|
||
import sqlalchemy as sa | ||
from alembic import op | ||
|
||
# revision identifiers, used by Alembic. | ||
revision: str = "fc3ffc0db6db" | ||
down_revision: Union[str, None] = "101f23913167" | ||
branch_labels: Union[str, Sequence[str], None] = None | ||
depends_on: Union[str, Sequence[str], None] = None | ||
|
||
|
||
def upgrade(): | ||
op.create_table( | ||
"annual_reports_publications", | ||
sa.Column("id", sa.Integer, primary_key=True), | ||
sa.Column("year", sa.Integer, nullable=False), | ||
sa.Column("publications", sa.Integer, nullable=False), | ||
sa.Column("journals", sa.Integer, nullable=False), | ||
sa.Column("contributions", sa.Integer, nullable=False), | ||
sa.Column("theses", sa.Integer, nullable=False), | ||
sa.Column("rest", sa.Integer, nullable=False), | ||
sa.Column("created_at", sa.TIMESTAMP(timezone=True), nullable=False), | ||
sa.Column("updated_at", sa.TIMESTAMP(timezone=True), nullable=False), | ||
) | ||
op.create_table( | ||
"annual_reports_journals", | ||
sa.Column("id", sa.Integer, primary_key=True), | ||
sa.Column("year", sa.Integer, nullable=False), | ||
sa.Column("journal", sa.String, nullable=False), | ||
sa.Column("count", sa.Integer, nullable=False), | ||
sa.Column("created_at", sa.TIMESTAMP(timezone=True), nullable=False), | ||
sa.Column("updated_at", sa.TIMESTAMP(timezone=True), nullable=False), | ||
) | ||
op.create_table( | ||
"annual_reports_categories", | ||
sa.Column("id", sa.Integer, primary_key=True), | ||
sa.Column("year", sa.Integer, nullable=False), | ||
sa.Column("category", sa.String, nullable=False), | ||
sa.Column("count", sa.Integer, nullable=False), | ||
sa.Column("created_at", sa.TIMESTAMP(timezone=True), nullable=False), | ||
sa.Column("updated_at", sa.TIMESTAMP(timezone=True), nullable=False), | ||
) | ||
|
||
|
||
def downgrade() -> None: | ||
op.drop_table("annual_reports_publications") | ||
op.drop_table("annual_reports_journals") | ||
op.drop_table("annual_reports_categories") |