Skip to content

Commit

Permalink
Remove unneeded eliminate_indices and eliminated_indices
Browse files Browse the repository at this point in the history
  • Loading branch information
asterycs committed Feb 15, 2025
1 parent 3b0899b commit 3872deb
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 92 deletions.
6 changes: 3 additions & 3 deletions src/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ function evaluate(::Mult, arg1::KrD, arg2::Zero)
end

function evaluate(::Mult, arg1::Zero, arg2::KrD)
contracting_index = eliminated_indices(get_free_indices(arg1), get_free_indices(arg2))
contracting_index = eliminated_indices([get_indices(arg1); get_free_indices(arg2)])

if isempty(contracting_index)
return Zero(union(arg1.indices, arg2.indices)...)
Expand All @@ -262,7 +262,7 @@ end

function evaluate(::Mult, arg1::KrD, arg2::Tensor)
arg2_indices = get_free_indices(arg2)
contracting_index = eliminated_indices(get_free_indices(arg1), arg2_indices)
contracting_index = eliminated_indices([get_indices(arg1); arg2_indices])

if is_diag(arg2, arg1)
return BinaryOperation{Mult}(arg1, arg2)
Expand Down Expand Up @@ -347,7 +347,7 @@ end

function evaluate(::Mult, arg1::Union{Tensor,KrD}, arg2::KrD)
arg1_indices = get_free_indices(arg1)
contracting_index = eliminated_indices(arg1_indices, get_free_indices(arg2))
contracting_index = eliminated_indices([arg1_indices; get_indices(arg2)])

if isempty(contracting_index) # Is an outer product
return BinaryOperation{Mult}(arg1, arg2)
Expand Down
39 changes: 0 additions & 39 deletions src/ricci.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,37 +118,6 @@ function cos(arg::TensorExpr)
return Cos(arg)
end

function _eliminate_indices(arg1::IndexList, arg2::IndexList)
CanBeNothing = Union{Nothing,Lower,Upper}
available1 = CanBeNothing[i for i unique(arg1)]
available2 = CanBeNothing[i for i unique(arg2)]
eliminated = LowerOrUpperIndex[]

for i eachindex(available1)
if isnothing(available1[i])
continue
end

for j eachindex(available2)
if isnothing(available2[j])
continue
end

if flip(available2[j]) == available1[i] # contraction
push!(eliminated, available1[i])
push!(eliminated, available2[j])
available1[i] = nothing
available2[j] = nothing
end
end
end

filtered1 = filter(i -> i available1, arg1)
filtered2 = filter(i -> i available2, arg2)

return (filtered1, filtered2), eliminated
end

function _eliminate_indices(arg::IndexList)
CanBeNothing = Union{Nothing,Lower,Upper}
available = CanBeNothing[i for i unique(arg)]
Expand Down Expand Up @@ -187,14 +156,6 @@ function eliminated_indices(arg::IndexList)
return setdiff(arg, remaining)
end

function eliminate_indices(arg1::IndexList, arg2::IndexList)
return first(_eliminate_indices(arg1, arg2))
end

function eliminated_indices(arg1::IndexList, arg2::IndexList)
return last(_eliminate_indices(arg1, arg2))
end

function count_values(input::AbstractArray{T}) where {T}
return Dict((i => count(==(i), input)) for i unique(input))
end
Expand Down
50 changes: 0 additions & 50 deletions test/RicciTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,56 +163,6 @@ end
@test dc.flip(Upper(3)) == Lower(3)
end

@testset "eliminate_indices removes correct indices" begin
IdxUnion = dc.LowerOrUpperIndex

indicesl = IdxUnion[
Lower(9)
Upper(9)
Upper(3)
Lower(2)
Lower(1)
]

indicesr = IdxUnion[
Lower(3)
Lower(2)
Upper(3)
Upper(9)
Lower(9)
]

l, r = dc.eliminate_indices(indicesl, indicesr)

@test [l; r] == [Lower(2); Lower(1); Lower(2); Upper(3)]
@test dc.eliminate_indices(IdxUnion[], IdxUnion[]) == (IdxUnion[], IdxUnion[])
end

@testset "eliminated_indices retains correct indices" begin
IdxUnion = dc.LowerOrUpperIndex

indicesl = IdxUnion[
Lower(9)
Upper(9)
Upper(3)
Lower(2)
Lower(1)
]

indicesr = IdxUnion[
Lower(3)
Lower(2)
Upper(3)
Upper(9)
Lower(9)
]

eliminated = dc.eliminated_indices(indicesl, indicesr)

@test eliminated == IdxUnion[Lower(9); Upper(9); Upper(9); Lower(9); Upper(3); Lower(3)]
@test dc.eliminated_indices(IdxUnion[], IdxUnion[]) == IdxUnion[]
end

@testset "get_free_indices with Tensor * Tensor and one matching pair" begin
xt = Tensor("x", Lower(1)) # row vector
A = Tensor("A", Upper(1), Lower(2))
Expand Down

0 comments on commit 3872deb

Please sign in to comment.