Skip to content

Commit

Permalink
Extend and test sum of sums
Browse files Browse the repository at this point in the history
  • Loading branch information
asterycs committed Feb 21, 2025
1 parent 7f7e40f commit ba4d82c
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 2 deletions.
31 changes: 29 additions & 2 deletions src/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -577,8 +577,35 @@ function _add_to_product(arg1::BinaryOperation{Mult}, arg2::BinaryOperation{Mult
end

function evaluate(::Add, arg1::BinaryOperation{Add}, arg2::BinaryOperation{Add})
# TODO: extend and change the below overload to (BinaryOp{Add}, UnaryValue)
return invoke(evaluate, Tuple{Add,BinaryOperation{Add},Value}, Add(), arg1, arg2)
if arg1.arg1 == arg2.arg1
return BinaryOperation{Add}(
BinaryOperation{Mult}(2, arg1.arg1),
evaluate(BinaryOperation{Add}(arg1.arg2, arg2.arg2)),
)
end

if arg1.arg1 == arg2.arg2
return BinaryOperation{Add}(
BinaryOperation{Mult}(2, arg1.arg1),
evaluate(BinaryOperation{Add}(arg1.arg2, arg2.arg1)),
)
end

if arg1.arg2 == arg2.arg1
return BinaryOperation{Add}(
BinaryOperation{Mult}(2, arg1.arg2),
evaluate(BinaryOperation{Add}(arg1.arg1, arg2.arg2)),
)
end

if arg1.arg2 == arg2.arg2
return BinaryOperation{Add}(
BinaryOperation{Mult}(2, arg1.arg2),
evaluate(BinaryOperation{Add}(arg1.arg1, arg2.arg1)),
)
end

return BinaryOperation{Add}(arg1, arg2)
end

function evaluate(::Add, arg1::BinaryOperation{Add}, arg2::Zero)
Expand Down
38 changes: 38 additions & 0 deletions test/ForwardTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,44 @@ end
@test evaluate(dc.BinaryOperation{dc.Add}(z, d)) == d
end

@testset "evaluate sum of addition and addition" begin
a = Tensor("a", Upper(1))
b = Tensor("b", Upper(1))
c = Tensor("c", Upper(1))
d = Tensor("d", Upper(1))

l = dc.BinaryOperation{dc.Add}(a, b)
r = dc.BinaryOperation{dc.Add}(a, c)
s = dc.BinaryOperation{dc.Add}(l, r)

@test dc.evaluate(s) == dc.evaluate(2 * a + (b + c))

l = dc.BinaryOperation{dc.Add}(a, b)
r = dc.BinaryOperation{dc.Add}(c, a)
s = dc.BinaryOperation{dc.Add}(l, r)

@test dc.evaluate(s) == dc.evaluate(2 * a + (b + c))

l = dc.BinaryOperation{dc.Add}(b, a)
r = dc.BinaryOperation{dc.Add}(a, c)
s = dc.BinaryOperation{dc.Add}(l, r)

@test dc.evaluate(s) == dc.evaluate(2 * a + (b + c))

l = dc.BinaryOperation{dc.Add}(b, a)
r = dc.BinaryOperation{dc.Add}(c, a)
s = dc.BinaryOperation{dc.Add}(l, r)

@test dc.evaluate(s) == dc.evaluate(2 * a + (b + c))

l = dc.BinaryOperation{dc.Add}(a, b)
r = dc.BinaryOperation{dc.Add}(c, d)
s = dc.BinaryOperation{dc.Add}(l, r)

@test dc.evaluate(s) == dc.evaluate((a + b) + (c + d))
end


@testset "evaluate sum of subtraction and addition" begin
a = Tensor("a", Upper(1))
b = Tensor("b", Upper(1))
Expand Down

0 comments on commit ba4d82c

Please sign in to comment.