Skip to content

Commit

Permalink
Fix deprecation of np.alltrue, iloc on pandas series
Browse files Browse the repository at this point in the history
  • Loading branch information
mhuen committed Apr 19, 2024
1 parent e7ef973 commit 7af8bd1
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 18 deletions.
35 changes: 18 additions & 17 deletions egenerator/data/modules/data/pulse_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,10 @@ def get_data_from_hdf(self, file, *args, **kwargs):
# create Dictionary with event IDs
size = len(_labels["Event"])
event_dict = {}
for row in _labels.iterrows():
event_dict[(row[1][0], row[1][1], row[1][2], row[1][3])] = row[0]
for idx, row in _labels.iterrows():
event_dict[
(row.iloc[0], row.iloc[1], row.iloc[2], row.iloc[3])
] = idx

# create empty array for DOM charges
x_dom_charge = np.zeros(
Expand Down Expand Up @@ -342,38 +344,37 @@ def get_data_from_hdf(self, file, *args, **kwargs):
# get pulse information
# ---------------------
for pulse_index, row in enumerate(pulses.itertuples()):
string = row[6]
dom = row[7]
string = row.string
dom = row.om
if dom > 60:
self._logger.warning(
"skipping pulse: {} {}".format(string, dom)
)
continue
index = event_dict[(row[1:5])]

# pulse charge: row[12], time: row[10]
# accumulate charge in DOMs
x_dom_charge[index, string - 1, dom - 1, 0] += row[12]
x_dom_charge[index, string - 1, dom - 1, 0] += row.charge

# gather pulses
if add_charge_quantiles:

# (charge, time, quantile)
cum_charge = float(x_dom_charge[index, string - 1, dom - 1, 0])
x_pulses[pulse_index] = [row[12], row[10], cum_charge]
x_pulses[pulse_index] = [row.charge, row.time, cum_charge]

else:
# (charge, time)
x_pulses[pulse_index] = [row[12], row[10]]
x_pulses[pulse_index] = [row.charge, row.time]

# gather pulse ids (batch index, string, dom)
x_pulses_ids[pulse_index] = [index, string - 1, dom - 1]

# update time window
if row[10] > x_time_window[index, 1]:
x_time_window[index, 1] = row[10]
if row[10] < x_time_window[index, 0]:
x_time_window[index, 0] = row[10]
if row.time > x_time_window[index, 1]:
x_time_window[index, 1] = row.time
if row.time < x_time_window[index, 0]:
x_time_window[index, 0] = row.time

# convert cumulative charge to fraction of total charge, e.g. quantile
if add_charge_quantiles:
Expand All @@ -398,8 +399,8 @@ def get_data_from_hdf(self, file, *args, **kwargs):
# -------------------
if time_exclusions is not None:
for tw_index, row in enumerate(time_exclusions.itertuples()):
string = row[6]
dom = row[7]
string = row.string
dom = row.om
if dom > 60:
self._logger.warning(
"skipping tw: {} {}".format(string, dom)
Expand All @@ -410,7 +411,7 @@ def get_data_from_hdf(self, file, *args, **kwargs):
# t_start (pulse time): row[10], t_end (pulse width): row[11]

# (t_start, t_end)
x_time_exclusions[tw_index] = [row[10], row[11]]
x_time_exclusions[tw_index] = [row.time, row.width]

# gather pulse ids (batch index, string, dom)
x_time_exclusions_ids[tw_index] = [index, string - 1, dom - 1]
Expand All @@ -420,8 +421,8 @@ def get_data_from_hdf(self, file, *args, **kwargs):
# ------------------
if dom_exclusions is not None:
for row in dom_exclusions.itertuples():
string = row[7]
dom = row[8]
string = row.string
dom = row.om
if dom > 60:
msg = "skipping exclusion DOM: {!r} {!r}"
self._logger.info(msg.format(string, dom))
Expand Down
2 changes: 1 addition & 1 deletion egenerator/data/trafo.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def _perform_update_step(self, trafo_log, data_batch, n, mean, M2, dtype):

# perform logarithm on bins
if trafo_log is not None:
if np.alltrue(trafo_log):
if np.all(trafo_log):
data_batch = np.log(1.0 + data_batch)
else:
for bin_i, log_bin in enumerate(trafo_log):
Expand Down

0 comments on commit 7af8bd1

Please sign in to comment.