diff --git a/class_wrapper.h b/class_wrapper.h index dbb64e4a..4fedde09 100644 --- a/class_wrapper.h +++ b/class_wrapper.h @@ -65,7 +65,38 @@ class ClassWrapper template void def_read_write(MemberType T::*pm, const char* name) { m_class->def_read_write(pm, name); } - private: + + // declare the given class a base class of this and register + // conversion functions + template + void declare_base(ClassWrapper const & base) + { + m_class->declare_base(base.m_class.get()); + } + + // declare the given class a base class of this and register + // conversion functions + template + void declare_base(ClassWrapper const & base, WithoutDowncast) + { + m_class->declare_base(base.m_class.get(), without_downcast); + } + + template + void declare_base(ExtensionClass * base) + { + m_class->declare_base(base); + } + + // declare the given class a base class of this and register + // only upcast function + template + void declare_base(ExtensionClass * base, WithoutDowncast) + { + m_class->declare_base(base, without_downcast); + } + +// private: PyPtr > m_class; }; diff --git a/extclass.cpp b/extclass.cpp index a1734760..a4f723ae 100644 --- a/extclass.cpp +++ b/extclass.cpp @@ -157,7 +157,7 @@ void report_missing_instance_data( } else { - two_string_error(PyExc_TypeError, "extension class '%.*s' is not derived from '%.*s'.", + two_string_error(PyExc_TypeError, "extension class '%.*s' is not convertible into '%.*s'.", instance->ob_type->tp_name, target_class->tp_name); } } @@ -218,6 +218,41 @@ ExtensionClassBase::ExtensionClassBase(const char* name) { } +void * ExtensionClassBase::try_class_conversions(InstanceHolderBase * object) const +{ + void * result = try_sub_class_conversions(object); + if(result) return result; + result = try_super_class_conversions(object); + return result; +} + +void * ExtensionClassBase::try_super_class_conversions(InstanceHolderBase * object) const +{ + void * result = 0; + for(int i=0; iconvert_from_holder(object); + if(result) return (*base_classes()[i].second)(result); + result = base_classes()[i].first->try_super_class_conversions(object); + if(result) return (*base_classes()[i].second)(result); + } + return 0; +} + +void * ExtensionClassBase::try_sub_class_conversions(InstanceHolderBase * object) const +{ + void * result = 0; + for(int i=0; iconvert_from_holder(object); + if(result) return (*sub_classes()[i].second)(result); + result = sub_classes()[i].first->try_sub_class_conversions(object); + if(result) return (*sub_classes()[i].second)(result); + } + return 0; +} + void ExtensionClassBase::add_method(Function* method, const char* name) { add_method(PyPtr(method), name); diff --git a/extclass.h b/extclass.h index 3374c0ae..fb96955d 100644 --- a/extclass.h +++ b/extclass.h @@ -52,6 +52,9 @@ T* check_non_null(T* p) } template class HeldInstance; +typedef void * (*ConversionFct)(void *); +typedef std::pair BaseClassInfo; +typedef std::pair SubClassInfo; class ExtensionClassBase : public Class { @@ -66,18 +69,30 @@ class ExtensionClassBase : public Class void add_constructor_object(Function*); void add_setter_method(Function*, const char* name); void add_getter_method(Function*, const char* name); + + virtual void * try_class_conversions(InstanceHolderBase*) const; + virtual void * try_super_class_conversions(InstanceHolderBase*) const; + virtual void * try_sub_class_conversions(InstanceHolderBase*) const; + virtual std::vector const & base_classes() const = 0; + virtual std::vector const & sub_classes() const = 0; }; template class ClassRegistry { - public: + public: static Class* class_object() { return static_class_object; } static void register_class(py::Class*); static void unregister_class(py::Class*); + static void register_base_class(BaseClassInfo const &); + static void register_sub_class(SubClassInfo const &); + static std::vector const & base_classes(); + static std::vector const & sub_classes(); private: static py::Class* static_class_object; + static std::vector static_base_class_info; + static std::vector static_sub_class_info; }; #ifdef PY_NO_INLINE_FRIENDS_IN_NAMESPACE // back to global namespace for this GCC bug @@ -123,6 +138,10 @@ class PyExtensionClassConverters py::InstanceHolder* held = dynamic_cast*>(*p); if (held != 0) return held->target(); + + void * target = py::ClassRegistry::class_object()->try_class_conversions(*p); + if(target) + return static_cast(target); } py::report_missing_instance_data(self, py::ClassRegistry::class_object(), typeid(T)); throw py::ArgumentError(); @@ -250,6 +269,23 @@ class ReadOnlySetattrFunction : public Function String m_name; }; +template +struct ConversionFunction +{ + static void * upcast_ptr(void * v) + { + return static_cast(static_cast(v)); + } + + static void * downcast_ptr(void * v) + { + return dynamic_cast(static_cast(v)); + } +}; + +enum WithoutDowncast { without_downcast }; + + // An easy way to make an extension base class which wraps T. Note that Python // subclasses of this class will simply be Class objects. // @@ -344,10 +380,40 @@ class ExtensionClass this->def_getter(pm, name); this->def_setter(pm, name); } + + // declare the given class a base class of this and register + // conversion functions + template + void declare_base(ExtensionClass * base) + { + BaseClassInfo baseInfo(base, &ConversionFunction::downcast_ptr); + ClassRegistry::register_base_class(baseInfo); + add_base(Ptr(as_object(base))); + + SubClassInfo subInfo(this, &ConversionFunction::upcast_ptr); + ClassRegistry::register_sub_class(subInfo); + } + + // declare the given class a base class of this and register + // only upcast function + template + void declare_base(ExtensionClass * base, WithoutDowncast) + { + BaseClassInfo baseInfo(base, 0); + ClassRegistry::register_base_class(baseInfo); + add_base(Ptr(as_object(base))); + + SubClassInfo subInfo(this, &ConversionFunction::upcast_ptr); + ClassRegistry::register_sub_class(subInfo); + } private: typedef InstanceValueHolder Holder; + virtual std::vector const & base_classes() const; + virtual std::vector const & sub_classes() const; + virtual void * convert_from_holder(InstanceHolderBase * v) const; + template void add_constructor(Signature sig) { @@ -464,6 +530,28 @@ ExtensionClass::ExtensionClass(const char* name) ClassRegistry::register_class(this); } +template +inline +std::vector const & ExtensionClass::base_classes() const +{ + return ClassRegistry::base_classes(); +} + +template +inline +std::vector const & ExtensionClass::sub_classes() const +{ + return ClassRegistry::sub_classes(); +} + +template +void * ExtensionClass::convert_from_holder(InstanceHolderBase * v) const +{ + py::InstanceHolder* held = dynamic_cast*>(v); + if(held) return held->target(); + return 0; +} + template ExtensionClass::~ExtensionClass() { @@ -487,11 +575,39 @@ inline void ClassRegistry::unregister_class(Class* p) static_class_object = 0; } +template +void ClassRegistry::register_base_class(BaseClassInfo const & i) +{ + static_base_class_info.push_back(i); +} + +template +void ClassRegistry::register_sub_class(SubClassInfo const & i) +{ + static_sub_class_info.push_back(i); +} + +template +std::vector const & ClassRegistry::base_classes() +{ + return static_base_class_info; +} + +template +std::vector const & ClassRegistry::sub_classes() +{ + return static_sub_class_info; +} + // // Static data member declaration. // template Class* ClassRegistry::static_class_object; +template +std::vector ClassRegistry::static_base_class_info; +template +std::vector ClassRegistry::static_sub_class_info; } // namespace py diff --git a/extclass_demo.cpp b/extclass_demo.cpp index 7b678120..6ea54f5e 100644 --- a/extclass_demo.cpp +++ b/extclass_demo.cpp @@ -363,6 +363,68 @@ static int getX(OverloadTest * u) return u->x(); } + +/************************************************************/ +/* */ +/* classes to test base declarations snd conversions */ +/* */ +/************************************************************/ + +struct Dummy +{ + int dummy_; +}; + +struct Base +{ + virtual int x() const { return 999; }; +}; + +// inherit Dummy so that the Base part of Concrete starts at an offset +// otherwise, typecast tests wouldn't be very meaningful +struct Derived1 : public Dummy, public Base +{ + Derived1(int x): x_(x) {} + virtual int x() const { return x_; } + + private: + int x_; +}; + +struct Derived2 : public Dummy, public Base +{ + Derived2(int x): x_(x) {} + virtual int x() const { return x_; } + + private: + int x_; +}; + +static int testUpcast(Base * b) +{ + return b->x(); +} + +static std::auto_ptr derived1Factory(int i) +{ + return std::auto_ptr(new Derived1(i)); +} + +static std::auto_ptr derived2Factory(int i) +{ + return std::auto_ptr(new Derived2(i)); +} + +static int testDowncast1(Derived1 * d) +{ + return d->x(); +} + +static int testDowncast2(Derived2 * d) +{ + return d->x(); +} + /************************************************************/ /* */ /* init the module */ @@ -426,6 +488,25 @@ void init_module(py::Module& m) over.def(&OverloadTest::p4, "overloaded"); over.def(&OverloadTest::p5, "overloaded"); + py::ClassWrapper base(m, "Base"); + base.def(&Base::x, "x"); + + py::ClassWrapper derived1(m, "Derived1"); + // this enables conversions between Base and Derived1 + // and makes wrapped methods of Base available + derived1.declare_base(base); + derived1.def(py::Constructor()); + + py::ClassWrapper derived2(m, "Derived2"); + // don't enable downcast from Base to Derived2 + derived2.declare_base(base, py::without_downcast); + derived2.def(py::Constructor()); + + m.def(&testUpcast, "testUpcast"); + m.def(&derived1Factory, "derived1Factory"); + m.def(&derived2Factory, "derived2Factory"); + m.def(&testDowncast1, "testDowncast1"); + m.def(&testDowncast2, "testDowncast2"); } void init_module() diff --git a/newtypes.h b/newtypes.h index 94382232..51a27e9d 100644 --- a/newtypes.h +++ b/newtypes.h @@ -30,6 +30,8 @@ namespace py { +class InstanceHolderBase; + class TypeObjectBase : public PythonType { public: @@ -59,6 +61,11 @@ class TypeObjectBase : public PythonType virtual PyObject* instance_getattr(PyObject* instance, const char* name) const; virtual int instance_setattr(PyObject* instance, const char* name, PyObject* value) const; + virtual void * try_class_conversions(InstanceHolderBase*) const { return 0; } + virtual void * try_super_class_conversions(InstanceHolderBase*) const { return 0; } + virtual void * try_sub_class_conversions(InstanceHolderBase*) const { return 0; } + virtual void * convert_from_holder(InstanceHolderBase*) const { return 0; } + // Dealloc is a special case, since every type needs a nonzero tp_dealloc slot. virtual void instance_dealloc(PyObject*) const = 0; diff --git a/test_extclass.py b/test_extclass.py index 138a7107..8a0c06f0 100644 --- a/test_extclass.py +++ b/test_extclass.py @@ -130,7 +130,7 @@ But objects not derived from Bar cannot: >>> baz.pass_bar(baz) Traceback (innermost last): ... - TypeError: extension class 'Baz' is not derived from 'Bar'. + TypeError: extension class 'Baz' is not convertible into 'Bar'. The clone function on Baz returns a smart pointer; we wrap it into an ExtensionInstance and make it look just like any other Baz instance. @@ -437,28 +437,28 @@ Testing overloaded free functions Testing overloaded constructors - >>> x = OverloadTest() - >>> x.getX() + >>> over = OverloadTest() + >>> over.getX() 1000 - >>> x = OverloadTest(1) - >>> x.getX() + >>> over = OverloadTest(1) + >>> over.getX() 1 - >>> x = OverloadTest(1,1) - >>> x.getX() + >>> over = OverloadTest(1,1) + >>> over.getX() 2 - >>> x = OverloadTest(1,1,1) - >>> x.getX() + >>> over = OverloadTest(1,1,1) + >>> over.getX() 3 - >>> x = OverloadTest(1,1,1,1) - >>> x.getX() + >>> over = OverloadTest(1,1,1,1) + >>> over.getX() 4 - >>> x = OverloadTest(1,1,1,1,1) - >>> x.getX() + >>> over = OverloadTest(1,1,1,1,1) + >>> over.getX() 5 - >>> x = OverloadTest(x) - >>> x.getX() + >>> over = OverloadTest(over) + >>> over.getX() 5 - >>> try: x = OverloadTest(1, 'foo') + >>> try: over = OverloadTest(1, 'foo') ... except TypeError, err: ... assert re.match("No overloaded functions match \(OverloadTest, int, string\)\. Candidates are:", ... str(err)) @@ -467,25 +467,51 @@ Testing overloaded constructors Testing overloaded methods - >>> x.setX(3) - >>> x.overloaded() + >>> over.setX(3) + >>> over.overloaded() 3 - >>> x.overloaded(1) + >>> over.overloaded(1) 1 - >>> x.overloaded(1,1) + >>> over.overloaded(1,1) 2 - >>> x.overloaded(1,1,1) + >>> over.overloaded(1,1,1) 3 - >>> x.overloaded(1,1,1,1) + >>> over.overloaded(1,1,1,1) 4 - >>> x.overloaded(1,1,1,1,1) + >>> over.overloaded(1,1,1,1,1) 5 - >>> try: x.overloaded(1,'foo') + >>> try: over.overloaded(1,'foo') ... except TypeError, err: ... assert re.match("No overloaded functions match \(OverloadTest, int, string\)\. Candidates are:", ... str(err)) ... else: ... print 'no exception' + +Testing base class conversions + + >>> testUpcast(over) + Traceback (innermost last): + TypeError: extension class 'OverloadTest' is not convertible into 'Base'. + >>> der1 = Derived1(333) + >>> der1.x() + 333 + >>> testUpcast(der1) + 333 + >>> der1 = derived1Factory(1000) + >>> testDowncast1(der1) + 1000 + >>> testDowncast2(der1) + Traceback (innermost last): + TypeError: extension class 'Base' is not convertible into 'Derived2'. + >>> der2 = Derived2(444) + >>> der2.x() + 444 + >>> testUpcast(der2) + 444 + >>> der2 = derived2Factory(1111) + >>> testDowncast2(der1) + Traceback (innermost last): + TypeError: extension class 'Base' is not convertible into 'Derived2'. ''' from demo import *