Skip to content

Commit

Permalink
Merge pull request #557 from tachukao/master
Browse files Browse the repository at this point in the history
Fixed bug in Jacobian with different input/output dimensions
  • Loading branch information
jzstark authored Nov 12, 2020
2 parents f164864 + 060eb15 commit 6c1c601
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 9 deletions.
20 changes: 17 additions & 3 deletions src/base/algodiff/owl_algodiff_check.ml
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,29 @@ module Make (Algodiff : Owl_algodiff_generic_sig.Sig) = struct

module Reverse = struct
let finite_difference_grad =
let two = F A.(float_to_elt 2.) in
let h x = F A.(float_to_elt x) in
let two = h 2. in
let eight = h 8. in
let twelve = h 12. in
fun ~order ~f ?(eps = 1E-5) x d ->
let eps = F A.(float_to_elt eps) in
let dx = Maths.(eps * d) in
let df1 = Maths.(f (x + dx) - f (x - dx)) in
match order with
| `eighth ->
let twodx = Maths.(h 2. * dx) in
let threedx = Maths.(h 3. * dx) in
let fourdx = Maths.(h 4. * dx) in
let df2 = Maths.(f (x + twodx) - f (x - twodx)) in
let df3 = Maths.(f (x + threedx) - f (x - threedx)) in
let df4 = Maths.(f (x + fourdx) - f (x - fourdx)) in
Maths.(
((h (4. /. 5.) * df1)
+ (h (-1. /. 5.) * df2)
+ (h (4. /. 105.) * df3)
+ (h (-1. /. 280.) * df4))
/ eps)
| `fourth ->
let eight = F A.(float_to_elt 8.) in
let twelve = F A.(float_to_elt 12.) in
let df2 =
let twodx = Maths.(two * dx) in
Maths.(f (x + twodx) - f (x - twodx))
Expand Down
2 changes: 1 addition & 1 deletion src/base/algodiff/owl_algodiff_check.mli
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ module Make (AD : Owl_algodiff_generic_sig.Sig) : sig
module Reverse : sig
val check
: threshold:float
-> order:[ `second | `fourth ]
-> order:[ `second | `fourth | `eighth ]
-> ?verbose:bool
-> ?eps:float
-> f:(AD.t -> AD.t)
Expand Down
10 changes: 5 additions & 5 deletions src/base/algodiff/owl_algodiff_generic.ml
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ module Make (A : Owl_types_ndarray_algodiff.Sig) = struct
let jacobianTv f x v = jacobianTv' f x v |> snd

(* jacobian of f (vector -> vector) at x, both x and y are row vectors, also return the
original value *)
original value. The jacobian J satisfies dx = J df *)
let jacobian' =
let dim_typ x =
match primal' x with
Expand Down Expand Up @@ -143,14 +143,14 @@ module Make (A : Owl_types_ndarray_algodiff.Sig) = struct
match v with
| Arr v ->
if m > n
then A.copy_col_to (A.transpose v) z i
else A.copy_row_to v z i
then A.copy_row_to v z i
else A.copy_col_to (A.transpose v) z i
| _ -> failwith "error: jacobian");
Arr z
| DF _ | DR _ ->
if m > n
then Ops.Maths.concatenate ~axis:1 jvps
else Ops.Maths.concatenate ~axis:0 jvps
then Ops.Maths.concatenate ~axis:0 jvps
else Ops.Maths.concatenate ~axis:0 jvps |> Ops.Maths.transpose
in
primal y, z

Expand Down

0 comments on commit 6c1c601

Please sign in to comment.