23 #include <type_traits>
29 # pragma warning(push)
30 # pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
37 static_assert(
sizeof(::
pybind11::ssize_t) ==
sizeof(Py_intptr_t),
"ssize_t != Py_intptr_t");
98 pybind11_fail(std::string(
"NumPy type info missing for ") + tinfo.name());
103 return get_type_info(
typeid(
typename std::remove_cv<T>::type), throw_if_missing);
108 ptr = &get_or_create_shared_data<numpy_internals>(
"_numpy_internals");
125 template <
typename Concrete,
typename T,
typename... Ts,
typename... Ints>
127 return sizeof(Concrete) ==
sizeof(T) ? I :
platform_lookup<Concrete, Ts...>(Is...);
157 NPY_INT32_ = platform_lookup<std::int32_t, long, int, short>(
159 NPY_UINT32_ = platform_lookup<std::uint32_t, unsigned long, unsigned int, unsigned short>(
161 NPY_INT64_ = platform_lookup<std::int64_t, long, long long, int>(
163 NPY_UINT64_ = platform_lookup<std::uint64_t, unsigned long, unsigned long long, unsigned int>(
185 PyObject *(*PyArray_DescrFromType_)(int);
186 PyObject *(*PyArray_NewFromDescr_)
187 (PyTypeObject *, PyObject *, int, Py_intptr_t
const *,
188 Py_intptr_t
const *,
void *, int, PyObject *);
190 PyObject *(*PyArray_DescrNewFromType_)(int);
192 PyObject *(*PyArray_NewCopy_)(PyObject *, int);
196 PyObject *(*PyArray_DescrFromScalar_)(PyObject *);
197 PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
201 Py_intptr_t *, PyObject **, PyObject *);
202 PyObject *(*PyArray_Squeeze_)(PyObject *);
208 API_PyArray_GetNDArrayCFeatureVersion = 211,
209 API_PyArray_Type = 2,
210 API_PyArrayDescr_Type = 3,
211 API_PyVoidArrType_Type = 39,
212 API_PyArray_DescrFromType = 45,
213 API_PyArray_DescrFromScalar = 57,
214 API_PyArray_FromAny = 69,
215 API_PyArray_Resize = 80,
216 API_PyArray_CopyInto = 82,
217 API_PyArray_NewCopy = 85,
218 API_PyArray_NewFromDescr = 94,
219 API_PyArray_DescrNewFromType = 96,
220 API_PyArray_DescrConverter = 174,
221 API_PyArray_EquivTypes = 182,
222 API_PyArray_GetArrayParamsFromObject = 278,
223 API_PyArray_Squeeze = 136,
224 API_PyArray_SetBaseObject = 282
229 auto c =
m.attr(
"_ARRAY_API");
230 #if PY_MAJOR_VERSION >= 3
231 void **api_ptr = (
void **) PyCapsule_GetPointer(c.ptr(), NULL);
233 void **api_ptr = (
void **) PyCObject_AsVoidPtr(c.ptr());
236 #define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
238 if (
api.PyArray_GetNDArrayCFeatureVersion_() < 0x7)
239 pybind11_fail(
"pybind11 numpy support requires numpy >= 1.7.0");
283 template <
typename T>
struct is_complex : std::false_type { };
284 template <
typename T>
struct is_complex<std::complex<T>> : std::true_type { };
301 static constexpr
size_t extent = N;
309 static constexpr
auto extents = _<array_info<T>::is_array>(
321 std::is_standard_layout<T>,
322 #if !defined(__GNUG__) || defined(_LIBCPP_VERSION) || defined(_GLIBCXX_USE_CXX11_ABI)
325 std::is_trivially_copyable<T>,
328 std::is_trivially_destructible<T>,
336 std::is_standard_layout<T>,
341 template <
ssize_t Dim = 0,
typename Strides,
typename... Ix>
343 return i * strides[Dim] + byte_offset_unsafe<Dim + 1>(strides, index...);
351 template <
typename T, s
size_t Dims>
364 template <
bool Dyn = Dynamic>
366 :
data_{
reinterpret_cast<const unsigned char *
>(
data)},
dims_{Dims} {
373 template <
bool Dyn = Dynamic>
383 template <
typename... Ix>
const T &
operator()(Ix... index)
const {
385 "Invalid number of indices for unchecked array reference");
392 template <s
size_t D = Dims,
typename = enable_if_t<D == 1 || Dynamic>>
408 template <
bool Dyn = Dynamic>
410 return std::accumulate(
shape_.begin(),
shape_.end(), (
ssize_t) 1, std::multiplies<ssize_t>());
412 template <
bool Dyn = Dynamic>
424 template <
typename T, s
size_t Dims>
428 using ConstBase::ConstBase;
432 using ConstBase::operator();
433 using ConstBase::operator[];
437 static_assert(
ssize_t{
sizeof...(Ix)} == Dims || Dynamic,
438 "Invalid number of indices for unchecked array reference");
446 template <s
size_t D = Dims,
typename = enable_if_t<D == 1 || Dynamic>>
453 template <
typename T, s
size_t Dim>
455 static_assert(Dim == 0 && Dim > 0 ,
"unchecked array proxy object is not castable");
457 template <
typename T, s
size_t Dim>
472 explicit dtype(
const std::string &format) {
473 m_ptr = from_args(pybind11::str(format)).release().ptr();
480 args[
"names"] = names;
481 args[
"formats"] = formats;
482 args[
"offsets"] = offsets;
483 args[
"itemsize"] = pybind11::int_(itemsize);
484 m_ptr = from_args(
args).release().ptr();
489 PyObject *ptr =
nullptr;
490 if (!detail::npy_api::get().PyArray_DescrConverter_(
args.
ptr(), &ptr) || !ptr)
492 return reinterpret_steal<dtype>(ptr);
497 return detail::npy_format_descriptor<typename std::remove_cv<T>::type>
::dtype();
516 static object _dtype_from_pep3118() {
518 .attr(
"_dtype_from_pep3118").
cast<
object>().release().ptr();
519 return reinterpret_borrow<object>(obj);
529 std::vector<field_descr> field_descriptors;
531 for (
auto field : attr(
"fields").attr(
"items")()) {
532 auto spec = field.cast<
tuple>();
533 auto name = spec[0].cast<pybind11::str>();
534 auto format = spec[1].cast<
tuple>()[0].cast<dtype>();
535 auto offset = spec[1].cast<
tuple>()[1].cast<pybind11::int_>();
536 if (!
len(
name) && format.kind() ==
'V')
538 field_descriptors.push_back({(
PYBIND11_STR_TYPE)
name, format.strip_padding(format.itemsize()), offset});
541 std::sort(field_descriptors.begin(), field_descriptors.end(),
542 [](
const field_descr& a,
const field_descr& b) {
543 return a.offset.cast<int>() < b.offset.cast<int>();
546 list names, formats, offsets;
547 for (
auto&
descr : field_descriptors) {
552 return dtype(names, formats, offsets, itemsize);
561 c_style = detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_,
562 f_style = detail::npy_api::NPY_ARRAY_F_CONTIGUOUS_,
563 forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
580 pybind11_fail(
"NumPy: shape ndim doesn't match strides ndim");
585 if (isinstance<array>(
base))
590 flags = detail::npy_api::NPY_ARRAY_WRITEABLE_;
593 auto &
api = detail::npy_api::get();
594 auto tmp = reinterpret_steal<object>(
api.PyArray_NewFromDescr_(
597 reinterpret_cast<Py_intptr_t*
>(
shape->data()),
598 reinterpret_cast<Py_intptr_t*
>(
strides->data()),
599 const_cast<void *
>(
ptr),
flags,
nullptr));
604 api.PyArray_SetBaseObject_(tmp.ptr(),
base.inc_ref().ptr());
606 tmp = reinterpret_steal<object>(
api.PyArray_NewCopy_(tmp.ptr(), -1 ));
609 m_ptr = tmp.release().ptr();
619 template <
typename T>
623 template <
typename T>
627 template <
typename T>
704 template<
typename... Ix>
const void*
data(Ix... index)
const {
720 fail_dim_check(
sizeof...(index),
"too many indices for an array");
739 if (Dims >= 0 &&
ndim() != Dims)
740 throw std::domain_error(
"array has incorrect number of dimensions: " + std::to_string(
ndim()) +
741 "; expected " + std::to_string(Dims));
752 template <
typename T,
ssize_t Dims = -1> detail::unchecked_reference<T, Dims>
unchecked() const & {
753 if (Dims >= 0 &&
ndim() != Dims)
754 throw std::domain_error(
"array has incorrect number of dimensions: " + std::to_string(
ndim()) +
755 "; expected " + std::to_string(Dims));
761 auto&
api = detail::npy_api::get();
762 return reinterpret_steal<array>(
api.PyArray_Squeeze_(
m_ptr));
769 detail::npy_api::PyArray_Dims d = {
771 reinterpret_cast<Py_intptr_t*
>(new_shape->data()),
772 int(new_shape->size())
775 auto new_array = reinterpret_steal<object>(
776 detail::npy_api::get().PyArray_Resize_(
m_ptr, &d,
int(refcheck), -1)
779 if (isinstance<array>(new_array)) { *
this =
std::move(new_array); }
785 auto result = reinterpret_steal<array>(
raw_array(h.
ptr(), ExtraFlags));
795 throw index_error(
msg +
": " + std::to_string(dim) +
796 " (ndim = " + std::to_string(
ndim()) +
")");
806 throw std::domain_error(
"array is not writeable");
817 throw index_error(std::string(
"index ") + std::to_string(
i) +
818 " is out of bounds for axis " + std::to_string(axis) +
819 " with size " + std::to_string(*
shape));
826 if (
ptr ==
nullptr) {
827 PyErr_SetString(PyExc_ValueError,
"cannot create a pybind11::array from a nullptr");
830 return detail::npy_api::get().PyArray_FromAny_(
831 ptr,
nullptr, 0, 0, detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags,
nullptr);
835 template <
typename T,
int ExtraFlags = array::forcecast>
class array_t :
public array {
837 struct private_ctor {};
842 static_assert(!detail::array_info<T>::is_array,
"Array types cannot be used with array_t");
852 if (!
m_ptr) PyErr_Clear();
883 template<
typename... Ix>
const T*
data(Ix... index)
const {
884 return static_cast<const T*
>(
array::data(index...));
892 template<
typename... Ix>
const T&
at(Ix... index)
const {
912 return array::mutable_unchecked<T, Dims>();
923 return array::unchecked<T, Dims>();
936 const auto &
api = detail::npy_api::get();
937 return api.PyArray_Check_(h.
ptr())
945 if (
ptr ==
nullptr) {
946 PyErr_SetString(PyExc_ValueError,
"cannot create a pybind11::array_t from a nullptr");
949 return detail::npy_api::get().PyArray_FromAny_(
951 detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags,
nullptr);
955 template <
typename T>
958 return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::format();
963 static std::string
format() {
return std::to_string(N) +
"s"; }
966 static std::string
format() {
return std::to_string(N) +
"s"; }
969 template <
typename T>
973 typename std::remove_cv<typename std::underlying_type<T>::type>
::type>::format();
977 template <
typename T>
980 using namespace detail;
987 template <
typename T,
int ExtraFlags>
992 if (!convert && !type::check_(src))
994 value = type::ensure(src);
995 return static_cast<bool>(
value);
1004 template <
typename T>
1011 template <
typename T,
typename =
void>
1014 template <
typename T>
1021 template <
typename T>
1024 _(
"numpy.float") + _<sizeof(T)*8>(),
_(
"numpy.longdouble")
1028 template <
typename T>
1032 _(
"numpy.complex") + _<sizeof(typename T::value_type)*16>(),
_(
"numpy.longcomplex")
1036 template <
typename T>
1041 constexpr
static const int values[15] = {
1050 static constexpr
int value = values[detail::is_fmt_numeric<T>::index];
1054 return reinterpret_steal<pybind11::dtype>(ptr);
1059 #define PYBIND11_DECL_CHAR_FMT \
1060 static constexpr auto name = _("S") + _<N>(); \
1061 static pybind11::dtype dtype() { return pybind11::dtype(std::string("S") + std::to_string(N)); }
1064 #undef PYBIND11_DECL_CHAR_FMT
1084 static constexpr
auto name = base_descr::name;
1085 static pybind11::dtype
dtype() {
return base_descr::dtype(); }
1098 const std::type_info& tinfo,
ssize_t itemsize,
1099 bool (*direct_converter)(PyObject *,
void *&)) {
1107 std::vector<field_descriptor> ordered_fields(
std::move(fields));
1108 std::sort(ordered_fields.begin(), ordered_fields.end(),
1111 list names, formats, offsets;
1112 for (
auto& field : ordered_fields) {
1114 pybind11_fail(std::string(
"NumPy: unsupported field dtype: `") +
1115 field.name +
"` @ " + tinfo.name());
1117 formats.
append(field.descr);
1118 offsets.
append(pybind11::int_(field.offset));
1120 auto dtype_ptr = pybind11::dtype(names, formats, offsets, itemsize).release().ptr();
1130 std::ostringstream oss;
1137 for (
auto& field : ordered_fields) {
1138 if (field.offset > offset)
1139 oss << (field.offset - offset) <<
'x';
1140 oss << field.format <<
':' << field.name <<
':';
1141 offset = field.offset + field.size;
1143 if (itemsize > offset)
1144 oss << (itemsize - offset) <<
'x';
1146 auto format_str = oss.str();
1151 if (!
api.PyArray_EquivTypes_(dtype_ptr,
arr.dtype().ptr()))
1154 auto tindex = std::type_index(tinfo);
1160 static_assert(
is_pod_struct<T>::value,
"Attempt to use a non-POD or unimplemented POD type as a numpy dtype");
1165 return reinterpret_borrow<pybind11::dtype>(dtype_ptr());
1175 sizeof(T), &direct_converter);
1179 static PyObject* dtype_ptr() {
1184 static bool direct_converter(PyObject *obj,
void*&
value) {
1186 if (!PyObject_TypeCheck(obj,
api.PyVoidArrType_Type_))
1188 if (
auto descr = reinterpret_steal<object>(
api.PyArray_DescrFromScalar_(obj))) {
1189 if (
api.PyArray_EquivTypes_(dtype_ptr(),
descr.ptr())) {
1198 #ifdef __CLION_IDE__ // replace heavy macro with dummy code for the IDE (doesn't affect code)
1199 # define PYBIND11_NUMPY_DTYPE(Type, ...) ((void)0)
1200 # define PYBIND11_NUMPY_DTYPE_EX(Type, ...) ((void)0)
1203 #define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name) \
1204 ::pybind11::detail::field_descriptor { \
1205 Name, offsetof(T, Field), sizeof(decltype(std::declval<T>().Field)), \
1206 ::pybind11::format_descriptor<decltype(std::declval<T>().Field)>::format(), \
1207 ::pybind11::detail::npy_format_descriptor<decltype(std::declval<T>().Field)>::dtype() \
1211 #define PYBIND11_FIELD_DESCRIPTOR(T, Field) PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, #Field)
1215 #define PYBIND11_EVAL0(...) __VA_ARGS__
1216 #define PYBIND11_EVAL1(...) PYBIND11_EVAL0 (PYBIND11_EVAL0 (PYBIND11_EVAL0 (__VA_ARGS__)))
1217 #define PYBIND11_EVAL2(...) PYBIND11_EVAL1 (PYBIND11_EVAL1 (PYBIND11_EVAL1 (__VA_ARGS__)))
1218 #define PYBIND11_EVAL3(...) PYBIND11_EVAL2 (PYBIND11_EVAL2 (PYBIND11_EVAL2 (__VA_ARGS__)))
1219 #define PYBIND11_EVAL4(...) PYBIND11_EVAL3 (PYBIND11_EVAL3 (PYBIND11_EVAL3 (__VA_ARGS__)))
1220 #define PYBIND11_EVAL(...) PYBIND11_EVAL4 (PYBIND11_EVAL4 (PYBIND11_EVAL4 (__VA_ARGS__)))
1221 #define PYBIND11_MAP_END(...)
1222 #define PYBIND11_MAP_OUT
1223 #define PYBIND11_MAP_COMMA ,
1224 #define PYBIND11_MAP_GET_END() 0, PYBIND11_MAP_END
1225 #define PYBIND11_MAP_NEXT0(test, next, ...) next PYBIND11_MAP_OUT
1226 #define PYBIND11_MAP_NEXT1(test, next) PYBIND11_MAP_NEXT0 (test, next, 0)
1227 #define PYBIND11_MAP_NEXT(test, next) PYBIND11_MAP_NEXT1 (PYBIND11_MAP_GET_END test, next)
1228 #if defined(_MSC_VER) && !defined(__clang__) // MSVC is not as eager to expand macros, hence this workaround
1229 #define PYBIND11_MAP_LIST_NEXT1(test, next) \
1230 PYBIND11_EVAL0 (PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0))
1232 #define PYBIND11_MAP_LIST_NEXT1(test, next) \
1233 PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0)
1235 #define PYBIND11_MAP_LIST_NEXT(test, next) \
1236 PYBIND11_MAP_LIST_NEXT1 (PYBIND11_MAP_GET_END test, next)
1237 #define PYBIND11_MAP_LIST0(f, t, x, peek, ...) \
1238 f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST1) (f, t, peek, __VA_ARGS__)
1239 #define PYBIND11_MAP_LIST1(f, t, x, peek, ...) \
1240 f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST0) (f, t, peek, __VA_ARGS__)
1242 #define PYBIND11_MAP_LIST(f, t, ...) \
1243 PYBIND11_EVAL (PYBIND11_MAP_LIST1 (f, t, __VA_ARGS__, (), 0))
1245 #define PYBIND11_NUMPY_DTYPE(Type, ...) \
1246 ::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
1247 (::std::vector<::pybind11::detail::field_descriptor> \
1248 {PYBIND11_MAP_LIST (PYBIND11_FIELD_DESCRIPTOR, Type, __VA_ARGS__)})
1250 #if defined(_MSC_VER) && !defined(__clang__)
1251 #define PYBIND11_MAP2_LIST_NEXT1(test, next) \
1252 PYBIND11_EVAL0 (PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0))
1254 #define PYBIND11_MAP2_LIST_NEXT1(test, next) \
1255 PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0)
1257 #define PYBIND11_MAP2_LIST_NEXT(test, next) \
1258 PYBIND11_MAP2_LIST_NEXT1 (PYBIND11_MAP_GET_END test, next)
1259 #define PYBIND11_MAP2_LIST0(f, t, x1, x2, peek, ...) \
1260 f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT (peek, PYBIND11_MAP2_LIST1) (f, t, peek, __VA_ARGS__)
1261 #define PYBIND11_MAP2_LIST1(f, t, x1, x2, peek, ...) \
1262 f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT (peek, PYBIND11_MAP2_LIST0) (f, t, peek, __VA_ARGS__)
1264 #define PYBIND11_MAP2_LIST(f, t, ...) \
1265 PYBIND11_EVAL (PYBIND11_MAP2_LIST1 (f, t, __VA_ARGS__, (), 0))
1267 #define PYBIND11_NUMPY_DTYPE_EX(Type, ...) \
1268 ::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
1269 (::std::vector<::pybind11::detail::field_descriptor> \
1270 {PYBIND11_MAP2_LIST (PYBIND11_FIELD_DESCRIPTOR_EX, Type, __VA_ARGS__)})
1272 #endif // __CLION_IDE__
1283 : p_ptr(reinterpret_cast<char*>(ptr)), m_strides(strides.
size()) {
1284 m_strides.back() =
static_cast<value_type>(strides.back());
1285 for (
size_type i = m_strides.size() - 1;
i != 0; --
i) {
1288 m_strides[j] = strides[j] + m_strides[
i] - strides[
i] * s;
1293 p_ptr += m_strides[dim];
1302 container_type m_strides;
1311 : m_shape(shape.
size()), m_index(shape.
size(), 0),
1312 m_common_iterator() {
1315 for (
size_t i = 0;
i < shape.size(); ++
i)
1316 m_shape[
i] = shape[
i];
1319 for (
size_t i = 0;
i < N; ++
i)
1320 init_common_iterator(buffers[
i], shape, m_common_iterator[
i], strides);
1324 for (
size_t j = m_index.size(); j != 0; --j) {
1326 if (++m_index[
i] != m_shape[
i]) {
1327 increment_common_iterator(
i);
1336 template <
size_t K,
class T =
void> T*
data()
const {
1337 return reinterpret_cast<T*
>(m_common_iterator[K].data());
1345 const container_type &shape,
1347 container_type &strides) {
1348 auto buffer_shape_iter =
buffer.shape.rbegin();
1349 auto buffer_strides_iter =
buffer.strides.rbegin();
1350 auto shape_iter = shape.rbegin();
1351 auto strides_iter = strides.rbegin();
1353 while (buffer_shape_iter !=
buffer.shape.rend()) {
1354 if (*shape_iter == *buffer_shape_iter)
1355 *strides_iter = *buffer_strides_iter;
1359 ++buffer_shape_iter;
1360 ++buffer_strides_iter;
1365 std::fill(strides_iter, strides.rend(), 0);
1369 void increment_common_iterator(
size_t dim) {
1370 for (
auto &
iter : m_common_iterator)
1371 iter.increment(dim);
1374 container_type m_shape;
1375 container_type m_index;
1376 std::array<common_iter, N> m_common_iterator;
1388 return std::max(res, buf.
ndim);
1392 shape.resize((
size_t) ndim, 1);
1396 for (
size_t i = 0;
i < N; ++
i) {
1397 auto res_iter = shape.rbegin();
1398 auto end = buffers[
i].shape.rend();
1399 for (
auto shape_iter = buffers[
i].shape.rbegin(); shape_iter !=
end; ++shape_iter, ++res_iter) {
1400 const auto &dim_size_in = *shape_iter;
1401 auto &dim_size_out = *res_iter;
1404 if (dim_size_out == 1)
1405 dim_size_out = dim_size_in;
1406 else if (dim_size_in != 1 && dim_size_in != dim_size_out)
1407 pybind11_fail(
"pybind11::vectorize: incompatible size/dimension of inputs!");
1411 bool trivial_broadcast_c =
true;
1412 bool trivial_broadcast_f =
true;
1413 for (
size_t i = 0;
i < N && (trivial_broadcast_c || trivial_broadcast_f); ++
i) {
1414 if (buffers[
i].
size == 1)
1418 if (buffers[
i].ndim != ndim)
1419 return broadcast_trivial::non_trivial;
1422 if (!std::equal(buffers[
i].shape.cbegin(), buffers[
i].shape.cend(), shape.cbegin()))
1423 return broadcast_trivial::non_trivial;
1426 if (trivial_broadcast_c) {
1427 ssize_t expect_stride = buffers[
i].itemsize;
1428 auto end = buffers[
i].shape.crend();
1429 for (
auto shape_iter = buffers[
i].shape.crbegin(), stride_iter = buffers[
i].strides.crbegin();
1430 trivial_broadcast_c && shape_iter !=
end; ++shape_iter, ++stride_iter) {
1431 if (expect_stride == *stride_iter)
1432 expect_stride *= *shape_iter;
1434 trivial_broadcast_c =
false;
1439 if (trivial_broadcast_f) {
1440 ssize_t expect_stride = buffers[
i].itemsize;
1441 auto end = buffers[
i].shape.cend();
1442 for (
auto shape_iter = buffers[
i].shape.cbegin(), stride_iter = buffers[
i].strides.cbegin();
1443 trivial_broadcast_f && shape_iter !=
end; ++shape_iter, ++stride_iter) {
1444 if (expect_stride == *stride_iter)
1445 expect_stride *= *shape_iter;
1447 trivial_broadcast_f =
false;
1453 trivial_broadcast_c ? broadcast_trivial::c_trivial :
1455 broadcast_trivial::non_trivial;
1458 template <
typename T>
1475 template <
typename Func,
typename Return,
typename... Args>
1494 static void call(Return *out,
size_t i, Func &f, Args &...
args) {
1495 out[
i] = f(
args...);
1500 template <
typename Func,
typename... Args>
1512 static detail::void_type
call(Func &f, Args &...
args) {
1517 static void call(
void *,
size_t, Func &f, Args &...
args) {
1523 template <
typename Func,
typename Return,
typename... Args>
1533 static constexpr
size_t N =
sizeof...(Args);
1535 static_assert(NVectorized >= 1,
1536 "pybind11::vectorize(...) requires a function with at least one vectorizable argument");
1539 template <
typename T>
1554 using arg_call_types = std::tuple<typename vectorize_arg<Args>::call_type...>;
1555 template <
size_t Index>
using param_n_t =
typename std::tuple_element<Index, arg_call_types>::type;
1566 template <
size_t... Index,
size_t... VIndex,
size_t... BIndex>
object run(
1573 std::array<void *, N> params{{ &
args... }};
1576 std::array<buffer_info, NVectorized> buffers{{
reinterpret_cast<array *
>(params[VIndex])->request()... }};
1580 std::vector<ssize_t> shape(0);
1581 auto trivial =
broadcast(buffers, nd, shape);
1584 size_t size = std::accumulate(shape.begin(), shape.end(), (
size_t) 1, std::multiplies<size_t>());
1588 if (
size == 1 && ndim == 0) {
1590 return cast(returned_array::call(f, *
reinterpret_cast<param_n_t<Index> *
>(params[Index])...));
1593 auto result = returned_array::create(trivial, shape);
1598 auto mutable_data = returned_array::mutable_data(result);
1599 if (trivial == broadcast_trivial::non_trivial)
1600 apply_broadcast(buffers, params, mutable_data,
size, shape, i_seq, vi_seq, bi_seq);
1602 apply_trivial(buffers, params, mutable_data,
size, i_seq, vi_seq, bi_seq);
1607 template <
size_t... Index,
size_t... VIndex,
size_t... BIndex>
1608 void apply_trivial(std::array<buffer_info, NVectorized> &buffers,
1609 std::array<void *, N> ¶ms,
1617 std::array<std::pair<unsigned char *&, const size_t>, NVectorized> vecparams{{
1618 std::pair<unsigned char *&, const size_t>(
1619 reinterpret_cast<unsigned char *&
>(params[VIndex] = buffers[BIndex].ptr),
1620 buffers[BIndex].
size == 1 ? 0 :
sizeof(param_n_t<VIndex>)
1624 for (
size_t i = 0;
i <
size; ++
i) {
1625 returned_array::call(out,
i, f, *
reinterpret_cast<param_n_t<Index> *
>(params[Index])...);
1626 for (
auto &
x : vecparams)
x.first +=
x.second;
1630 template <
size_t... Index,
size_t... VIndex,
size_t... BIndex>
1631 void apply_broadcast(std::array<buffer_info, NVectorized> &buffers,
1632 std::array<void *, N> ¶ms,
1635 const std::vector<ssize_t> &output_shape,
1640 for (
size_t i = 0;
i <
size; ++
i, ++input_iter) {
1642 params[VIndex] = input_iter.template data<BIndex>()
1644 returned_array::call(out,
i, f, *
reinterpret_cast<param_n_t<Index> *
>(std::get<Index>(params))...);
1649 template <
typename Func,
typename Return,
typename... Args>
1652 return detail::vectorize_helper<Func, Return, Args...>(f);
1662 template <
typename Return,
typename... Args>
1663 detail::vectorize_helper<Return (*)(Args...), Return, Args...>
1665 return detail::vectorize_helper<Return (*)(Args...), Return, Args...>(f);
1676 template <
typename Return,
typename Class,
typename... Args,
1677 typename Helper = detail::vectorize_helper<decltype(std::mem_fn(std::declval<Return (Class::*)(Args...)>())), Return, Class *, Args...>>
1679 return Helper(std::mem_fn(f));
1683 template <
typename Return,
typename Class,
typename... Args,
1684 typename Helper = detail::vectorize_helper<decltype(std::mem_fn(std::declval<Return (Class::*)(Args...)
const>())), Return,
const Class *, Args...>>
1685 Helper
vectorize(Return (Class::*f)(Args...)
const) {
1686 return Helper(std::mem_fn(f));
1691 #if defined(_MSC_VER)
1692 #pragma warning(pop)