New dataset #31

Merged · 16 commits · Sep 14, 2024
Changes from all commits
7 changes: 6 additions & 1 deletion src/graphnet/data/__init__.py
@@ -8,4 +8,9 @@
from .pre_configured import I3ToParquetConverter
from .pre_configured import I3ToSQLiteConverter
from .datamodule import GraphNeTDataModule
from .curated_datamodule import CuratedDataset, ERDAHostedDataset
from .curated_datamodule import (
CuratedDataset,
ERDAHostedDataset,
SecretDataset,
PublicBenchmarkDataset,
)
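With this change, the two new dataset classes are exposed at the package top level alongside the existing ones. A minimal import sketch (assumes a graphnet build that includes this PR):

from graphnet.data import (
    CuratedDataset,
    ERDAHostedDataset,
    PublicBenchmarkDataset,
    SecretDataset,
)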
286 changes: 280 additions & 6 deletions src/graphnet/data/curated_datamodule.py
@@ -4,10 +4,12 @@
import and download pre-converted datasets for training of deep learning based
methods in GraphNeT.
"""

from typing import Dict, Any, Optional, List, Tuple, Union
from abc import abstractmethod
import os
from glob import glob
import pandas as pd
from graphnet.training.labels import Direction, Track

from .datamodule import GraphNeTDataModule
from graphnet.models.graphs import GraphDefinition
@@ -243,10 +245,17 @@ def available_backends(self) -> List[str]:
@property
def dataset_dir(self) -> str:
"""Produce path directory that contains dataset files."""
dataset_dir = os.path.join(
self._download_dir, self.__class__.__name__, self._backend
)
return dataset_dir
        if hasattr(self, "_secret"):
            dataset_dir = os.path.join(
                self._download_dir,
                self.__class__.__name__ + "-" + self._secret,
                self._backend,
            )
        else:
            dataset_dir = os.path.join(
                self._download_dir, self.__class__.__name__, self._backend
            )
        return dataset_dir
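# For illustration (assumed values, not from the diff): with
# download_dir="/data" and backend="parquet", the public case resolves to
# "/data/PublicBenchmarkDataset/parquet", while a SecretDataset created with
# secret="abc123" resolves to "/data/SecretDataset-abc123/parquet".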


class ERDAHostedDataset(CuratedDataset):
@@ -271,17 +280,282 @@ def prepare_data(self) -> None:
"""Prepare the dataset for training."""
assert self._file_hashes is not None # mypy
file_hash = self._file_hashes[self._backend]
file_path = os.path.join(self.dataset_dir, file_hash + ".tar.gz")
if os.path.exists(self.dataset_dir):
return
else:
# Download, unzip and delete zipped file

os.makedirs(self.dataset_dir, exist_ok=True)

os.makedirs(self.dataset_dir)
_, file_name = os.path.split(file_hash)
extension = ".tar.gz" if ".tar.gz" not in file_name else ""
file_path = os.path.join(
self.dataset_dir,
file_name + extension,
)

os.system(f"wget -O {file_path} {self._mirror}/{file_hash}")
os.system(f"tar -xf {file_path} -C {self.dataset_dir}")
            print("Unzipping file, this might take a while...")
if self._backend == "parquet":
os.system(f"tar -xf {file_path} -C {self.dataset_dir}")
else:
os.system(f"tar -xvzf {file_path} -C {self.dataset_dir}")
os.system(f"rm {file_path}")


class PublicBenchmarkDataset(ERDAHostedDataset):
"""A generic class for public Prometheus Datasets hosted using ERDA."""

def __init__(
self,
graph_definition: GraphDefinition,
download_dir: str,
backend: str = "parquet",
mode: str = "train",
train_dataloader_kwargs: Optional[Dict[str, Any]] = None,
        validation_dataloader_kwargs: Optional[Dict[str, Any]] = None,
        test_dataloader_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
"""Download a public dataset and build DataLoaders.

        The Dataset can be instantiated in three modes: "train", "test" or
        "test-no-noise". When instantiated in "train" mode, input data is
        read from the "merged_photons" table and dataloaders for training and
        validation are constructed using a pre-defined selection of
events/chunks. The GraphDefinition passed to this dataset should in
this case apply time and charge smearing and subsequent merging of
coincident pulses in order to be comparable to the test set. NOTE that
the test set is not constructed in "train" mode.

If instantiated in "test" or "test-no-noise" mode,
already processed photons will be read from "pulses" or
"pulses_no_noise", respectively. GraphDefinition passed to the dataset
should in this case not smear charge and time variables, and should
not apply any merging.

Args:
graph_definition: Method that defines the data representation.
download_dir: Directory to download dataset to.
backend (Optional): data backend to use. Either "parquet" or
"sqlite". Defaults to "parquet".
train_dataloader_kwargs (Optional): Arguments for the training
DataLoader. Default None.
validation_dataloader_kwargs (Optional): Arguments for the
                validation DataLoader. Default None.
test_dataloader_kwargs (Optional): Arguments for the test
DataLoader. Default None.
            mode: Mode in which to instantiate the dataset. One of
                ['train', 'test', 'test-no-noise'].
"""
# Static Member Variables:
self._mode = mode
if self._mode == "train":
self._pulsemaps = ["merged_photons"]
elif self._mode == "test":
self._pulsemaps = ["pulses"]
elif self._mode == "test-no-noise":
self._pulsemaps = ["pulses_no_noise"]
else:
            raise ValueError(
                "'mode' must be one of "
                "['train', 'test', 'test-no-noise'], "
                f"got '{mode}'."
            )
self._truth_table = "mc_truth"
self._event_truth = [
"interaction",
"initial_state_energy",
"initial_state_type",
"initial_state_zenith",
"initial_state_azimuth",
"initial_state_x",
"initial_state_y",
"initial_state_z",
"visible_inelasticity",
]
self._pulse_truth = "pulses"
self._features = [
"sensor_pos_x",
"sensor_pos_y",
"sensor_pos_z",
"t",
"charge",
"string_id",
]

ERDAHostedDataset.__init__(
self,
graph_definition=graph_definition,
download_dir=download_dir,
backend=backend,
train_dataloader_kwargs=train_dataloader_kwargs,
validation_dataloader_kwargs=validation_dataloader_kwargs,
test_dataloader_kwargs=test_dataloader_kwargs,
)

def _prepare_args(
self, backend: str, features: List[str], truth: List[str]
) -> Tuple[Dict[str, Any], Union[List[int], None], Union[List[int], None]]:
"""Prepare arguments for dataset.

Args:
backend: backend of dataset. Either "parquet" or "sqlite".
features: List of features from user to use as input.
truth: List of event-level truth variables from user.

Returns: Dataset arguments, train/val selection, test selection
"""
if backend == "sqlite":
dataset_path = glob(os.path.join(self.dataset_dir, "*.db"))
# Cast from list to string if just 1 path
            if isinstance(dataset_path, list) and len(dataset_path) == 1:
                dataset_path: str = dataset_path[0]  # type: ignore

if self._mode == "train":
train_val = pd.read_parquet(
os.path.join(
self.dataset_dir,
"selections",
"train_selection.parquet",
)
)["event_no"].tolist()
test = None
elif self._mode == "test":
train_val = None
test = pd.read_parquet(
os.path.join(
self.dataset_dir,
"selections",
"test_noise_selection.parquet",
)
)["event_no"].tolist()
elif self._mode == "test-no-noise":
train_val = None
test = pd.read_parquet(
os.path.join(
self.dataset_dir,
"selections",
"test_selection.parquet",
)
)["event_no"].tolist()
elif backend == "parquet":
dataset_path = self.dataset_dir # type: ignore
if self._mode == "train":
train_val = pd.read_parquet(
os.path.join(
self.dataset_dir, "selections", "train_batches.parquet"
)
)["chunk_id"].tolist()
test = None
elif self._mode == "test":
train_val = None
test = pd.read_parquet(
os.path.join(
self.dataset_dir,
"selections",
"test_noise_batches.parquet",
)
)["chunk_id"].tolist()
elif self._mode == "test-no-noise":
train_val = None
test = pd.read_parquet(
os.path.join(
self.dataset_dir, "selections", "test_batches.parquet"
)
)["chunk_id"].tolist()
dataset_args = {
"truth_table": self._truth_table,
"pulsemaps": self._pulsemaps,
"path": dataset_path,
"graph_definition": self._graph_definition,
"features": features,
"truth": truth,
"labels": {
"direction": Direction(
azimuth_key="initial_state_azimuth",
zenith_key="initial_state_zenith",
),
"track": Track(
pid_key="initial_state_type", interaction_key="interaction"
),
},
}

return dataset_args, train_val, test
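# Sketch of what the selection files above are assumed to contain: a single
# id column that is read back as a plain Python list (file and column names
# are taken from the code above; the directory is illustrative):
import os
import pandas as pd

dataset_dir = "/data/PublicBenchmarkDataset/parquet"  # illustrative path
batches = pd.read_parquet(
    os.path.join(dataset_dir, "selections", "train_batches.parquet")
)
chunk_ids = batches["chunk_id"].tolist()  # chunk ids for the parquet backend
# The sqlite backend reads "train_selection.parquet" and its "event_no"
# column instead, yielding per-event ids rather than per-file chunk ids.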


class SecretDataset(PublicBenchmarkDataset):
"""A Secret Dataset."""

def __init__(
self,
secret: str,
graph_definition: GraphDefinition,
download_dir: str,
backend: str = "parquet",
mode: str = "train",
train_dataloader_kwargs: Optional[Dict[str, Any]] = None,
        validation_dataloader_kwargs: Optional[Dict[str, Any]] = None,
        test_dataloader_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
"""Download a secret Dataset with a ERDA sharelink ID.

        The Dataset can be instantiated in three modes: "train", "test" or
        "test-no-noise". When instantiated in "train" mode, input data is
        read from the "merged_photons" table and dataloaders for training and
        validation are constructed using a pre-defined selection of
events/chunks. The GraphDefinition passed to this dataset should in
this case apply time and charge smearing and subsequent merging of
coincident pulses in order to be comparable to the test set. NOTE that
the test set is not constructed in "train" mode.

If instantiated in "test" or "test-no-noise" mode,
already processed photons will be read from "pulses" or
"pulses_no_noise", respectively. GraphDefinition passed to the dataset
should in this case not smear charge and time variables, and should
not apply any merging.

Args:
secret: ERDA sharelink ID
graph_definition: Method that defines the data representation.
download_dir: Directory to download dataset to.
backend (Optional): data backend to use. Either "parquet" or
"sqlite". Defaults to "parquet".
train_dataloader_kwargs (Optional): Arguments for the training
DataLoader. Default None.
validation_dataloader_kwargs (Optional): Arguments for the
                validation DataLoader. Default None.
test_dataloader_kwargs (Optional): Arguments for the test
DataLoader. Default None.
            mode: Mode in which to instantiate the dataset. One of
                ['train', 'test', 'test-no-noise'].
"""
self._experiment = "Unknown"
self._citation = "NA"
self._creator = "NA"
self._available_backends = [backend]
self._secret = secret
self._file_hashes = {backend: secret}

        val_args = validation_dataloader_kwargs  # alias keeps line length down
super().__init__(
graph_definition=graph_definition,
download_dir=download_dir,
backend=backend,
mode=mode,
train_dataloader_kwargs=train_dataloader_kwargs,
validation_dataloader_kwargs=val_args,
test_dataloader_kwargs=test_dataloader_kwargs,
)
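For reference, a usage sketch of the new SecretDataset. The sharelink ID, paths, and dataloader settings below are placeholders, and the KNNGraph/Prometheus imports are assumptions about the surrounding graphnet API rather than part of this diff:

from graphnet.data import SecretDataset
from graphnet.models.detector.prometheus import Prometheus
from graphnet.models.graphs import KNNGraph

dataset = SecretDataset(
    secret="<erda-sharelink-id>",  # placeholder, not a real ID
    graph_definition=KNNGraph(detector=Prometheus()),
    download_dir="./data",
    backend="parquet",
    mode="train",
    train_dataloader_kwargs={"batch_size": 16, "num_workers": 4},
)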
20 changes: 17 additions & 3 deletions src/graphnet/data/datamodule.py
Original file line number Diff line number Diff line change
@@ -192,8 +192,13 @@ def setup(self, stage: str) -> None:
self._train_dataset = self._create_dataset(
self._train_selection
)
else:
self._train_dataset = None

if self._val_selection is not None:
self._val_dataset = self._create_dataset(self._val_selection)
else:
self._val_dataset = None

return

@@ -377,12 +382,12 @@ def _resolve_selections(self) -> None:
self._selection
)

else: # selection is None
        elif self._test_selection is None:  # selection and test_selection are both None
# If not provided, we infer it by grabbing
# all event ids in the dataset.
self.info(
f"{self.__class__.__name__} did not receive an"
" for `selection`. Selection will "
f"{self.__class__.__name__} did not receive an argument"
" for `selection`. Selection "
"will automatically be created with a split of "
f"train: {self._train_val_split[0]} and "
f"validation: {self._train_val_split[1]}"
@@ -391,6 +396,15 @@ def _resolve_selections(self) -> None:
self._train_selection,
self._val_selection,
) = self._infer_selections() # type: ignore
else:
# Only test selection given - no training / val selection inferred
self.info(
f"{self.__class__.__name__} only received arguments for a"
" test selection. DataLoaders for training and validation"
" will not be available."
)
self._train_selection = None # type: ignore
self._val_selection = None # type: ignore

def _split_selection(
self, selection: Union[int, List[int], List[List[int]]]
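Taken together, _resolve_selections now has three outcomes: an explicit selection is split into train/val, no selection at all triggers an inferred split, and a test-only selection leaves train/val unset. A minimal standalone sketch of that decision logic (names and split values are illustrative, not the graphnet API):

from typing import List, Optional, Tuple

def resolve_selections(
    selection: Optional[List[int]],
    test_selection: Optional[List[int]],
) -> Tuple[Optional[List[int]], Optional[List[int]]]:
    """Return (train, val) selections mirroring the three cases above."""
    if selection is not None:
        # Explicit selection given: split it into train and validation.
        split = int(0.9 * len(selection))  # assumed 90/10 split
        return selection[:split], selection[split:]
    if test_selection is None:
        # Nothing given: infer a selection over all events (stubbed here).
        all_events = list(range(100))
        split = int(0.9 * len(all_events))
        return all_events[:split], all_events[split:]
    # Only a test selection given: no training or validation loaders.
    return None, None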