diff --git a/include/pybind11/cast.h b/include/pybind11/cast.h index 8b5beb0ef6..13ea5efa73 100644 --- a/include/pybind11/cast.h +++ b/include/pybind11/cast.h @@ -876,6 +876,10 @@ struct handle_type_name { static constexpr auto name = const_name(); }; template <> +struct handle_type_name { + static constexpr auto name = const_name("object"); +}; +template <> struct handle_type_name { static constexpr auto name = const_name("bool"); }; diff --git a/include/pybind11/stl_bind.h b/include/pybind11/stl_bind.h index 4d0854f6a8..1eb81f38d2 100644 --- a/include/pybind11/stl_bind.h +++ b/include/pybind11/stl_bind.h @@ -15,6 +15,8 @@ #include "operators.h" #include +#include +#include #include #include @@ -483,6 +485,137 @@ void vector_buffer(Class_ &cl) { cl, detail::any_of...>{}); } +// Issue #3986 and #4529: map C++ types to Python types with typing strings +template +struct type_mapper { + using py_type = T; + static std::string py_name() { return detail::type_info_description(typeid(T)); } +}; + +template <> +struct type_mapper { + using py_type = pybind11::none; + static std::string py_name() { + constexpr auto descr = const_name("None"); + return descr.text; + } +}; + +template <> +struct type_mapper { + using py_type = pybind11::bool_; + static std::string py_name() { + constexpr auto descr = const_name("bool"); + return descr.text; + } +}; + +template +struct type_mapper::value && !is_std_char_type::value>> { + using py_type + = conditional_t::value, pybind11::float_, pybind11::int_>; + static std::string py_name() { + constexpr auto descr = const_name::value>("int", "float"); + return descr.text; + } +}; + +template +struct type_mapper> { + using py_type = std::complex::py_type>; + static std::string py_name() { + constexpr auto descr = const_name("complex"); + return descr.text; + } +}; + +template +struct type_mapper::value>> { + using py_type = pybind11::str; + static std::string py_name() { + constexpr auto descr = const_name(PYBIND11_STRING_NAME); + return descr.text; + } +}; + +template +struct type_mapper::value>> { + using py_type = T; + static std::string py_name() { + constexpr auto descr = handle_type_name::name; + return descr.text; + } +}; + +template +struct type_mapper> : public type_mapper {}; + +template +struct type_mapper> : public type_mapper {}; + +template +struct type_mapper, + enable_if_t::value>> { + using py_type = pybind11::str; + static std::string py_name() { + constexpr auto descr = const_name(PYBIND11_STRING_NAME); + return descr.text; + } +}; + +#ifdef PYBIND11_HAS_STRING_VIEW +template +struct type_mapper, + enable_if_t::value>> { + using py_type = pybind11::str; + static std::string py_name() { + constexpr auto descr = const_name(PYBIND11_STRING_NAME); + return descr.text; + } +}; +#endif + +template +struct type_mapper> { + using py_type + = std::tuple::py_type, typename type_mapper::py_type>; + static std::string py_name() { + return "tuple[" + type_mapper::py_name() + ", " + type_mapper::py_name() + "]"; + } +}; + +template +struct type_mapper> { + using py_type = std::tuple::py_type...>; + static std::string py_name() { + std::vector names = {type_mapper::py_name()...}; + std::ostringstream s; + s << "tuple["; + for (size_t i = 0; i < names.size(); ++i) { + s << (i != 0 ? ", " : "") << names[i]; + } + s << "]"; + return s.str(); + } +}; + +template +struct type_mapper> { + using retval_type = conditional_t::value, std::nullptr_t, Return>; + using py_type = std::function::py_type( + typename type_mapper::py_type...)>; + static std::string py_name() { + std::vector names = {type_mapper::py_name()...}; + std::ostringstream s; + s << "Callable[["; + for (size_t i = 0; i < names.size(); ++i) { + s << (i != 0 ? ", " : "") << names[i]; + } + s << "], " << type_mapper::py_name() << "]"; + return s.str(); + } +}; + PYBIND11_NAMESPACE_END(detail) // @@ -649,8 +782,7 @@ template struct keys_view { virtual size_t len() = 0; virtual iterator iter() = 0; - virtual bool contains(const KeyType &k) = 0; - virtual bool contains(const object &k) = 0; + virtual bool contains(const handle &k) = 0; virtual ~keys_view() = default; }; @@ -673,8 +805,13 @@ struct KeysViewImpl : public KeysView { explicit KeysViewImpl(Map &map) : map(map) {} size_t len() override { return map.size(); } iterator iter() override { return make_key_iterator(map.begin(), map.end()); } - bool contains(const typename Map::key_type &k) override { return map.find(k) != map.end(); } - bool contains(const object &) override { return false; } + bool contains(const handle &k) override { + try { + return map.find(k.template cast()) != map.end(); + } catch (const cast_error &) { + return false; + } + } Map ↦ }; @@ -702,9 +839,11 @@ class_ bind_map(handle scope, const std::string &name, Args && using MappedType = typename Map::mapped_type; using StrippedKeyType = detail::remove_cvref_t; using StrippedMappedType = detail::remove_cvref_t; - using KeysView = detail::keys_view; - using ValuesView = detail::values_view; - using ItemsView = detail::items_view; + using PyKeyType = typename detail::type_mapper::py_type; + using PyMappedType = typename detail::type_mapper::py_type; + using KeysView = detail::keys_view; + using ValuesView = detail::values_view; + using ItemsView = detail::items_view; using Class_ = class_; // If either type is a non-module-local bound type then make the map binding non-local as well; @@ -718,20 +857,10 @@ class_ bind_map(handle scope, const std::string &name, Args && } Class_ cl(scope, name.c_str(), pybind11::module_local(local), std::forward(args)...); - static constexpr auto key_type_descr = detail::make_caster::name; - static constexpr auto mapped_type_descr = detail::make_caster::name; - std::string key_type_name(key_type_descr.text), mapped_type_name(mapped_type_descr.text); - - // If key type isn't properly wrapped, fall back to C++ names - if (key_type_name == "%") { - key_type_name = detail::type_info_description(typeid(KeyType)); - } - // Similarly for value type: - if (mapped_type_name == "%") { - mapped_type_name = detail::type_info_description(typeid(MappedType)); - } + std::string key_type_name = detail::type_mapper::py_name(); + std::string mapped_type_name = detail::type_mapper::py_name(); - // Wrap KeysView[KeyType] if it wasn't already wrapped + // Wrap KeysView[PyKeyType] if it wasn't already wrapped if (!detail::get_type_info(typeid(KeysView))) { class_ keys_view( scope, ("KeysView[" + key_type_name + "]").c_str(), pybind11::module_local(local)); @@ -741,10 +870,7 @@ class_ bind_map(handle scope, const std::string &name, Args && keep_alive<0, 1>() /* Essential: keep view alive while iterator exists */ ); keys_view.def("__contains__", - static_cast(&KeysView::contains)); - // Fallback for when the object is not of the key type - keys_view.def("__contains__", - static_cast(&KeysView::contains)); + static_cast(&KeysView::contains)); } // Similarly for ValuesView: if (!detail::get_type_info(typeid(ValuesView))) { diff --git a/tests/test_stl_binders.cpp b/tests/test_stl_binders.cpp index e52a03b6d2..32802eec50 100644 --- a/tests/test_stl_binders.cpp +++ b/tests/test_stl_binders.cpp @@ -187,6 +187,35 @@ TEST_SUBMODULE(stl_binders, m) { py::bind_map>(m, "MapStringDouble"); py::bind_map>(m, "UnorderedMapStringDouble"); + // test_map_view_types + py::bind_map>(m, "MapStringFloat"); + py::bind_map>(m, "UnorderedMapStringFloat"); + py::bind_map>(m, "MapInt16Double"); + py::bind_map>(m, "MapInt32Double"); + py::bind_map>(m, "MapInt64Double"); + py::bind_map>(m, "MapUInt64Double"); + py::bind_map, double>>(m, "MapPairShortShortDouble"); + py::bind_map, std::complex>>( + m, "MapPairShortLongComplexFloat"); + py::bind_map, std::complex>>( + m, "MapPairLongShortComplexDouble"); + py::bind_map, std::complex>>( + m, "MapTupleLongLongComplexDouble"); + py::bind_map>>(m, + "MapCharFunctionFloatIntFloat"); + py::bind_map>>( + m, "MapStringFunctionDoubleLongDouble"); + py::bind_map>>( + m, "MapStringFunctionVoidLongDouble"); + py::bind_map>(m, "MapStringNone"); + + py::bind_map, int>>>(m, "MapIntMapIntIntInt"); + py::bind_map, long>>>(m, "MapIntMapIntIntLong"); + py::bind_map, long>>>(m, "MapIntMapLongIntLong"); + + py::bind_map>(m, "MapPyIntInt"); + py::bind_map>(m, "MapPyIntPyInt"); + // test_map_string_double_const py::bind_map>(m, "MapStringDoubleConst"); py::bind_map>(m, diff --git a/tests/test_stl_binders.py b/tests/test_stl_binders.py index c00d45b926..4bbf2d0850 100644 --- a/tests/test_stl_binders.py +++ b/tests/test_stl_binders.py @@ -314,6 +314,8 @@ def test_map_delitem(): def test_map_view_types(): map_string_double = m.MapStringDouble() unordered_map_string_double = m.UnorderedMapStringDouble() + map_string_float = m.MapStringFloat() + unordered_map_string_float = m.UnorderedMapStringFloat() map_string_double_const = m.MapStringDoubleConst() unordered_map_string_double_const = m.UnorderedMapStringDoubleConst() @@ -321,20 +323,123 @@ def test_map_view_types(): assert map_string_double.values().__class__.__name__ == "ValuesView[float]" assert map_string_double.items().__class__.__name__ == "ItemsView[str, float]" - keys_type = type(map_string_double.keys()) - assert type(unordered_map_string_double.keys()) is keys_type - assert type(map_string_double_const.keys()) is keys_type - assert type(unordered_map_string_double_const.keys()) is keys_type + assert map_string_float.keys().__class__.__name__ == "KeysView[str]" + assert map_string_float.values().__class__.__name__ == "ValuesView[float]" + assert map_string_float.items().__class__.__name__ == "ItemsView[str, float]" + + keys_type = map_string_double.keys().__class__ + assert unordered_map_string_double.keys().__class__ is keys_type + assert map_string_double_const.keys().__class__ is keys_type + assert unordered_map_string_double_const.keys().__class__ is keys_type + assert map_string_float.keys().__class__ is keys_type + assert unordered_map_string_float.keys().__class__ is keys_type + + values_type = map_string_double.values().__class__ + assert unordered_map_string_double.values().__class__ is values_type + assert map_string_double_const.values().__class__ is values_type + assert unordered_map_string_double_const.values().__class__ is values_type + assert map_string_float.values().__class__ is values_type + assert unordered_map_string_float.values().__class__ is values_type + + items_type = map_string_double.items().__class__ + assert unordered_map_string_double.items().__class__ is items_type + assert map_string_double_const.items().__class__ is items_type + assert unordered_map_string_double_const.items().__class__ is items_type + assert map_string_float.items().__class__ is items_type + assert unordered_map_string_float.items().__class__ is items_type + + map_int16_double = m.MapInt16Double() + map_int32_double = m.MapInt32Double() + map_int64_double = m.MapInt64Double() + map_uint64_double = m.MapUInt64Double() + + assert map_int16_double.keys().__class__.__name__ == "KeysView[int]" + assert map_int16_double.keys().__class__ is map_int32_double.keys().__class__ + assert map_int16_double.keys().__class__ is map_int64_double.keys().__class__ + assert map_int16_double.keys().__class__ is map_uint64_double.keys().__class__ + + assert (1 << 50) not in map_uint64_double.keys() + map_uint64_double[1 << 50] = 1.0 + assert (1 << 50) in map_uint64_double.keys() + + map_pair_short_short_double = m.MapPairShortShortDouble() + map_pair_short_long_complex_float = m.MapPairShortLongComplexFloat() + map_pair_long_short_complex_double = m.MapPairLongShortComplexDouble() + map_tuple_long_long_complex_double = m.MapTupleLongLongComplexDouble() + assert ( + map_pair_short_long_complex_float.keys().__class__.__name__ + == "KeysView[tuple[int, int]]" + ) + assert ( + map_pair_short_long_complex_float.values().__class__.__name__ + == "ValuesView[complex]" + ) + assert ( + map_pair_short_long_complex_float.items().__class__.__name__ + == "ItemsView[tuple[int, int], complex]" + ) + assert ( + map_pair_short_long_complex_float.keys().__class__ + is map_pair_short_short_double.keys().__class__ + ) + assert ( + map_pair_short_long_complex_float.keys().__class__ + is map_pair_long_short_complex_double.keys().__class__ + ) + assert ( + map_pair_short_long_complex_float.keys().__class__ + is map_tuple_long_long_complex_double.keys().__class__ + ) + assert ( + map_pair_short_long_complex_float.values().__class__ + is map_pair_long_short_complex_double.values().__class__ + ) + assert ( + map_pair_short_long_complex_float.values().__class__ + is map_tuple_long_long_complex_double.values().__class__ + ) - values_type = type(map_string_double.values()) - assert type(unordered_map_string_double.values()) is values_type - assert type(map_string_double_const.values()) is values_type - assert type(unordered_map_string_double_const.values()) is values_type + map_char_func_fif = m.MapCharFunctionFloatIntFloat() + map_string_func_dld = m.MapStringFunctionDoubleLongDouble() + assert map_char_func_fif.keys().__class__.__name__ == "KeysView[str]" + assert ( + map_char_func_fif.values().__class__.__name__ + == "ValuesView[Callable[[int, float], float]]" + ) + assert ( + map_char_func_fif.items().__class__.__name__ + == "ItemsView[str, Callable[[int, float], float]]" + ) + assert map_char_func_fif.keys().__class__ is map_string_func_dld.keys().__class__ + assert ( + map_char_func_fif.values().__class__ is map_string_func_dld.values().__class__ + ) + assert map_char_func_fif.items().__class__ is map_string_func_dld.items().__class__ + + map_string_func_vld = m.MapStringFunctionVoidLongDouble() + map_string_none = m.MapStringNone() + assert ( + map_string_func_vld.values().__class__.__name__ + == "ValuesView[Callable[[int, float], None]]" + ) + assert map_string_none.values().__class__.__name__ == "ValuesView[None]" + + map_int_map_int_int_int = m.MapIntMapIntIntInt() + map_int_map_int_int_long = m.MapIntMapIntIntLong() + map_int_map_long_int_long = m.MapIntMapLongIntLong() + assert ( + map_int_map_int_int_int.values().__class__ + is map_int_map_int_int_long.values().__class__ + ) + assert ( + map_int_map_int_int_long.values().__class__ + is not map_int_map_long_int_long.values().__class__ + ) - items_type = type(map_string_double.items()) - assert type(unordered_map_string_double.items()) is items_type - assert type(map_string_double_const.items()) is items_type - assert type(unordered_map_string_double_const.items()) is items_type + map_pyint_int = m.MapPyIntInt() + map_pyint_pyint = m.MapPyIntPyInt() + assert map_pyint_int.items().__class__.__name__ == "ItemsView[int, int]" + assert map_pyint_int.items().__class__ is map_pyint_pyint.items().__class__ def test_recursive_vector():