Commit: formatting auto fix

henri123lemoine committed Aug 8, 2023
1 parent adfbfe0 commit 0f1b1c8
Showing 63 changed files with 3,343 additions and 2,377 deletions.
1 change: 1 addition & 0 deletions align_data/__init__.py
@@ -25,6 +25,7 @@
ALL_DATASETS = sorted([dataset.name for dataset in DATASET_REGISTRY])
DATASET_MAP = {dataset.name: dataset for dataset in DATASET_REGISTRY}


def get_dataset(name):
try:
return DATASET_MAP[name]
40 changes: 27 additions & 13 deletions align_data/analysis/analyse_jsonl_data.py
@@ -4,6 +4,7 @@

from collections import defaultdict


def is_valid_date_format(data_dict, format="%Y-%m-%dT%H:%M:%SZ"):
"""
Checks if the given date string matches the expected format.
@@ -15,37 +16,44 @@ def is_valid_date_format(data_dict, format="%Y-%m-%dT%H:%M:%SZ"):
except ValueError:
return False


def validate_data(data_dict):
"""
Processes each dictionary element in the jsonl file.
Processes each dictionary element in the jsonl file.
"""
if not is_valid_date_format(data_dict):
raise ValueError(f"Invalid date format for source: {data_dict['source']}, title: {data_dict['title'][:30]}, date_pub: {data_dict['date_published']}")
raise ValueError(
f"Invalid date format for source: {data_dict['source']}, title: {data_dict['title'][:30]}, date_pub: {data_dict['date_published']}"
)
# TODO: add more checks here


def check_for_duplicates(data_dict, seen_urls):
id = data_dict.get('id')
id = data_dict.get("id")
seen_urls[id].append(data_dict)

#TODO: Add more validation logic here
return seen_urls
# TODO: Add more validation logic here
return seen_urls


def get_data_dict_str(data_dict):
"""
Returns a string representation of the given data_dict.
"""
return f"source: {data_dict['source']}, title: {data_dict['title'][:50]}, date_pub: {data_dict['date_published']}, url: {data_dict['url']}\n"


def files_iterator(data_dir):
"""
Goes through the data directory, opens every jsonl file sequentially,
Goes through the data directory, opens every jsonl file sequentially,
and yields every element (which is a dictionary) in the jsonl file.
"""
for path in Path(data_dir).glob('*.jsonl'):
for path in Path(data_dir).glob("*.jsonl"):
with jsonlines.open(path) as f:
for line in f:
yield line


def process_jsonl_files(data_dir):
seen_urls = defaultdict(list) # holds all seen urls
for data_dict in files_iterator(data_dir):
@@ -57,23 +65,29 @@ def process_jsonl_files(data_dir):
except Exception as e:
print(f"Unexpected error: {e}")
dup_count = 0

for id, duplicates in seen_urls.items():
if len(duplicates) > 1:
list_of_duplicates = '\n'.join(get_data_dict_str(duplicate) for duplicate in duplicates)
print(f"{len(duplicates)} duplicate ids found. \nId: {id}\n{list_of_duplicates}\n\n\n\n")
list_of_duplicates = "\n".join(
get_data_dict_str(duplicate) for duplicate in duplicates
)
print(
f"{len(duplicates)} duplicate ids found. \nId: {id}\n{list_of_duplicates}\n\n\n\n"
)
dup_count += 1
print(f"Total number of duplicate ids found: {dup_count}")


def delete_all_txt_and_jsonl(data_dir):
"""
Deletes all txt and jsonl files in the given directory.
"""
for path in Path(data_dir).glob('*.txt'):
for path in Path(data_dir).glob("*.txt"):
path.unlink()
for path in Path(data_dir).glob('*.jsonl'):
for path in Path(data_dir).glob("*.jsonl"):
path.unlink()


if __name__ == "__main__":
process_jsonl_files("data/")
#delete_all_txt_and_jsonl("data/")
# delete_all_txt_and_jsonl("data/")
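
For orientation, a minimal usage sketch of the validation helpers in this file. The record below is made up — its field values are illustrative assumptions, not data from the repository — but the keys match the ones the functions above read:

from collections import defaultdict

from align_data.analysis.analyse_jsonl_data import check_for_duplicates, validate_data

# Hypothetical jsonl record shaped like the entries the validators expect.
record = {
    "id": "abc123",
    "source": "example-source",
    "title": "An example post",
    "url": "https://example.com/post",
    "date_published": "2023-08-08T00:00:00Z",  # matches "%Y-%m-%dT%H:%M:%SZ"
}

validate_data(record)  # raises ValueError if the date format is off

seen_urls = defaultdict(list)  # same structure process_jsonl_files builds
check_for_duplicates(record, seen_urls)
check_for_duplicates(record, seen_urls)  # same id twice -> reported as a duplicate
assert len(seen_urls["abc123"]) == 2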
12 changes: 7 additions & 5 deletions align_data/analysis/count_tokens.py
@@ -2,11 +2,15 @@
import jsonlines
import logging
from typing import Tuple

logger = logging.getLogger(__name__)

def count_token(merged_dataset_path : str = "data/merged_dataset/alignment_texts.jsonl") -> Tuple[int , int , int]:

def count_token(
merged_dataset_path: str = "data/merged_dataset/alignment_texts.jsonl",
) -> Tuple[int, int, int]:
tokenizer = AutoTokenizer.from_pretrained("gpt2")
total_token_count , total_word_count , total_character_count = 0 , 0 , 0
total_token_count, total_word_count, total_character_count = 0, 0, 0

with jsonlines.open(merged_dataset_path) as reader:
for obj in reader:
@@ -18,6 +22,4 @@ def count_token(merged_dataset_path : str = "data/merged_dataset/alignment_texts
logger.info(f"Total token count: {total_token_count}")
logger.info(f"Total word count: {total_word_count}")
logger.info(f"Total character count: {total_character_count}")
return total_token_count , total_word_count , total_character_count


return total_token_count, total_word_count, total_character_count
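
A short, hedged usage sketch for count_token — the path is just the function's default argument, the jsonl file has to exist for a real run, and logging needs to be configured before the totals show up:

import logging

from align_data.analysis.count_tokens import count_token

logging.basicConfig(level=logging.INFO)

# Tokenizes every entry of the merged dataset with the GPT-2 tokenizer and
# returns the aggregate (token, word, character) counts.
tokens, words, chars = count_token("data/merged_dataset/alignment_texts.jsonl")
print(tokens, words, chars)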
51 changes: 32 additions & 19 deletions align_data/common/alignment_dataset.py
@@ -38,10 +38,10 @@ class AlignmentDataset:

_: KW_ONLY

files_path = Path('')
files_path = Path("")
"""The path where data can be found. Usually a folder"""

done_key = 'id'
done_key = "id"
"""The key of the entry to use as the id when checking if already processed."""

COOLDOWN = 0
@@ -58,30 +58,30 @@ class AlignmentDataset:
_outputted_items = set()
"""A set of the ids of all previously processed items"""
_: KW_ONLY
id_fields: List[str] = field(default_factory=lambda: ['url', 'title'])
id_fields: List[str] = field(default_factory=lambda: ["url", "title"])
"""A list of fields to use as the id of the entry. If not set, will use ['url', 'title']"""

def __str__(self) -> str:
return self.name

def __post_init__(self, data_path=Path(__file__).parent / '../../data/'):
def __post_init__(self, data_path=Path(__file__).parent / "../../data/"):
self.data_path = data_path
self.raw_data_path = self.data_path / 'raw'
self.raw_data_path = self.data_path / "raw"

# set the default place to look for data
self.files_path = self.raw_data_path / self.name

def _add_authors(self, article: Article, authors: List[str]) -> Article:
# TODO: Don't keep adding the same authors - come up with some way to reuse them
article.authors = ','.join(authors)
article.authors = ",".join(authors)
if len(article.authors) > 1024:
article.authors = ','.join(article.authors[:1024].split(',')[:-1])
article.authors = ",".join(article.authors[:1024].split(",")[:-1])
return article

def make_data_entry(self, data, **kwargs) -> Article:
data = dict(data, **kwargs)
summary = data.pop('summary', None)
authors = data.pop('authors', [])
summary = data.pop("summary", None)
authors = data.pop("authors", [])

article = Article(
id_fields=self.id_fields,
@@ -95,21 +95,21 @@ def make_data_entry(self, data, **kwargs) -> Article:

def to_jsonl(self, out_path=None, filename=None) -> Path:
if not out_path:
out_path=Path(__file__).parent / '../../data/'
out_path = Path(__file__).parent / "../../data/"

if not filename:
filename = f"{self.name}.jsonl"
filename = Path(out_path) / filename

with jsonlines.open(filename, 'w') as jsonl_writer:
with jsonlines.open(filename, "w") as jsonl_writer:
for article in self.read_entries():
jsonl_writer.write(article.to_dict())
return filename.resolve()

def read_entries(self, sort_by=None):
"""Iterate through all the saved entries."""
with make_session() as session:
query = select(Article).where(Article.source==self.name)
query = select(Article).where(Article.source == self.name)
if sort_by is not None:
query = query.order_by(sort_by)
for item in session.scalars(query):
@@ -136,8 +136,8 @@ def commit():
for entry in batch:
session.add(entry)
if not commit():
logger.error(f'found duplicate of {entry}')
logger.error(f"found duplicate of {entry}")

def setup(self):
self._outputted_items = self._load_outputted_items()

@@ -160,9 +160,14 @@ def _load_outputted_items(self) -> Set[str]:
# This doesn't filter by self.name. The good thing about that is that it should handle a lot more
# duplicates. The bad thing is that this could potentially return a massive amount of data if there
# are lots of items.
return set(session.scalars(select(getattr(Article, self.done_key))).all())
return set(
session.scalars(select(getattr(Article, self.done_key))).all()
)
# TODO: Properly handle this - it should create a proper SQL JSON select
return {item.get(self.done_key) for item in session.scalars(select(Article.meta)).all()}
return {
item.get(self.done_key)
for item in session.scalars(select(Article.meta)).all()
}

def unprocessed_items(self, items=None) -> Iterable:
"""Return a list of all items to be processed.
@@ -213,15 +218,17 @@ def _get_published_date(self, date) -> Optional[datetime]:


class SummaryDataset(AlignmentDataset):

def unprocessed_items(self, items=None) -> Iterable:
# This breaks the possible lazy loading of the items. Should be fine...
items = list(super().unprocessed_items(items))

urls = map(self.get_item_key, items)
with make_session() as session:
self.articles = {
a.url: a for a in session.query(Article).options(joinedload(Article.summaries)).filter(Article.url.in_(urls))
a.url: a
for a in session.query(Article)
.options(joinedload(Article.summaries))
.filter(Article.url.in_(urls))
if a.url
}

@@ -230,7 +237,13 @@ def unprocessed_items(self, items=None) -> Iterable:
def _load_outputted_items(self) -> Set[str]:
"""Load the output file (if it exists) in order to know which items have already been output."""
with make_session() as session:
return set(session.scalars(select(Article.url).join(Article.summaries).filter(Summary.source == self.name)))
return set(
session.scalars(
select(Article.url)
.join(Article.summaries)
.filter(Summary.source == self.name)
)
)

def _add_batch(self, session, batch):
def merge(item):
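
As a side note on the _add_authors hunk above: the joined author string is capped at 1024 characters by dropping whatever trailing name would overflow. A standalone sketch of that same trimming idea (not the repository's own helper, which works on an Article instance):

from typing import List


def join_authors(authors: List[str], limit: int = 1024) -> str:
    """Join author names with commas; if the result is too long, cut it at the
    limit and drop the last (possibly truncated) name, as _add_authors does."""
    joined = ",".join(authors)
    if len(joined) > limit:
        joined = ",".join(joined[:limit].split(",")[:-1])
    return joined


print(join_authors(["Alice", "Bob"]))               # -> "Alice,Bob"
print(len(join_authors(["x" * 300] * 10)) <= 1024)  # -> True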
52 changes: 28 additions & 24 deletions align_data/common/html_dataset.py
@@ -16,11 +16,13 @@

logger = logging.getLogger(__name__)


@dataclass
class HTMLDataset(AlignmentDataset):
"""
Fetches articles from a different blog by collecting links to articles from an index page.
"""

url: str
done_key = "url"

@@ -29,9 +31,9 @@ class HTMLDataset(AlignmentDataset):
source_key: str = None
summary_key: str = None

item_selector = 'article'
title_selector = 'article h1'
text_selector = 'article'
item_selector = "article"
title_selector = "article h1"
text_selector = "article"
source_type = "blog"
ignored_selectors = []

@@ -64,16 +66,18 @@ def process_entry(self, article):
if not text:
return None

return self.make_data_entry({
"text": text,
"url": article_url,
"title": title,
"source": self.name,
"source_type": "blog",
"date_published": date_published,
"authors": self.extract_authors(contents),
**self._extra_values(contents),
})
return self.make_data_entry(
{
"text": text,
"url": article_url,
"title": title,
"source": self.name,
"source_type": "blog",
"date_published": date_published,
"authors": self.extract_authors(contents),
**self._extra_values(contents),
}
)

def _get_contents(self, url):
logger.info("Fetching {}".format(url))
@@ -93,44 +97,44 @@ def _get_text(self, contents):

def _find_date(self, items):
for i in items:
if re.match('\w+ \d{1,2}, \d{4}', i.text):
return datetime.strptime(i.text, '%b %d, %Y').replace(tzinfo=pytz.UTC)
if re.match("\w+ \d{1,2}, \d{4}", i.text):
return datetime.strptime(i.text, "%b %d, %Y").replace(tzinfo=pytz.UTC)

def _extract_markdown(self, element):
return element and markdownify(str(element)).strip()


@dataclass
class RSSDataset(HTMLDataset):
date_format = '%a, %d %b %Y %H:%M:%S %z'
date_format = "%a, %d %b %Y %H:%M:%S %z"

def get_item_key(self, item):
return item

@property
def feed_url(self):
return f'{self.url}/rss.xml'
return f"{self.url}/rss.xml"

def extract_authors(self, item):
if 'authors' in item:
return [a['name'] for a in item['authors'] if a.get('name')]
if "authors" in item:
return [a["name"] for a in item["authors"] if a.get("name")]
return self.authors

@staticmethod
def _get_title(item):
return item['title']
return item["title"]

def _get_published_date(self, item):
date_published = item.get('published') or item.get('pubDate')
date_published = item.get("published") or item.get("pubDate")
return super()._get_published_date(date_published)

def _get_text(self, item):
text = item.get('content') and item['content'][0].get('value')
text = item.get("content") and item["content"][0].get("value")
return self._extract_markdown(text)

def _get_contents(self, url):
item = self.items[url]
if 'content' in item:
if "content" in item:
return item

logger.info("Fetching {}".format(url))
@@ -145,5 +149,5 @@ def _get_contents(self, url):
def items_list(self):
logger.info(f"Fetching entries from {self.feed_url}")
feed = feedparser.parse(self.feed_url)
self.items = {item['link']: item for item in feed['entries']}
self.items = {item["link"]: item for item in feed["entries"]}
return list(self.items.keys())