Skip to content

Add void* to tabulate_tensor kernel #749

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 11 commits into from
3 changes: 2 additions & 1 deletion ffcx/codegeneration/C/expressions_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
const {scalar_type}* restrict c,
const {geom_type}* restrict coordinate_dofs,
const int* restrict entity_local_index,
const uint8_t* restrict quadrature_permutation)
const uint8_t* restrict quadrature_permutation,
void* custom_data)
{{
{tabulate_expression}
}}
Expand Down
3 changes: 2 additions & 1 deletion ffcx/codegeneration/C/integrals_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
const {scalar_type}* restrict c,
const {geom_type}* restrict coordinate_dofs,
const int* restrict entity_local_index,
const uint8_t* restrict quadrature_permutation)
const uint8_t* restrict quadrature_permutation,
void* custom_data)
{{
{tabulate_tensor}
}}
Expand Down
15 changes: 11 additions & 4 deletions ffcx/codegeneration/ufcx.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,15 @@ extern "C"
/// For integrals not on interior facets, this argument has no effect and a
/// null pointer can be passed. For interior facets the array will have size 2
/// (one permutation for each cell adjacent to the facet).
/// @param[in] custom_data Custom user data passed to the tabulate function.
/// For example, a struct with additional data needed for the tabulate function.
/// See the implementation of runtime integrals for further details.
typedef void(ufcx_tabulate_tensor_float32)(
float* restrict A, const float* restrict w, const float* restrict c,
const float* restrict coordinate_dofs,
const int* restrict entity_local_index,
const uint8_t* restrict quadrature_permutation);
const uint8_t* restrict quadrature_permutation,
void* custom_data);

/// Tabulate integral into tensor A with compiled
/// quadrature rule and double precision
Expand All @@ -100,7 +104,8 @@ extern "C"
double* restrict A, const double* restrict w, const double* restrict c,
const double* restrict coordinate_dofs,
const int* restrict entity_local_index,
const uint8_t* restrict quadrature_permutation);
const uint8_t* restrict quadrature_permutation,
void* custom_data);

#ifndef __STDC_NO_COMPLEX__
/// Tabulate integral into tensor A with compiled
Expand All @@ -111,7 +116,8 @@ extern "C"
float _Complex* restrict A, const float _Complex* restrict w,
const float _Complex* restrict c, const float* restrict coordinate_dofs,
const int* restrict entity_local_index,
const uint8_t* restrict quadrature_permutation);
const uint8_t* restrict quadrature_permutation,
void* custom_data);
#endif // __STDC_NO_COMPLEX__

#ifndef __STDC_NO_COMPLEX__
Expand All @@ -123,7 +129,8 @@ extern "C"
double _Complex* restrict A, const double _Complex* restrict w,
const double _Complex* restrict c, const double* restrict coordinate_dofs,
const int* restrict entity_local_index,
const uint8_t* restrict quadrature_permutation);
const uint8_t* restrict quadrature_permutation,
void* custom_data);
#endif // __STDC_NO_COMPLEX__

typedef struct ufcx_integral
Expand Down
57 changes: 57 additions & 0 deletions ffcx/codegeneration/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
import numpy as np
import numpy.typing as npt

try:
import numba
except ImportError:
numba = None


def dtype_to_c_type(dtype: typing.Union[npt.DTypeLike, str]) -> str:
"""For a NumPy dtype, return the corresponding C type.
Expand Down Expand Up @@ -80,6 +85,58 @@ def numba_ufcx_kernel_signature(dtype: npt.DTypeLike, xdtype: npt.DTypeLike):
types.CPointer(from_dtype(xdtype)),
types.CPointer(types.intc),
types.CPointer(types.uint8),
types.CPointer(types.void),
)
except ImportError as e:
raise e


if numba is not None:

@numba.extending.intrinsic
def empty_void_pointer(typingctx):
"""Custom intrinsic to return an empty void* pointer.
This function creates a void pointer initialized to null (0).
This is used to pass a nullptr to the UFCx tabulate_tensor interface.

Args:
typingctx: The typing context.

Returns:
A Numba signature and a code generation function that returns a void pointer.
""" # noqa: D205
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be enough to add a blank line after the summary line "Custom intrinsic to return an empty void* pointer." to be able to drop the `# noqa: D205`.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right. Done.


def codegen(context, builder, signature, args):
null_ptr = context.get_constant(numba.types.voidptr, 0)
return null_ptr

sig = numba.types.voidptr()
return sig, codegen

@numba.extending.intrinsic
def get_void_pointer(typingctx, arr):
"""Custom intrinsic to get a void* pointer from a NumPy array.

This function takes a NumPy array and returns a void pointer to the array's data.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually don't know how numpy lays out its data in an ndarray - could we be a bit more precise here on what this void ptr points to?

This is used to pass custom data organised in a NumPy array
to the UFCx tabulate_tensor interface.

Args:
typingctx: The typing context.
arr: The NumPy array to get the void pointer from.

Returns:
A Numba signature and a code generation function that returns a void pointer
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, please be more precise about what "the array's data" refers to, i.e. what exactly the returned void pointer points to.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have expanded on the comment and I have added a test.

to the array's data.
"""
if not isinstance(arr, numba.types.Array):
raise TypeError("Expected a NumPy array")

def codegen(context, builder, signature, args):
[arr] = args
raw_ptr = numba.core.cgutils.alloca_once_value(builder, arr)
void_ptr = builder.bitcast(raw_ptr, context.get_value_type(numba.types.voidptr))
return void_ptr

sig = numba.types.voidptr(arr)
return sig, codegen
3 changes: 3 additions & 0 deletions test/test_add_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def test_additive_facet_integral(dtype, compile_args):
ffi.cast(f"{c_xtype} *", coords.ctypes.data),
ffi.cast("int *", facets.ctypes.data),
ffi.cast("uint8_t *", perm.ctypes.data),
ffi.NULL,
)
assert np.isclose(A.sum(), np.sqrt(12) * (i + 1))

Expand Down Expand Up @@ -158,6 +159,7 @@ def test_additive_cell_integral(dtype, compile_args):
ffi.cast(f"{c_xtype} *", coords.ctypes.data),
ffi.NULL,
ffi.NULL,
ffi.NULL,
)

A0 = np.array(A)
Expand All @@ -169,6 +171,7 @@ def test_additive_cell_integral(dtype, compile_args):
ffi.cast(f"{c_xtype} *", coords.ctypes.data),
ffi.NULL,
ffi.NULL,
ffi.NULL,
)

assert np.all(np.isclose(A, (i + 2) * A0))
7 changes: 7 additions & 0 deletions test/test_jit_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def test_matvec(compile_args):
ffi.cast(f"{c_xtype} *", coords.ctypes.data),
ffi.cast("int *", entity_index.ctypes.data),
ffi.cast("uint8_t *", quad_perm.ctypes.data),
ffi.NULL,
)

# Check the computation against correct NumPy value
Expand Down Expand Up @@ -133,6 +134,7 @@ def test_rank1(compile_args):
ffi.cast(f"{c_xtype} *", coords.ctypes.data),
ffi.cast("int *", entity_index.ctypes.data),
ffi.cast("uint8_t *", quad_perm.ctypes.data),
ffi.NULL,
)

f = np.array([[1.0, 2.0, 3.0], [-4.0, -5.0, 6.0]])
Expand Down Expand Up @@ -203,6 +205,7 @@ def test_elimiate_zero_tables_tensor(compile_args):
ffi.cast(f"{c_xtype} *", coords.ctypes.data),
ffi.cast("int *", entity_index.ctypes.data),
ffi.cast("uint8_t *", quad_perm.ctypes.data),
ffi.NULL,
)

def exact_expr(x):
Expand Down Expand Up @@ -261,6 +264,7 @@ def test_grad_constant(compile_args):
ffi.cast(f"{c_xtype} *", coords.ctypes.data),
ffi.cast("int *", entity_index.ctypes.data),
ffi.cast("uint8_t *", quad_perm.ctypes.data),
ffi.NULL,
)

assert output[0] == pytest.approx(consts[1] * 2 * points[0, 0])
Expand Down Expand Up @@ -316,6 +320,7 @@ def test_facet_expression(compile_args):
ffi.cast(f"{c_xtype} *", coords.ctypes.data),
ffi.cast("int *", entity_index.ctypes.data),
ffi.cast("uint8_t *", quad_perm.ctypes.data),
ffi.NULL,
)
# Assert that facet normal is perpendicular to tangent
assert np.isclose(np.dot(output, tangent), 0)
Expand Down Expand Up @@ -366,6 +371,7 @@ def check_expression(expression_class, output_shape, entity_values, reference_va
ffi_data["coords"],
ffi_data["entity_index"],
ffi_data["quad_perm"],
ffi.NULL,
)
np.testing.assert_allclose(output, ref_val)

Expand Down Expand Up @@ -430,5 +436,6 @@ def test_facet_geometry_expressions_3D(compile_args):
ffi.cast(f"{c_xtype} *", coords.ctypes.data),
ffi.cast("int *", entity_index.ctypes.data),
ffi.cast("uint8_t *", quad_perm.ctypes.data),
ffi.NULL,
)
np.testing.assert_allclose(output, np.asarray(ref_fev)[:3, :])
17 changes: 17 additions & 0 deletions test/test_jit_forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def test_laplace_bilinear_form_2d(dtype, expected_result, compile_args):
ffi.cast(f"{c_xtype} *", coords.ctypes.data),
ffi.NULL,
ffi.NULL,
ffi.NULL,
)

assert np.allclose(A, np.trace(kappa_value) * expected_result)
Expand Down Expand Up @@ -233,6 +234,7 @@ def test_helmholtz_form_2d(dtype, expected_result, compile_args):
ffi.cast(f"{c_xtype} *", coords.ctypes.data),
ffi.NULL,
ffi.NULL,
ffi.NULL,
)

np.testing.assert_allclose(A, expected_result)
Expand Down Expand Up @@ -305,6 +307,7 @@ def test_laplace_bilinear_form_3d(dtype, expected_result, compile_args):
ffi.cast(f"{c_xtype} *", coords.ctypes.data),
ffi.NULL,
ffi.NULL,
ffi.NULL,
)

assert np.allclose(A, expected_result)
Expand Down Expand Up @@ -342,6 +345,7 @@ def test_form_coefficient(compile_args):
ffi.cast("double *", coords.ctypes.data),
ffi.NULL,
ffi.cast("uint8_t *", perm.ctypes.data),
ffi.NULL,
)

A_analytic = np.array([[2, 1, 1], [1, 2, 1], [1, 1, 2]], dtype=np.float64) / 24.0
Expand Down Expand Up @@ -452,6 +456,7 @@ def test_interior_facet_integral(dtype, compile_args):
ffi.cast(f"{c_xtype} *", coords.ctypes.data),
ffi.cast("int *", facets.ctypes.data),
ffi.cast("uint8_t *", perms.ctypes.data),
ffi.NULL,
)


Expand Down Expand Up @@ -512,6 +517,7 @@ def test_conditional(dtype, compile_args):
ffi.cast(f"{c_xtype} *", coords.ctypes.data),
ffi.NULL,
ffi.NULL,
ffi.NULL,
)

expected_result = np.array([[2, -1, -1], [-1, 1, 0], [-1, 0, 1]], dtype=dtype)
Expand All @@ -530,6 +536,7 @@ def test_conditional(dtype, compile_args):
ffi.cast(f"{c_xtype} *", coords.ctypes.data),
ffi.NULL,
ffi.NULL,
ffi.NULL,
)

expected_result = np.ones(3, dtype=dtype)
Expand Down Expand Up @@ -581,6 +588,7 @@ def test_custom_quadrature(compile_args):
ffi.cast("double *", coords.ctypes.data),
ffi.NULL,
ffi.NULL,
ffi.NULL,
)

# Check that A is diagonal
Expand Down Expand Up @@ -690,6 +698,7 @@ def test_lagrange_triangle(compile_args, order, dtype, sym_fun, ufl_fun):
ffi.cast(f"{c_xtype} *", coords.ctypes.data),
ffi.NULL,
ffi.NULL,
ffi.NULL,
)

# Check that the result is the same as for sympy
Expand Down Expand Up @@ -817,6 +826,7 @@ def test_lagrange_tetrahedron(compile_args, order, dtype, sym_fun, ufl_fun):
ffi.cast(f"{c_xtype} *", coords.ctypes.data),
ffi.NULL,
ffi.NULL,
ffi.NULL,
)

# Check that the result is the same as for sympy
Expand Down Expand Up @@ -852,6 +862,7 @@ def test_prism(compile_args):
ffi.cast("double *", coords.ctypes.data),
ffi.NULL,
ffi.NULL,
ffi.NULL,
)

assert np.isclose(sum(b), 0.5)
Expand Down Expand Up @@ -898,6 +909,7 @@ def test_complex_operations(compile_args):
ffi.cast(f"{c_xtype} *", coords.ctypes.data),
ffi.NULL,
ffi.NULL,
ffi.NULL,
)

expected_result = np.array(
Expand All @@ -918,6 +930,7 @@ def test_complex_operations(compile_args):
ffi.cast(f"{c_xtype} *", coords.ctypes.data),
ffi.NULL,
ffi.NULL,
ffi.NULL,
)

assert np.allclose(J_2, expected_result)
Expand Down Expand Up @@ -980,6 +993,7 @@ def test_interval_vertex_quadrature(compile_args):
ffi.cast("double *", coords.ctypes.data),
ffi.NULL,
ffi.NULL,
ffi.NULL,
)
assert np.isclose(J[0], (0.5 * a + 0.5 * b) * np.abs(b - a))

Expand Down Expand Up @@ -1033,6 +1047,7 @@ def test_facet_vertex_quadrature(compile_args):
ffi.cast("double *", coords.ctypes.data),
ffi.cast("int *", facets.ctypes.data),
ffi.NULL,
ffi.NULL,
)
solutions.append(J[0])
# Test against exact result
Expand Down Expand Up @@ -1084,6 +1099,7 @@ def test_manifold_derivatives(compile_args):
ffi.cast("double *", coords.ctypes.data),
ffi.NULL,
ffi.cast("uint8_t *", perm.ctypes.data),
ffi.NULL,
)

assert np.isclose(J[0], 0.0)
Expand Down Expand Up @@ -1186,6 +1202,7 @@ def tabulate_tensor(ele_type, V_cell_type, W_cell_type, coeffs):
ffi.cast(f"{c_xtype} *", coords.ctypes.data),
ffi.cast("int *", facet.ctypes.data),
ffi.cast("uint8_t *", perm.ctypes.data),
ffi.NULL,
)

return A
Expand Down
1 change: 1 addition & 0 deletions test/test_submesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def compute_tensor(forms: list[ufl.form.Form], dtype: str, compile_args: list[st
ffi.cast(f"{c_xtype} *", coords.ctypes.data),
ffi.NULL,
ffi.NULL,
ffi.NULL,
)
return A

Expand Down
2 changes: 2 additions & 0 deletions test/test_tensor_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def test_bilinear_form(dtype, P, cell_type):
ffi.cast(f"{c_xtype} *", coords.ctypes.data),
ffi.NULL,
ffi.NULL,
ffi.NULL,
)

# Use sum factorization
Expand All @@ -125,6 +126,7 @@ def test_bilinear_form(dtype, P, cell_type):
ffi.cast(f"{c_xtype} *", coords.ctypes.data),
ffi.NULL,
ffi.NULL,
ffi.NULL,
)

np.testing.assert_allclose(A, A1, rtol=1e-6, atol=1e-6)