Skip to content

Commit

Permalink
Merge pull request #757 from DHI/strongly_typed
Browse files Browse the repository at this point in the history
Dataclass instead of tuple
  • Loading branch information
ecomodeller authored Jan 24, 2025
2 parents 7179198 + acfbab7 commit c52ba1f
Showing 1 changed file with 139 additions and 102 deletions.
241 changes: 139 additions & 102 deletions mikeio/generic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Generic functions for working with all types of dfs files."""

from __future__ import annotations
from dataclasses import dataclass
import math
import os
import pathlib
Expand All @@ -21,6 +22,7 @@
DfsNonEqTimeAxis,
DfsEqCalendarAxis,
DfsNonEqCalendarAxis,
TimeAxisType,
)
from mikecore.DfsFileFactory import DfsFileFactory
from mikecore.eum import eumQuantity
Expand Down Expand Up @@ -597,10 +599,7 @@ def extract(

is_layered_dfsu = dfs_i.ItemInfo[0].Name == "Z coordinate"

file_start_new, start_step, start_sec, end_step, end_sec = _parse_start_end(
dfs_i.FileInfo.TimeAxis, start, end
)
timestep = _parse_step(dfs_i.FileInfo.TimeAxis, step)
time = _TimeInfo.parse(dfs_i.FileInfo.TimeAxis, start, end, step)
item_numbers = _valid_item_numbers(
dfs_i.ItemInfo, items, ignore_first=is_layered_dfsu
)
Expand All @@ -612,28 +611,28 @@ def extract(
dfs_o = _clone(
str(infilename),
str(outfilename),
start_time=file_start_new,
timestep=timestep,
start_time=time.file_start_new,
timestep=time.timestep,
items=item_numbers,
)

file_start_shift = 0
if file_start_new is not None:
if time.file_start_new is not None:
file_start_orig = dfs_i.FileInfo.TimeAxis.StartDateTime
file_start_shift = (file_start_new - file_start_orig).total_seconds()
file_start_shift = (time.file_start_new - file_start_orig).total_seconds()

timestep_out = -1
for timestep in range(start_step, end_step, step):
for timestep in range(time.start_step, time.end_step, step):
for item_out, item in enumerate(item_numbers):
itemdata = dfs_i.ReadItemTimeStep((item + 1), timestep)
time_sec = itemdata.Time

if time_sec > end_sec:
if time_sec > time.end_sec:
dfs_i.Close()
dfs_o.Close()
return

if time_sec >= start_sec:
if time_sec >= time.start_sec:
if item == item_numbers[0]:
timestep_out = timestep_out + 1
time_sec_out = time_sec - file_start_shift
Expand All @@ -647,100 +646,138 @@ def extract(
dfs_o.Close()


def _parse_start_end(
time_axis: TimeAxis,
start: int | float | str | datetime,
end: int | float | str | datetime,
) -> tuple[datetime | None, int, float, int, float]: # TODO better return type
"""Helper function for parsing start and end arguments."""
n_time_steps = time_axis.NumberOfTimeSteps
file_start_datetime = time_axis.StartDateTime
file_start_sec = time_axis.StartTimeOffset
start_sec = file_start_sec

timespan = 0
if time_axis.TimeAxisType == 3:
timespan = time_axis.TimeStep * (n_time_steps - 1)
elif time_axis.TimeAxisType == 4:
timespan = time_axis.TimeSpan
else:
raise ValueError("TimeAxisType not supported")

file_end_sec = start_sec + timespan
end_sec = file_end_sec

start_step = 0
if isinstance(start, int):
start_step = start
elif isinstance(start, float):
start_sec = start
elif isinstance(start, str):
parts = start.split(",")
start = parts[0]
if len(parts) == 2:
end = parts[1]
start = pd.to_datetime(start)

if isinstance(start, datetime):
start_sec = (start - file_start_datetime).total_seconds()

end_step = n_time_steps
if isinstance(end, int):
if end < 0:
end = end_step + end + 1
end_step = end
elif isinstance(end, float):
end_sec = end
elif isinstance(end, str):
end = pd.to_datetime(end)

if isinstance(end, datetime):
end_sec = (end - file_start_datetime).total_seconds()

if start_step < 0:
raise ValueError(
f"start cannot be before start of file. start={start_step} is invalid"
)
@dataclass
class _TimeInfo:
    """Parsed time information.

    Bundles the values derived from the ``start``, ``end`` and ``step``
    arguments so that they can be passed around as one typed object.

    Attributes
    ----------
    file_start_new : datetime | None
        new start time for the new file, or None when no new start time
        was derived
    start_step : int
        start step (index of the first time step to consider)
    start_sec : float
        start time in seconds
    end_step : int
        end step (upper bound of the step range)
    end_sec : float
        end time in seconds
    timestep : float | None
        timestep in seconds, or None when ``_parse_step`` yields no explicit
        value (``step == 1`` or a non-equidistant time axis)

    """

    file_start_new: datetime | None
    start_step: int
    start_sec: float
    end_step: int
    end_sec: float
    timestep: float | None

if start_sec < file_start_sec:
raise ValueError(
f"start cannot be before start of file start={start_step} is invalid"
@staticmethod
def parse(
    time_axis: TimeAxis,
    start: int | float | str | datetime,
    end: int | float | str | datetime,
    step: int,
) -> _TimeInfo:
    """Parse the start, end and step arguments against a file time axis.

    Parameters
    ----------
    time_axis : TimeAxis
        Time axis of the input file. Only the attributes
        ``NumberOfTimeSteps``, ``StartDateTime``, ``StartTimeOffset``,
        ``TimeAxisType`` and ``TimeStep`` / ``TimeSpan`` are read.
    start : int | float | str | datetime
        Start of the selection. An int is taken as a step index, a float as
        a time in seconds, a str is parsed with ``pandas.to_datetime``
        (a ``"start,end"`` string sets both bounds), and a datetime is
        converted to seconds since ``time_axis.StartDateTime``.
    end : int | float | str | datetime
        End of the selection, interpreted like ``start``. A negative int
        counts from the end of the file (``-1`` is the number of steps).
    step : int
        Stride between time steps, forwarded to ``_parse_step``.

    Returns
    -------
    _TimeInfo
        The parsed start/end steps and times, the new file start time and
        the time step.

    Raises
    ------
    ValueError
        If the time axis type is not supported, if start lies before the
        start of the file, if end lies before start, or if end lies after
        the end of the file.

    """
    n_time_steps = time_axis.NumberOfTimeSteps
    file_start_datetime = time_axis.StartDateTime
    file_start_sec = time_axis.StartTimeOffset
    start_sec = file_start_sec

    # NOTE(review): 3 / 4 presumably equal TimeAxisType.CalendarEquidistant /
    # CalendarNonEquidistant, which are used further down - consider using
    # the enum members here as well.
    if time_axis.TimeAxisType == 3:
        timespan = time_axis.TimeStep * (n_time_steps - 1)
    elif time_axis.TimeAxisType == 4:
        timespan = time_axis.TimeSpan
    else:
        raise ValueError("TimeAxisType not supported")

    file_end_sec = start_sec + timespan
    end_sec = file_end_sec

    start_step = 0
    if isinstance(start, int):
        start_step = start
    elif isinstance(start, float):
        start_sec = start
    elif isinstance(start, str):
        # A single "start,end" string may carry both bounds.
        parts = start.split(",")
        start = parts[0]
        if len(parts) == 2:
            end = parts[1]
        start = pd.to_datetime(start)

    # Not part of the chain above: a str start has just been converted to a
    # datetime and must be handled here too.
    if isinstance(start, datetime):
        start_sec = (start - file_start_datetime).total_seconds()

    end_step = n_time_steps
    if isinstance(end, int):
        if end < 0:
            # Negative values count backwards from the end of the file.
            end = end_step + end + 1
        end_step = end
    elif isinstance(end, float):
        end_sec = end
    elif isinstance(end, str):
        end = pd.to_datetime(end)

    if isinstance(end, datetime):
        end_sec = (end - file_start_datetime).total_seconds()

    if start_step < 0:
        raise ValueError(
            f"start cannot be before start of file. start={start_step} is invalid"
        )

    if start_sec < file_start_sec:
        # Report the offending time; start_step is a valid index at this point.
        raise ValueError(
            f"start cannot be before start of file. start={start_sec} is invalid"
        )

    if (end_sec < start_sec) or (end_step < start_step):
        raise ValueError("end must be after start")

    if end_step > n_time_steps:
        raise ValueError(
            f"end cannot be after end of file. end={end_step} is invalid."
        )

    if end_sec > file_end_sec:
        raise ValueError(
            f"end cannot be after end of file. end={end_sec} is invalid."
        )

    file_start_new = None
    if time_axis.TimeAxisType == TimeAxisType.CalendarEquidistant:
        dt = time_axis.TimeStep
        if (start_sec > file_start_sec) and (start_step == 0):
            # we can find the corresponding step
            start_step = int((start_sec - file_start_sec) / dt)
        file_start_new = file_start_datetime + timedelta(seconds=start_step * dt)
    elif time_axis.TimeAxisType == TimeAxisType.CalendarNonEquidistant:
        if start_sec > file_start_sec:
            file_start_new = file_start_datetime + timedelta(seconds=start_sec)

    timestep = _TimeInfo._parse_step(time_axis, step)

    return _TimeInfo(
        file_start_new=file_start_new,
        start_step=start_step,
        start_sec=start_sec,
        end_step=end_step,
        end_sec=end_sec,
        timestep=timestep,
    )

if (end_sec < start_sec) or (end_step < start_step):
raise ValueError("end must be after start")

if end_step > n_time_steps:
raise ValueError(f"end cannot be after end of file. end={end_step} is invalid.")

if end_sec > file_end_sec:
raise ValueError(f"end cannot be after end of file. end={end_sec} is invalid.")

file_start_new = None
if time_axis.TimeAxisType == 3:
dt = time_axis.TimeStep
if (start_sec > file_start_sec) and (start_step == 0):
# we can find the corresponding step
start_step = int((start_sec - file_start_sec) / dt)
file_start_new = file_start_datetime + timedelta(seconds=start_step * dt)
elif time_axis.TimeAxisType == 4:
if start_sec > file_start_sec:
file_start_new = file_start_datetime + timedelta(seconds=start_sec)

return file_start_new, start_step, start_sec, end_step, end_sec


def _parse_step(time_axis: TimeAxis, step: int) -> float | None:
    """Derive the time step that corresponds to a step stride.

    Returns None when ``step`` is 1 or the axis type is 4, and
    ``time_axis.TimeStep * step`` when the axis type is 3. Any other axis
    type raises ValueError.
    """
    # A stride of one never needs an explicit time step.
    if step == 1:
        return None

    axis_type = time_axis.TimeAxisType
    if axis_type == 3:
        return time_axis.TimeStep * step
    if axis_type == 4:
        return None
    raise ValueError("TimeAxisType not supported")
@staticmethod
def _parse_step(time_axis: TimeAxis, step: int) -> float | None:
    """Derive the time step that corresponds to a step stride.

    Returns None when ``step`` is 1 or the axis is calendar non-equidistant,
    and ``time_axis.TimeStep * step`` when the axis is calendar equidistant.
    Any other axis type raises ValueError.
    """
    # A stride of one never needs an explicit time step.
    if step == 1:
        return None

    axis_type = time_axis.TimeAxisType
    if axis_type == TimeAxisType.CalendarEquidistant:
        return time_axis.TimeStep * step
    if axis_type == TimeAxisType.CalendarNonEquidistant:
        return None
    raise ValueError("TimeAxisType not supported")


def avg_time(
Expand Down

0 comments on commit c52ba1f

Please sign in to comment.