Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend base linalg functions to complex numbers #479

Merged
merged 10 commits into from
Dec 19, 2019
Merged
98 changes: 57 additions & 41 deletions src/base/linalg/owl_base_linalg_generic.ml
Original file line number Diff line number Diff line change
Expand Up @@ -192,34 +192,42 @@ let linsolve_gauss a b =
* Test: https://github.com/scipy/scipy/blob/master/scipy/linalg/tests/test_decomp.py
*)
let _lu_base a =
let k = M.kind a in
let _abs = Owl_base_dense_common._abs_elt k in
let _mul = Owl_base_dense_common._mul_elt k in
let _div = Owl_base_dense_common._div_elt k in
let _sub = Owl_base_dense_common._sub_elt k in
let _flt = Owl_base_dense_common._float_typ_elt k in
let _zero = Owl_const.zero k in
let _one = Owl_const.one k in
let lu = M.copy a in
let n = (M.shape a).(0) in
let m = (M.shape a).(1) in
assert (n = m);
let indx = Array.make n 0 in
(* implicit scaling of each row *)
let vv = Array.make n 0. in
let tiny = 1.0e-40 in
let big = ref 0. in
let temp = ref 0. in
let vv = Array.make n _zero in
let tiny = _flt 1.0e-40 in
let big = ref _zero in
let temp = ref _zero in
(* flag of row exchange *)
let d = ref 1.0 in
let imax = ref 0 in
(* loop over rows to get the implicit scaling information *)
for i = 0 to n - 1 do
big := 0.;
big := _zero;
for j = 0 to n - 1 do
temp := M.get lu [| i; j |] |> abs_float;
temp := M.get lu [| i; j |] |> _abs;
if !temp > !big then big := !temp
done;
if !big = 0. then raise Owl_exception.SINGULAR;
vv.(i) <- 1.0 /. !big
if !big = _zero then raise Owl_exception.SINGULAR;
vv.(i) <- _div _one !big
done;
for k = 0 to n - 1 do
big := 0.;
big := _zero;
(* choose suitable pivot *)
for i = k to n - 1 do
temp := (M.get lu [| i; k |] |> abs_float) *. vv.(i);
temp := _mul (M.get lu [| i; k |] |> _abs) vv.(i);
if !temp > !big
then (
big := !temp;
Expand All @@ -237,15 +245,15 @@ let _lu_base a =
d := !d *. -1.;
vv.(!imax) <- vv.(k));
indx.(k) <- !imax;
if M.get lu [| k; k |] = 0. then M.set lu [| k; k |] tiny;
if M.get lu [| k; k |] = _zero then M.set lu [| k; k |] tiny;
for i = k + 1 to n - 1 do
let tmp0 = M.get lu [| i; k |] in
let tmp1 = M.get lu [| k; k |] in
temp := tmp0 /. tmp1;
temp := _div tmp0 tmp1;
M.set lu [| i; k |] !temp;
for j = k + 1 to n - 1 do
let prev = M.get lu [| i; j |] in
M.set lu [| i; j |] (prev -. (!temp *. M.get lu [| k; j |]))
M.set lu [| i; j |] (_sub prev (_mul !temp (M.get lu [| k; j |])))
done
done
done;
Expand All @@ -255,6 +263,7 @@ let _lu_base a =
(* LU decomposition, return L, U, and permutation vector *)
let lu a =
let k = M.kind a in
let _zero = Owl_const.zero k in
let lu, indx, _ = _lu_base a in
let n = (M.shape lu).(0) in
let m = (M.shape lu).(1) in
Expand All @@ -264,18 +273,23 @@ let lu a =
for c = 0 to r - 1 do
let v = M.get lu [| r; c |] in
M.set l [| r; c |] v;
M.set lu [| r; c |] 0.
M.set lu [| r; c |] _zero
done
done;
l, lu, indx


let _lu_solve_vec a b =
let _k = M.kind a in
let _mul = Owl_base_dense_common._mul_elt _k in
let _div = Owl_base_dense_common._div_elt _k in
let _sub = Owl_base_dense_common._sub_elt _k in
let _zero = Owl_const.zero _k in
assert (Array.length (M.shape b) = 1);
let n = (M.shape a).(0) in
if (M.shape b).(0) <> n then failwith "LUdcmp::solve bad sizes";
let ii = ref 0 in
let sum = ref 0. in
let sum = ref _zero in
let x = M.copy b in
let lu, indx, _ = _lu_base a in
for i = 0 to n - 1 do
Expand All @@ -285,18 +299,18 @@ let _lu_solve_vec a b =
if !ii <> 0
then
for j = !ii - 1 to i - 1 do
sum := !sum -. (M.get lu [| i; j |] *. M.get x [| j |])
sum := _sub !sum (_mul (M.get lu [| i; j |]) (M.get x [| j |]))
done
else if !sum <> 0.
else if !sum <> _zero
then ii := !ii + 1;
M.set x [| i |] !sum
done;
for i = n - 1 downto 0 do
sum := M.get x [| i |];
for j = i + 1 to n - 1 do
sum := !sum -. (M.get lu [| i; j |] *. M.get x [| j |])
sum := _sub !sum (_mul (M.get lu [| i; j |]) (M.get x [| j |]))
done;
M.set x [| i |] (!sum /. M.get lu [| i; i |])
M.set x [| i |] (_div !sum (M.get lu [| i; i |]))
done;
x

Expand All @@ -320,14 +334,17 @@ let linsolve_lu a b =

(* Determinant of matrix a *)
let det a =
let k = M.kind a in
let _mul = Owl_base_dense_common._mul_elt k in
let _flt = Owl_base_dense_common._float_typ_elt k in
let dims_a = M.shape a in
_check_is_matrix dims_a |> ignore;
assert (dims_a.(0) = dims_a.(1));
let n = dims_a.(0) in
let lu, _, sign = _lu_base a in
let big = ref sign in
let big = ref (_flt sign) in
for i = 0 to n - 1 do
big := !big *. M.get lu [| i; i |]
big := _mul !big (M.get lu [| i; i |])
done;
!big

Expand Down Expand Up @@ -359,43 +376,42 @@ let tridiag_solve_vec a b c r =
x


(* Matrix inverse *)

(* NOTE: deprecated implementation? *)
(* let inv a =
let dims_a = M.shape a in
_check_is_matrix dims_a |> ignore;
assert (dims_a.(0) = dims_a.(1));
let n = dims_a.(0) in
let b = M.eye (M.kind a) n in
linsolve_lu a b *)

(* TODO: optimise and test *)
(*
Implementing the following algorithm:
(* TODO: optimise and test *)
(* Implementing the following algorithm:
http://www.irma-international.org/viewtitle/41011/ *)
let inv varr =
let _k = M.kind varr in
let _add = Owl_base_dense_common._add_elt _k in
let _mul = Owl_base_dense_common._mul_elt _k in
let _div = Owl_base_dense_common._div_elt _k in
let _neg = Owl_base_dense_common._neg_elt _k in
let _zero = Owl_const.zero _k in
let _one = Owl_const.one _k in
let dims = M.shape varr in
let _ = _check_is_matrix dims in
let n = Array.unsafe_get dims 0 in
if Array.unsafe_get dims 1 != n
then failwith "no inverse - the matrix is not square"
else (
let pivot_row = Array.make n 0. in
let pivot_row = Array.make n _zero in
let result_varr = M.copy varr in
for p = 0 to n - 1 do
let pivot_elem = M.get result_varr [| p; p |] in
if M.get result_varr [| p; p |] = 0.
if M.get result_varr [| p; p |] = _zero
then failwith "the matrix does not have an inverse";
(* update elements of the pivot row, save old vals *)
for j = 0 to n - 1 do
pivot_row.(j) <- M.get result_varr [| p; j |];
if j != p then M.set result_varr [| p; j |] (pivot_row.(j) /. pivot_elem)
if j != p then M.set result_varr [| p; j |] (_div pivot_row.(j) pivot_elem)
done;
(* update elements of the pivot col *)
for i = 0 to n - 1 do
if i != p
then M.set result_varr [| i; p |] (M.get result_varr [| i; p |] /. ~-.pivot_elem)
then
M.set
result_varr
[| i; p |]
(_div (M.get result_varr [| i; p |]) (_neg pivot_elem))
done;
(* update the rest of the matrix *)
for i = 0 to n - 1 do
Expand All @@ -406,12 +422,12 @@ let inv varr =
let pivot_row_elem = pivot_row.(j) in
(* use old value *)
let old_val = M.get result_varr [| i; j |] in
let new_val = old_val +. (pivot_row_elem *. pivot_col_elem) in
let new_val = _add old_val (_mul pivot_row_elem pivot_col_elem) in
M.set result_varr [| i; j |] new_val)
done
done;
(* update the pivot element *)
M.set result_varr [| p; p |] (1. /. pivot_elem)
M.set result_varr [| p; p |] (_div _one pivot_elem)
done;
result_varr)

Expand Down
85 changes: 53 additions & 32 deletions src/base/linalg/owl_base_linalg_generic.mli
Original file line number Diff line number Diff line change
Expand Up @@ -9,42 +9,45 @@ open Bigarray

type ('a, 'b) t = ('a, 'b) Owl_base_dense_ndarray_generic.t

(** {6 Core functions} *)
(** {6 Basic functions} *)

val inv : ('a, 'b) t -> ('a, 'b) t
(**
``inv x`` calculates the inverse of an invertible square matrix ``x``
such that ``x *@ x = I`` wherein ``I`` is an identity matrix. (If ``x``
is singular, ``inv`` will return a useless result.)
*)

val det : ('a, 'b) t -> 'a
(** ``det x`` computes the determinant of a square matrix ``x``. *)

val logdet : ('a, 'b) t -> 'a
(** Refer to :doc:`owl_dense_matrix_generic` *)

(** {6 Check matrix types} *)

val is_tril : ('a, 'b) t -> bool
(** ``is_tril x`` returns ``true`` if ``x`` is lower triangular otherwise ``false``. *)

val is_triu : ('a, 'b) t -> bool
(** ``is_triu x`` returns ``true`` if ``x`` is upper triangular otherwise ``false``. *)

val is_diag : ('a, 'b) t -> bool
(** ``is_diag x`` returns ``true`` if ``x`` is diagonal otherwise ``false``. *)

val is_symmetric : ('a, 'b) t -> bool
(** ``is_symmetric x`` returns ``true`` if ``x`` is symmetric otherwise ``false``. *)

val is_hermitian : (Complex.t, 'b) t -> bool
(** ``is_hermitian x`` returns ``true`` if ``x`` is hermitian otherwise ``false``. *)

val lu : (float, 'a) t -> (float, 'a) t * (float, 'a) t * int array

val det : (float, 'a) t -> float
(** {6 Factorisation} *)

val linsolve_lu : (float, 'a) t -> (float, 'b) t -> (float, 'b) t

val linsolve_gauss : (float, 'a) t -> (float, 'b) t -> (float, 'a) t * (float, 'b) t

val tridiag_solve_vec
: float array
-> float array
-> float array
-> float array
-> float array

(* TODO: change float to 'a *)
val inv : (float, 'b) t -> (float, 'b) t
(** Refer to :doc:`owl_dense_matrix_generic` *)

val logdet : ('a, 'b) t -> 'a
(** Refer to :doc:`owl_dense_matrix_generic` *)

val chol : ?upper:bool -> ('a, 'b) t -> ('a, 'b) t
(** Refer to :doc:`owl_dense_matrix_generic` *)
val lu : ('a, 'b) t -> ('a, 'b) t * ('a, 'b) t * int array
(**
``lu x -> (l, u, ipiv)`` calculates LU decomposition of ``x``. The pivoting is
used by default.
*)

val qr
: ?thin:bool
Expand All @@ -59,22 +62,27 @@ val lq : ?thin:bool -> ('a, 'b) t -> ('a, 'b) t * ('a, 'b) t
val svd : ?thin:bool -> ('a, 'b) t -> ('a, 'b) t * ('a, 'b) t * ('a, 'b) t
(** Refer to :doc:`owl_dense_matrix_generic` *)

val sylvester : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
val chol : ?upper:bool -> ('a, 'b) t -> ('a, 'b) t
(** Refer to :doc:`owl_dense_matrix_generic` *)

val lyapunov : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
(** Refer to :doc:`owl_dense_matrix_generic` *)
(** {6 Linear system of equations} *)

val discrete_lyapunov
: ?solver:[ `default | `bilinear | `direct ]
val linsolve
: ?trans:bool
-> ?typ:[ `n | `u | `l ]
-> ('a, 'b) t
-> ('a, 'b) t
-> ('a, 'b) t
(** Refer to :doc:`owl_dense_matrix_generic` *)

val linsolve
: ?trans:bool
-> ?typ:[ `n | `u | `l ]
val sylvester : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
(** Refer to :doc:`owl_dense_matrix_generic` *)

val lyapunov : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
(** Refer to :doc:`owl_dense_matrix_generic` *)

val discrete_lyapunov
: ?solver:[ `default | `bilinear | `direct ]
-> ('a, 'b) t
-> ('a, 'b) t
-> ('a, 'b) t
Expand All @@ -88,3 +96,16 @@ val care
-> (float, 'b) t
-> (float, 'b) t
(** Refer to :doc:`owl_dense_matrix_generic` *)

(** {6 Non-standard functions} *)

val linsolve_lu : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

val linsolve_gauss : (float, 'a) t -> (float, 'b) t -> (float, 'a) t * (float, 'b) t

val tridiag_solve_vec
: float array
-> float array
-> float array
-> float array
-> float array
25 changes: 20 additions & 5 deletions src/base/linalg/owl_base_linalg_intf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,24 @@ module type Common = sig

type int32_mat

(** {6 Basic functions} *)

val inv : mat -> mat

val det : mat -> elt

val logdet : mat -> elt

val is_triu : mat -> bool

val is_tril : mat -> bool

val is_symmetric : mat -> bool

val is_diag : mat -> bool

(** {6 Factorisation} *)

val svd : ?thin:bool -> mat -> mat * mat * mat

val chol : ?upper:bool -> mat -> mat
Expand All @@ -22,22 +38,21 @@ module type Common = sig

val lq : ?thin:bool -> mat -> mat * mat

(** {6 Linear system of equations} *)

val linsolve : ?trans:bool -> ?typ:[ `n | `u | `l ] -> mat -> mat -> mat

val sylvester : mat -> mat -> mat -> mat

val lyapunov : mat -> mat -> mat

val discrete_lyapunov : ?solver:[ `default | `direct | `bilinear ] -> mat -> mat -> mat

val linsolve : ?trans:bool -> ?typ:[ `n | `u | `l ] -> mat -> mat -> mat
end

module type Real = sig
type elt

type mat

(* TODO: implement inv for both real and complex matrices *)
val inv : mat -> mat

val care : ?diag_r:bool -> mat -> mat -> mat -> mat -> mat
end
Loading