diff --git a/src/module.cpp b/src/module.cpp index 9faaca06..0851805f 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -7,7 +7,6 @@ // producing this work. #include -#include namespace boost { namespace python { @@ -47,9 +46,9 @@ void module_base::add(PyObject* x, const char* name) void module_base::add(ref x, const char* name) { - ref f(x); // First take possession of the object. - if (PyObject_SetAttrString(m_module, const_cast(name), x.get()) < 0) - throw error_already_set(); + // Use function::add_to_namespace to achieve overloading if + // appropriate. + objects::function::add_to_namespace(m_module, name, x.get()); } void module_base::add(PyTypeObject* x, const char* name /*= 0*/) @@ -57,22 +56,6 @@ void module_base::add(PyTypeObject* x, const char* name /*= 0*/) this->add((PyObject*)x, name ? name : x->tp_name); } -void module_base::add_overload(objects::function* x, const char* name) -{ - PyObject* existing = PyObject_HasAttrString(m_module, const_cast(name)) - ? PyObject_GetAttrString(m_module, const_cast(name)) - : 0; - - if (existing != 0 && existing->ob_type == &objects::function_type) - { - static_cast(existing)->add_overload(x); - } - else - { - add(x, name); - } -} - PyMethodDef module_base::initial_methods[] = { { 0, 0, 0, 0 } }; }} // namespace boost::python diff --git a/src/object/function.cpp b/src/object/function.cpp index 52a0456c..3755dfb1 100644 --- a/src/object/function.cpp +++ b/src/object/function.cpp @@ -6,6 +6,8 @@ #include #include +#include +#include namespace boost { namespace python { namespace objects { @@ -28,7 +30,7 @@ function::~function() PyObject* function::call(PyObject* args, PyObject* keywords) const { - int nargs = PyTuple_GET_SIZE(args); + std::size_t nargs = PyTuple_GET_SIZE(args); function const* f = this; do { @@ -63,6 +65,8 @@ void function::argument_error(PyObject* args, PyObject* keywords) const void function::add_overload(function* overload) { + Py_XINCREF(overload); + function* parent = this; while (parent->m_overloads != 0) @@ -72,6 +76,42 @@ void function::add_overload(function* overload) parent->m_overloads = overload; } +void function::add_to_namespace( + PyObject* name_space, char const* name_, PyObject* attribute_) +{ + ref attribute(attribute_, ref::increment_count); + string name(name_); + + if (attribute_->ob_type == &function_type) + { + PyObject* dict = 0; + + if (PyClass_Check(name_space)) + dict = ((PyClassObject*)name_space)->cl_dict; + else if (PyType_Check(name_space)) + dict = ((PyTypeObject*)name_space)->tp_dict; + else + dict = PyObject_GetAttrString(name_space, "__dict__"); + + if (dict == 0) + throw error_already_set(); + + ref existing(PyObject_GetItem(dict, name.get()), ref::null_ok); + + if (existing.get() && existing->ob_type == &function_type) + { + static_cast(existing.get())->add_overload( + static_cast(attribute_)); + return; + } + } + + // The PyObject_GetAttrString() call above left an active error + PyErr_Clear(); + if (PyObject_SetAttr(name_space, name.get(), attribute_) < 0) + throw error_already_set(); +} + extern "C" { // Stolen from Python's funcobject.c