Skip to content

Commit 5a9bf57

Browse files
authored
Merge pull request #597 from gridap/shifted_nabla
Adding new differential operators
2 parents 74622e5 + e3c116d commit 5a9bf57

File tree

9 files changed

+195
-6
lines changed

9 files changed

+195
-6
lines changed

src/CellData/CellFields.jl

+8
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,14 @@ for op in (:inner,:outer,:double_contraction,:+,:-,:*,:cross,:dot,:/)
479479
end
480480
end
481481

482+
Base.broadcasted(f,a::CellField,b::CellField) = Operation((i,j)->f.(i,j))(a,b)
483+
Base.broadcasted(f,a::Number,b::CellField) = Operation((i,j)->f.(i,j))(a,b)
484+
Base.broadcasted(f,a::CellField,b::Number) = Operation((i,j)->f.(i,j))(a,b)
485+
Base.broadcasted(f,a::Function,b::CellField) = Operation((i,j)->f.(i,j))(a,b)
486+
Base.broadcasted(f,a::CellField,b::Function) = Operation((i,j)->f.(i,j))(a,b)
487+
Base.broadcasted(::typeof(*),::typeof(∇),f::CellField) = Operation(Fields._extract_grad_diag)((f))
488+
Base.broadcasted(::typeof(*),s::Fields.ShiftedNabla,f::CellField) = Operation(Fields._extract_grad_diag)(s(f))
489+
482490
dot(::typeof(∇),f::CellField) = divergence(f)
483491
function (*)(::typeof(∇),f::CellField)
484492
msg = "Syntax ∇*f has been removed, use ∇⋅f (\\nabla \\cdot f) instead"

src/Fields/DiffOperators.jl

+51
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,54 @@ Equivalent to
116116
"""
117117
cross(::typeof(∇),f::Field) = curl(f)
118118
cross(::typeof(∇),f::Function) = curl(f)
119+
120+
_extract_grad_diag(x::TensorValue) = diag(x)
121+
_extract_grad_diag(x) = @notimplemented
122+
123+
function Base.broadcasted(::typeof(*),::typeof(∇),f)
124+
g = (f)
125+
Operation(_extract_grad_diag)(g)
126+
end
127+
128+
function Base.broadcasted(::typeof(*),::typeof(∇),f::Function)
129+
Base.broadcasted(*,∇,GenericField(f))
130+
end
131+
132+
struct ShiftedNabla{N,T}
133+
v::VectorValue{N,T}
134+
end
135+
136+
(+)(::typeof(∇),v::VectorValue) = ShiftedNabla(v)
137+
(+)(v::VectorValue,::typeof(∇)) = ShiftedNabla(v)
138+
(-)(::typeof(∇),v::VectorValue) = ShiftedNabla(-v)
139+
140+
function (s::ShiftedNabla)(f)
141+
Operation((a,b)->a+s.vb)(gradient(f),f)
142+
end
143+
144+
(s::ShiftedNabla)(f::Function) = s(GenericField(f))
145+
146+
function evaluate!(cache,k::Broadcasting{<:ShiftedNabla},f)
147+
s = k.f
148+
g = Broadcasting(∇)(f)
149+
Broadcasting(Operation((a,b)->a+s.vb))(g,f)
150+
end
151+
152+
dot(s::ShiftedNabla,f) = Operation(tr)(s(f))
153+
outer(s::ShiftedNabla,f) = s(f)
154+
outer(f,s::ShiftedNabla) = transpose(gradient(f))
155+
cross(s::ShiftedNabla,f) = Operation(grad2curl)(s(f))
156+
157+
dot(s::ShiftedNabla,f::Function) = dot(s,GenericField(f))
158+
outer(s::ShiftedNabla,f::Function) = outer(s,GenericField(f))
159+
outer(f::Function,s::ShiftedNabla) = outer(GenericField(f),s)
160+
cross(s::ShiftedNabla,f::Function) = cross(s,GenericField(f))
161+
162+
function Base.broadcasted(::typeof(*),s::ShiftedNabla,f)
163+
g = s(f)
164+
Operation(_extract_grad_diag)(g)
165+
end
166+
167+
function Base.broadcasted(::typeof(*),s::ShiftedNabla,f::Function)
168+
Base.broadcasted(*,s,GenericField(f))
169+
end

src/Fields/Fields.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ using Gridap.Algebra: fill_entries!
1616

1717
using Gridap.TensorValues
1818

19-
using LinearAlgebra: mul!, Transpose
19+
using LinearAlgebra: mul!, Transpose, diag
2020

2121
using ForwardDiff
2222
using FillArrays

src/Fields/FieldsInterfaces.jl

+13-4
Original file line numberDiff line numberDiff line change
@@ -344,10 +344,15 @@ end
344344

345345
@inline transpose(f::Field) = f
346346

347-
@inline *(A::Number, B::Field) = ConstantField(A)*B
348-
@inline *(A::Field, B::Number) = A*ConstantField(B)
349-
@inline (A::Number, B::Field) = ConstantField(A)B
350-
@inline (A::Field, B::Number) = AConstantField(B)
347+
for op in (:+,:-,:*,:,:,:)
348+
@eval ($op)(a::Field,b::Number) = Operation($op)(a,ConstantField(b))
349+
@eval ($op)(a::Number,b::Field) = Operation($op)(ConstantField(a),b)
350+
end
351+
352+
#@inline *(A::Number, B::Field) = ConstantField(A)*B
353+
#@inline *(A::Field, B::Number) = A*ConstantField(B)
354+
#@inline ⋅(A::Number, B::Field) = ConstantField(A)⋅B
355+
#@inline ⋅(A::Field, B::Number) = A⋅ConstantField(B)
351356

352357
#@inline *(A::Function, B::Field) = GenericField(A)*B
353358
#@inline *(A::Field, B::Function) = GenericField(B)*A
@@ -390,6 +395,10 @@ function product_rule(::typeof(⋅),f1::VectorValue,f2::VectorValue,∇f1,∇f2)
390395
∇f1f2 + ∇f2f1
391396
end
392397

398+
function product_rule(::typeof(),f1::TensorValue,f2::VectorValue,∇f1,∇f2)
399+
∇f1f2 + ∇f2transpose(f1)
400+
end
401+
393402
for op in (:*,:,:,:)
394403
@eval begin
395404
function gradient(a::OperationField{typeof($op)})

src/TensorValues/Operations.jl

+33
Original file line numberDiff line numberDiff line change
@@ -598,6 +598,39 @@ transpose(a::SymTensorValue) = a
598598
Meta.parse("SymTensorValue{D}($str)")
599599
end
600600

601+
###############################################################
602+
# diag
603+
###############################################################
604+
605+
function LinearAlgebra.diag(a::TensorValue{1,1})
606+
VectorValue(a.data[1])
607+
end
608+
609+
function LinearAlgebra.diag(a::TensorValue{2,2})
610+
VectorValue(a.data[1],a.data[4])
611+
end
612+
613+
function LinearAlgebra.diag(a::TensorValue{3,3})
614+
VectorValue(a.data[1],a.data[5],a.data[9])
615+
end
616+
617+
function LinearAlgebra.diag(a::TensorValue)
618+
@notimplemented
619+
end
620+
621+
###############################################################
622+
# Broadcast
623+
###############################################################
624+
# TODO more cases need to be added
625+
626+
function Base.broadcasted(f,a::VectorValue,b::VectorValue)
627+
VectorValue(map(f,a.data,b.data))
628+
end
629+
630+
function Base.broadcasted(f,a::TensorValue,b::TensorValue)
631+
TensorValue(map(f,a.data,b.data))
632+
end
633+
601634
###############################################################
602635
# Define new operations for Gridap types
603636
###############################################################

test/CellDataTests/CellFieldsTests.jl

+40
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,46 @@ test_array(∇vx,collect(∇vx))
9999
∇fx = (f)(x)
100100
test_array(∇fx,collect(∇fx))
101101

102+
103+
k = VectorValue(1.0,2.0)
104+
∇kfx = ((∇+k)(f))(x)
105+
test_array(∇kfx,collect(∇kfx))
106+
107+
∇kvx = ((∇+k)(v))(x)
108+
test_array(∇kvx,collect(∇kvx))
109+
110+
β(x) = 2*x[1]
111+
α = CellField(x->2*x,trian)
112+
ax = ((∇+k)(β*α))(x)
113+
test_array(ax,collect(ax))
114+
115+
ν = CellField(x->2*x,trian)
116+
ax =((∇-k)ν)(x)
117+
test_array(ax,collect(ax))
118+
119+
ax =((∇-k)×ν)(x)
120+
test_array(ax,collect(ax))
121+
122+
ax =((∇-k)ν)(x)
123+
test_array(ax,collect(ax))
124+
125+
ax =(∇.*ν)(x)
126+
test_array(ax,collect(ax))
127+
128+
ax =.*ν)(x)
129+
test_array(ax,collect(ax))
130+
131+
ax =((∇-k).*ν)(x)
132+
test_array(ax,collect(ax))
133+
134+
ax =(∇-k))(x)
135+
test_array(ax,collect(ax))
136+
137+
σ(x) = diagonal_tensor(VectorValue(1*x[1],2*x[2]))
138+
Fields.gradient(::typeof(σ)) = x-> ThirdOrderTensorValue{2,2,2,Float64}(1,0,0,0,0,0,0,2)
139+
ax = ((∇+k)(σα))(x)
140+
test_array(ax,collect(ax))
141+
102142
h = Operation(*)(2,f)
103143
hx = h(x)
104144
test_array(hx,2*fx)

test/FieldsTests/DiffOperatorsTests.jl

+17-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ np = 4
1212
p = Point(1,2)
1313
x = fill(p,np)
1414

15-
v = 3.0
15+
v = VectorValue(3.0,2.0)
1616
f = MockField(v)
1717

1818
@test (f) == gradient(f)
@@ -39,6 +39,20 @@ f = MockField(v)
3939

4040
@test Δ(f) ==(f)
4141

42+
@test (∇.*f)(x) != nothing
43+
44+
@test ((∇+p).*f)(x) != nothing
45+
46+
@test ((∇+p)(f))(x) == ((f) + pf)(x)
47+
48+
g(x) = 2*x[2]
49+
50+
@test ((∇+p)(g))(x) == ((GenericField(g)) + pGenericField(g))(x)
51+
52+
@test (∇+p)f != nothing
53+
@test (∇+p)×f != nothing
54+
@test (∇+p)f != nothing
55+
@test f(∇+p) != nothing
4256

4357
l = 10
4458
f = Fill(f,l)
@@ -47,6 +61,8 @@ f = Fill(f,l)
4761
@test Broadcasting(curl)(f) == Broadcasting(Operation(grad2curl))(Broadcasting(∇)(f))
4862
@test Broadcasting(ε)(f) == Broadcasting(Operation(symmetric_part))(Broadcasting(∇)(f))
4963

64+
@test evaluate(Broadcasting(∇+p)(f),x) == evaluate( Broadcasting(Operation((g,f)->g+pf))(Broadcasting(∇)(f),f) ,x)
65+
5066
# Test automatic differentiation
5167

5268
u_scal(x) = x[1]^2 + x[2]

test/FieldsTests/FieldInterfacesTests.jl

+16
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,22 @@ test_field(f,z,f.(z),grad=∇(f).(z))
250250
#c = return_cache(∇f,x)
251251
#@btime evaluate!($c,$∇f,$x)
252252

253+
Tfun(x) = diagonal_tensor(VectorValue(1*x[1],2*x[2]))
254+
bfun(x) = VectorValue(x[2],x[1])
255+
Fields.gradient(::typeof(Tfun)) = x-> ThirdOrderTensorValue{2,2,2,Float64}(1,0,0,0,0,0,0,2)
256+
a = GenericField(Tfun)
257+
b = GenericField(bfun)
258+
259+
f = Operation()(a,b)
260+
cp = Tfun(p)bfun(p)
261+
∇cp = (Tfun)(p)bfun(p) + (bfun)(p)transpose(Tfun(p))
262+
test_field(f,p,cp)
263+
test_field(f,p,cp,grad=∇cp)
264+
test_field(f,x,f.(x))
265+
test_field(f,x,f.(x),grad=(f).(x))
266+
test_field(f,z,f.(z))
267+
test_field(f,z,f.(z),grad=(f).(z))
268+
253269
afun(x) = x.+2
254270
bfun(x) = 2*x
255271

test/TensorValuesTests/OperationsTests.jl

+16
Original file line numberDiff line numberDiff line change
@@ -688,4 +688,20 @@ b = 4.0 - 3.0*im
688688
@test outer(a,b) == a*b
689689
@test inner(a,b) == a*b
690690

691+
# Broadcast
692+
a = VectorValue(1,2,3)
693+
b = VectorValue(1.,2.,3.)
694+
c = a .* b
695+
@test isa(c,VectorValue)
696+
@test c.data == map(*,a.data,b.data)
697+
698+
a = TensorValue(1,2,3,4)
699+
b = TensorValue(1.,2.,3.,4.)
700+
c = a .* b
701+
@test isa(c,TensorValue)
702+
@test c.data == map(*,a.data,b.data)
703+
704+
@test diag(a) == VectorValue(1,4)
705+
706+
691707
end # module OperationsTests

0 commit comments

Comments
 (0)