Skip to content

Commit 35ff1bb

Browse files
committed
complete test
1 parent df765c6 commit 35ff1bb

17 files changed

+75
-441
lines changed

Project.toml

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "HydroModels"
22
uuid = "7e3cde01-c141-467b-bff6-5350a0af19fc"
33
authors = ["jingx <50790703+chooron@users.noreply.github.com>"]
4-
version = "0.1.0"
4+
version = "0.1.1"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
@@ -47,12 +47,16 @@ Symbolics = "6"
4747
TOML = "1"
4848
Test = "1"
4949
julia = "1.10"
50+
CSV = "0.10"
51+
DataFrames = "1"
52+
Statistics = "1"
5053

5154
[extras]
5255
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
5356
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
5457
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
58+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
5559
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5660

5761
[targets]
58-
test = ["Test", "Aqua", "CSV", "DataFrames"]
62+
test = ["Test", "Aqua", "CSV", "DataFrames", "Statistics"]

src/HydroModels.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,11 @@ include("utils/name.jl")
6969
include("utils/show.jl")
7070
include("utils/build.jl")
7171
include("utils/sort.jl")
72+
include("utils/check.jl")
7273
include("utils/io.jl")
73-
inclue("utils/check.jl")
7474
export NamedTupleIOAdapter
75+
include("utils/solver.jl")
76+
export ManualSolver
7577

7678
# framework build
7779
include("flux.jl")

src/bucket.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -219,8 +219,8 @@ function (ele::HydroBucket{F,D,FF,OF,M})(
219219
params_vec, nn_params_vec = param_func(pas), nn_param_func(pas)
220220
flux_output = ele.flux_func.(eachslice(input, dims=2), Ref(params_vec), Ref(nn_params_vec), timeidx)
221221
#* convert vector{vector} to matrix
222-
flux_output_matrix = reduce(hcat, flux_output)
223-
flux_output_matrix
222+
flux_output_mat = reduce(hcat, flux_output)
223+
flux_output_mat
224224
end
225225

226226
function (ele::HydroBucket{F,D,FF,OF,M})(
@@ -241,7 +241,7 @@ function (ele::HydroBucket{F,D,FF,OF,M})(
241241
check_ptypes(ele, input, ptypes)
242242
check_stypes(ele, input, stypes)
243243
#* check initial states
244-
check_initstates(ele, pas)
244+
check_initstates(ele, pas, stypes)
245245
#* prepare initial states
246246
init_states_mat = reduce(hcat, [collect(pas[:initstates][stype][get_state_names(ele)]) for stype in stypes])
247247
#* extract params and nn params

src/route.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ function (route::HydroRoute{F,PF,M})(
262262
sol_arr = solver(du_func, pas, init_states_mat, timeidx, convert_to_array=true)
263263
sol_arr_permuted = permutedims(sol_arr, (2, 1, 3))
264264
cat_arr = cat(input, sol_arr_permuted, dims=1)
265-
output_vec = [route.rfunc.func.(eachslice(cat_arr_, dims=2), param_func(pas), timeidx[i]) for cat_arr_ in eachslice(cat_arr, dims=3)]
265+
output_vec = [route.rfunc.func.(eachslice(cat_arr[:, :, i], dims=2), param_func(pas), timeidx[i]) for i in axes(cat_arr, 3)]
266266
out_arr = reduce(hcat, reduce.(vcat, output_vec))
267267
#* return route_states and q_out
268268
return cat(sol_arr_permuted, reshape(out_arr, 1, size(out_arr)...), dims=1)
@@ -345,8 +345,8 @@ function (route::RapidRoute)(
345345
itp_funcs = interp.(eachslice(input[1, :, :], dims=1), Ref(timeidx), extrapolate=true)
346346

347347
#* prepare the parameters for the routing function
348-
k_ps = [pas[:params][ptype][:k] for ptype in ptypes]
349-
x_ps = [pas[:params][ptype][:x] for ptype in ptypes]
348+
k_ps = [pas[:params][ptype][:rapid_k] for ptype in ptypes]
349+
x_ps = [pas[:params][ptype][:rapid_x] for ptype in ptypes]
350350
c0 = @. ((delta_t / k_ps) - (2 * x_ps)) / ((2 * (1 - x_ps)) + (delta_t / k_ps))
351351
c1 = @. ((delta_t / k_ps) + (2 * x_ps)) / ((2 * (1 - x_ps)) + (delta_t / k_ps))
352352
c2 = @. ((2 * (1 - x_ps)) - (delta_t / k_ps)) / ((2 * (1 - x_ps)) + (delta_t / k_ps))

src/uh.jl

+6-7
Original file line numberDiff line numberDiff line change
@@ -151,22 +151,21 @@ Apply the unit hydrograph flux model to input data of various dimensions.
151151

152152
(::UnitHydrograph)(::AbstractVector, ::ComponentVector; kwargs...) = @error "UnitHydrograph is not support for single timepoint"
153153

154-
function (flux::UnitHydrograph{<:Any,<:Any,<:Any,:DISCRETE})(input::AbstractArray{T,2}, pas::ComponentVector; kwargs...) where {T}
154+
function (flux::UnitHydrograph{<:Any,<:Any,<:Any,:DISCRETE})(input::AbstractArray{T,2}, pas::ComponentVector; config::NamedTuple=NamedTuple(), kwargs...) where {T}
155155
solver = get(config, :solver, ManualSolver{true}())
156-
timeidx = get(kwargs, :timeidx, collect(1:size(input, 2)))
156+
timeidx = get(config, :timeidx, collect(1:size(input, 2)))
157157
input_vec = input[1, :]
158158
#* convert the lagflux to a discrete problem
159-
lag_du_func(u,p,t) = input_vec[Int(t)] .* p[:weight] .+ [diff(u); -u[end]]
159+
lag_du_func(u, p, t) = input_vec[Int(t)] .* p[:weight] .+ [diff(u); -u[end]]
160160
#* prepare the initial states
161161
lag = pas[:params][get_param_names(flux)[1]]
162162
uh_weight = map(t -> flux.uhfunc(t, lag), 1:get_uh_tmax(flux.uhfunc, lag))[1:end-1]
163163
if length(uh_weight) == 0
164164
@warn "The unit hydrograph weight is empty, please check the unit hydrograph function"
165165
return input
166166
else
167-
initstates = input_vec[1] .* uh_weight ./ sum(uh_weight)
168167
#* solve the problem
169-
sol = solver(lag_du_func, ComponentVector(weight=uh_weight ./ sum(uh_weight)), initstates, timeidx)
168+
sol = solver(lag_du_func, ComponentVector(weight=uh_weight ./ sum(uh_weight)), zeros(length(uh_weight)), timeidx)
170169
reshape(sol[1, :], 1, length(input_vec))
171170
end
172171
end
@@ -191,9 +190,9 @@ function (flux::UnitHydrograph{<:Any,<:Any,<:Any,:SPARSE})(input::AbstractArray{
191190
end
192191

193192
# todo: 卷积计算的结果与前两个计算结果不太一致
194-
function (flux::UnitHydrograph{<:Any,<:Any,<:Any,:INTEGRAL})(input::AbstractArray{T,2}, pas::ComponentVector; kwargs...) where {T}
193+
function (flux::UnitHydrograph{<:Any,<:Any,<:Any,:INTEGRAL})(input::AbstractArray{T,2}, pas::ComponentVector; config::NamedTuple=NamedTuple(), kwargs...) where {T}
195194
input_vec = input[1, :]
196-
itp_method = get(kwargs, :interp, LinearInterpolation)
195+
itp_method = get(config, :interp, LinearInterpolation)
197196
itp = itp_method(input_vec, collect(1:length(input_vec)), extrapolate=true)
198197
#* construct the unit hydrograph function based on the interpolation method and parameter
199198
lag = pas[:params][get_param_names(flux)[1]]

src/utils/check.jl

+13-10
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@ end
1818

1919
function check_pas(component::AbstractComponent, pas::ComponentVector)
2020
check_parameters(component, pas)
21-
check_states(component, pas)
21+
check_initstates(component, pas)
2222
check_nns(component, pas)
2323
end
2424

2525
function check_pas(component::AbstractComponent, pas::ComponentVector, ptypes::AbstractVector{Symbol}, stypes::AbstractVector{Symbol})
2626
check_parameters(component, pas, ptypes)
27-
check_states(component, pas, stypes)
27+
check_initstates(component, pas, stypes)
2828
check_nns(component, pas)
2929
end
3030

@@ -42,31 +42,34 @@ function check_parameters(component::AbstractComponent, pas::ComponentVector, pt
4242
param_names = get_param_names(component)
4343
cpt_name = get_name(component)
4444
for ptype in ptypes
45+
tmp_ptype_params_keys = keys(pas[:params][ptype])
4546
for param_name in param_names
46-
@assert(param_name in keys(pas[ptype][:params]),
47-
"Parameter '$(param_name)' in component '$(cpt_name)' is required but not found in parameter type '$(ptype)'. Available parameters: $(keys(pas[ptype][:params]))"
47+
@assert(param_name in tmp_ptype_params_keys,
48+
"Parameter '$(param_name)' in component '$(cpt_name)' is required but not found in parameter type '$(ptype)'. Available parameters: $(tmp_ptype_params_keys)"
4849
)
4950
end
5051
end
5152
end
5253

53-
function check_states(component::AbstractComponent, pas::ComponentVector)
54+
function check_initstates(component::AbstractComponent, pas::ComponentVector)
5455
state_names = get_state_names(component)
5556
cpt_name = get_name(component)
5657
for state_name in state_names
57-
@assert(state_name in keys(pas[:initstates]),
58-
"Initial state '$(state_name)' in component '$(cpt_name)' is required but not found in pas[:initstates]. Available states: $(keys(pas[:initstates]))"
58+
tmp_ptype_initstates_keys = keys(pas[:initstates])
59+
@assert(state_name in tmp_ptype_initstates_keys,
60+
"Initial state '$(state_name)' in component '$(cpt_name)' is required but not found in parameter type '$(ptype)'. Available states: $(tmp_ptype_initstates_keys)"
5961
)
6062
end
6163
end
6264

63-
function check_states(component::AbstractComponent, pas::ComponentVector, stypes::AbstractVector{Symbol})
65+
function check_initstates(component::AbstractComponent, pas::ComponentVector, stypes::AbstractVector{Symbol})
6466
state_names = get_state_names(component)
6567
cpt_name = get_name(component)
6668
for stype in stypes
69+
tmp_ptype_initstates_keys = keys(pas[:initstates][stype])
6770
for state_name in state_names
68-
@assert(state_name in keys(pas[stype][:initstates]),
69-
"Initial state '$(state_name)' in component '$(cpt_name)' is required but not found in state type '$(stype)'. Available states: $(keys(pas[stype][:initstates]))"
71+
@assert(state_name in tmp_ptype_initstates_keys,
72+
"Initial state '$(state_name)' in component '$(cpt_name)' is required but not found in state type '$(stype)'. Available states: $(tmp_ptype_initstates_keys)"
7073
)
7174
end
7275
end

src/utils/show.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -65,15 +65,15 @@ function Base.show(io::IO, uh::AbstractHydrograph)
6565
print(io, "inputs: ", isempty(uh.meta.inputs) ? "nothing" : join(uh.meta.inputs, ", "))
6666
print(io, ", outputs: ", isempty(uh.meta.outputs) ? "nothing" : join(uh.meta.outputs, ", "))
6767
print(io, ", params: ", isempty(uh.meta.params) ? "nothing" : join(uh.meta.params, ", "))
68-
print(io, ", uhfunc: ", nameof(typeof(uh.uhfunc).parameters[1]))
68+
print(io, ", uhfunc: ", typeof(uh.uhfunc).parameters[1])
6969
print(io, ")")
7070
else
7171
println(io, "UnitHydroFlux:")
7272
println(io, " Inputs: ", isempty(uh.meta.inputs) ? "nothing" : join(uh.meta.inputs, ", "))
7373
println(io, " Outputs: ", isempty(uh.meta.outputs) ? "nothing" : join(uh.meta.outputs, ", "))
7474
println(io, " Parameters: ", isempty(uh.meta.params) ? "nothing" : join(uh.meta.params, ", "))
75-
println(io, " UnitFunction: ", nameof(typeof(uh.uhfunc).parameters[1]))
76-
println(io, " SolveType: ", nameof(typeof(uh).parameters[end]))
75+
println(io, " UnitFunction: ", typeof(uh.uhfunc).parameters[1])
76+
println(io, " SolveType: ", typeof(uh).parameters[end])
7777
end
7878
end
7979

test/run_bucket.jl

+4-33
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,11 @@
3131
@test Set(HydroModels.get_output_names(snow_ele)) == Set((:pet, :snowfall, :rainfall, :melt))
3232
@test Set(HydroModels.get_state_names(snow_ele)) == Set((:snowpack,))
3333
end
34-
35-
result = snow_ele(input, pas)
34+
config = (timeidx=ts, solver=ManualSolver{true}())
35+
result = snow_ele(input, pas, config=config)
3636
ele_state_and_output_names = vcat(HydroModels.get_state_names(snow_ele), HydroModels.get_output_names(snow_ele))
3737
result = NamedTuple{Tuple(ele_state_and_output_names)}(eachslice(result, dims=1))
38+
3839
@testset "test first output for hydro element" begin
3940
snowpack0 = init_states[:snowpack]
4041
pet0 = snow_funcs[1]([input_ntp.temp[1], input_ntp.lday[1]], ComponentVector(params=ComponentVector()))[1]
@@ -46,42 +47,12 @@
4647
@test melt0 == result.melt[1]
4748
end
4849

49-
@testset "test ode solved results" begin
50-
prcp_itp = LinearInterpolation(input_ntp.prcp, ts)
51-
temp_itp = LinearInterpolation(input_ntp.temp, ts)
52-
53-
function snowpack_bucket!(du, u, p, t)
54-
snowpack_ = u[1]
55-
Df, Tmax, Tmin = p.Df, p.Tmax, p.Tmin
56-
prcp_, temp_ = prcp_itp(t), temp_itp(t)
57-
snowfall_ = step_func(Tmin - temp_) * prcp_
58-
melt_ = step_func(temp_ - Tmax) * step_func(snowpack_) * min(snowpack_, Df * (temp_ - Tmax))
59-
du[1] = snowfall_ - melt_
60-
end
61-
prob = ODEProblem(snowpack_bucket!, [init_states.snowpack], (ts[1], ts[end]), params)
62-
sol = solve(prob, Tsit5(), saveat=ts, reltol=1e-3, abstol=1e-3)
63-
num_u = length(prob.u0)
64-
manual_result = [sol[i, :] for i in 1:num_u]
65-
ele_params_idx = [getaxes(pas[:params])[1][nm].idx for nm in HydroModels.get_param_names(snow_ele)]
66-
paramfunc = (p) -> [p[:params][idx] for idx in ele_params_idx]
67-
68-
param_func, nn_param_func = HydroModels._get_parameter_extractors(snow_ele, pas)
69-
itpfunc_list = map((var) -> LinearInterpolation(var, ts, extrapolate=true), eachrow(input))
70-
ode_input_func = (t) -> [itpfunc(t) for itpfunc in itpfunc_list]
71-
du_func = HydroModels._get_du_func(snow_ele, ode_input_func, param_func, nn_param_func)
72-
solver = HydroModels.ODESolver(alg=Tsit5(), reltol=1e-3, abstol=1e-3)
73-
initstates_mat = collect(pas[:initstates][HydroModels.get_state_names(snow_ele)])
74-
#* solve the problem by call the solver
75-
solved_states = solver(du_func, pas, initstates_mat, ts)
76-
@test manual_result[1] == solved_states[1, :]
77-
end
78-
7950
@testset "test all of the output" begin
8051
param_func, nn_param_func = HydroModels._get_parameter_extractors(snow_ele, pas)
8152
itpfunc_list = map((var) -> LinearInterpolation(var, ts, extrapolate=true), eachrow(input))
8253
ode_input_func = (t) -> [itpfunc(t) for itpfunc in itpfunc_list]
8354
du_func = HydroModels._get_du_func(snow_ele, ode_input_func, param_func, nn_param_func)
84-
solver = HydroModels.ODESolver(alg=Tsit5(), reltol=1e-3, abstol=1e-3)
55+
solver = ManualSolver{true}()
8556
initstates_mat = collect(pas[:initstates][HydroModels.get_state_names(snow_ele)])
8657
#* solve the problem by call the solver
8758
snowpack_vec = solver(du_func, pas, initstates_mat, ts)[1, :]

test/run_flux.jl

-20
Original file line numberDiff line numberDiff line change
@@ -46,26 +46,6 @@ end
4646
@test HydroModels.get_state_names(state_flux_3) == [:d,]
4747
end
4848

49-
# todo muskingum need rebuild
50-
# @testset "test muskingum route flux" begin
51-
# @variables q1
52-
53-
# # Building the Muskingum routing flux
54-
# k, x = 3.0, 0.2
55-
# pas = ComponentVector(params=(k=k, x=x,))
56-
# msk_flux = HydroModels.MuskingumRouteFlux(q1)
57-
# input = Float64[1 2 3 2 3 2 5 7 8 3 2 1]
58-
# re = msk_flux(input, pas)
59-
60-
# # Verifying the input, output, and parameter names
61-
# @test HydroModels.get_input_names(msk_flux) == [:q1]
62-
# @test HydroModels.get_output_names(msk_flux) == [:q1_routed]
63-
# @test HydroModels.get_param_names(msk_flux) == [:k, :x]
64-
65-
# # Checking the size and values of the output
66-
# @test size(re) == size(input)
67-
# @test re ≈ [1.0 0.977722 1.30086 1.90343 1.919 2.31884 2.15305 3.07904 4.39488 5.75286 4.83462 3.89097] atol = 1e-1
68-
# end
6949

7050

7151
@testset "test neural flux (single output)" begin

test/run_groute.jl

-23
This file was deleted.

0 commit comments

Comments
 (0)