1
+ """
2
+ build_flux_func(inputs::Vector{Num}, outputs::Vector{Num}, params::Vector{Num}, exprs::Vector{Num})
3
+
4
+ Generates a runtime function that computes flux calculations based on symbolic expressions.
5
+
6
+ # Arguments
7
+ - `inputs::Vector{Num}`: Vector of symbolic input variables that will be provided at runtime
8
+ - `outputs::Vector{Num}`: Vector of symbolic output variables that will be computed
9
+ - `params::Vector{Num}`: Vector of symbolic parameters used in the flux calculations
10
+ - `exprs::Vector{Num}`: Vector of symbolic expressions defining how outputs are computed from inputs and parameters
11
+
12
+ # Returns
13
+ - A runtime-generated function with signature `(inputs, pas)` where:
14
+ - `inputs`: Vector of input values corresponding to the symbolic inputs
15
+ - `pas`: A parameter struct containing fields matching the parameter names
16
+
17
+ # Details
18
+ The function generates code that:
19
+ 1. Assigns input values from the input vector to local variables
20
+ 2. Retrieves parameter values from the parameter struct
21
+ 3. Computes outputs using the provided expressions
22
+ 4. Returns a vector of computed outputs
23
+ """
1
24
function build_flux_func (inputs:: Vector{Num} , outputs:: Vector{Num} , params:: Vector{Num} , exprs:: Vector{Num} )
2
25
input_names, output_names = Symbolics. tosymbol .(inputs), Symbolics. tosymbol .(outputs)
3
26
param_names = Symbolics. tosymbol .(params)
@@ -16,39 +39,77 @@ function build_flux_func(inputs::Vector{Num}, outputs::Vector{Num}, params::Vect
16
39
return generated_flux_func
17
40
end
18
41
42
+ """
43
+ build_ele_func(fluxes::Vector{<:AbstractFlux}, dfluxes::Vector{<:AbstractStateFlux}, meta::ComponentVector)
44
+
45
+ Builds runtime-generated functions for both flux calculations and state differentials in a hydrological model element.
46
+
47
+ # Arguments
48
+ - `fluxes::Vector{<:AbstractFlux}`: Vector of flux components that define the element's behavior
49
+ - `dfluxes::Vector{<:AbstractStateFlux}`: Vector of state differential components that define state changes
50
+ - `meta::ComponentVector`: Metadata containing:
51
+ - `inputs`: Input variable names
52
+ - `outputs`: Output variable names
53
+ - `states`: State variable names
54
+ - `params`: Parameter names
55
+
56
+ # Returns
57
+ A tuple containing:
58
+ - First element: `Vector{Function}` with two functions:
59
+ 1. Regular flux function `(inputs, states, pas) -> outputs`
60
+ 2. Multi-dimensional flux function for batch processing
61
+ - Second element: Either `nothing` (if no states) or `Vector{Function}` with two functions:
62
+ 1. State differential function `(inputs, states, pas) -> dstates`
63
+ 2. Multi-dimensional state differential function for batch processing
64
+
65
+ # Details
66
+ The function generates four types of runtime functions:
67
+ 1. Single-sample flux computation
68
+ 2. Multi-sample flux computation (batched)
69
+ 3. Single-sample state differential computation (if states exist)
70
+ 4. Multi-sample state differential computation (if states exist)
71
+
72
+ For neural network fluxes (`AbstractNeuralFlux`), the function handles:
73
+ - Input tensor preparation
74
+ - Neural network forward passes
75
+ - Output tensor reshaping
76
+
77
+ For regular fluxes, it directly computes using provided expressions.
78
+ ```
79
+ """
19
80
function build_ele_func (
20
81
fluxes:: Vector{<:AbstractFlux} ,
21
82
dfluxes:: Vector{<:AbstractStateFlux} ,
22
- meta :: ComponentVector ,
83
+ infos :: NamedTuple ,
23
84
)
24
- input_names, output_names = tosymbol .(meta . inputs), tosymbol .(meta . outputs)
25
- state_names, param_names = tosymbol .(meta . states), tosymbol .(meta . params)
85
+ input_names, output_names = infos . inputs, infos . outputs
86
+ state_names, param_names = infos . states, infos . params
26
87
27
88
input_define_calls = [:($ i = inputs[$ idx]) for (idx, i) in enumerate (input_names)]
28
89
state_define_calls = [:($ s = states[$ idx]) for (idx, s) in enumerate (state_names)]
29
90
params_assign_calls = [:($ p = pas. params.$ p) for p in param_names]
30
- nn_params_assign_calls = [:($ nn = pas. nns.$ nn) for nn in [nflux. nninfos [:nns ] for nflux in filter (f -> f isa AbstractNeuralFlux, fluxes)]]
91
+ nn_params_assign_calls = [:($ nn = pas. nns.$ nn) for nn in [nflux. infos [:nns ][ 1 ] for nflux in filter (f -> f isa AbstractNeuralFlux, fluxes)]]
31
92
define_calls = reduce (vcat, [input_define_calls, state_define_calls, params_assign_calls, nn_params_assign_calls])
32
93
33
94
# varibles definitions expressions
34
95
state_compute_calls, multi_state_compute_calls, flux_compute_calls, multi_flux_compute_calls = [], [], [], []
35
96
for f in fluxes
36
97
if f isa AbstractNeuralFlux
37
- append! (state_compute_calls, [:($ (f. nninfos[ :inputs ]) = [$ (get_input_names (f)... )])])
38
- push! (state_compute_calls, :($ (f. nninfos[ :outputs ]) = $ (f. func)($ (f. nninfos[ :inputs ]), $ (f. nninfos [:nns ]))))
39
- append! (state_compute_calls, [:($ (nm) = $ (f. nninfos[ :outputs ])[$ i]) for (i, nm) in enumerate (get_output_names (f))])
98
+ append! (state_compute_calls, [:($ (f. infos[ :nn_inputs ]) = [$ (get_input_names (f)... )])])
99
+ push! (state_compute_calls, :($ (f. infos[ :nn_outputs ]) = $ (f. func)($ (f. infos[ :nn_inputs ]), $ (f. infos [:nns ][ 1 ]))))
100
+ append! (state_compute_calls, [:($ (nm) = $ (f. infos[ :nn_outputs ])[$ i]) for (i, nm) in enumerate (get_output_names (f))])
40
101
41
- append! (multi_state_compute_calls, [:($ (f. nninfos[ :inputs ]) = permutedims (reduce (hcat, [$ (get_input_names (f)... )])))])
42
- push! (multi_state_compute_calls, :($ (f. nninfos[ :outputs ]) = $ (f. func)($ (f. nninfos[ :inputs ]), $ (f. nninfos [:nns ]))))
43
- append! (multi_state_compute_calls, [:($ (nm) = $ (f. nninfos[ :outputs ])[$ i, :]) for (i, nm) in enumerate (get_output_names (f))])
102
+ append! (multi_state_compute_calls, [:($ (f. infos[ :nn_inputs ]) = permutedims (reduce (hcat, [$ (get_input_names (f)... )])))])
103
+ push! (multi_state_compute_calls, :($ (f. infos[ :nn_outputs ]) = $ (f. func)($ (f. infos[ :nn_inputs ]), $ (f. infos [:nns ][ 1 ]))))
104
+ append! (multi_state_compute_calls, [:($ (nm) = $ (f. infos[ :nn_outputs ])[$ i, :]) for (i, nm) in enumerate (get_output_names (f))])
44
105
45
- append! (flux_compute_calls, [:($ (f. nninfos[ :inputs ]) = permutedims (reduce (hcat, [$ (get_input_names (f)... )])))])
46
- push! (flux_compute_calls, :($ (f. nninfos[ :outputs ]) = $ (f. func)($ (f. nninfos[ :inputs ]), $ (f. nninfos [:nns ]))))
47
- append! (flux_compute_calls, [:($ (nm) = $ (f. nninfos[ :outputs ])[$ i, :]) for (i, nm) in enumerate (get_output_names (f))])
106
+ append! (flux_compute_calls, [:($ (f. infos[ :nn_inputs ]) = permutedims (reduce (hcat, [$ (get_input_names (f)... )])))])
107
+ push! (flux_compute_calls, :($ (f. infos[ :nn_outputs ]) = $ (f. func)($ (f. infos[ :nn_inputs ]), $ (f. infos [:nns ][ 1 ]))))
108
+ append! (flux_compute_calls, [:($ (nm) = $ (f. infos[ :nn_outputs ])[$ i, :]) for (i, nm) in enumerate (get_output_names (f))])
48
109
49
- append! (multi_flux_compute_calls, [:($ (f. nninfos[ :inputs ]) = permutedims (reduce ((m1, m2) -> cat (m1, m2, dims= 3 ), [$ (get_input_names (f)... )]), (3 , 1 , 2 )))])
50
- push! (multi_flux_compute_calls, :($ (f. nninfos[ :outputs ]) = permutedims (reduce ((m1, m2) -> cat (m1, m2, dims= 3 ), $ (f. func). (eachslice ($ (f. nninfos[ :inputs ]), dims= 2 ), Ref ($ (f. nninfos [:nns ])))), (1 , 3 , 2 ))))
51
- append! (multi_flux_compute_calls, [:($ (nm) = $ (f. nninfos[ :outputs ])[$ i, :, :]) for (i, nm) in enumerate (get_output_names (f))])
110
+ append! (multi_flux_compute_calls, [:($ (f. infos[ :nn_inputs ]) = permutedims (reduce ((m1, m2) -> cat (m1, m2, dims= 3 ), [$ (get_input_names (f)... )]), (3 , 1 , 2 )))])
111
+ push! (multi_flux_compute_calls, :($ (f. infos[ :nn_outputs ]) = permutedims (reduce ((m1, m2) -> cat (m1, m2, dims= 3 ), $ (f. func). (eachslice ($ (f. infos[ :nn_inputs ]), dims= 2 ), Ref ($ (f. infos [:nns ][ 1 ])))), (1 , 3 , 2 ))))
112
+ append! (multi_flux_compute_calls, [:($ (nm) = $ (f. infos[ :nn_outputs ])[$ i, :, :]) for (i, nm) in enumerate (get_output_names (f))])
52
113
else
53
114
append! (state_compute_calls, [:($ nm = $ (toexpr (expr))) for (nm, expr) in zip (get_output_names (f), f. exprs)])
54
115
append! (multi_state_compute_calls, [:($ nm = $ (toexprv2 (unwrap (expr)))) for (nm, expr) in zip (get_output_names (f), f. exprs)])
@@ -79,57 +140,97 @@ function build_ele_func(
79
140
$ (return_flux)
80
141
end )
81
142
82
- diff_func_expr = :(function (inputs, states, pas)
83
- $ (meta_exprs... )
84
- $ (define_calls... )
85
- $ (state_compute_calls... )
86
- $ (return_state)
87
- end )
88
-
89
- multi_diff_func_expr = :(function (inputs, states, pas)
90
- $ (meta_exprs... )
91
- $ (define_calls... )
92
- $ (multi_state_compute_calls... )
93
- $ (return_multi_state)
94
- end )
95
-
96
143
generated_flux_func = @RuntimeGeneratedFunction (flux_func_expr)
97
144
generated_multi_flux_func = @RuntimeGeneratedFunction (multi_flux_func_expr)
98
- generated_diff_func = @RuntimeGeneratedFunction (diff_func_expr)
99
- generated_multi_diff_func = @RuntimeGeneratedFunction (multi_diff_func_expr)
100
- return generated_flux_func, generated_multi_flux_func, generated_diff_func, generated_multi_diff_func
145
+
146
+ if length (state_names) > 0
147
+ diff_func_expr = :(function (inputs, states, pas)
148
+ $ (meta_exprs... )
149
+ $ (define_calls... )
150
+ $ (state_compute_calls... )
151
+ $ (return_state)
152
+ end )
153
+
154
+ multi_diff_func_expr = :(function (inputs, states, pas)
155
+ $ (meta_exprs... )
156
+ $ (define_calls... )
157
+ $ (multi_state_compute_calls... )
158
+ $ (return_multi_state)
159
+ end )
160
+
161
+ generated_diff_func = @RuntimeGeneratedFunction (diff_func_expr)
162
+ generated_multi_diff_func = @RuntimeGeneratedFunction (multi_diff_func_expr)
163
+ return [generated_flux_func, generated_multi_flux_func], [generated_diff_func, generated_multi_diff_func]
164
+ else
165
+ return [generated_flux_func, generated_multi_flux_func], nothing
166
+ end
101
167
end
102
168
169
+ """
170
+ build_route_func(fluxes::AbstractVector{<:AbstractFlux}, dfluxes::AbstractVector{<:AbstractStateFlux}, meta::ComponentVector)
171
+
172
+ Builds runtime-generated functions for routing calculations in a hydrological model, handling both flux computations and state updates with special support for outflow tracking.
173
+
174
+ # Arguments
175
+ - `fluxes::AbstractVector{<:AbstractFlux}`: Vector of flux components that define the routing behavior
176
+ - `dfluxes::AbstractVector{<:AbstractStateFlux}`: Vector of state differential components that define state changes
177
+ - `meta::ComponentVector`: Metadata containing:
178
+ - `inputs`: Input variable names
179
+ - `outputs`: Output variable names
180
+ - `states`: State variable names
181
+ - `params`: Parameter names
182
+
183
+ # Returns
184
+ A tuple containing two runtime-generated functions:
185
+ 1. `flux_func(inputs, states, pas) -> outputs`: Computes routing fluxes
186
+ 2. `diff_func(inputs, states, pas) -> (dstates, outflows)`: Computes state changes and tracks outflows
187
+
188
+ # Details
189
+ The function specializes in routing calculations by:
190
+ 1. Handling both regular and neural network-based flux components
191
+ 2. Supporting batch processing for neural network operations
192
+ 3. Managing tensor transformations for multi-dimensional routing
193
+ 4. Tracking outflows separately from other state changes
194
+
195
+ For neural network fluxes (`AbstractNeuralFlux`), the function:
196
+ - Prepares input tensors with appropriate dimensions
197
+ - Handles batched neural network forward passes
198
+ - Reshapes outputs for compatibility with the routing system
199
+
200
+ For regular fluxes:
201
+ - Directly computes expressions for both states and fluxes
202
+ - Maintains dimensional consistency with neural network outputs
203
+ """
103
204
function build_route_func (
104
205
fluxes:: AbstractVector{<:AbstractFlux} ,
105
206
dfluxes:: AbstractVector{<:AbstractStateFlux} ,
106
- meta :: ComponentVector ,
207
+ infos :: NamedTuple ,
107
208
)
108
- input_names, output_names = tosymbol .(meta . inputs), tosymbol .(meta . outputs)
109
- state_names, param_names = tosymbol .(meta . states), tosymbol .(meta . params)
209
+ input_names, output_names = tosymbol .(infos . inputs), tosymbol .(infos . outputs)
210
+ state_names, param_names = tosymbol .(infos . states), tosymbol .(infos . params)
110
211
111
212
input_define_calls = [:($ i = inputs[$ idx]) for (idx, i) in enumerate (input_names)]
112
213
state_define_calls = [:($ s = states[$ idx]) for (idx, s) in enumerate (state_names)]
113
214
params_assign_calls = [:($ p = pas. params.$ p) for p in param_names]
114
- nn_params_assign_calls = [:($ nn = pas. nns.$ nn) for nn in [nflux. nninfos [:nns ] for nflux in filter (f -> f isa AbstractNeuralFlux, fluxes)]]
215
+ nn_params_assign_calls = [:($ nn = pas. nns.$ nn) for nn in [nflux. infos [:nns ][ 1 ] for nflux in filter (f -> f isa AbstractNeuralFlux, fluxes)]]
115
216
define_calls = reduce (vcat, [input_define_calls, state_define_calls, params_assign_calls, nn_params_assign_calls])
116
217
state_compute_calls, flux_compute_calls, = [], []
117
218
for f in fluxes
118
219
if f isa AbstractNeuralFlux
119
- append! (state_compute_calls, [:($ (f. nninfos[ :inputs ]) = permutedims (reduce (hcat, [$ (get_input_names (f)... )])))])
120
- push! (state_compute_calls, :($ (f. nninfos[ :outputs ]) = $ (f. func)($ (f. nninfos[ :inputs ]), $ (f. nninfos [:nns ]))))
121
- append! (state_compute_calls, [:($ (nm) = $ (f. nninfos[ :outputs ])[$ i, :]) for (i, nm) in enumerate (get_output_names (f))])
220
+ append! (state_compute_calls, [:($ (f. infos[ :nn_inputs ]) = permutedims (reduce (hcat, [$ (get_input_names (f)... )])))])
221
+ push! (state_compute_calls, :($ (f. infos[ :nn_outputs ]) = $ (f. func)($ (f. infos[ :nn_inputs ]), $ (f. infos [:nns ]))))
222
+ append! (state_compute_calls, [:($ (nm) = $ (f. infos[ :nn_outputs ])[$ i, :]) for (i, nm) in enumerate (get_output_names (f))])
122
223
123
- append! (flux_compute_calls, [:($ (f. nninfos[ :inputs ]) = permutedims (reduce ((m1, m2) -> cat (m1, m2, dims= 3 ), [$ (get_input_names (f)... )]), (3 , 1 , 2 )))])
124
- push! (flux_compute_calls, :($ (f. nninfos[ :outputs ]) = permutedims (reduce ((m1, m2) -> cat (m1, m2, dims= 3 ), $ (f. func). (eachslice ($ (f. nninfos[ :inputs ]), dims= 2 ), Ref ($ (f. nninfos [:nns ])))), (1 , 3 , 2 ))))
125
- append! (flux_compute_calls, [:($ (nm) = $ (f. nninfos[ :outputs ])[$ i, :, :]) for (i, nm) in enumerate (get_output_names (f))])
224
+ append! (flux_compute_calls, [:($ (f. infos[ :nn_inputs ]) = permutedims (reduce ((m1, m2) -> cat (m1, m2, dims= 3 ), [$ (get_input_names (f)... )]), (3 , 1 , 2 )))])
225
+ push! (flux_compute_calls, :($ (f. infos[ :nn_outputs ]) = permutedims (reduce ((m1, m2) -> cat (m1, m2, dims= 3 ), $ (f. func). (eachslice ($ (f. infos[ :nn_inputs ]), dims= 2 ), Ref ($ (f. infos [:nns ][ 1 ])))), (1 , 3 , 2 ))))
226
+ append! (flux_compute_calls, [:($ (nm) = $ (f. infos[ :nn_outputs ])[$ i, :, :]) for (i, nm) in enumerate (get_output_names (f))])
126
227
else
127
228
append! (state_compute_calls, [:($ (nm) = $ (toexprv2 (unwrap (expr)))) for (nm, expr) in zip (get_output_names (f), f. exprs)])
128
229
append! (flux_compute_calls, [:($ (nm) = $ (toexprv2 (unwrap (expr)))) for (nm, expr) in zip (get_output_names (f), f. exprs)])
129
230
end
130
231
end
131
232
132
- dfluxes_outflows = reduce (vcat, [tosymbol .( dflux. meta . outflows) for dflux in dfluxes])
233
+ dfluxes_outflows = reduce (vcat, [dflux. infos . outflows for dflux in dfluxes])
133
234
return_state = :(return [$ (map (expr -> :($ (toexprv2 (unwrap (expr)))), reduce (vcat, get_exprs .(dfluxes)))... )], [$ (dfluxes_outflows... )])
134
235
# Create function expression
135
236
meta_exprs = [:(Base. @_inline_meta )]
0 commit comments