Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CLI for reboost #36

Merged
merged 11 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 70 additions & 54 deletions src/reboost/build_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from lgdo.lh5 import LH5Iterator, LH5Store
from numpy.typing import ArrayLike

from reboost import utils

log = logging.getLogger(__name__)


Expand Down Expand Up @@ -174,8 +176,8 @@ def get_stp_evtids(


def build_glm(
stp_file: str,
glm_file: str | None,
stp_files: str | list[str],
glm_files: str | list[str] | None,
*,
out_table_name: str = "glm",
id_name: str = "g4_evtid",
Expand All @@ -191,9 +193,9 @@ def build_glm(

Parameters
----------
stp_file
stp_files
path to the stp (input) file.
glm_file
glm_files
path to the glm data, can also be `None` in which case an `ak.Array` is returned in memory.
out_table_name
name for the output table.
Expand All @@ -208,72 +210,86 @@ def build_glm(
-------
either `None` or an `ak.Array`
"""
msg = f"Start generating glm for {stp_file} to {glm_file}"
log.info(msg)
store = LH5Store()
files = utils.get_file_dict(stp_files=stp_files, glm_files=glm_files)

# loop over files
glm_sum = {}

# loop over the lh5_tables
lh5_table_list = [table for table in lh5.ls(stp_file, "stp/") if table != "stp/vertices"]
for file_idx, stp_file in enumerate(files.stp):
msg = f"Start generating glm for {stp_file} "
log.info(msg)

# get rows in the table
if glm_file is None:
glm_sum = {lh5_table.replace("stp/", ""): None for lh5_table in lh5_table_list}
else:
glm_sum = None
# loop over the lh5_tables
lh5_table_list = [table for table in lh5.ls(stp_file, "stp/") if table != "stp/vertices"]

# start row for each table
start_row = {lh5_tab: 0 for lh5_tab in lh5_table_list}
# get rows in the table
if files.glm is None:
for lh5_table in lh5_table_list:
if lh5_table.replace("stp/", "") not in glm_sum:
glm_sum[lh5_table.replace("stp/", "")] = None
else:
glm_sum = None

vfield = f"stp/vertices/{id_name}"
# start row for each table
start_row = {lh5_tab: 0 for lh5_tab in lh5_table_list}

# iterate over the vertex table
for vert_obj, vidx, n_evtid in LH5Iterator(stp_file, vfield, buffer_len=evtid_buffer):
# range of vertices
vert_ak = vert_obj.view_as("ak")[:n_evtid]
vfield = f"stp/vertices/{id_name}"

msg = f"... read chunk {vidx}"
log.debug(msg)
# iterate over the vertex table
for vert_obj, vidx, n_evtid in LH5Iterator(stp_file, vfield, buffer_len=evtid_buffer):
# range of vertices
vert_ak = vert_obj.view_as("ak")[:n_evtid]

for idx, lh5_table in enumerate(lh5_table_list):
# create the output table
out_tab = Table(size=len(vert_ak))
msg = f"... read chunk {vidx}"
log.debug(msg)

# read the stp rows starting from `start_row` until the
# evtid is larger than that in the vertices
for idx, lh5_table in enumerate(lh5_table_list):
# create the output table
out_tab = Table(size=len(vert_ak))

start_row_tmp, chunk_row, evtids = get_stp_evtids(
lh5_table,
stp_file,
id_name,
start_row[lh5_table],
last_vertex_evtid=vert_ak[-1],
stp_buffer=stp_buffer,
)
# read the stp rows starting from `start_row` until the
# evtid is larger than that in the vertices

# set the start row for the next chunk
start_row[lh5_table] = start_row_tmp
start_row_tmp, chunk_row, evtids = get_stp_evtids(
lh5_table,
stp_file,
id_name,
start_row[lh5_table],
last_vertex_evtid=vert_ak[-1],
stp_buffer=stp_buffer,
)

# now get the glm rows
glm = get_glm_rows(evtids, vert_ak, start_row=chunk_row)
# set the start row for the next chunk
start_row[lh5_table] = start_row_tmp

for field in ["evtid", "n_rows", "start_row"]:
out_tab.add_field(field, Array(glm[field].to_numpy()))
# now get the glm rows
glm = get_glm_rows(evtids, vert_ak, start_row=chunk_row)

# write the output file
mode = "of" if (vidx == 0 and idx == 0) else "append"
for field in ["evtid", "n_rows", "start_row"]:
out_tab.add_field(field, Array(glm[field].to_numpy()))

lh5_subgroup = lh5_table.replace("stp/", "")
# write the output file
mode = "of" if (vidx == 0 and idx == 0) else "append"

lh5_subgroup = lh5_table.replace("stp/", "")

if files.glm is not None:
store.write(
out_tab,
f"{out_table_name}/{lh5_subgroup}",
files.glm[file_idx],
wo_mode=mode,
)
else:
glm_sum[lh5_subgroup] = (
copy.deepcopy(glm)
if glm_sum[lh5_subgroup] is None
else ak.concatenate((glm_sum[lh5_subgroup], glm))
)
msg = f"Finished generating glm for {stp_file} "
log.info(msg)

if glm_file is not None:
store.write(out_tab, f"{out_table_name}/{lh5_subgroup}", glm_file, wo_mode=mode)
else:
glm_sum[lh5_subgroup] = (
copy.deepcopy(glm)
if glm_sum[lh5_subgroup] is None
else ak.concatenate((glm_sum[lh5_subgroup], glm))
)
msg = f"Finished generating glm for {stp_file} to {glm_file}"
log.info(msg)
# return if it was requested to keep glm in memory
if glm_sum is not None:
return ak.Array(glm_sum)
Expand Down
26 changes: 11 additions & 15 deletions src/reboost/build_hit.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,9 @@
def build_hit(
config: Mapping | str,
args: Mapping | AttrsDict,
stp_files: list | str,
glm_files: list | str,
hit_files: list | str | None,
stp_files: str | list[str],
glm_files: str | list[str],
hit_files: str | list[str] | None,
*,
start_evtid: int = 0,
n_evtid: int | None = None,
Expand Down Expand Up @@ -231,19 +231,14 @@ def build_hit(
)

# get the input files
files = {}
for file_type, file_list in zip(["stp", "glm", "hit"], [stp_files, glm_files, hit_files]):
if isinstance(file_list, str):
files[file_type] = [file_list]
else:
files[file_type] = file_list
files = utils.get_file_dict(stp_files=stp_files, glm_files=glm_files, hit_files=hit_files)

output_tables = {}
# iterate over files
for file_idx, (stp_file, glm_file) in enumerate(zip(files["stp"], files["glm"])):
for file_idx, (stp_file, glm_file) in enumerate(zip(files.stp, files.glm)):
msg = (
f"... starting post processing of {stp_file} to {files['hit'][file_idx]} "
if files["hit"] is not None
f"... starting post processing of {stp_file} to {files.hit[file_idx]} "
if files.hit is not None
else f"... starting post processing of {stp_file}"
)
log.info(msg)
Expand All @@ -252,6 +247,7 @@ def build_hit(
for group_idx, proc_group in enumerate(config["processing_groups"]):
proc_name = proc_group.get("name", "default")
time_dict[proc_name] = ProfileDict()

# extract the output detectors and the mapping to input detectors
detectors_mapping = utils.merge_dicts(
[
Expand Down Expand Up @@ -297,7 +293,7 @@ def build_hit(

for out_det_idx, out_detector in enumerate(out_detectors):
# loop over the rows
if out_detector not in output_tables and files["hit"] is None:
if out_detector not in output_tables and files.hit is None:
output_tables[out_detector] = None

hit_table = core.evaluate_hit_table_layout(
Expand Down Expand Up @@ -341,14 +337,14 @@ def build_hit(
)

# now write
if files["hit"] is not None:
if files.hit is not None:
if time_dict is not None:
start_time = time.time()

lh5.write(
hit_table,
f"{out_detector}/{out_field}",
files["hit"][file_idx],
files.hit[file_idx],
wo_mode=wo_mode,
)
if time_dict is not None:
Expand Down
Loading
Loading