From 07abf703c6a44846665d48ceb26ef379c2c3f3e7 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Wed, 3 Apr 2024 05:48:06 -0400 Subject: [PATCH 1/5] failing test --- samples/tests/higher/mod.rs | 33 +++++++++++++++++++++++++++++++++ samples/tests/mod.rs | 2 ++ 2 files changed, 35 insertions(+) create mode 100644 samples/tests/higher/mod.rs diff --git a/samples/tests/higher/mod.rs b/samples/tests/higher/mod.rs new file mode 100644 index 0000000..881d1ef --- /dev/null +++ b/samples/tests/higher/mod.rs @@ -0,0 +1,33 @@ +#![feature(autodiff)] + +// A direct translation of +// https://enzyme.mit.edu/index.fcgi/julia/stable/generated/autodiff/#Forward-over-reverse + +#[autodiff(ddf, Forward, Dual, Dual, Dual, Dual)] +fn df2(x: &[f32;2], dx: &mut [f32;2], out: &mut [f32;1], dout: &mut [f32;1]) { + df(x, dx, out, dout); +} + +#[autodiff(df, ReverseFirst, Duplicated, Duplicated)] +fn f(x: &[f32;2], y: &mut [f32;1]) { + y[0] = x[0] * x[0] + x[1] * x[0] +} + +fn sum2(x: &f32, out: &mut f32) { *out = 2.0 * x; } + +fn main() { + let mut y = [0.0]; + let mut x = [2.0, 2.0]; + + let mut dy = [0.0]; + let mut dx = [1.0, 0.0]; + + let mut bx = [0.0, 0.0]; + let mut by = [1.0]; + let mut dbx = [0.0, 0.0]; + let mut dby = [0.0]; + + ddf(&x, &mut bx, &mut dx, &mut dbx, + &mut y, &mut by, &mut dy, &mut dby); +} + diff --git a/samples/tests/mod.rs b/samples/tests/mod.rs index 95f9e3c..457923a 100644 --- a/samples/tests/mod.rs +++ b/samples/tests/mod.rs @@ -3,3 +3,5 @@ mod forward; mod neohookean; mod reverse; +mod higher; + From c1cde59bf934f8e0d3672bb64bb9551126217b64 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Wed, 3 Apr 2024 13:37:22 -0400 Subject: [PATCH 2/5] simplify --- samples/tests/mod.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/samples/tests/mod.rs b/samples/tests/mod.rs index 457923a..b3ab1d8 100644 --- a/samples/tests/mod.rs +++ b/samples/tests/mod.rs @@ -1,7 +1,7 @@ #![feature(autodiff)] -mod forward; -mod neohookean; -mod reverse; +//mod forward; +//mod neohookean; +//mod reverse; mod higher; From 12da7da1ca2dd19b6b16d1dfd89db19f8eb92eb4 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Wed, 3 Apr 2024 13:39:30 -0400 Subject: [PATCH 3/5] running it as test --- samples/tests/higher/mod.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/samples/tests/higher/mod.rs b/samples/tests/higher/mod.rs index 881d1ef..1fe6a7c 100644 --- a/samples/tests/higher/mod.rs +++ b/samples/tests/higher/mod.rs @@ -15,6 +15,7 @@ fn f(x: &[f32;2], y: &mut [f32;1]) { fn sum2(x: &f32, out: &mut f32) { *out = 2.0 * x; } +#[test] fn main() { let mut y = [0.0]; let mut x = [2.0, 2.0]; From becb2e9eb086644d6838252dc4b2a184c179bbc3 Mon Sep 17 00:00:00 2001 From: Jed Brown Date: Wed, 3 Apr 2024 11:43:35 -0600 Subject: [PATCH 4/5] remove syntax warnings --- samples/tests/higher/mod.rs | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/samples/tests/higher/mod.rs b/samples/tests/higher/mod.rs index 1fe6a7c..488af09 100644 --- a/samples/tests/higher/mod.rs +++ b/samples/tests/higher/mod.rs @@ -1,6 +1,4 @@ -#![feature(autodiff)] - -// A direct translation of +// A direct translation of // https://enzyme.mit.edu/index.fcgi/julia/stable/generated/autodiff/#Forward-over-reverse #[autodiff(ddf, Forward, Dual, Dual, Dual, Dual)] @@ -13,12 +11,10 @@ fn f(x: &[f32;2], y: &mut [f32;1]) { y[0] = x[0] * x[0] + x[1] * x[0] } -fn sum2(x: &f32, out: &mut f32) { *out = 2.0 * x; } - #[test] fn main() { let mut y = [0.0]; - let mut x = [2.0, 2.0]; + let x = [2.0, 2.0]; let mut dy = [0.0]; let mut dx = [1.0, 0.0]; From d8f4369057823f8e45aaf3b36c40368e06c89e00 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Wed, 3 Apr 2024 17:54:17 -0400 Subject: [PATCH 5/5] add new test --- samples/tests/mod.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/samples/tests/mod.rs b/samples/tests/mod.rs index b3ab1d8..457923a 100644 --- a/samples/tests/mod.rs +++ b/samples/tests/mod.rs @@ -1,7 +1,7 @@ #![feature(autodiff)] -//mod forward; -//mod neohookean; -//mod reverse; +mod forward; +mod neohookean; +mod reverse; mod higher;