Skip to content

Commit

Permalink
Merge pull request #686 from RasmusOrsoe/parquet_hangtime_bugfix
Browse files Browse the repository at this point in the history
Bugfix to #685 and #683
  • Loading branch information
RasmusOrsoe authored Apr 8, 2024
2 parents c8564c9 + 79b400b commit 839075d
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/graphnet/data/datamodule.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Base `Dataloader` class(es) used in `graphnet`."""
from typing import Dict, Any, Optional, List, Tuple, Union
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from copy import deepcopy
from sklearn.model_selection import train_test_split
import pandas as pd
Expand All @@ -13,6 +12,7 @@
ParquetDataset,
)
from graphnet.utilities.logging import Logger
from graphnet.data.dataloader import DataLoader


class GraphNeTDataModule(pl.LightningDataModule, Logger):
Expand Down
6 changes: 4 additions & 2 deletions src/graphnet/data/dataset/parquet/parquet_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,10 @@ def _query_table(
)
data = df.select(columns)
if isinstance(data[columns[0]][0], Series):
data = data.explode(columns)
array = data.to_numpy()
x = [data[col][0].to_numpy().reshape(-1, 1) for col in columns]
array = np.concatenate(x, axis=1)
else:
array = data.to_numpy()
else:
array = np.array()
return array
Expand Down

0 comments on commit 839075d

Please sign in to comment.