Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Add validation/function to Delay and Duration tags. #885

Merged
merged 1 commit into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 9 additions & 37 deletions hed/models/base_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
import os

import openpyxl
import pandas
import pandas as pd

from hed.models.column_mapper import ColumnMapper
from hed.errors.exceptions import HedFileError, HedExceptions
import pandas as pd

from hed.models.df_util import _handle_curly_braces_refs
from hed.models.df_util import _handle_curly_braces_refs, filter_series_by_onset


class BaseInput:
Expand Down Expand Up @@ -118,37 +117,10 @@ def series_filtered(self):
"""Return the assembled dataframe as a series, with rows that have the same onset combined.

Returns:
Series: the assembled dataframe with columns merged, and the rows filtered together.
Series or None: the assembled dataframe with columns merged, and the rows filtered together.
"""
if self.onsets is not None:
indexed_dict = self._indexed_dict_from_onsets(self.onsets.astype(float))
return self._filter_by_index_list(self.series_a, indexed_dict=indexed_dict)

@staticmethod
def _indexed_dict_from_onsets(onsets):
current_onset = -1000000.0
tol = 1e-9
from collections import defaultdict
indexed_dict = defaultdict(list)
for i, onset in enumerate(onsets):
if abs(onset - current_onset) > tol:
current_onset = onset
indexed_dict[current_onset].append(i)

return indexed_dict

# This would need to store the index list -> So it can optionally apply to other columns on request.
@staticmethod
def _filter_by_index_list(original_series, indexed_dict):
new_series = pd.Series([""] * len(original_series), dtype=str)

for onset, indices in indexed_dict.items():
if indices:
first_index = indices[0] # Take the first index of each onset group
# Join the corresponding original series entries and place them at the first index
new_series[first_index] = ",".join([str(original_series[i]) for i in indices])

return new_series
return filter_series_by_onset(self.series_a, self.onsets)

@property
def onsets(self):
Expand All @@ -161,7 +133,7 @@ def needs_sorting(self):
"""Return True if this both has an onset column, and it needs sorting."""
onsets = self.onsets
if onsets is not None:
onsets = onsets.astype(float)
onsets = pd.to_numeric(self.dataframe['onset'], errors='coerce')
return not onsets.is_monotonic_increasing

@property
Expand Down Expand Up @@ -369,9 +341,9 @@ def _get_dataframe_from_worksheet(worksheet, has_headers):
# first row is columns
cols = next(data)
data = list(data)
return pandas.DataFrame(data, columns=cols, dtype=str)
return pd.DataFrame(data, columns=cols, dtype=str)
else:
return pandas.DataFrame(worksheet.values, dtype=str)
return pd.DataFrame(worksheet.values, dtype=str)

def validate(self, hed_schema, extra_def_dicts=None, name=None, error_handler=None):
"""Creates a SpreadsheetValidator and returns all issues with this file.
Expand Down Expand Up @@ -483,14 +455,14 @@ def _open_dataframe_file(self, file, has_column_names, input_type):
if not has_column_names:
pandas_header = None

if isinstance(file, pandas.DataFrame):
if isinstance(file, pd.DataFrame):
self._dataframe = file.astype(str)
self._has_column_names = self._dataframe_has_names(self._dataframe)
elif not file:
raise HedFileError(HedExceptions.FILE_NOT_FOUND, "Empty file passed to BaseInput.", file)
elif input_type in self.TEXT_EXTENSION:
try:
self._dataframe = pandas.read_csv(file, delimiter='\t', header=pandas_header,
self._dataframe = pd.read_csv(file, delimiter='\t', header=pandas_header,
dtype=str, keep_default_na=True, na_values=("", "null"))
except Exception as e:
raise HedFileError(HedExceptions.INVALID_FILE_FORMAT, str(e), self.name) from e
Expand Down
1 change: 0 additions & 1 deletion hed/models/definition_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def __init__(self, def_dicts=None, hed_schema=None):
"""

self.defs = {}
self._label_tag_name = DefTagNames.DEF_KEY
self._issues = []
if def_dicts:
self.add_definitions(def_dicts, hed_schema)
Expand Down
104 changes: 99 additions & 5 deletions hed/models/df_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from functools import partial
import pandas as pd
from hed.models.hed_string import HedString
from hed.models.model_constants import DefTagNames


def get_assembled(tabular_file, hed_schema, extra_def_dicts=None, defs_expanded=True, return_filtered=False):
def get_assembled(tabular_file, hed_schema, extra_def_dicts=None, defs_expanded=True):
""" Create an array of assembled HedString objects (or list of these) of the same length as tabular file input.

Parameters:
Expand All @@ -14,16 +15,14 @@ def get_assembled(tabular_file, hed_schema, extra_def_dicts=None, defs_expanded=
extra_def_dicts: list of DefinitionDict, optional
Any extra DefinitionDict objects to use when parsing the HED tags.
defs_expanded (bool): (Default True) Expands definitions if True, otherwise shrinks them.
return_filtered (bool): If true, combines lines with the same onset.
Further lines with that onset are marked n/a
Returns:
tuple:
hed_strings(list of HedStrings): A list of HedStrings
def_dict(DefinitionDict): The definitions from this Sidecar.
"""

def_dict = tabular_file.get_def_dict(hed_schema, extra_def_dicts=extra_def_dicts)
series_a = tabular_file.series_a if not return_filtered else tabular_file.series_filtered
series_a = tabular_file.series_a
if defs_expanded:
return [HedString(x, hed_schema, def_dict).expand_defs() for x in series_a], def_dict
else:
Expand Down Expand Up @@ -217,7 +216,102 @@ def _handle_curly_braces_refs(df, refs, column_names):
# df[column_name] = pd.Series(x.replace(column_name_brackets, y) for x, y
# in zip(df[column_name], saved_columns[replacing_name]))
new_df[column_name] = pd.Series(replace_ref(x, y, replacing_name) for x, y
in zip(new_df[column_name], saved_columns[replacing_name]))
in zip(new_df[column_name], saved_columns[replacing_name]))
new_df = new_df[remaining_columns]

return new_df


# todo: Consider updating this to be a pure string function(or at least, only instantiating the Duration tags)
def split_delay_tags(series, hed_schema, onsets):
    """Sort the series based on Delay tags, so that the onsets are in order after delay is applied.

    Parameters:
        series(pd.Series or None): the series of tags to split/sort
        hed_schema(HedSchema): The schema to use to identify tags
        onsets(pd.Series or None): the onset column matching series

    Returns:
        sorted_df(pd.DataFrame or None): If we had onsets, a dataframe with 3 columns
            "HED": The hed strings(still str)
            "onset": the updated onsets
            "original_index": the original source line.  Multiple lines can have the same original source line.

    Note: This dataframe may be longer than the original series, but it will never be shorter.
    """
    if series is None or onsets is None:
        return
    split_df = pd.DataFrame({"onset": onsets, "HED": series, "original_index": series.index})
    # Cheap substring pre-filter: only parse strings that could contain a Delay tag.
    delay_strings = [(i, HedString(hed_string, hed_schema)) for (i, hed_string) in series.items() if
                     "delay/" in hed_string.lower()]
    for i, delay_string in delay_strings:
        delay_tags = delay_string.find_top_level_tags({DefTagNames.DELAY_KEY})
        to_remove = []
        for tag, group in delay_tags:
            # The delayed group's onset is the original onset shifted by the delay value
            # (converted to the tag's default unit).
            onset_mod = tag.value_as_default_unit() + float(onsets[i])
            to_remove.append(group)
            # Append the delayed group as a new row at the end of the dataframe.
            insert_index = split_df.index.max() + 1
            split_df.loc[insert_index] = {'HED': str(group), 'onset': onset_mod, 'original_index': i}
        delay_string.remove(to_remove)
        # update the old string with the removals done
        split_df.at[i, "HED"] = str(delay_string)

    split_df = sort_dataframe_by_onsets(split_df)
    split_df.reset_index(drop=True, inplace=True)

    # Merge rows that now share the same onset into a single row.
    split_df = filter_series_by_onset(split_df, split_df.onset)
    return split_df


def filter_series_by_onset(series, onsets):
    """Return the series, with rows that have the same onset combined.

    Parameters:
        series(pd.Series or pd.Dataframe): the series to filter. If dataframe, it filters the "HED" column
        onsets(pd.Series): the onset column to filter by
    Returns:
        Series or Dataframe: the series with rows filtered together.
    """
    # Group consecutive rows whose onsets match within tolerance, then join each group.
    onset_groups = _indexed_dict_from_onsets(onsets.astype(float))
    return _filter_by_index_list(series, indexed_dict=onset_groups)


def _indexed_dict_from_onsets(onsets):
"""Finds series of consecutive lines with the same(or close enough) onset"""
current_onset = -1000000.0
tol = 1e-9
from collections import defaultdict
indexed_dict = defaultdict(list)
for i, onset in enumerate(onsets):
if abs(onset - current_onset) > tol:
current_onset = onset
indexed_dict[current_onset].append(i)

return indexed_dict


def _filter_by_index_list(original_data, indexed_dict):
"""Filters a series or dataframe by the indexed_dict, joining lines as indicated"""
if isinstance(original_data, pd.Series):
data_series = original_data
elif isinstance(original_data, pd.DataFrame):
data_series = original_data["HED"]
else:
raise TypeError("Input must be a pandas Series or DataFrame")

new_series = pd.Series([""] * len(data_series), dtype=str)
for onset, indices in indexed_dict.items():
if indices:
first_index = indices[0]
new_series[first_index] = ",".join([str(data_series[i]) for i in indices])

if isinstance(original_data, pd.Series):
return new_series
else:
result_df = original_data.copy()
result_df["HED"] = new_series
return result_df
12 changes: 7 additions & 5 deletions hed/models/hed_group.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
""" A single parenthesized HED string. """
from hed.models.hed_tag import HedTag
from hed.models.model_constants import DefTagNames
import copy
from typing import Iterable, Union

Expand Down Expand Up @@ -441,7 +442,7 @@ def find_tags(self, search_tags, recursive=False, include_groups=2):
tags = self.get_all_tags()
else:
tags = self.tags()

search_tags = {tag.lower() for tag in search_tags}
for tag in tags:
if tag.short_base_tag.lower() in search_tags:
found_tags.append((tag, tag._parent))
Expand All @@ -453,7 +454,7 @@ def find_tags(self, search_tags, recursive=False, include_groups=2):
def find_wildcard_tags(self, search_tags, recursive=False, include_groups=2):
""" Find the tags and their containing groups.

This searches tag.short_tag, with an implicit wildcard on the end.
This searches tag.short_tag.lower(), with an implicit wildcard on the end.

e.g. "Eve" will find Event, but not Sensory-event.

Expand All @@ -474,6 +475,8 @@ def find_wildcard_tags(self, search_tags, recursive=False, include_groups=2):
else:
tags = self.tags()

search_tags = {search_tag.lower() for search_tag in search_tags}

for tag in tags:
for search_tag in search_tags:
if tag.short_tag.lower().startswith(search_tag):
Expand Down Expand Up @@ -539,15 +542,14 @@ def find_def_tags(self, recursive=False, include_groups=3):

@staticmethod
def _get_def_tags_from_group(group):
    """Return the Def/Def-expand tags found directly under the given group.

    Parameters:
        group (HedGroup): the group to scan (one level of children only).

    Returns:
        list of tuples: (def_tag, location, group) where location is the tag itself
        for a bare Def tag, or the containing child group for a Def-expand tag.
    """
    def_tags = []
    for child in group.children:
        if isinstance(child, HedTag):
            if child.short_base_tag == DefTagNames.DEF_KEY:
                def_tags.append((child, child, group))
        else:
            for tag in child.tags():
                if tag.short_base_tag == DefTagNames.DEF_EXPAND_KEY:
                    def_tags.append((tag, child, group))
    return def_tags

Expand Down
26 changes: 2 additions & 24 deletions hed/models/hed_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def shrink_defs(self):
for def_expand_tag, def_expand_group in self.find_tags({DefTagNames.DEF_EXPAND_KEY}, recursive=True):
expanded_parent = def_expand_group._parent
if expanded_parent:
def_expand_tag.short_base_tag = DefTagNames.DEF_ORG_KEY
def_expand_tag.short_base_tag = DefTagNames.DEF_KEY
def_expand_tag._parent = expanded_parent
expanded_parent.replace(def_expand_group, def_expand_tag)

Expand Down Expand Up @@ -353,6 +353,7 @@ def find_top_level_tags(self, anchor_tags, include_groups=2):
Returns:
list: The returned result depends on include_groups.
"""
anchor_tags = {tag.lower() for tag in anchor_tags}
top_level_tags = []
for group in self.groups():
for tag in group.tags():
Expand All @@ -365,29 +366,6 @@ def find_top_level_tags(self, anchor_tags, include_groups=2):
return [tag[include_groups] for tag in top_level_tags]
return top_level_tags

def find_top_level_tags_grouped(self, anchor_tags):
    """ Find top level groups with an anchor tag.

    This is an alternate one designed to be easy to use with Delay/Duration tag.

    Parameters:
        anchor_tags (container): A list/set/etc. of short_base_tags to find groups by.
    Returns:
        list of tuples:
            list of tags: the tags in the same subgroup
            group: the subgroup containing the tags
    """
    # NOTE(review): anchor_tags are matched case-insensitively, so the caller is
    # expected to supply lowercase entries — confirm against call sites.
    results = []
    for group in self.groups():
        matching = [tag for tag in group.tags() if tag.short_base_tag.lower() in anchor_tags]
        if matching:
            results.append((matching, group))
    return results

def remove_refs(self):
""" Remove any refs(tags contained entirely inside curly braces) from the string.

Expand Down
8 changes: 4 additions & 4 deletions hed/models/hed_tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(self, hed_string, hed_schema, span=None, def_dict=None):

self._def_entry = None
if def_dict:
if self.short_base_tag in {DefTagNames.DEF_ORG_KEY, DefTagNames.DEF_EXPAND_ORG_KEY}:
if self.short_base_tag in {DefTagNames.DEF_KEY, DefTagNames.DEF_EXPAND_KEY}:
self._def_entry = def_dict.get_definition_entry(self)

def copy(self):
Expand Down Expand Up @@ -277,7 +277,7 @@ def expandable(self):
self._parent = save_parent
if def_contents is not None:
self._expandable = def_contents
self._expanded = self.short_base_tag == DefTagNames.DEF_EXPAND_ORG_KEY
self._expanded = self.short_base_tag == DefTagNames.DEF_EXPAND_KEY
return self._expandable

def is_column_ref(self):
Expand Down Expand Up @@ -621,12 +621,12 @@ def __eq__(self, other):
return True

if isinstance(other, str):
return self.lower() == other
return self.lower() == other.lower()

if not isinstance(other, HedTag):
return False

if self.short_tag.lower() == other.short_tag.lower():
if self.short_tag == other.short_tag:
return True

if self.org_tag.lower() == other.org_tag.lower():
Expand Down
Loading
Loading