diff --git a/Project.toml b/Project.toml index 5eaf3541f..ae8c124f4 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45" DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a" +JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" @@ -21,6 +22,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" StatsModels = "3eaba693-59b7-5ba5-a881-562e759f1c8d" +StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [compat] @@ -28,6 +30,7 @@ Arrow = "1" DataAPI = "1" Distributions = "0.21, 0.22, 0.23, 0.24" GLM = "1" +JSON3 = "1" LazyArtifacts = "1" NLopt = "0.5, 0.6" PooledArrays = "0.5, 1" @@ -36,6 +39,7 @@ StaticArrays = "0.11, 0.12, 1" StatsBase = "0.31, 0.32, 0.33" StatsFuns = "0.8, 0.9" StatsModels = "0.6" +StructTypes = "1" Tables = "1" julia = "1.4" diff --git a/src/MixedModels.jl b/src/MixedModels.jl index 49afef0ae..47c890287 100644 --- a/src/MixedModels.jl +++ b/src/MixedModels.jl @@ -4,6 +4,7 @@ using Arrow using DataAPI using Distributions using GLM +using JSON3 using LazyArtifacts using LinearAlgebra using Markdown @@ -16,6 +17,7 @@ using StaticArrays using Statistics using StatsBase using StatsModels +using StructTypes using Tables using LinearAlgebra: BlasFloat, BlasReal, HermOrSym, PosDefException, copytri! @@ -108,6 +110,8 @@ export @formula, replicate, residuals, response, + restoreoptsum!, + saveoptsum, shortestcovint, sdest, setθ!, diff --git a/src/linearmixedmodel.jl b/src/linearmixedmodel.jl index 908f89bfc..753d75f12 100644 --- a/src/linearmixedmodel.jl +++ b/src/linearmixedmodel.jl @@ -782,6 +782,44 @@ StatsBase.residuals(m::LinearMixedModel) = response(m) .- fitted(m) StatsBase.response(m::LinearMixedModel) = m.y +""" + restoreoptsum!(m::LinearMixedModel, io::IO) + restoreoptsum!(m::LinearMixedModel, fnm::AbstractString) + +Read, check, and restore the `optsum` field from a JSON stream or filename. +""" +function restoreoptsum!(m::LinearMixedModel, io::IO) + dict = JSON3.read(io) + ops = m.optsum + okay = (setdiff(propertynames(ops), keys(dict)) == [:lowerbd]) && + all(ops.lowerbd .≤ dict.initial) && + all(ops.lowerbd .≤ dict.final) + if !okay + throw(ArgumentError("initial or final parameters in io do not satify lowerbd")) + end + for fld in (:feval, :finitial, :fmin, :ftol_rel, :ftol_abs, :maxfeval, :nAGQ, :REML) + setproperty!(ops, fld, getproperty(dict, fld)) + end + ops.initial_step = copy(dict.initial_step) + ops.xtol_rel = copy(dict.xtol_rel) + copyto!(ops.initial, dict.initial) + copyto!(ops.final, dict.final) + for (v, f) in (:initial => :finitial, :final => :fmin) + if !isapprox(objective(updateL!(setθ!(m, getfield(ops, v)))), getfield(ops, f)) + throw(ArgumentError("model m at $v does not give stored $f")) + end + end + ops.optimizer = Symbol(dict.optimizer) + ops.returnvalue = Symbol(dict.returnvalue) + m +end + +function restoreoptsum!(m::LinearMixedModel, fnm::AbstractString) + open(fnm, "r") do io + restoreoptsum!(m, io) + end +end + function reweight!(m::LinearMixedModel, weights) sqrtwts = map!(sqrt, m.sqrtwts, weights) reweight!.(m.reterms, Ref(sqrtwts)) @@ -790,6 +828,22 @@ function reweight!(m::LinearMixedModel, weights) updateL!(m) end +""" + saveoptsum(io::IO, m::LinearMixedModel) + saveoptsum(fnm::AbstractString, m::LinearMixedModel) + +Save `m.optsum` (w/o the `lowerbd` field) in JSON format to an IO stream or a file + +The reason for omitting the `lowerbd` field is because it often contains `-Inf` +values that are not allowed in JSON. +""" +saveoptsum(io::IO, m::LinearMixedModel) = JSON3.write(io, m.optsum) +function saveoptsum(fnm::AbstractString, m::LinearMixedModel) + open(fnm, "w") do io + saveoptsum(io, m) + end +end + """ sdest(m::LinearMixedModel) @@ -802,7 +856,7 @@ sdest(m::LinearMixedModel) = √varest(m) Install `v` as the θ parameters in `m`. """ -function setθ!(m::LinearMixedModel{T}, θ::Vector{T}) where {T} +function setθ!(m::LinearMixedModel{T}, θ::AbstractVector) where {T} parmap, reterms = m.parmap, m.reterms length(θ) == length(parmap) || throw(DimensionMismatch()) reind = 1 diff --git a/src/optsummary.jl b/src/optsummary.jl index bc2ed089c..5c375a3c1 100644 --- a/src/optsummary.jl +++ b/src/optsummary.jl @@ -110,3 +110,6 @@ function NLopt.Opt(optsum::OptSummary) end opt end + +StructTypes.StructType(::Type{<:OptSummary}) = StructTypes.Mutable() +StructTypes.excludes(::Type{<:OptSummary}) = (:lowerbd, ) diff --git a/test/pls.jl b/test/pls.jl index ef99450f7..200543907 100644 --- a/test/pls.jl +++ b/test/pls.jl @@ -401,6 +401,24 @@ end @test countlines(seekstart(io)) == 3 @test "BlkDiag" in Set(split(String(take!(io)), r"\s+")) + @testset "optsumJSON" begin + fm = last(models(:sleepstudy)) + # using a IOBuffer for saving JSON + saveoptsum(seekstart(io), fm) + m = LinearMixedModel(fm.formula, MixedModels.dataset(:sleepstudy)) + restoreoptsum!(m, seekstart(io)) + @test loglikelihood(fm) ≈ loglikelihood(m) + @test bic(fm) ≈ bic(m) + @test coef(fm) ≈ coef(m) + # using a temporary file for saving JSON + fnm = first(mktemp()) + saveoptsum(fnm, fm) + m = LinearMixedModel(fm.formula, MixedModels.dataset(:sleepstudy)) + restoreoptsum!(m, fnm) + @test loglikelihood(fm) ≈ loglikelihood(m) + @test bic(fm) ≈ bic(m) + @test coef(fm) ≈ coef(m) + end end @testset "d3" begin