Skip to content

Commit

Permalink
Replace pybind11 with nanobind (#1299)
Browse files Browse the repository at this point in the history
  • Loading branch information
petiaccja authored Aug 17, 2023
1 parent 36b4821 commit fa30664
Show file tree
Hide file tree
Showing 8 changed files with 93 additions and 114 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,15 @@ def _type_string(type_: ts.TypeSpec) -> str:
return f"std::tuple<{','.join(_type_string(t) for t in type_.types)}>"
elif isinstance(type_, ts.FieldType):
ndims = len(type_.dims)
buffer_t = "pybind11::buffer"
dtype = cpp_interface.render_scalar_type(type_.dtype)
shape = f"nanobind::shape<{', '.join(['nanobind::any'] * ndims)}>"
buffer_t = f"nanobind::ndarray<{dtype}, {shape}>"
origin_t = f"std::tuple<{', '.join(['ptrdiff_t'] * ndims)}>"
return f"std::pair<{buffer_t}, {origin_t}>"
elif isinstance(type_, ts.ScalarType):
return cpp_interface.render_scalar_type(type_)
else:
raise ValueError(f"Type '{type_}' is not supported in pybind11 interfaces.")
raise ValueError(f"Type '{type_}' is not supported in nanobind interfaces.")


class BindingCodeGenerator(TemplatedGenerator):
Expand Down Expand Up @@ -131,7 +133,7 @@ class BindingCodeGenerator(TemplatedGenerator):

BindingModule = as_jinja(
"""\
PYBIND11_MODULE({{name}}, module) {
NB_MODULE({{name}}, module) {
module.doc() = "{{doc}}";
{{"\n".join(functions)}}
}\
Expand All @@ -149,9 +151,7 @@ def visit_BufferSID(self, sid: BufferSID, **kwargs):
dims = [self.visit(dim) for dim in sid.dimensions]
origin = f"{sid.source_buffer}.second"

as_sid = f"gridtools::as_sid<{cpp_interface.render_scalar_type(sid.scalar_type)},\
{sid.dimensions.__len__()},\
gridtools::sid::unknown_kind>({pybuffer})"
as_sid = f"gridtools::nanobind::as_sid({pybuffer})"
shifted = f"gridtools::sid::shift_sid_origin({as_sid}, {origin})"
renamed = f"gridtools::sid::rename_numbered_dimensions<{', '.join(dims)}>({shifted})"
return renamed
Expand Down Expand Up @@ -187,7 +187,7 @@ def make_argument(name: str, type_: ts.TypeSpec) -> str | BufferSID | CompositeS
elif isinstance(type_, ts.ScalarType):
return name
else:
raise ValueError(f"Type '{type_}' is not supported in pybind11 interfaces.")
raise ValueError(f"Type '{type_}' is not supported in nanobind interfaces.")


def create_bindings(
Expand All @@ -210,9 +210,10 @@ def create_bindings(
file_binding = BindingFile(
callee_header_file=f"{program_source.entry_point.name}.{program_source.language_settings.header_extension}",
header_files=[
"pybind11/pybind11.h",
"pybind11/stl.h",
"gridtools/storage/adapter/python_sid_adapter.hpp",
"nanobind/nanobind.h",
"nanobind/stl/tuple.h",
"nanobind/stl/pair.h",
"nanobind/ndarray.h",
"gridtools/sid/composite.hpp",
"gridtools/sid/unknown_kind.hpp",
"gridtools/sid/rename_dimensions.hpp",
Expand All @@ -221,6 +222,7 @@ def create_bindings(
"gridtools/fn/unstructured.hpp",
"gridtools/fn/cartesian.hpp",
"gridtools/fn/backend/naive.hpp",
"gridtools/storage/adapter/nanobind_adapter.hpp",
],
wrapper=WrapperFunction(
name=wrapper_name,
Expand Down Expand Up @@ -258,7 +260,7 @@ def create_bindings(

return stages.BindingSource(
src,
(interface.LibraryDependency("pybind11", "2.9.2"),),
(interface.LibraryDependency("nanobind", "1.4.0"),),
)


Expand Down
23 changes: 20 additions & 3 deletions src/gt4py/next/otf/compilation/build_systems/cmake_lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ def visit_FindDependency(self, dep: FindDependency):
import pybind11

return f"find_package(pybind11 CONFIG REQUIRED PATHS {pybind11.get_cmake_dir()} NO_DEFAULT_PATH)"
case "nanobind":
import nanobind

py = "find_package(Python COMPONENTS Interpreter Development REQUIRED)"
nb = f"find_package(nanobind CONFIG REQUIRED PATHS {nanobind.cmake_dir()} NO_DEFAULT_PATHS)"
return py + "\n" + nb
case "gridtools":
import gridtools_cpp

Expand All @@ -93,13 +99,24 @@ def visit_LinkDependency(self, dep: LinkDependency):
match dep.name:
case "pybind11":
lib_name = "pybind11::module"
case "nanobind":
lib_name = "nanobind-static"
case "gridtools":
lib_name = "GridTools::fn_naive"
case _:
raise ValueError("Library {name} is not supported".format(name=dep.name))
return "target_link_libraries({target} PUBLIC {lib})".format(
target=dep.target, lib=lib_name
)

cfg = ""
if dep.name == "nanobind":
cfg = "\n".join(
[
"nanobind_build_library(nanobind-static)",
f"nanobind_compile_options({dep.target})",
f"nanobind_link_options({dep.target})",
]
)
lnk = f"target_link_libraries({dep.target} PUBLIC {lib_name})"
return cfg + "\n" + lnk


def generate_cmakelists_source(
Expand Down
45 changes: 29 additions & 16 deletions src/gt4py/next/otf/compilation/build_systems/compiledb.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import json
import pathlib
import re
import shutil
import subprocess
from typing import Optional

Expand Down Expand Up @@ -124,6 +125,20 @@ def build(self):
self._run_build()

def _write_files(self):
def ignore_not_libraries(folder: str, children: list[str]) -> list[str]:
pattern = r"((lib.*\.a)|(.*\.lib))"
libraries = [child for child in children if re.match(pattern, child)]
folders = [child for child in children if (pathlib.Path(folder) / child).is_dir()]
ignored = list(set(children) - set(libraries) - set(folders))
return ignored

shutil.copytree(
self.compile_commands_cache.parent,
self.root_path,
ignore=ignore_not_libraries,
dirs_exist_ok=True,
)

for name, content in self.source_files.items():
(self.root_path / name).write_text(content, encoding="utf-8")

Expand All @@ -140,7 +155,7 @@ def _run_config(self):
compile_db = json.loads(self.compile_commands_cache.read_text())

(self.root_path / "build").mkdir(exist_ok=True)
(self.root_path / "bin").mkdir(exist_ok=True)
(self.root_path / "build" / "bin").mkdir(exist_ok=True)

for entry in compile_db:
for key, value in entry.items():
Expand All @@ -155,7 +170,7 @@ def _run_config(self):
build_data.write_data(
build_data.BuildData(
status=build_data.BuildStatus.CONFIGURED,
module=pathlib.Path(compile_db[-1]["output"]),
module=pathlib.Path(compile_db[-1]["directory"]) / compile_db[-1]["output"],
entry_point_name=self.program_name,
),
self.root_path,
Expand All @@ -171,7 +186,7 @@ def _run_build(self):
log_file_pointer.write(entry["command"] + "\n")
subprocess.check_call(
entry["command"],
cwd=self.root_path,
cwd=entry["directory"],
shell=True,
stdout=log_file_pointer,
stderr=log_file_pointer,
Expand Down Expand Up @@ -251,19 +266,17 @@ def _cc_create_compiledb(
program_name=name,
)

prototype_project._write_files()
prototype_project._run_config()
prototype_project.build()

log_file = cache_path / "log_compiledb.txt"

with log_file.open("w") as log_file_pointer:
commands = json.loads(
subprocess.check_output(
["ninja", "-t", "compdb"],
cwd=cache_path / "build",
stderr=log_file_pointer,
).decode("utf-8")
)
commands_json_str = subprocess.check_output(
["ninja", "-t", "compdb"],
cwd=cache_path / "build",
stderr=log_file_pointer,
).decode("utf-8")
commands = json.loads(commands_json_str)

compile_db = [
cmd for cmd in commands if name in pathlib.Path(cmd["file"]).stem and cmd["command"]
Expand All @@ -272,24 +285,24 @@ def _cc_create_compiledb(
assert compile_db

for entry in compile_db:
entry["directory"] = "$SRC_PATH"
entry["directory"] = entry["directory"].replace(str(cache_path), "$SRC_PATH")
entry["command"] = (
entry["command"]
.replace(f"CMakeFiles/{name}.dir", "build")
.replace(f"CMakeFiles/{name}.dir", ".")
.replace(str(cache_path), "$SRC_PATH")
.replace(f"{name}.cpp", "$BINDINGS_FILE")
.replace(f"{name}", "$NAME")
.replace("-I$SRC_PATH/build/_deps", f"-I{cache_path}/build/_deps")
)
entry["file"] = (
entry["file"]
.replace(f"CMakeFiles/{name}.dir", "build")
.replace(f"CMakeFiles/{name}.dir", ".")
.replace(str(cache_path), "$SRC_PATH")
.replace(f"{name}.cpp", "$BINDINGS_FILE")
)
entry["output"] = (
entry["output"]
.replace(f"CMakeFiles/{name}.dir", "build")
.replace(f"CMakeFiles/{name}.dir", ".")
.replace(f"{name}.cpp", "$BINDINGS_FILE")
.replace(f"{name}", "$NAME")
)
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/program_processors/runners/gtfn_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from gt4py.eve.utils import content_hash
from gt4py.next import common
from gt4py.next.otf import languages, recipes, stages, workflow
from gt4py.next.otf.binding import cpp_interface, pybind
from gt4py.next.otf.binding import cpp_interface, nanobind
from gt4py.next.otf.compilation import cache, compiler
from gt4py.next.otf.compilation.build_systems import compiledb
from gt4py.next.program_processors import otf_compile_executor
Expand Down Expand Up @@ -102,7 +102,7 @@ def compilation_hash(otf_closure: stages.ProgramCall) -> int:

GTFN_DEFAULT_WORKFLOW = recipes.OTFCompileWorkflow(
translation=GTFN_DEFAULT_TRANSLATION_STEP,
bindings=pybind.bind_source,
bindings=nanobind.bind_source,
compilation=GTFN_DEFAULT_COMPILE_STEP,
decoration=convert_args,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import numpy as np

from gt4py.next.otf import workflow
from gt4py.next.otf.binding import pybind
from gt4py.next.otf.binding import nanobind
from gt4py.next.otf.compilation import cache, compiler
from gt4py.next.otf.compilation.build_systems import cmake, compiledb

Expand All @@ -28,7 +28,7 @@

def test_gtfn_cpp_with_cmake(program_source_with_name):
example_program_source = program_source_with_name("gtfn_cpp_with_cmake")
build_the_program = workflow.make_step(pybind.bind_source).chain(
build_the_program = workflow.make_step(nanobind.bind_source).chain(
compiler.Compiler(
cache_strategy=cache.Strategy.SESSION, builder_factory=cmake.CMakeFactory()
),
Expand All @@ -46,7 +46,7 @@ def test_gtfn_cpp_with_cmake(program_source_with_name):

def test_gtfn_cpp_with_compiledb(program_source_with_name):
example_program_source = program_source_with_name("gtfn_cpp_with_compiledb")
build_the_program = workflow.make_step(pybind.bind_source).chain(
build_the_program = workflow.make_step(nanobind.bind_source).chain(
compiler.Compiler(
cache_strategy=cache.Strategy.SESSION,
builder_factory=compiledb.CompiledbFactory(),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2023, ETH Zurich
# All rights reserved.
#
# This file is part of the GT4Py project and the GridTools framework.
# GT4Py is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the
# Free Software Foundation, either version 3 of the License, or any later
# version. See the LICENSE.txt file at the top-level directory of this
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

from gt4py.next.otf.binding import nanobind

from next_tests.unit_tests.otf_tests.compilation_tests.build_systems_tests.conftest import (
program_source_example,
)


def test_bindings(program_source_example):
module = nanobind.create_bindings(program_source_example)
assert module.library_deps[0].name == "nanobind"
77 changes: 0 additions & 77 deletions tests/next_tests/unit_tests/otf_tests/binding_tests/test_pybind.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import gt4py.next as gtx
import gt4py.next.type_system.type_specifications as ts
from gt4py.next.otf import languages, stages
from gt4py.next.otf.binding import cpp_interface, interface, pybind
from gt4py.next.otf.binding import cpp_interface, interface, nanobind
from gt4py.next.otf.compilation import cache


Expand Down Expand Up @@ -99,7 +99,7 @@ def program_source_example():
def compilable_source_example(program_source_example):
return stages.CompilableSource(
program_source=program_source_example,
binding_source=pybind.create_bindings(program_source_example),
binding_source=nanobind.create_bindings(program_source_example),
)


Expand Down

0 comments on commit fa30664

Please sign in to comment.