Skip to content

Commit

Permalink
Add AbstractReMat, make reterms and feterms fields (#380)
Browse files Browse the repository at this point in the history
* Add AbstractReMat, make reterms and feterms fields

* Clean up problems with tests

* drop fetrm and feind methods
  • Loading branch information
dmbates authored Sep 28, 2020
1 parent ad29709 commit fd07f92
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 47 deletions.
1 change: 1 addition & 0 deletions src/MixedModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import NLopt: Opt
import StatsBase: fit, fit!

export @formula,
AbstractReMat,
Bernoulli,
Binomial,
Block,
Expand Down
7 changes: 5 additions & 2 deletions src/generalizedlinearmixedmodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,8 @@ function GeneralizedLinearMixedModel(
LMM = LinearMixedModel(
LMM.formula,
LMM.allterms,
LMM.reterms,
LMM.feterms,
fill!(similar(y), 1),
LMM.parmap,
LMM.dims,
Expand Down Expand Up @@ -412,7 +414,9 @@ function Base.getproperty(m::GeneralizedLinearMixedModel, s::Symbol)
σs(m)
elseif s == :σρs
σρs(m)
elseif s (:A, :L, , :lowerbd, :corr, :PCA, :rePCA, :optsum, :X, :reterms, :feterms, :formula)
elseif s (:A, :L, :optsum, :allterms, :reterms, :feterms, :formula)
getfield(m.LMM, s)
elseif s (, :lowerbd, :corr, :PCA, :rePCA, :X,)
getproperty(m.LMM, s)
elseif s == :y
m.resp.y
Expand Down Expand Up @@ -612,7 +616,6 @@ varest(m::GeneralizedLinearMixedModel{T}) where {T} = one(T)
# delegate GLMM method to LMM field
for f in (
:feL,
:fetrm,
:(LinearAlgebra.logdet),
:lowerbd,
:PCA,
Expand Down
62 changes: 24 additions & 38 deletions src/linearmixedmodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ Linear mixed-effects model representation
* `formula`: the formula for the model
* `allterms`: a vector of random-effects terms, the fixed-effects terms and the response
* `reterms`: a `Vector{AbstractReMat{T}}` of random-effects terms.
* `feterms`: a `Vector{FeMat{T}}` of the fixed-effects model matrix and the response
* `sqrtwts`: vector of square roots of the case weights. Can be empty.
* `parmap` : Vector{NTuple{3,Int}} of (block, row, column) mapping of θ to λ
* `dims` : NamedTuple{(:n, :p, :nretrms),NTuple{3,Int}} of dimensions. `p` is the rank of `X`, which may be smaller than `size(X, 2)`.
Expand All @@ -21,16 +23,16 @@ Linear mixed-effects model representation
* `λ` or `lambda`: a vector of lower triangular matrices repeated on the diagonal blocks of `Λ`
* `σ` or `sigma`: current value of the standard deviation of the per-observation noise
* `b`: random effects on the original scale, as a vector of matrices
* `reterms`: a `Vector{ReMat{T}}` of random-effects terms.
* `feterms`: a `Vector{FeMat{T}}` of the fixed-effects model matrix and the response
* `u`: random effects on the orthogonal scale, as a vector of matrices
* `lowerbd`: lower bounds on the elements of θ
* `X`: the fixed-effects model matrix
* `y`: the response vector
"""
struct LinearMixedModel{T<:AbstractFloat} <: MixedModel{T}
formula::FormulaTerm
allterms::Vector{Union{ReMat{T}, FeMat{T}}}
allterms::Vector{Union{AbstractReMat{T}, FeMat{T}}}
reterms::Vector{AbstractReMat{T}}
feterms::Vector{FeMat{T}}
sqrtwts::Vector{T}
parmap::Vector{NTuple{3,Int}}
dims::NamedTuple{(:n, :p, :nretrms),NTuple{3,Int}}
Expand Down Expand Up @@ -70,10 +72,10 @@ function LinearMixedModel(
y = reshape(float(y), (:, 1)) # y as a floating-point matrix
T = eltype(y)

reterms = ReMat{T}[]
reterms = AbstractReMat{T}[]
feterms = FeMat{T}[]
for (i, x) in enumerate(Xs)
if isa(x, ReMat{T})
if isa(x, AbstractReMat{T})
push!(reterms, x)
else
cnames = coefnames(form.rhs[i])
Expand All @@ -88,7 +90,7 @@ function LinearMixedModel(
end

sort!(reterms, by = nranef, rev = true)
allterms = convert(Vector{Union{ReMat{T},FeMat{T}}}, vcat(reterms, feterms))
allterms = convert(Vector{Union{AbstractReMat{T},FeMat{T}}}, vcat(reterms, feterms))
sqrtwts = sqrt.(convert(Vector{T}, wts))
reweight!.(allterms, Ref(sqrtwts))
A, L = createAL(allterms)
Expand All @@ -100,6 +102,8 @@ function LinearMixedModel(
LinearMixedModel(
form,
allterms,
reterms,
feterms,
sqrtwts,
mkparmap(reterms),
(n = size(X, 1), p = X.rank, nretrms = length(reterms)),
Expand Down Expand Up @@ -183,14 +187,14 @@ fit(
)

function StatsBase.coef(m::LinearMixedModel{T}) where {T}
piv = fetrm(m).piv
piv = first(m.feterms).piv
invpermute!(fixef!(similar(piv, T), m), piv)
end

βs(m::LinearMixedModel) = NamedTuple{(Symbol.(coefnames(m))...,)}(coef(m))

function StatsBase.coefnames(m::LinearMixedModel)
Xtrm = fetrm(m)
Xtrm = first(m.feterms)
invpermute!(copy(Xtrm.cnames), Xtrm.piv)
end

Expand Down Expand Up @@ -240,7 +244,7 @@ function condVar(m::LinearMixedModel{T}) where {T}
Array{T,3}[reshape(abs2.(ll ./ Ld) .* varest(m), (1, 1, length(Ld)))]
end

function createAL(allterms::Vector{Union{ReMat{T},FeMat{T}}}) where {T}
function createAL(allterms::Vector{Union{AbstractReMat{T},FeMat{T}}}) where {T}
k = length(allterms)
sz = [isa(t, ReMat) ? size(t, 2) : rank(t) for t in allterms]
A = BlockArray(undef_blocks, AbstractMatrix{T}, sz, sz)
Expand Down Expand Up @@ -285,31 +289,17 @@ function StatsBase.dof_residual(m::LinearMixedModel)::Int
dd.n - dd.p - 1
end

"""
feind(m::LinearMixedModel)
An internal utility to return the index in `m.allterms` of the fixed-effects term.
"""
feind(m::LinearMixedModel) = m.dims.nretrms + 1

"""
feL(m::LinearMixedModel)
Return the lower Cholesky factor for the fixed-effects parameters, as an `LowerTriangular`
`p × p` matrix.
"""
function feL(m::LinearMixedModel)
k = feind(m)
k = m.dims.nretrms + 1
LowerTriangular(m.L.blocks[k, k])
end

"""
fetrm(m::LinearMixedModel)
Return the fixed-effects term from `m.allterms`
"""
fetrm(m::LinearMixedModel) = m.allterms[feind(m)]

"""
fit!(m::LinearMixedModel[; verbose::Bool=false, REML::Bool=false])
Expand Down Expand Up @@ -359,7 +349,7 @@ end

function fitted!(v::AbstractArray{T}, m::LinearMixedModel{T}) where {T}
## FIXME: Create and use `effects(m) -> β, b` w/o calculating β twice
Xtrm = fetrm(m)
Xtrm = first(m.feterms)
vv = mul!(vec(v), Xtrm, fixef!(similar(Xtrm.piv, T), m))
for (rt, bb) in zip(m.reterms, ranef(m))
unscaledre!(vv, rt, bb)
Expand All @@ -379,7 +369,7 @@ the length of `v` can be the rank of `X` or the number of columns of `X`. In th
case the calculated coefficients are padded with -0.0 out to the number of columns.
"""
function fixef!(v::AbstractVector{T}, m::LinearMixedModel{T}) where {T}
Xtrm = fetrm(m)
Xtrm = first(m.feterms)
if isfullrank(Xtrm)
ldiv!(feL(m)', copyto!(v, m.L.blocks[end, end-1]))
else
Expand All @@ -400,15 +390,15 @@ In the rank-deficient case the truncated parameter vector, of length `rank(m)` i
This is unlike `coef` which always returns a vector whose length matches the number of
columns in `X`.
"""
fixef(m::LinearMixedModel{T}) where {T} = fixef!(Vector{T}(undef, fetrm(m).rank), m)
fixef(m::LinearMixedModel{T}) where {T} = fixef!(Vector{T}(undef, first(m.feterms).rank), m)

"""
fixefnames(m::MixedModel)
Return a (permuted and truncated in the rank-deficient case) vector of coefficient names.
"""
function fixefnames(m::LinearMixedModel{T}) where {T}
Xtrm = fetrm(m)
Xtrm = first(m.feterms)
Xtrm.cnames[1:Xtrm.rank]
end

Expand Down Expand Up @@ -461,8 +451,6 @@ function Base.getproperty(m::LinearMixedModel{T}, s::Symbol) where {T}
σρs(m)
elseif s == :b
ranef(m)
elseif s == :feterms
convert(Vector{FeMat{T}}, filter(Base.Fix2(isa, FeMat), getfield(m, :allterms)))
elseif s == :objective
objective(m)
elseif s == :corr
Expand All @@ -473,8 +461,6 @@ function Base.getproperty(m::LinearMixedModel{T}, s::Symbol) where {T}
NamedTuple{fnames(m)}(PCA.(m.reterms))
elseif s == :pvalues
ccdf.(Chisq(1), abs2.(coef(m) ./ stderror(m)))
elseif s == :reterms
convert(Vector{ReMat{T}}, getfield(m, :allterms)[Base.OneTo(getfield(m, :dims).nretrms)])
elseif s == :stderror
stderror(m)
elseif s == :u
Expand Down Expand Up @@ -529,7 +515,7 @@ end

lowerbd(m::LinearMixedModel) = m.optsum.lowerbd

function mkparmap(reterms::Vector)
function mkparmap(reterms::Vector{AbstractReMat{T}}) where {T}
parmap = NTuple{3,Int}[]
for (k, trm) in enumerate(reterms)
n = LinearAlgebra.checksquare(trm.λ)
Expand All @@ -542,7 +528,7 @@ function mkparmap(reterms::Vector)
end

function StatsBase.modelmatrix(m::LinearMixedModel)
fe = fetrm(m)
fe = first(m.feterms)
if fe.rank == size(fe, 2)
fe.x
else
Expand Down Expand Up @@ -622,7 +608,7 @@ function ranef!(
β::AbstractArray{T},
uscale::Bool,
) where {T}
(k = length(v)) == length(m.reterms) || throw(DimensionMismatch(""))
(k = length(v)) == m.dims.nretrms || throw(DimensionMismatch(""))
L = m.L
for j = 1:k
mul!(
Expand Down Expand Up @@ -750,7 +736,7 @@ sdest(m::LinearMixedModel) = √varest(m)
Install `v` as the θ parameters in `m`.
"""
function setθ!(m::LinearMixedModel{T}, θ::Vector{T}) where {T}
parmap, reterms = m.parmap, m.allterms
parmap, reterms = m.parmap, m.reterms
length(θ) == length(parmap) || throw(DimensionMismatch())
reind = 1
λ = first(reterms).λ
Expand Down Expand Up @@ -871,12 +857,12 @@ function stderror!(v::AbstractVector{T}, m::LinearMixedModel{T}) where {T}
scr[i] = true
v[i] = s * norm(ldiv!(L, scr))
end
invpermute!(v, fetrm(m).piv)
invpermute!(v, first(m.feterms).piv)
v
end

function StatsBase.stderror(m::LinearMixedModel{T}) where {T}
stderror!(similar(fetrm(m).piv, T), m)
stderror!(similar(first(m.feterms).piv, T), m)
end

"""
Expand Down
2 changes: 1 addition & 1 deletion src/mixedmodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ Returns the variance-covariance matrix of the fixed effects.
If `corr=true`, then correlation of fixed effects is returned instead.
"""
function StatsBase.vcov(m::MixedModel; corr=false)
Xtrm = fetrm(m)
Xtrm = first(m isa GeneralizedLinearMixedModel ? m.LMM.feterms : m.feterms)
iperm = invperm(Xtrm.piv)
p = length(iperm)
r = Xtrm.rank
Expand Down
12 changes: 6 additions & 6 deletions src/remat.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
abstract type AbstractReMat{T} <: AbstractMatrix{T} end

"""
ReMat{T,S} <: AbstractMatrix{T}
Expand All @@ -12,7 +14,7 @@ A section of a model matrix generated by a random-effects term.
- `inds`: a `Vector{Int}` of linear indices of the potential nonzeros in `λ`
- `adjA`: the adjoint of the matrix as a `SparseMatrixCSC{T}`
"""
mutable struct ReMat{T,S} <: AbstractMatrix{T}
mutable struct ReMat{T,S} <: AbstractReMat{T}
trm
refs::Vector{Int32}
levels
Expand All @@ -26,21 +28,19 @@ mutable struct ReMat{T,S} <: AbstractMatrix{T}
end

"""
amalgamate(reterms::Vector{ReMat})
amalgamate(reterms::Vector{AbstractReMat})
Combine multiple ReMat with the same grouping variable into a single object.
"""
amalgamate(reterms::Vector{ReMat{T,S} where S}) where {T} = _amalgamate(reterms,T)
# constant S
amalgamate(reterms::Vector{ReMat{T,S}}) where {T,S} = _amalgamate(reterms,T)
amalgamate(reterms::Vector{AbstractReMat{T}}) where {T} = _amalgamate(reterms,T)

function _amalgamate(reterms::Vector, T::Type)
factordict = Dict{Symbol, Vector{Int}}()
for (i, rt) in enumerate(reterms)
push!(get!(factordict, fname(rt), Int[]), i)
end
length(factordict) == length(reterms) && return reterms
value = ReMat{T}[]
value = AbstractReMat{T}[]
for (f, inds) in factordict
if isone(length(inds))
push!(value, reterms[only(inds)])
Expand Down

0 comments on commit fd07f92

Please sign in to comment.