@@ -116,3 +116,54 @@ Equivalent to
116
116
"""
117
117
cross (:: typeof (∇),f:: Field ) = curl (f)
118
118
cross (:: typeof (∇),f:: Function ) = curl (f)
119
+
120
+ _extract_grad_diag (x:: TensorValue ) = diag (x)
121
+ _extract_grad_diag (x) = @notimplemented
122
+
123
+ function Base. broadcasted (:: typeof (* ),:: typeof (∇),f)
124
+ g = ∇ (f)
125
+ Operation (_extract_grad_diag)(g)
126
+ end
127
+
128
+ function Base. broadcasted (:: typeof (* ),:: typeof (∇),f:: Function )
129
+ Base. broadcasted (* ,∇,GenericField (f))
130
+ end
131
+
132
+ struct ShiftedNabla{N,T}
133
+ v:: VectorValue{N,T}
134
+ end
135
+
136
+ (+ )(:: typeof (∇),v:: VectorValue ) = ShiftedNabla (v)
137
+ (+ )(v:: VectorValue ,:: typeof (∇)) = ShiftedNabla (v)
138
+ (- )(:: typeof (∇),v:: VectorValue ) = ShiftedNabla (- v)
139
+
140
+ function (s:: ShiftedNabla )(f)
141
+ Operation ((a,b)-> a+ s. v⊗ b)(gradient (f),f)
142
+ end
143
+
144
+ (s:: ShiftedNabla )(f:: Function ) = s (GenericField (f))
145
+
146
+ function evaluate! (cache,k:: Broadcasting{<:ShiftedNabla} ,f)
147
+ s = k. f
148
+ g = Broadcasting (∇)(f)
149
+ Broadcasting (Operation ((a,b)-> a+ s. v⊗ b))(g,f)
150
+ end
151
+
152
+ dot (s:: ShiftedNabla ,f) = Operation (tr)(s (f))
153
+ outer (s:: ShiftedNabla ,f) = s (f)
154
+ outer (f,s:: ShiftedNabla ) = transpose (gradient (f))
155
+ cross (s:: ShiftedNabla ,f) = Operation (grad2curl)(s (f))
156
+
157
+ dot (s:: ShiftedNabla ,f:: Function ) = dot (s,GenericField (f))
158
+ outer (s:: ShiftedNabla ,f:: Function ) = outer (s,GenericField (f))
159
+ outer (f:: Function ,s:: ShiftedNabla ) = outer (GenericField (f),s)
160
+ cross (s:: ShiftedNabla ,f:: Function ) = cross (s,GenericField (f))
161
+
162
+ function Base. broadcasted (:: typeof (* ),s:: ShiftedNabla ,f)
163
+ g = s (f)
164
+ Operation (_extract_grad_diag)(g)
165
+ end
166
+
167
+ function Base. broadcasted (:: typeof (* ),s:: ShiftedNabla ,f:: Function )
168
+ Base. broadcasted (* ,s,GenericField (f))
169
+ end
0 commit comments