Cython/Python/C++ - Inheritance: Passing Derived Class as Argument to Function expecting base class

旧时难觅i 2021-01-31 06:28

I am using Cython to wrap a set of C++ classes and provide a Python interface to them. Example code is provided below:

BaseClass.h:

#ifndef __BaseClass__
#         


        
3 Answers
  • 2021-01-31 06:54

    Your code, as written, doesn't compile. I suspect that your real PyDerivedClass doesn't actually derive from PyBaseClass, because if it did, that last line would have to be

    (<DerivedClass*>self.thisptr).SetObject(inputObject.thisptr)
    

    This would also explain the type error you're getting, which is a bug I can't reproduce.
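    To make the cast concrete: once PyDerivedClass inherits thisptr from PyBaseClass, the stored pointer has the static type BaseClass*, so calling a DerivedClass-only method such as SetObject needs a downcast. As far as I know, Cython emits <DerivedClass*> as a plain C-style cast, which between related classes behaves like static_cast and is only valid if the object really is a DerivedClass. Here is a minimal C++ sketch of that difference; the class bodies and the CallSetObject* helpers are hypothetical, and a checked dynamic_cast is shown alongside for comparison.

```cpp
#include <cassert>
#include <string>

// Hypothetical, simplified stand-ins for the classes in the question.
struct BaseClass {
    virtual ~BaseClass() {}  // polymorphic base, required for dynamic_cast
};

struct DerivedClass : BaseClass {
    std::string SetObject(BaseClass* input) {
        assert(input != nullptr);
        return "DERIVED CLASS: in set object";
    }
};

// Unchecked downcast, comparable to (<DerivedClass*>self.thisptr):
// the caller must guarantee that thisptr really points at a DerivedClass.
std::string CallSetObjectUnchecked(BaseClass* thisptr, BaseClass* input) {
    return static_cast<DerivedClass*>(thisptr)->SetObject(input);
}

// Checked downcast: dynamic_cast yields nullptr if the object is not a DerivedClass.
bool CallSetObjectChecked(BaseClass* thisptr, BaseClass* input, std::string* out) {
    DerivedClass* derived = dynamic_cast<DerivedClass*>(thisptr);
    if (derived == nullptr) {
        return false;
    }
    *out = derived->SetObject(input);
    return true;
}
```

    The unchecked version is what the Cython cast amounts to; the dynamic_cast version is an optional safety check, not something Cython does for you.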

  • 2021-01-31 07:06

    After a lot of help from the answers below, plus some experimentation, I think I understand how basic inheritance works in Cython. I'm answering my own question to validate and improve my understanding, and hopefully to help anyone who runs into a related issue in the future. If anything in this explanation is wrong, feel free to correct me in the comments and I will edit it. I doubt this is the only way to do it, and other approaches probably work too, but this is the way that worked for me.

    Overview/Things Learnt:

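    So basically, from my understanding, as long as Cython is given the appropriate declarations, calling a virtual function through the wrapper ends up calling the implementation that matches the actual type of the object you call it on. Strictly speaking, Cython itself does not traverse the inheritance hierarchy: it emits an ordinary C++ call through the stored pointer, and C++ virtual dispatch picks the right override at runtime.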
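    A minimal C++ sketch of what makes this work, assuming the wrapper holds a base-class pointer as PyBaseClass does below. Base, Derived, Kind() and the Call* helpers are hypothetical names, and the methods return strings instead of printing so the result can be checked. Only the virtual function is dispatched on the object's runtime type; the non-virtual one is bound to the pointer's static type.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-ins for BaseClass/DerivedClass.
struct Base {
    virtual ~Base() {}
    // virtual: resolved from the object's runtime type
    virtual std::string Evaluate() const { return "BASE"; }
    // non-virtual: resolved from the pointer's static type
    std::string Kind() const { return "BASE"; }
};

struct Derived : Base {
    std::string Evaluate() const override { return "DERIVED"; }
    // Hides Base::Kind but does not override it.
    std::string Kind() const { return "DERIVED"; }
};

// Roughly what a wrapper method does: it only ever sees a Base pointer.
std::string CallEvaluate(const Base* thisptr) {
    assert(thisptr != nullptr);
    return thisptr->Evaluate();
}

std::string CallKind(const Base* thisptr) {
    assert(thisptr != nullptr);
    return thisptr->Kind();
}
```

    For a Derived object, CallEvaluate returns "DERIVED" while CallKind returns "BASE", which is why declaring the virtual methods once on the base cppclass is enough.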

    The important thing is to mirror, in your .pyx file, the C++ inheritance structure you are trying to wrap. This means ensuring that:

    1) The imported C++ classes (the cppclasses declared inside cdef extern from blocks) inherit from each other the same way the actual C++ classes do.

    2) Each imported class declares only its own new methods/member variables. A virtual function that is implemented differently in BaseClass and DerivedClass does not need a declaration in both; as long as one class inherits from the other, declaring it once in the imported base class is enough.

    3) The Python wrapper classes (i.e. PyBaseClass / PyDerivedClass) also inherit from each other the same way the actual C++ classes do.

    4) Similarly, the Python interface to a virtual function only needs to exist in the PyBaseClass wrapper. There is no need to repeat it in the derived wrappers, because the correct implementation is selected when the code actually runs.

    5) Every Python wrapper class that is subclassed needs an if type(self) is ClassName: check in both __cinit__() and __dealloc__(). Cython runs the __cinit__() of every class in the chain (base class first) and the __dealloc__() of every class (most-derived class first), so without the check each level would allocate its own C++ object (leaking all but the last one), and the same object could be deleted more than once, which causes segfaults. "Leaf nodes" in the hierarchy (classes that are never subclassed) don't need this check.

    6) In __dealloc__(), delete only the current pointer, not any of the inherited ones. They all point at the same object, so it must be deleted exactly once.

    7) Again, in the __cinit__() of an inherited class, set the current pointer and all the inherited (base-class) pointers to a single object of the type you are trying to create (e.g. self.nextDerivedptr = self.derivedptr = self.thisptr = new NextDerivedClass()).
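    Points 5 to 7 come down to one rule: each wrapper level keeps its own typed pointer, but all of them point at a single C++ object that is created once and deleted once. Below is a rough C++ model of what PyNextDerivedClass ends up holding; the NextDerivedWrapper struct and the g_live_objects counter are hypothetical and only exist to make the ownership checkable. Because ~BaseClass() is virtual, that single delete would also be correct through thisptr or derivedptr, but deleting through more than one of them would be a double free.

```cpp
#include <cassert>

// Hypothetical live-object counter, used only to show there is no leak.
static int g_live_objects = 0;

struct BaseClass {
    BaseClass() { ++g_live_objects; }
    // virtual: deleting through any base pointer runs the full destructor chain
    virtual ~BaseClass() { --g_live_objects; }
};
struct DerivedClass : BaseClass {};
struct NextDerivedClass : DerivedClass {};

// Hypothetical model of the three cdef attributes PyNextDerivedClass ends up with.
struct NextDerivedWrapper {
    BaseClass* thisptr;
    DerivedClass* derivedptr;
    NextDerivedClass* nextDerivedptr;

    // Like __cinit__: one allocation, with every inherited pointer aimed at it.
    NextDerivedWrapper() {
        nextDerivedptr = new NextDerivedClass();
        derivedptr = nextDerivedptr;
        thisptr = nextDerivedptr;
    }

    // Like __dealloc__: delete exactly once; the base levels skip their delete.
    ~NextDerivedWrapper() { delete nextDerivedptr; }

    // Copying would duplicate ownership, so forbid it.
    NextDerivedWrapper(const NextDerivedWrapper&) = delete;
    NextDerivedWrapper& operator=(const NextDerivedWrapper&) = delete;
};
```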

    Hopefully the points above will make more sense once you see the code below. It compiles and runs as I intend it to.

    BaseClass.h:

    #ifndef __BaseClass__
    #define __BaseClass__
    
    #include <stdio.h>
    #include <stdlib.h>
    #include <string>
    
    using namespace std;
    
    class BaseClass
    {
        public:
            BaseClass(){};
            virtual ~BaseClass(){};
            virtual void SetName(string name){printf("BASE: in set name\n");}
            virtual float Evaluate(float time){printf("BASE: in Evaluate\n");return 0;}
            virtual bool DataExists(){printf("BASE: in data exists\n");return false;}
    };
    #endif /* defined(__BaseClass__) */ 
    

    DerivedClass.h:

    #ifndef __DerivedClass__
    #define __DerivedClass__
    
    #include "BaseClass.h"
    #include <string>
    
    using namespace std;
    
    class DerivedClass:public BaseClass
    {
        public:
            DerivedClass(){};
            virtual ~DerivedClass(){};
            virtual void SetName(string name){printf("DERIVED CLASS: in Set name \n");}
            virtual float Evaluate(float time){printf("DERIVED CLASS: in Evaluate\n");return 1.0;}
            virtual bool DataExists(){printf("DERIVED CLASS:in data exists\n");return true;}
            virtual void MyFunction(){printf("DERIVED CLASS: in my function\n");}
            virtual void SetObject(BaseClass *input){printf("DERIVED CLASS: in set object\n");}
    };
    #endif /* defined(__DerivedClass__) */
    

    NextDerivedClass.h:

    #ifndef __NextDerivedClass__
    #define __NextDerivedClass__
    
    #include "DerivedClass.h"
    
    class NextDerivedClass:public DerivedClass
    {
        public:
            NextDerivedClass(){};
            virtual ~NextDerivedClass(){};
            virtual void SetObject(BaseClass *input){printf("NEXT DERIVED CLASS: in set object\n");}
            virtual bool DataExists(){printf("NEXT DERIVED CLASS: in data exists \n");return true;}
    };
    #endif /* defined(__NextDerivedClass__) */
    

    inheritTest.pyx:

    #Necessary Compilation Options
    #distutils: language = c++
    #distutils: extra_compile_args = ["-std=c++11", "-g"]
    
    #Import necessary modules
    from libcpp cimport bool
    from libcpp.string cimport string
    from libcpp.map cimport map
    from libcpp.pair cimport pair
    from libcpp.vector cimport vector
    
    cdef extern from "BaseClass.h":
        cdef cppclass BaseClass:
            BaseClass() except +
            void SetName(string)
            float Evaluate(float)
            bool DataExists()
    
    cdef extern from "DerivedClass.h":
        cdef cppclass DerivedClass(BaseClass):
            DerivedClass() except +
            void MyFunction()
            void SetObject(BaseClass *)
    
    cdef extern from "NextDerivedClass.h":
        cdef cppclass NextDerivedClass(DerivedClass):
            NextDerivedClass() except +
    
    cdef class PyBaseClass:
        cdef BaseClass *thisptr
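        def __cinit__(self):
            # Only allocate when this exact class is instantiated; subclasses
            # allocate their own (more-derived) object in their __cinit__().
            if type(self) is PyBaseClass:
                self.thisptr = new BaseClass()
        def __dealloc__(self):
            # Subclasses free the shared object themselves, so skip it here.
            if type(self) is PyBaseClass:
                del self.thisptr
        def SetName(self, name):
            # std::string is converted from bytes, so encode a Python 3 str first.
            self.thisptr.SetName(name.encode('utf-8'))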
        def Evaluate(self, time):
            return self.thisptr.Evaluate(time)
        def DataExists(self):
            return self.thisptr.DataExists()
    
    cdef class PyDerivedClass(PyBaseClass):
        cdef DerivedClass *derivedptr
        def __cinit__(self):
            if type(self) is PyDerivedClass:
                self.derivedptr = self.thisptr = new DerivedClass()
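        def __dealloc__(self):
            # Must test for PyDerivedClass (not PyBaseClass); otherwise the
            # DerivedClass object is never deleted.
            if type(self) is PyDerivedClass:
                del self.derivedptr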
        def SetObject(self, PyBaseClass inputObject):
            self.derivedptr.SetObject(<BaseClass *>inputObject.thisptr)
        def MyFunction(self):
            self.derivedptr.MyFunction()
    
    cdef class PyNextDerivedClass(PyDerivedClass):
        cdef NextDerivedClass *nextDerivedptr
        def __cinit__(self):
            self.nextDerivedptr = self.derivedptr = self.thisptr = new NextDerivedClass()
        def __dealloc__(self):
            del self.nextDerivedptr
    

    test.py:

    from inheritTest import PyBaseClass as base
    from inheritTest import PyDerivedClass as der
    from inheritTest import PyNextDerivedClass as nextDer
    
    a = der()
    b = der()
    a.SetObject(b)
    c = nextDer()
    a.SetObject(c)
    c.DataExists()
    c.SetObject(b)
    c.Evaluate(0.3)
    
    
    baseSig = base()
    signal = der()
    baseSig.SetName('test')
    signal.SetName('testingone')
    baseSig.Evaluate(0.3)
    signal.Evaluate(0.5)
    signal.SetObject(b)
    baseSig.DataExists()
    signal.DataExists()
    

    Notice that when I call:

    c = nextDer()
    c.Evaluate(0.3)
    

    What happens is that the wrapper calls Evaluate through the BaseClass pointer (thisptr), and because Evaluate is virtual, C++ calls the "latest" implementation for the object's actual type. If NextDerivedClass.h overrode Evaluate, that version would be called (I have tried that and it works); since it doesn't, the nearest override further up the hierarchy, in DerivedClass, is used. Thus the output is:

    >> DERIVED CLASS: in Evaluate
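    The same lookup can be reproduced in plain C++. In this minimal sketch the methods return strings modeled on the printf messages instead of printing, and WrapperEvaluate/WrapperDataExists are hypothetical stand-ins for the PyBaseClass methods; the Evaluate/DataExists overrides mirror the three headers above. Through a BaseClass* to a NextDerivedClass object, DataExists resolves to NextDerivedClass's own override, while Evaluate resolves to DerivedClass, the nearest class above it that overrides it.

```cpp
#include <cassert>
#include <string>

struct BaseClass {
    virtual ~BaseClass() {}
    virtual std::string Evaluate(float) { return "BASE: in Evaluate"; }
    virtual std::string DataExists() { return "BASE: in data exists"; }
};

struct DerivedClass : BaseClass {
    std::string Evaluate(float) override { return "DERIVED CLASS: in Evaluate"; }
    std::string DataExists() override { return "DERIVED CLASS: in data exists"; }
};

// Overrides DataExists only; Evaluate is inherited from DerivedClass.
struct NextDerivedClass : DerivedClass {
    std::string DataExists() override { return "NEXT DERIVED CLASS: in data exists"; }
};

// Hypothetical equivalents of the PyBaseClass methods: calls through the base pointer.
std::string WrapperEvaluate(BaseClass* thisptr, float time) {
    assert(thisptr != nullptr);
    return thisptr->Evaluate(time);
}

std::string WrapperDataExists(BaseClass* thisptr) {
    assert(thisptr != nullptr);
    return thisptr->DataExists();
}
```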
    

    I hope this helps someone in the future. Again, if there are errors in my understanding, or just in grammar/syntax, feel free to comment below and I will try to address them. Big thanks to those who answered below; this is essentially a summary of their answers, written to help validate my understanding. Thanks!

  • 2021-01-31 07:07

    Honestly, this looks like a bug. The object you're passing in is an instance of the desired class, but it still throws an error. You may want to bring it up on the cython-users mailing list so the main developers can look at it.

    A possible workaround would be to define a fused type that represents both types of arguments and use that inside the method. That seems like overkill though.
