Skip to content

Commit

Permalink
Test evaluate subtraction with product with real
Browse files Browse the repository at this point in the history
  • Loading branch information
asterycs committed Feb 16, 2025
1 parent 907a857 commit 15d5f3a
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -728,11 +728,11 @@ function _sub_from_product(arg1::BinaryOperation{Mult}, arg2::Value)
end

if evaluate(arg1.arg1) isa Real && evaluate(arg1.arg2) == evaluate(arg2)
return BinaryOperation{Mult}(evaluate(arg1.arg1) - 1, evaluate(arg2))
return evaluate(BinaryOperation{Mult}(evaluate(arg1.arg1) - 1, evaluate(arg2)))
end

if evaluate(arg1.arg2) isa Real && evaluate(arg1.arg1) == evaluate(arg2)
return BinaryOperation{Mult}(evaluate(arg1.arg2) - 1, evaluate(arg2))
return evaluate(BinaryOperation{Mult}(evaluate(arg1.arg2) - 1, evaluate(arg2)))
end

return BinaryOperation{Sub}(evaluate(arg1), evaluate(arg2))
Expand Down
17 changes: 17 additions & 0 deletions test/ForwardTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,23 @@ end
@test equivalent(evaluate(op2), Zero(Upper(2)))
end

@testset "evaluate subtraction with product with real" begin
A = Tensor("A", Upper(1), Lower(2))

function mul(l, r)
return dc.BinaryOperation{dc.Mult}(l, r)
end

function sub(l, r)
return dc.BinaryOperation{dc.Sub}(l, r)
end

@test dc.evaluate(sub(mul(2, A), A)) == A
@test dc.evaluate(sub(mul(A, 2), A)) == A
# @test dc.evaluate(sub(mul(A, 2), A)) == A
# @test dc.evaluate(sub(mul(A, 2), A)) == A
end

@testset "evaluate unary operations" begin
A = Tensor("A", Upper(1), Lower(2))

Expand Down

0 comments on commit 15d5f3a

Please sign in to comment.