Skip to content

Commit 38f78ee

Browse files
committed
complete v0.2.0 upgrade
## Breaking changes - Refactored the generation method of the RuntimeGenerate function: - Added support for ComponentArray type as parameter input, eliminating the need to convert it to Vector type. - Enabled broadcasting calculations between vectors and scalars in the RuntimeGenerate function, improving computational efficiency and gradient computation efficiency. - Modified the construction method of Route, now using rfluxes, dfluxes, and proj_func to build the Route type. - Revised the computational logic of Route, Bucket, and Flux, providing corresponding functions for both three-dimensional and two-dimensional data inputs, avoiding some complex internal data integration operations. - Added a macro construction method (under test).
1 parent 75e9e57 commit 38f78ee

14 files changed

+347
-276
lines changed

Manifest.toml

+153-146
Large diffs are not rendered by default.

Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ authors = ["jingx <50790703+chooron@users.noreply.github.com>"]
44
version = "0.2.0"
55

66
[deps]
7+
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
78
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
9+
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
810
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
911
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
1012
HydroModelTools = "31f4d4b3-d71d-422d-8932-b4ab24c6e7e3"

TODO.md

+1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
- 使用@generated的函数时,只能获取到类型,具体的实例属性是获取不到的
3030
- [ ] AbstractBucket分成HydroBucket和MultiBucket
3131
- [ ] 参数名称与输出输入名称不能冲突
32+
- [ ] Meta这个存储的内容得更改了
3233

3334

3435
julia语言中,我想使用@generated生成struct函数,但是这个生成函数需要的参数需要struct的属性值,这个属性值是术语Num,因此无法作为struct的类型参数,请问应该如何解决这个问题呢

release_note.txt

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
## Breaking changes
2+
3+
- Refactored the generation method of the RuntimeGenerate function:
4+
- Added support for ComponentArray type as parameter input, eliminating the need to convert it to Vector type.
5+
- Enabled broadcasting calculations between vectors and scalars in the RuntimeGenerate function, improving computational efficiency and gradient computation efficiency.
6+
- Modified the construction method of Route, now using rfluxes, dfluxes, and proj_func to build the Route type.
7+
- Revised the computational logic of Route, Bucket, and Flux, providing corresponding functions for both three-dimensional and two-dimensional data inputs, avoiding some complex internal data integration operations.
8+
- Added a macro construction method (under test).

src/HydroModels.jl

+2-7
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ using SymbolicUtils.Code
2424
import SymbolicUtils: BasicSymbolic, Sym, Term, iscall, operation, arguments, issym, symtype, sorted_arguments
2525
@reexport using ModelingToolkit: @variables, @parameters
2626
using ModelingToolkit: isparameter
27-
using ModelingToolkit: t_nounits as t
2827
# graph compute
2928
using Graphs
3029

@@ -38,9 +37,6 @@ using NNlib
3837

3938
## Abstract Component Types
4039
abstract type AbstractComponent end
41-
abstract type AbstractIOAdapter end
42-
abstract type AbstractHydroWrapper <: AbstractComponent end
43-
abstract type AbstractNeuralWrapper <: AbstractComponent end
4440

4541
abstract type AbstractFlux <: AbstractComponent end
4642
abstract type AbstractHydroFlux <: AbstractFlux end
@@ -54,16 +50,15 @@ abstract type AbstractRoute <: AbstractElement end
5450
abstract type AbstractHydroRoute <: AbstractRoute end
5551
abstract type AbstractModel <: AbstractComponent end
5652

57-
export AbstractComponent, AbstractHydroWrapper, AbstractNeuralWrapper
58-
export AbstractFlux, AbstractHydroFlux, AbstractNeuralFlux, AbstractStateFlux
53+
export AbstractComponent, AbstractFlux, AbstractHydroFlux, AbstractNeuralFlux, AbstractStateFlux
5954
export AbstractElement, AbstractBucket, AbstractHydrograph, AbstractRoute, AbstractHydroRoute, AbstractModel
6055

6156
# utils
6257
include("utils/expression.jl")
6358
include("utils/attribute.jl")
59+
include("utils/tools.jl")
6460
include("utils/display.jl")
6561
include("utils/build.jl")
66-
include("utils/sort.jl")
6762
include("utils/check.jl")
6863
#! A discrete ODE solver, if want to use more efficient solver, please import HydroModelTools.jl
6964
include("utils/solver.jl")

src/bucket.jl

+64-57
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,60 @@
11
"""
2-
HydroBucket(; funcs::Vector{<:AbstractHydroFlux}, dfuncs::Vector{<:AbstractStateFlux}=StateFlux[], name::Union{Symbol,Nothing}=nothing, sort_funcs::Bool=false)
2+
HydroBucket(;
3+
funcs::Vector{<:AbstractHydroFlux},
4+
dfuncs::Vector{<:AbstractStateFlux}=StateFlux[],
5+
name::Union{Symbol,Nothing}=nothing,
6+
sort_funcs::Bool=false
7+
)
38
4-
Represents a hydrological bucket model component.
9+
Represents a hydrological bucket model component that handles both single-node and multi-node computations.
510
611
# Arguments
7-
- `funcs::Vector{<:AbstractHydroFlux}`: A vector of flux functions that describe the hydrological processes.
8-
- `dfuncs::Vector{<:AbstractStateFlux}`: A vector of state derivative functions (default is an empty vector of StateFlux).
9-
- `name::Union{Symbol,Nothing}`: Optional name for the bucket. If not provided, a name will be automatically generated from state variable names.
10-
- `sort_funcs::Bool`: Whether to sort the flux functions (default is false).
12+
- `funcs::Vector{<:AbstractHydroFlux}`: Vector of flux functions describing hydrological processes
13+
- `dfuncs::Vector{<:AbstractStateFlux}`: Vector of state derivative functions for ODE calculations (default: empty)
14+
- `name::Union{Symbol,Nothing}`: Optional bucket identifier. Defaults to auto-generated name from state variables
15+
- `sort_funcs::Bool}`: Whether to topologically sort flux functions based on dependencies (default: false)
1116
1217
# Fields
13-
- `funcs::Vector{<:AbstractHydroFlux}`: Vector of flux functions describing hydrological processes.
14-
- `dfuncs::Vector{<:AbstractStateFlux}`: Vector of state derivative functions for ODE calculations.
15-
- `flux_func::Function`: Combined function for calculating all hydrological fluxes.
16-
- `ode_func::Union{Nothing,Function}`: Function for ordinary differential equations (ODE) calculations, or nothing if no ODE calculations are needed.
17-
- `meta::HydroMeta`: Contains metadata about the bucket, including input, output, state, parameter, and neural network names.
18+
- `fluxes::Vector{<:AbstractHydroFlux}`: Vector of flux functions describing hydrological processes
19+
- `dfluxes::Vector{<:AbstractStateFlux}`: Vector of state derivative functions for ODE calculations
20+
- `flux_func::Function`: Generated function for single-node flux calculations
21+
- `multi_flux_func::Function`: Generated function for multi-node parallel flux calculations
22+
- `ode_func::Union{Nothing,Function}`: Generated function for single-node ODE calculations
23+
- `multi_ode_func::Union{Nothing,Function}`: Generated function for multi-node parallel ODE calculations
24+
- `meta::ComponentVector`: Metadata containing model structure information
1825
1926
# Description
20-
HydroBucket is a structure that encapsulates the behavior of a hydrological bucket model.
21-
It combines multiple flux functions and state derivative functions to model water movement
22-
and storage within a hydrological unit.
23-
24-
The structure automatically extracts relevant information from the provided functions to
25-
populate the metadata, which includes names of:
26-
- Inputs: Variables that drive the model
27-
- Outputs: Variables produced by the model
28-
- States: Internal model states that evolve over time
29-
- Parameters: Model parameters that control behavior
30-
- Neural Networks: Any neural network components (if applicable)
31-
32-
The `flux_func` and `ode_func` are constructed based on the provided `funcs` and `dfuncs`,
33-
enabling efficient calculation of fluxes and state changes over time.
34-
35-
This structure is particularly useful for building complex hydrological models by combining
36-
multiple HydroBucket instances to represent different components of a water system.
37-
27+
HydroBucket is a type-stable implementation of a hydrological bucket model that supports both
28+
single-node and distributed (multi-node) computations. It automatically generates optimized
29+
functions for flux and ODE calculations based on the provided process functions.
30+
31+
## Model Structure
32+
The bucket model consists of:
33+
- Process functions (`fluxes`): Define water movement between storages
34+
- State derivatives (`dfluxes`): Define rate of change for state variables
35+
- Generated functions: Optimized implementations for both single and multi-node calculations
36+
37+
## Metadata Components
38+
The `meta` field tracks:
39+
- `inputs`: External forcing variables (e.g., precipitation, temperature)
40+
- `outputs`: Model-generated variables (e.g., runoff, evaporation)
41+
- `states`: Internal storage variables (e.g., soil moisture, groundwater)
42+
- `params`: Model parameters controlling process behavior
43+
- `nn_vars`: Neural network components (if any) for hybrid modeling
44+
45+
## Performance Features
46+
- Type-stable computations for both single and multi-node cases
47+
- Efficient broadcasting operations for vectorized calculations
48+
- Automatic function generation with optimized broadcasting
49+
- Support for ComponentArray parameters for structured data handling
50+
51+
## Usage Notes
52+
1. For single-node simulations: Use `flux_func` and `ode_func`
53+
2. For multi-node simulations: Use `multi_flux_func` and `multi_ode_func`
54+
3. Parameters should be provided as ComponentVector for type stability
55+
4. Broadcasting operations are automatically handled for multi-node cases
56+
57+
See also: [`AbstractHydroFlux`](@ref), [`AbstractStateFlux`](@ref), [`ComponentVector`](@ref)
3858
"""
3959
struct HydroBucket{S} <: AbstractBucket
4060
"Name of the bucket"
@@ -43,13 +63,13 @@ struct HydroBucket{S} <: AbstractBucket
4363
fluxes::Vector{<:AbstractHydroFlux}
4464
"Vector of state derivative functions for ODE calculations."
4565
dfluxes::Vector{<:AbstractStateFlux}
46-
"Generated function for calculating all hydrological fluxes."
66+
"Generated function for calculating all hydrological fluxes. (Supports single-node data)"
4767
flux_func::Function
48-
"Generated function for calculating all hydrological fluxes."
68+
"Generated function for calculating all hydrological fluxes. (Supports multi-nodes data)"
4969
multi_flux_func::Function
50-
"Generated function for ordinary differential equations (ODE) calculations, or nothing if no ODE calculations are needed."
70+
"Generated function for ordinary differential equations (ODE) calculations, or nothing if no ODE calculations are needed. (Supports single-node data)"
5171
ode_func::Union{Nothing,Function}
52-
"Generated function for ordinary differential equations (ODE) calculations, or nothing if no ODE calculations are needed."
72+
"Generated function for ordinary differential equations (ODE) calculations, or nothing if no ODE calculations are needed. (Supports multi-nodes data)"
5373
multi_ode_func::Union{Nothing,Function}
5474
"Metadata about the bucket, including input, output, state, parameter, and neural network names."
5575
meta::ComponentVector
@@ -118,10 +138,8 @@ The input dimensions must match the number of input variables defined in the mod
118138
Required parameters and initial states must be present in the pas argument.
119139
"""
120140
function (ele::HydroBucket{true})(
121-
input::AbstractArray{T,2},
122-
pas::ComponentVector;
123-
config::NamedTuple=NamedTuple(),
124-
kwargs...,
141+
input::AbstractArray{T,2}, pas::ComponentVector;
142+
config::NamedTuple=NamedTuple(), kwargs...
125143
) where {T}
126144
#* get kwargs
127145
solver = get(config, :solver, ManualSolver{true}())
@@ -139,15 +157,13 @@ function (ele::HydroBucket{true})(
139157
end
140158

141159
(ele::HydroBucket{false})(input::AbstractArray{T,2}, pas::ComponentVector; kwargs...) where {T} = begin
142-
flux_output = ele.flux_func(input, nothing, pas)
160+
flux_output = ele.flux_func(eachslice(input, dims=1), nothing, pas)
143161
permutedims(reduce(hcat, flux_output))
144162
end
145163

146164
function (ele::HydroBucket{true})(
147-
input::AbstractArray{T,3},
148-
pas::ComponentVector;
149-
config::NamedTuple=NamedTuple(),
150-
kwargs...,
165+
input::AbstractArray{T,3}, pas::ComponentVector;
166+
config::NamedTuple=NamedTuple(), kwargs...
151167
) where {T}
152168
input_dims, num_nodes, time_len = size(input)
153169

@@ -159,15 +175,11 @@ function (ele::HydroBucket{true})(
159175
timeidx = get(config, :timeidx, collect(1:size(input, 3)))
160176

161177
#* prepare states parameters and nns
162-
params = view(pas, :params)
163-
nn_params = isempty(get_nn_vars(ele)) ? ones(eltype(pas), num_nodes) : view(pas, :nns)
164-
expand_params = ComponentVector(NamedTuple{Tuple(get_param_names(ele))}([params[p][ptyidx] for p in get_param_names(ele)]))
165-
new_pas = ComponentVector(params=expand_params, nns=nn_params)
166-
initstates_mat = view(reshape(Vector(view(pas, :initstates)), num_nodes, :)', :, styidx)
178+
new_pas = expand_component_params(pas, ptyidx)
179+
initstates_mat = expand_component_initstates(pas, styidx)
167180

168181
#* prepare input function
169-
input_reshape = reshape(input, input_dims * num_nodes, time_len)
170-
itpfuncs = interp(input_reshape, timeidx)
182+
itpfuncs = interp(reshape(input, input_dims * num_nodes, time_len), timeidx)
171183
solved_states = solver(
172184
(u, p, t) -> begin
173185
tmp_input = reshape(itpfuncs(t), input_dims, num_nodes)
@@ -182,16 +194,11 @@ function (ele::HydroBucket{true})(
182194
end
183195

184196
function (ele::HydroBucket{false})(
185-
input::AbstractArray{T,3},
186-
pas::ComponentVector;
187-
config::NamedTuple=NamedTuple(),
188-
kwargs...,
197+
input::AbstractArray{T,3}, pas::ComponentVector;
198+
config::NamedTuple=NamedTuple(), kwargs...,
189199
) where {T}
190200
ptyidx = get(config, :ptyidx, 1:size(input, 2))
191-
params = view(pas, :params)
192-
nn_params = isempty(get_nn_vars(ele)) ? ones(eltype(pas), size(input, 2)) : view(pas, :nns)
193-
expand_params = ComponentVector(NamedTuple{Tuple(get_param_names(ele))}([params[p][ptyidx] for p in get_param_names(ele)]))
194-
new_pas = ComponentVector(params=expand_params, nns=nn_params)
201+
new_pas = expand_component_params(pas, ptyidx)
195202
#* run other functions
196203
output = ele.flux_func(eachslice(input, dims=1), nothing, new_pas)
197204
permutedims(reduce((m1, m2) -> cat(m1, m2, dims=3), output), (3, 1, 2))

src/flux.jl

+4-3
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,8 @@ where traditional equations may be insufficient or unknown.
139139
struct NeuralFlux <: AbstractNeuralFlux
140140
"Name of the flux"
141141
name::Symbol
142+
"chain of the neural network"
143+
chain::LuxCore.AbstractLuxLayer
142144
"Compiled function that calculates the flux using the neural network"
143145
func::Function
144146
"Metadata about the flux, including input, output, and neural network parameter names"
@@ -154,7 +156,6 @@ struct NeuralFlux <: AbstractNeuralFlux
154156
chain_name::Union{Symbol,Nothing}=nothing,
155157
) where {T<:Num}
156158
#* Check chain name
157-
@assert chain.name isa Symbol "Neural network chain should have a name with Symbol type"
158159
chain_name = chain_name === nothing ? chain.name : chain_name
159160

160161
ps, st = Lux.setup(StableRNG(42), chain)
@@ -165,8 +166,8 @@ struct NeuralFlux <: AbstractNeuralFlux
165166

166167
meta = ComponentVector(inputs=inputs, outputs=outputs, nns=NamedTuple{Tuple([chain_name])}([chain_params]))
167168
nninfos = (inputs=nn_input_name, outputs=nn_output_name, nns=chain_name)
168-
flux_name = isnothing(name) ? Symbol("##neural_flux#", meta) : name
169-
new(flux_name, nn_func, meta, nninfos)
169+
flux_name = isnothing(name) ? Symbol("##neural_flux#", hash(meta)) : name
170+
new(flux_name, chain, nn_func, meta, nninfos)
170171
end
171172

172173
#* construct neural flux with input fluxes and output fluxes

src/route.jl

+3-6
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ struct HydroRoute <: AbstractHydroRoute
6666
#* define the route name
6767
route_name = isnothing(name) ? Symbol("##route#", hash(meta)) : name
6868
#* build the route function
69-
multi_flux_func, multi_ode_func = build_route_func(rfluxes, dfluxes, proj_func, meta)
69+
multi_flux_func, multi_ode_func = build_route_func(rfluxes, dfluxes, meta)
7070
return new(route_name, rfluxes, multi_flux_func, multi_ode_func, proj_func, meta)
7171
end
7272
end
@@ -223,11 +223,8 @@ function (route::HydroRoute)(
223223
timeidx = get(config, :timeidx, collect(1:time_len))
224224

225225
#* prepare states parameters and nns
226-
params = view(pas, :params)
227-
expand_params = ComponentVector(NamedTuple{Tuple(get_param_names(route))}([params[p][ptyidx] for p in get_param_names(route)]))
228-
nn_params = isempty(get_nn_vars(route)) ? Vector{eltype(pas)}[] : view(pas, :nns)
229-
new_pas = ComponentVector(params=expand_params) # , nns=nn_params
230-
initstates_mat = view(reshape(Vector(view(pas, :initstates)), num_nodes, :)', :, styidx)
226+
new_pas = expand_component_params(pas, ptyidx)
227+
initstates_mat = expand_component_initstates(pas, styidx)
231228

232229
#* prepare input function
233230
input_reshape = reshape(input, input_dims * num_nodes, time_len)

0 commit comments

Comments
 (0)