Refactor training strategy handling #547

Merged · 7 commits · Jun 29, 2022
10 changes: 9 additions & 1 deletion src/NeuralPDE.jl
@@ -28,6 +28,8 @@ using Flux: @nograd
import Optimisers
import UnPack: @unpack

RuntimeGeneratedFunctions.init(@__MODULE__)

abstract type NeuralPDEAlgorithm <: DiffEqBase.AbstractODEAlgorithm end
"""
TerminalPDEProblem(g, f, μ, σ, x0, tspan)
@@ -171,14 +173,20 @@ function Base.show(io::IO, A::ParamKolmogorovPDEProblem)
show(io, A.g)
end

abstract type AbstractPINN end

abstract type AbstractTrainingStrategy end

include("pinn_types.jl")
include("symbolic_utilities.jl")
include("training_strategies.jl")
include("adaptive_losses.jl")
include("ode_solve.jl")
include("kolmogorov_solve.jl")
include("rode_solve.jl")
include("stopping_solve.jl")
include("transform_inf_integral.jl")
include("pinns_pde_solve.jl")
include("discretize.jl")
include("neural_adapter.jl")
include("param_kolmogorov_solve.jl")

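With the new `AbstractTrainingStrategy` supertype in place, each sampling scheme can live in `training_strategies.jl` as its own subtype, and solvers can dispatch on the strategy type rather than branching internally. A minimal sketch of what such a subtype could look like; the struct, its field, and the `sample_points` helper are illustrative assumptions, not the package's actual definitions:

abstract type AbstractTrainingStrategy end  # as declared in this PR

# Hypothetical fixed-grid strategy: discretize the domain at spacing dx.
struct FixedGridStrategy <: AbstractTrainingStrategy
    dx::Float64
end

# Produce training points for a 1-D domain [lo, hi] at spacing dx.
sample_points(s::FixedGridStrategy, lo, hi) = collect(lo:s.dx:hi)

Dispatching on the strategy type is what lets the solver code stay agnostic to how training points are generated.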
15 changes: 14 additions & 1 deletion src/adaptive_losses.jl
@@ -1,5 +1,18 @@

abstract type AbstractAdaptiveLoss end

# Utils
# Convert a scalar or a Vector to a Vector with element type T.
function vectorify(x, t::Type{T}) where {T <: Real}
    convertfunc(y) = convert(t, y)
    returnval = if x isa Vector
        convertfunc.(x)
    else
        t[convertfunc(x)]
    end
end
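# Illustrative usage of vectorify (an aside, not part of this diff):
#   vectorify(0.5, Float64)        # -> [0.5]
#   vectorify([1, 2, 3], Float32)  # -> Float32[1.0, 2.0, 3.0]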

# Dispatches

"""
A way of weighting the components of the loss function in the total sum that does not change during optimization

@@ -40,7 +53,7 @@ A way of adaptively reweighting the components of the loss function in the total
* `additional_loss_weights`: a scalar describing the weight the additional loss function has in the full loss sum; this weight is not adaptive and remains constant under this adaptive loss,

from paper
Understanding and mitigating gradient pathologies in physics-informed neural networks
Sifan Wang, Yujun Teng, Paris Perdikaris
https://arxiv.org/abs/2001.04536v1
with code reference
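The adaptive loss documented above follows the gradient-statistics reweighting rule of Wang et al. (arXiv:2001.04536). A standalone sketch of that update step, written from the paper's formulation; the function and argument names are illustrative, not NeuralPDE internals:

using Statistics: mean

# One reweighting step (Wang et al.): rescale each boundary-loss weight
# by max|∇θ L_pde| / mean|∇θ L_bc|, then smooth with an exponential
# moving average so the weights do not oscillate between iterations.
function update_bc_weights(bc_weights, pde_grad, bc_grads; α = 0.9)
    scale = maximum(abs, pde_grad)
    target = [scale / mean(abs, g) for g in bc_grads]
    return (1 - α) .* bc_weights .+ α .* target
end

Here `pde_grad` is the gradient of the PDE-residual loss with respect to the network parameters and `bc_grads` is a vector of per-boundary-condition gradients; the returned weights multiply the boundary losses in the next total-loss evaluation.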