Skip to content

Commit fb04fb0

Browse files
committed
update with new version
1 parent 4cde39d commit fb04fb0

File tree

3 files changed

+155
-77
lines changed

3 files changed

+155
-77
lines changed

src/utils/attribute.jl

+6-29
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@ get_name(cpt::AbstractComponent)::Symbol = cpt.name
1212
Get the input variables or their names from a component's metadata.
1313
Returns empty vector if no inputs are defined.
1414
"""
15-
get_input_vars(cpt::AbstractComponent)::AbstractVector{Num} = haskey(cpt.meta, :inputs) ? cpt.meta.inputs : Num[]
16-
get_input_names(cpt::AbstractComponent)::AbstractVector{Symbol} = Symbolics.tosymbol.(get_input_vars(cpt))
15+
get_input_names(cpt::AbstractComponent)::AbstractVector{Symbol} = haskey(cpt.infos, :inputs) ? cpt.infos.inputs : Symbol[]
1716

1817
"""
1918
get_output_vars(cpt::AbstractComponent)::AbstractVector{Num}
@@ -22,8 +21,7 @@ get_input_names(cpt::AbstractComponent)::AbstractVector{Symbol} = Symbolics.tosy
2221
Get the output variables or their names from a component's metadata.
2322
Returns empty vector if no outputs are defined.
2423
"""
25-
get_output_vars(cpt::AbstractComponent)::AbstractVector{Num} = haskey(cpt.meta, :outputs) ? cpt.meta.outputs : Num[]
26-
get_output_names(cpt::AbstractComponent)::AbstractVector{Symbol} = Symbolics.tosymbol.(get_output_vars(cpt))
24+
get_output_names(cpt::AbstractComponent)::AbstractVector{Symbol} = haskey(cpt.infos, :outputs) ? cpt.infos.outputs : Symbol[]
2725

2826
"""
2927
get_state_vars(cpt::AbstractComponent)::AbstractVector{Num}
@@ -32,8 +30,7 @@ get_output_names(cpt::AbstractComponent)::AbstractVector{Symbol} = Symbolics.tos
3230
Get the state variables or their names from a component's metadata.
3331
Returns empty vector if no states are defined.
3432
"""
35-
get_state_vars(cpt::AbstractComponent)::AbstractVector{Num} = haskey(cpt.meta, :states) ? cpt.meta.states : Num[]
36-
get_state_names(cpt::AbstractComponent)::AbstractVector{Symbol} = Symbolics.tosymbol.(get_state_vars(cpt))
33+
get_state_names(cpt::AbstractComponent)::AbstractVector{Symbol} = haskey(cpt.infos, :states) ? cpt.infos.states : Symbol[]
3734

3835
"""
3936
get_param_vars(cpt::AbstractComponent)::AbstractVector{Num}
@@ -42,8 +39,8 @@ get_state_names(cpt::AbstractComponent)::AbstractVector{Symbol} = Symbolics.tosy
4239
Get the parameter variables or their names from a component's metadata.
4340
Returns empty vector if no parameters are defined.
4441
"""
45-
get_param_vars(cpt::AbstractComponent)::AbstractVector{Num} = haskey(cpt.meta, :params) ? cpt.meta.params : Num[]
46-
get_param_names(cpt::AbstractComponent)::AbstractVector{Symbol} = Symbolics.tosymbol.(get_param_vars(cpt))
42+
get_param_vars(cpt::AbstractComponent)::AbstractVector{Num} = haskey(cpt.infos, :params) ? cpt.infos.params : Num[]
43+
get_param_names(cpt::AbstractComponent)::AbstractVector{Symbol} = haskey(cpt.infos, :params) ? cpt.infos.params : Symbol[]
4744

4845
"""
4946
get_nn_vars(cpt::AbstractComponent)::AbstractVector
@@ -52,16 +49,8 @@ get_param_names(cpt::AbstractComponent)::AbstractVector{Symbol} = Symbolics.tosy
5249
Get the neural network variables or their names from a component's metadata.
5350
Returns empty vector/ComponentVector if no neural networks are defined.
5451
"""
55-
get_nn_vars(cpt::AbstractComponent)::AbstractVector = haskey(cpt.meta, :nns) ? cpt.meta.nns : ComponentVector()
56-
get_nn_names(cpt::AbstractComponent)::AbstractVector{Symbol} = haskey(cpt.meta, :nns) ? collect(keys(cpt.meta.nns)) : Symbol[]
52+
get_nn_names(cpt::AbstractComponent)::AbstractVector{Symbol} = haskey(cpt.infos, :nns) ? cpt.infos.nns : Symbol[]
5753

58-
"""
59-
get_all_vars(cpt::AbstractComponent)::AbstractVector{Num}
60-
61-
Get all variables (inputs, outputs, states, and neural networks) from a component.
62-
Returns the union of all variable types.
63-
"""
64-
get_all_vars(cpt::AbstractComponent)::AbstractVector{Num} = reduce(union, [get_input_vars(cpt), get_output_vars(cpt), get_state_vars(cpt), get_nn_vars(cpt)])
6554

6655
"""
6756
get_exprs(cpt::AbstractComponent)
@@ -70,18 +59,6 @@ Get the expressions defined in a component.
7059
"""
7160
get_exprs(cpt::AbstractComponent) = cpt.exprs
7261

73-
function get_all_vars(cpts::Vector{<:AbstractComponent})
74-
inputs, outputs, states = Vector{Num}(), Vector{Num}(), Vector{Num}()
75-
for cpt in cpts
76-
tmp_inputs, tmp_outputs, tmp_states = get_input_vars(cpt), get_output_vars(cpt), get_state_vars(cpt)
77-
union!(inputs, tmp_inputs)
78-
union!(outputs, tmp_outputs)
79-
union!(states, tmp_states)
80-
end
81-
new_inputs = setdiff(inputs, vcat(outputs, states))
82-
return new_inputs, outputs, states
83-
end
84-
8562
"""
8663
get_var_names(comps::AbstractComponent)
8764

src/utils/build.jl

+145-44
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,26 @@
1+
"""
2+
build_flux_func(inputs::Vector{Num}, outputs::Vector{Num}, params::Vector{Num}, exprs::Vector{Num})
3+
4+
Generates a runtime function that computes flux calculations based on symbolic expressions.
5+
6+
# Arguments
7+
- `inputs::Vector{Num}`: Vector of symbolic input variables that will be provided at runtime
8+
- `outputs::Vector{Num}`: Vector of symbolic output variables that will be computed
9+
- `params::Vector{Num}`: Vector of symbolic parameters used in the flux calculations
10+
- `exprs::Vector{Num}`: Vector of symbolic expressions defining how outputs are computed from inputs and parameters
11+
12+
# Returns
13+
- A runtime-generated function with signature `(inputs, pas)` where:
14+
- `inputs`: Vector of input values corresponding to the symbolic inputs
15+
- `pas`: A parameter struct containing fields matching the parameter names
16+
17+
# Details
18+
The function generates code that:
19+
1. Assigns input values from the input vector to local variables
20+
2. Retrieves parameter values from the parameter struct
21+
3. Computes outputs using the provided expressions
22+
4. Returns a vector of computed outputs
23+
"""
124
function build_flux_func(inputs::Vector{Num}, outputs::Vector{Num}, params::Vector{Num}, exprs::Vector{Num})
225
input_names, output_names = Symbolics.tosymbol.(inputs), Symbolics.tosymbol.(outputs)
326
param_names = Symbolics.tosymbol.(params)
@@ -16,39 +39,77 @@ function build_flux_func(inputs::Vector{Num}, outputs::Vector{Num}, params::Vect
1639
return generated_flux_func
1740
end
1841

42+
"""
43+
build_ele_func(fluxes::Vector{<:AbstractFlux}, dfluxes::Vector{<:AbstractStateFlux}, meta::ComponentVector)
44+
45+
Builds runtime-generated functions for both flux calculations and state differentials in a hydrological model element.
46+
47+
# Arguments
48+
- `fluxes::Vector{<:AbstractFlux}`: Vector of flux components that define the element's behavior
49+
- `dfluxes::Vector{<:AbstractStateFlux}`: Vector of state differential components that define state changes
50+
- `meta::ComponentVector`: Metadata containing:
51+
- `inputs`: Input variable names
52+
- `outputs`: Output variable names
53+
- `states`: State variable names
54+
- `params`: Parameter names
55+
56+
# Returns
57+
A tuple containing:
58+
- First element: `Vector{Function}` with two functions:
59+
1. Regular flux function `(inputs, states, pas) -> outputs`
60+
2. Multi-dimensional flux function for batch processing
61+
- Second element: Either `nothing` (if no states) or `Vector{Function}` with two functions:
62+
1. State differential function `(inputs, states, pas) -> dstates`
63+
2. Multi-dimensional state differential function for batch processing
64+
65+
# Details
66+
The function generates four types of runtime functions:
67+
1. Single-sample flux computation
68+
2. Multi-sample flux computation (batched)
69+
3. Single-sample state differential computation (if states exist)
70+
4. Multi-sample state differential computation (if states exist)
71+
72+
For neural network fluxes (`AbstractNeuralFlux`), the function handles:
73+
- Input tensor preparation
74+
- Neural network forward passes
75+
- Output tensor reshaping
76+
77+
For regular fluxes, it directly computes using provided expressions.
78+
```
79+
"""
1980
function build_ele_func(
2081
fluxes::Vector{<:AbstractFlux},
2182
dfluxes::Vector{<:AbstractStateFlux},
22-
meta::ComponentVector,
83+
infos::NamedTuple,
2384
)
24-
input_names, output_names = tosymbol.(meta.inputs), tosymbol.(meta.outputs)
25-
state_names, param_names = tosymbol.(meta.states), tosymbol.(meta.params)
85+
input_names, output_names = infos.inputs, infos.outputs
86+
state_names, param_names = infos.states, infos.params
2687

2788
input_define_calls = [:($i = inputs[$idx]) for (idx, i) in enumerate(input_names)]
2889
state_define_calls = [:($s = states[$idx]) for (idx, s) in enumerate(state_names)]
2990
params_assign_calls = [:($p = pas.params.$p) for p in param_names]
30-
nn_params_assign_calls = [:($nn = pas.nns.$nn) for nn in [nflux.nninfos[:nns] for nflux in filter(f -> f isa AbstractNeuralFlux, fluxes)]]
91+
nn_params_assign_calls = [:($nn = pas.nns.$nn) for nn in [nflux.infos[:nns][1] for nflux in filter(f -> f isa AbstractNeuralFlux, fluxes)]]
3192
define_calls = reduce(vcat, [input_define_calls, state_define_calls, params_assign_calls, nn_params_assign_calls])
3293

3394
# varibles definitions expressions
3495
state_compute_calls, multi_state_compute_calls, flux_compute_calls, multi_flux_compute_calls = [], [], [], []
3596
for f in fluxes
3697
if f isa AbstractNeuralFlux
37-
append!(state_compute_calls, [:($(f.nninfos[:inputs]) = [$(get_input_names(f)...)])])
38-
push!(state_compute_calls, :($(f.nninfos[:outputs]) = $(f.func)($(f.nninfos[:inputs]), $(f.nninfos[:nns]))))
39-
append!(state_compute_calls, [:($(nm) = $(f.nninfos[:outputs])[$i]) for (i, nm) in enumerate(get_output_names(f))])
98+
append!(state_compute_calls, [:($(f.infos[:nn_inputs]) = [$(get_input_names(f)...)])])
99+
push!(state_compute_calls, :($(f.infos[:nn_outputs]) = $(f.func)($(f.infos[:nn_inputs]), $(f.infos[:nns][1]))))
100+
append!(state_compute_calls, [:($(nm) = $(f.infos[:nn_outputs])[$i]) for (i, nm) in enumerate(get_output_names(f))])
40101

41-
append!(multi_state_compute_calls, [:($(f.nninfos[:inputs]) = permutedims(reduce(hcat, [$(get_input_names(f)...)])))])
42-
push!(multi_state_compute_calls, :($(f.nninfos[:outputs]) = $(f.func)($(f.nninfos[:inputs]), $(f.nninfos[:nns]))))
43-
append!(multi_state_compute_calls, [:($(nm) = $(f.nninfos[:outputs])[$i, :]) for (i, nm) in enumerate(get_output_names(f))])
102+
append!(multi_state_compute_calls, [:($(f.infos[:nn_inputs]) = permutedims(reduce(hcat, [$(get_input_names(f)...)])))])
103+
push!(multi_state_compute_calls, :($(f.infos[:nn_outputs]) = $(f.func)($(f.infos[:nn_inputs]), $(f.infos[:nns][1]))))
104+
append!(multi_state_compute_calls, [:($(nm) = $(f.infos[:nn_outputs])[$i, :]) for (i, nm) in enumerate(get_output_names(f))])
44105

45-
append!(flux_compute_calls, [:($(f.nninfos[:inputs]) = permutedims(reduce(hcat, [$(get_input_names(f)...)])))])
46-
push!(flux_compute_calls, :($(f.nninfos[:outputs]) = $(f.func)($(f.nninfos[:inputs]), $(f.nninfos[:nns]))))
47-
append!(flux_compute_calls, [:($(nm) = $(f.nninfos[:outputs])[$i, :]) for (i, nm) in enumerate(get_output_names(f))])
106+
append!(flux_compute_calls, [:($(f.infos[:nn_inputs]) = permutedims(reduce(hcat, [$(get_input_names(f)...)])))])
107+
push!(flux_compute_calls, :($(f.infos[:nn_outputs]) = $(f.func)($(f.infos[:nn_inputs]), $(f.infos[:nns][1]))))
108+
append!(flux_compute_calls, [:($(nm) = $(f.infos[:nn_outputs])[$i, :]) for (i, nm) in enumerate(get_output_names(f))])
48109

49-
append!(multi_flux_compute_calls, [:($(f.nninfos[:inputs]) = permutedims(reduce((m1, m2) -> cat(m1, m2, dims=3), [$(get_input_names(f)...)]), (3, 1, 2)))])
50-
push!(multi_flux_compute_calls, :($(f.nninfos[:outputs]) = permutedims(reduce((m1, m2) -> cat(m1, m2, dims=3), $(f.func).(eachslice($(f.nninfos[:inputs]), dims=2), Ref($(f.nninfos[:nns])))), (1, 3, 2))))
51-
append!(multi_flux_compute_calls, [:($(nm) = $(f.nninfos[:outputs])[$i, :, :]) for (i, nm) in enumerate(get_output_names(f))])
110+
append!(multi_flux_compute_calls, [:($(f.infos[:nn_inputs]) = permutedims(reduce((m1, m2) -> cat(m1, m2, dims=3), [$(get_input_names(f)...)]), (3, 1, 2)))])
111+
push!(multi_flux_compute_calls, :($(f.infos[:nn_outputs]) = permutedims(reduce((m1, m2) -> cat(m1, m2, dims=3), $(f.func).(eachslice($(f.infos[:nn_inputs]), dims=2), Ref($(f.infos[:nns][1])))), (1, 3, 2))))
112+
append!(multi_flux_compute_calls, [:($(nm) = $(f.infos[:nn_outputs])[$i, :, :]) for (i, nm) in enumerate(get_output_names(f))])
52113
else
53114
append!(state_compute_calls, [:($nm = $(toexpr(expr))) for (nm, expr) in zip(get_output_names(f), f.exprs)])
54115
append!(multi_state_compute_calls, [:($nm = $(toexprv2(unwrap(expr)))) for (nm, expr) in zip(get_output_names(f), f.exprs)])
@@ -79,57 +140,97 @@ function build_ele_func(
79140
$(return_flux)
80141
end)
81142

82-
diff_func_expr = :(function (inputs, states, pas)
83-
$(meta_exprs...)
84-
$(define_calls...)
85-
$(state_compute_calls...)
86-
$(return_state)
87-
end)
88-
89-
multi_diff_func_expr = :(function (inputs, states, pas)
90-
$(meta_exprs...)
91-
$(define_calls...)
92-
$(multi_state_compute_calls...)
93-
$(return_multi_state)
94-
end)
95-
96143
generated_flux_func = @RuntimeGeneratedFunction(flux_func_expr)
97144
generated_multi_flux_func = @RuntimeGeneratedFunction(multi_flux_func_expr)
98-
generated_diff_func = @RuntimeGeneratedFunction(diff_func_expr)
99-
generated_multi_diff_func = @RuntimeGeneratedFunction(multi_diff_func_expr)
100-
return generated_flux_func, generated_multi_flux_func, generated_diff_func, generated_multi_diff_func
145+
146+
if length(state_names) > 0
147+
diff_func_expr = :(function (inputs, states, pas)
148+
$(meta_exprs...)
149+
$(define_calls...)
150+
$(state_compute_calls...)
151+
$(return_state)
152+
end)
153+
154+
multi_diff_func_expr = :(function (inputs, states, pas)
155+
$(meta_exprs...)
156+
$(define_calls...)
157+
$(multi_state_compute_calls...)
158+
$(return_multi_state)
159+
end)
160+
161+
generated_diff_func = @RuntimeGeneratedFunction(diff_func_expr)
162+
generated_multi_diff_func = @RuntimeGeneratedFunction(multi_diff_func_expr)
163+
return [generated_flux_func, generated_multi_flux_func], [generated_diff_func, generated_multi_diff_func]
164+
else
165+
return [generated_flux_func, generated_multi_flux_func], nothing
166+
end
101167
end
102168

169+
"""
170+
build_route_func(fluxes::AbstractVector{<:AbstractFlux}, dfluxes::AbstractVector{<:AbstractStateFlux}, meta::ComponentVector)
171+
172+
Builds runtime-generated functions for routing calculations in a hydrological model, handling both flux computations and state updates with special support for outflow tracking.
173+
174+
# Arguments
175+
- `fluxes::AbstractVector{<:AbstractFlux}`: Vector of flux components that define the routing behavior
176+
- `dfluxes::AbstractVector{<:AbstractStateFlux}`: Vector of state differential components that define state changes
177+
- `meta::ComponentVector`: Metadata containing:
178+
- `inputs`: Input variable names
179+
- `outputs`: Output variable names
180+
- `states`: State variable names
181+
- `params`: Parameter names
182+
183+
# Returns
184+
A tuple containing two runtime-generated functions:
185+
1. `flux_func(inputs, states, pas) -> outputs`: Computes routing fluxes
186+
2. `diff_func(inputs, states, pas) -> (dstates, outflows)`: Computes state changes and tracks outflows
187+
188+
# Details
189+
The function specializes in routing calculations by:
190+
1. Handling both regular and neural network-based flux components
191+
2. Supporting batch processing for neural network operations
192+
3. Managing tensor transformations for multi-dimensional routing
193+
4. Tracking outflows separately from other state changes
194+
195+
For neural network fluxes (`AbstractNeuralFlux`), the function:
196+
- Prepares input tensors with appropriate dimensions
197+
- Handles batched neural network forward passes
198+
- Reshapes outputs for compatibility with the routing system
199+
200+
For regular fluxes:
201+
- Directly computes expressions for both states and fluxes
202+
- Maintains dimensional consistency with neural network outputs
203+
"""
103204
function build_route_func(
104205
fluxes::AbstractVector{<:AbstractFlux},
105206
dfluxes::AbstractVector{<:AbstractStateFlux},
106-
meta::ComponentVector,
207+
infos::NamedTuple,
107208
)
108-
input_names, output_names = tosymbol.(meta.inputs), tosymbol.(meta.outputs)
109-
state_names, param_names = tosymbol.(meta.states), tosymbol.(meta.params)
209+
input_names, output_names = tosymbol.(infos.inputs), tosymbol.(infos.outputs)
210+
state_names, param_names = tosymbol.(infos.states), tosymbol.(infos.params)
110211

111212
input_define_calls = [:($i = inputs[$idx]) for (idx, i) in enumerate(input_names)]
112213
state_define_calls = [:($s = states[$idx]) for (idx, s) in enumerate(state_names)]
113214
params_assign_calls = [:($p = pas.params.$p) for p in param_names]
114-
nn_params_assign_calls = [:($nn = pas.nns.$nn) for nn in [nflux.nninfos[:nns] for nflux in filter(f -> f isa AbstractNeuralFlux, fluxes)]]
215+
nn_params_assign_calls = [:($nn = pas.nns.$nn) for nn in [nflux.infos[:nns][1] for nflux in filter(f -> f isa AbstractNeuralFlux, fluxes)]]
115216
define_calls = reduce(vcat, [input_define_calls, state_define_calls, params_assign_calls, nn_params_assign_calls])
116217
state_compute_calls, flux_compute_calls, = [], []
117218
for f in fluxes
118219
if f isa AbstractNeuralFlux
119-
append!(state_compute_calls, [:($(f.nninfos[:inputs]) = permutedims(reduce(hcat, [$(get_input_names(f)...)])))])
120-
push!(state_compute_calls, :($(f.nninfos[:outputs]) = $(f.func)($(f.nninfos[:inputs]), $(f.nninfos[:nns]))))
121-
append!(state_compute_calls, [:($(nm) = $(f.nninfos[:outputs])[$i, :]) for (i, nm) in enumerate(get_output_names(f))])
220+
append!(state_compute_calls, [:($(f.infos[:nn_inputs]) = permutedims(reduce(hcat, [$(get_input_names(f)...)])))])
221+
push!(state_compute_calls, :($(f.infos[:nn_outputs]) = $(f.func)($(f.infos[:nn_inputs]), $(f.infos[:nns]))))
222+
append!(state_compute_calls, [:($(nm) = $(f.infos[:nn_outputs])[$i, :]) for (i, nm) in enumerate(get_output_names(f))])
122223

123-
append!(flux_compute_calls, [:($(f.nninfos[:inputs]) = permutedims(reduce((m1, m2) -> cat(m1, m2, dims=3), [$(get_input_names(f)...)]), (3, 1, 2)))])
124-
push!(flux_compute_calls, :($(f.nninfos[:outputs]) = permutedims(reduce((m1, m2) -> cat(m1, m2, dims=3), $(f.func).(eachslice($(f.nninfos[:inputs]), dims=2), Ref($(f.nninfos[:nns])))), (1, 3, 2))))
125-
append!(flux_compute_calls, [:($(nm) = $(f.nninfos[:outputs])[$i, :, :]) for (i, nm) in enumerate(get_output_names(f))])
224+
append!(flux_compute_calls, [:($(f.infos[:nn_inputs]) = permutedims(reduce((m1, m2) -> cat(m1, m2, dims=3), [$(get_input_names(f)...)]), (3, 1, 2)))])
225+
push!(flux_compute_calls, :($(f.infos[:nn_outputs]) = permutedims(reduce((m1, m2) -> cat(m1, m2, dims=3), $(f.func).(eachslice($(f.infos[:nn_inputs]), dims=2), Ref($(f.infos[:nns][1])))), (1, 3, 2))))
226+
append!(flux_compute_calls, [:($(nm) = $(f.infos[:nn_outputs])[$i, :, :]) for (i, nm) in enumerate(get_output_names(f))])
126227
else
127228
append!(state_compute_calls, [:($(nm) = $(toexprv2(unwrap(expr)))) for (nm, expr) in zip(get_output_names(f), f.exprs)])
128229
append!(flux_compute_calls, [:($(nm) = $(toexprv2(unwrap(expr)))) for (nm, expr) in zip(get_output_names(f), f.exprs)])
129230
end
130231
end
131232

132-
dfluxes_outflows = reduce(vcat, [tosymbol.(dflux.meta.outflows) for dflux in dfluxes])
233+
dfluxes_outflows = reduce(vcat, [dflux.infos.outflows for dflux in dfluxes])
133234
return_state = :(return [$(map(expr -> :($(toexprv2(unwrap(expr)))), reduce(vcat, get_exprs.(dfluxes)))...)], [$(dfluxes_outflows...)])
134235
# Create function expression
135236
meta_exprs = [:(Base.@_inline_meta)]

0 commit comments

Comments
 (0)