Skip to content

Commit 9ea9a80

Browse files
authored
[Nonlinear.SymbolicAD] simplify quadratic functions if possible (#2685)
1 parent f31be21 commit 9ea9a80

File tree

2 files changed

+205
-69
lines changed

2 files changed

+205
-69
lines changed

src/Nonlinear/SymbolicAD/SymbolicAD.jl

+156-48
Original file line numberDiff line numberDiff line change
@@ -77,16 +77,19 @@ function simplify!(f::MOI.ScalarAffineFunction{T}) where {T}
7777
if isempty(f.terms)
7878
return f.constant
7979
end
80+
if iszero(f.constant) && length(f.terms) == 1
81+
term = only(f.terms)
82+
if isone(term.coefficient)
83+
return term.variable
84+
end
85+
end
8086
return f
8187
end
8288

8389
function simplify!(f::MOI.ScalarQuadraticFunction{T}) where {T}
8490
f = MOI.Utilities.canonicalize!(f)
8591
if isempty(f.quadratic_terms)
86-
if isempty(f.affine_terms)
87-
return f.constant
88-
end
89-
return MOI.ScalarAffineFunction(f.affine_terms, f.constant)
92+
return simplify!(MOI.ScalarAffineFunction(f.affine_terms, f.constant))
9093
end
9194
return f
9295
end
@@ -117,7 +120,7 @@ function simplify!(f::MOI.ScalarNonlinearFunction)
117120
push!(result_stack, arg)
118121
end
119122
end
120-
return _simplify_if_affine!(only(result_stack))
123+
return _simplify_if_quadratic!(only(result_stack))
121124
end
122125

123126
function simplify!(f::MOI.VectorAffineFunction{T}) where {T}
@@ -140,10 +143,12 @@ function simplify!(f::MOI.VectorQuadraticFunction{T}) where {T}
140143
end
141144

142145
function simplify!(f::MOI.VectorNonlinearFunction)
143-
for (i, row) in enumerate(f.rows)
144-
f.rows[i] = simplify!(row)
146+
rows = simplify!.(f.rows)
147+
Y = reduce(promote_type, typeof.(rows))
148+
if isconcretetype(Y)
149+
return MOI.Utilities.vectorize(convert(Vector{Y}, rows))
145150
end
146-
return f
151+
return MOI.VectorNonlinearFunction(rows)
147152
end
148153

149154
# If a ScalarNonlinearFunction has only constant arguments, we should return
@@ -1507,100 +1512,203 @@ function MOI.eval_hessian_lagrangian(model::Evaluator, H, x, σ, μ)
15071512
end
15081513

15091514
# A default fallback for all types
1510-
_add_to_affine!(::Any, ::Any, ::T) where {T} = nothing
1515+
_add_to_quadratic!(::Any, ::Real, ::Any) = nothing
1516+
_add_to_quadratic!(::Any, ::Real, ::Any, ::Any) = nothing
15111517

1512-
# The creation of `ret::MOI.ScalarAffineFunction` has been delayed until now.
1513-
function _add_to_affine!(
1514-
::Nothing,
1515-
f::Union{Real,MOI.VariableIndex,MOI.ScalarAffineFunction},
1518+
# The creation of `ret::MOI.ScalarQuadraticFunction` has been delayed until now.
1519+
function _add_to_quadratic!(
1520+
::Missing,
15161521
scale::T,
1517-
) where {T}
1518-
return _add_to_affine!(zero(MOI.ScalarAffineFunction{T}), f, scale)
1519-
end
1520-
1521-
function _add_to_affine!(
1522-
ret::MOI.ScalarAffineFunction{T},
1523-
x::S,
1522+
f::Union{
1523+
Real,
1524+
MOI.VariableIndex,
1525+
MOI.ScalarAffineFunction,
1526+
MOI.ScalarQuadraticFunction,
1527+
}...,
1528+
) where {T<:Real}
1529+
return _add_to_quadratic!(zero(MOI.ScalarQuadraticFunction{T}), scale, f...)
1530+
end
1531+
1532+
function _add_to_quadratic!(
1533+
ret::MOI.ScalarQuadraticFunction{T},
15241534
scale::T,
1525-
) where {T,S<:Real}
1535+
x::S,
1536+
) where {T<:Real,S<:Real}
15261537
if promote_type(T, S) != T
15271538
return # We can't store `S` in `T`.
15281539
end
15291540
ret.constant += scale * convert(T, x)
15301541
return ret
15311542
end
15321543

1533-
function _add_to_affine!(
1534-
ret::MOI.ScalarAffineFunction{T},
1535-
x::MOI.VariableIndex,
1544+
function _add_to_quadratic!(
1545+
ret::MOI.ScalarQuadraticFunction{T},
15361546
scale::T,
1537-
) where {T}
1538-
push!(ret.terms, MOI.ScalarAffineTerm(scale, x))
1547+
f::MOI.ScalarAffineTerm{S},
1548+
) where {T<:Real,S}
1549+
@assert promote_type(T, S) == T
1550+
push!(
1551+
ret.affine_terms,
1552+
MOI.ScalarAffineTerm{T}(scale * f.coefficient, f.variable),
1553+
)
15391554
return ret
15401555
end
15411556

1542-
function _add_to_affine!(
1543-
ret::MOI.ScalarAffineFunction{T},
1544-
f::MOI.ScalarAffineFunction{S},
1557+
function _add_to_quadratic!(
1558+
ret::MOI.ScalarQuadraticFunction{T},
15451559
scale::T,
1546-
) where {T,S}
1560+
f::MOI.ScalarQuadraticTerm{S},
1561+
) where {T<:Real,S}
1562+
@assert promote_type(T, S) == T
1563+
push!(
1564+
ret.quadratic_terms,
1565+
MOI.ScalarQuadraticTerm{T}(
1566+
scale * f.coefficient,
1567+
f.variable_1,
1568+
f.variable_2,
1569+
),
1570+
)
1571+
return ret
1572+
end
1573+
1574+
function _add_to_quadratic!(
1575+
ret::MOI.ScalarQuadraticFunction{T},
1576+
scale::T,
1577+
x::MOI.VariableIndex,
1578+
) where {T<:Real}
1579+
return _add_to_quadratic!(ret, scale, MOI.ScalarAffineTerm(one(T), x))
1580+
end
1581+
1582+
function _add_to_quadratic!(
1583+
ret::MOI.ScalarQuadraticFunction{T},
1584+
scale::T,
1585+
f::MOI.ScalarAffineFunction{S},
1586+
) where {T<:Real,S}
15471587
if promote_type(T, S) != T
15481588
return # We can't store `S` in `T`.
15491589
end
1550-
ret = _add_to_affine!(ret, f.constant, scale)
1590+
ret = _add_to_quadratic!(ret, scale, f.constant)
15511591
for term in f.terms
1552-
ret = _add_to_affine!(ret, term.variable, scale * term.coefficient)
1592+
ret = _add_to_quadratic!(ret, scale, term)
15531593
end
15541594
return ret
15551595
end
15561596

1557-
function _add_to_affine!(
1558-
ret::Union{Nothing,MOI.ScalarAffineFunction{T}},
1559-
f::MOI.ScalarNonlinearFunction,
1597+
function _add_to_quadratic!(
1598+
ret::MOI.ScalarQuadraticFunction{T},
15601599
scale::T,
1561-
) where {T}
1600+
f::MOI.ScalarQuadraticFunction{S},
1601+
) where {T<:Real,S}
1602+
if promote_type(T, S) != T
1603+
return # We can't store `S` in `T`.
1604+
end
1605+
ret = _add_to_quadratic!(ret, scale, f.constant)
1606+
for term in f.affine_terms
1607+
ret = _add_to_quadratic!(ret, scale, term)
1608+
end
1609+
for q_term in f.quadratic_terms
1610+
ret = _add_to_quadratic!(ret, scale, q_term)
1611+
end
1612+
return ret
1613+
end
1614+
1615+
function _add_to_quadratic!(
1616+
ret::MOI.ScalarQuadraticFunction{T},
1617+
scale::T,
1618+
f::MOI.VariableIndex,
1619+
g::MOI.VariableIndex,
1620+
) where {T<:Real}
1621+
return _add_to_quadratic!(ret, scale, one(T) * f * g)
1622+
end
1623+
1624+
function _add_to_quadratic!(
1625+
ret::MOI.ScalarQuadraticFunction{T},
1626+
scale::T,
1627+
f::MOI.ScalarAffineFunction{F},
1628+
g::MOI.ScalarAffineFunction{G},
1629+
) where {T<:Real,F,G}
1630+
H = MOI.ScalarAffineFunction{promote_type(F, G)}
1631+
return _add_to_quadratic!(ret, scale, convert(H, f) * convert(H, g))
1632+
end
1633+
1634+
function _add_to_quadratic!(
1635+
ret::MOI.ScalarQuadraticFunction{T},
1636+
scale::T,
1637+
f::MOI.VariableIndex,
1638+
g::MOI.ScalarAffineFunction,
1639+
) where {T<:Real}
1640+
return _add_to_quadratic!(ret, scale, f * g)
1641+
end
1642+
1643+
function _add_to_quadratic!(
1644+
ret::MOI.ScalarQuadraticFunction{T},
1645+
scale::T,
1646+
f::MOI.ScalarAffineFunction,
1647+
g::MOI.VariableIndex,
1648+
) where {T<:Real}
1649+
return _add_to_quadratic!(ret, scale, g, f)
1650+
end
1651+
1652+
function _add_to_quadratic!(
1653+
ret::Union{Missing,MOI.ScalarQuadraticFunction{T}},
1654+
scale::T,
1655+
f::MOI.ScalarNonlinearFunction,
1656+
) where {T<:Real}
15621657
if f.head == :+
15631658
for arg in f.args
1564-
ret = _add_to_affine!(ret, arg, scale)
1659+
ret = _add_to_quadratic!(ret, scale, arg)
15651660
if ret === nothing
15661661
return
15671662
end
15681663
end
15691664
return ret
15701665
elseif f.head == :-
15711666
if length(f.args) == 1
1572-
return _add_to_affine!(ret, only(f.args), -scale)
1667+
return _add_to_quadratic!(ret, -scale, only(f.args))
15731668
end
15741669
@assert length(f.args) == 2
1575-
ret = _add_to_affine!(ret, f.args[1], scale)
1670+
ret = _add_to_quadratic!(ret, scale, f.args[1])
15761671
if ret === nothing
15771672
return
15781673
end
1579-
return _add_to_affine!(ret, f.args[2], -scale)
1674+
return _add_to_quadratic!(ret, -scale, f.args[2])
15801675
elseif f.head == :*
1581-
y = nothing
1676+
y1, y2 = nothing, nothing
15821677
for arg in f.args
15831678
if arg isa Real
15841679
scale *= arg
1585-
elseif y === nothing
1586-
y = arg
1680+
elseif y1 === nothing
1681+
y1 = arg
1682+
elseif y2 === nothing
1683+
y2 = arg
15871684
else
15881685
return # We already have a `y`. Can't multiple factors.
15891686
end
15901687
end
1591-
return _add_to_affine!(ret, something(y, one(T)), convert(T, scale))
1688+
if y1 === nothing
1689+
@assert y2 === nothing
1690+
return _add_to_quadratic!(ret, one(T), scale)
1691+
elseif y2 === nothing
1692+
return _add_to_quadratic!(ret, scale, y1)
1693+
else
1694+
return _add_to_quadratic!(ret, scale, y1, y2)
1695+
end
1696+
elseif f.head == :^ && f.args[2] isa Real && f.args[2] == 2
1697+
return _add_to_quadratic!(ret, scale, f.args[1], f.args[1])
1698+
elseif f.head == :/ && f.args[2] isa Real
1699+
return _add_to_quadratic!(ret, convert(T, scale / f.args[2]), f.args[1])
15921700
end
15931701
return # An unsupported f.head
15941702
end
15951703

1596-
function _simplify_if_affine!(f::MOI.ScalarNonlinearFunction)
1597-
ret = _add_to_affine!(nothing, f, 1.0)
1704+
function _simplify_if_quadratic!(f::MOI.ScalarNonlinearFunction)
1705+
ret = _add_to_quadratic!(missing, 1.0, f)
15981706
if ret === nothing
15991707
return f
16001708
end
1601-
return simplify!(ret::MOI.ScalarAffineFunction{Float64})
1709+
return simplify!(ret::MOI.ScalarQuadraticFunction{Float64})
16021710
end
16031711

1604-
_simplify_if_affine!(f::Any) = f
1712+
_simplify_if_quadratic!(f::Any) = f
16051713

16061714
end # module

0 commit comments

Comments
 (0)