How to find all the subclasses of a class given its name?

遥遥无期 2020-11-22 10:14

I need a working approach for getting all classes that inherit from a given base class in Python.

  • 2020-11-22 11:08

    If you just want direct subclasses then .__subclasses__() works fine. If you want all subclasses, subclasses of subclasses, and so on, you'll need a function to do that for you.

    Here's a simple, readable function that recursively finds all subclasses of a given class:

    def get_all_subclasses(cls):
        all_subclasses = []
    
        for subclass in cls.__subclasses__():
            all_subclasses.append(subclass)
            all_subclasses.extend(get_all_subclasses(subclass))
    
        return all_subclasses
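    One caveat: with multiple inheritance the same class can be reached along more than one path, so the function above can return duplicates. Here is a minimal sketch of a variant that skips classes it has already visited. The function name and the diamond-shaped Top/Left/Right/Bottom classes are hypothetical, chosen only for illustration:

```python
# get_all_subclasses_unique and the Top/Left/Right/Bottom classes are
# hypothetical names, used only for this sketch.
def get_all_subclasses_unique(cls, _seen=None):
    """Return all subclasses of cls, recursively, listing each class once."""
    if _seen is None:
        _seen = set()
    result = []
    for subclass in cls.__subclasses__():
        if subclass in _seen:
            continue  # already reached through another base class
        _seen.add(subclass)
        result.append(subclass)
        result.extend(get_all_subclasses_unique(subclass, _seen))
    return result


# A diamond: Bottom inherits from both Left and Right.
class Top(object): pass
class Left(Top): pass
class Right(Top): pass
class Bottom(Left, Right): pass

print([c.__name__ for c in get_all_subclasses_unique(Top)])
# ['Left', 'Bottom', 'Right']
```

    For the same diamond, the function above returns Bottom twice: ['Left', 'Bottom', 'Right', 'Bottom'].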
    
  • 2020-11-22 11:09

    New-style classes (i.e. subclassed from object, which is the default in Python 3) have a __subclasses__ method that returns their direct subclasses:

    class Foo(object): pass
    class Bar(Foo): pass
    class Baz(Foo): pass
    class Bing(Bar): pass
    

    Here are the names of the subclasses:

    print([cls.__name__ for cls in Foo.__subclasses__()])
    # ['Bar', 'Baz']
    

    Here are the subclasses themselves:

    print(Foo.__subclasses__())
    # [<class '__main__.Bar'>, <class '__main__.Baz'>]
    

    Confirmation that the subclasses do indeed list Foo as their base:

    for cls in Foo.__subclasses__():
        print(cls.__base__)
    # <class '__main__.Foo'>
    # <class '__main__.Foo'>
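    __base__ reports only a single base class. When a subclass has several bases, the full __bases__ tuple or issubclass() is a more reliable check. Here is a small sketch in which Mixin and Qux are hypothetical classes added for illustration:

```python
# Foo as in the example above; Mixin and Qux are hypothetical, for illustration.
class Foo(object): pass
class Mixin(object): pass
class Qux(Mixin, Foo): pass

print(Qux.__base__ in Qux.__bases__)  # True, but only one base is reported
print(Qux.__bases__ == (Mixin, Foo))  # True: __bases__ lists every base
print(Qux in Foo.__subclasses__())    # True
print(issubclass(Qux, Foo))           # True
```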
    

    Note that if you want sub-subclasses (and deeper), you'll have to recurse:

    def all_subclasses(cls):
        return set(cls.__subclasses__()).union(
            [s for c in cls.__subclasses__() for s in all_subclasses(c)])
    
    print(all_subclasses(Foo))
    # {<class '__main__.Bar'>, <class '__main__.Baz'>, <class '__main__.Bing'>}
    

    Note that if the class definition of a subclass hasn't been executed yet - for example, if the subclass's module hasn't been imported yet - then that subclass doesn't exist yet, and __subclasses__ won't find it.
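    A quick way to see this is that __subclasses__() reflects only the class statements that have run so far. In the following sketch, the hypothetical subclass Late is missing until the function that defines it is called. Keeping a reference to it matters because a class holds only weak references to its immediate subclasses:

```python
# Base, Early and Late are hypothetical names for this sketch.
class Base(object): pass
class Early(Base): pass

def define_late_subclass():
    # This class statement only runs when the function is called.
    class Late(Base): pass
    return Late

print([c.__name__ for c in Base.__subclasses__()])  # ['Early']
late = define_late_subclass()  # keep a reference: subclass links are weak
print([c.__name__ for c in Base.__subclasses__()])  # ['Early', 'Late']
```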


    You mentioned "given its name". Since Python classes are first-class objects, you don't need to use a string with the class's name in place of the class or anything like that. You can just use the class directly, and you probably should.

    If you do have a string representing the name of a class and you want to find that class's subclasses, then there are two steps: find the class given its name, and then find the subclasses with __subclasses__ as above.

    How to find the class from the name depends on where you're expecting to find it. If you're expecting to find it in the same module as the code that's trying to locate the class, then

    cls = globals()[name]
    

    would do the job, or in the unlikely case that you're expecting to find it in locals,

    cls = locals()[name]
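    Both lookups raise a bare KeyError if the name is missing. They will also return a non-class object that happens to have that name. Here is a minimal sketch of a hypothetical helper, find_class, that checks for both cases before you call __subclasses__():

```python
# find_class is a hypothetical helper; Foo and Bar mirror the example above.
def find_class(name, namespace):
    """Return namespace[name] if it exists and is a class."""
    try:
        obj = namespace[name]
    except KeyError:
        raise LookupError('no class named %r in this namespace' % name) from None
    if not isinstance(obj, type):
        raise TypeError('%r names a %s, not a class'
                        % (name, type(obj).__name__))
    return obj


class Foo(object): pass
class Bar(Foo): pass

cls = find_class('Foo', globals())
print([c.__name__ for c in cls.__subclasses__()])  # ['Bar']
```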
    

    If the class could be in any module, then your name string should contain the fully-qualified name - something like 'pkg.module.Foo' instead of just 'Foo'. Use importlib to load the class's module, then retrieve the corresponding attribute:

    import importlib
    modname, _, clsname = name.rpartition('.')
    mod = importlib.import_module(modname)
    cls = getattr(mod, clsname)
    

    However you find the class, cls.__subclasses__() will then return a list of its direct subclasses. To get the indirect ones too, use a recursive function such as all_subclasses above.
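    The following sketch puts the two steps together in a hypothetical helper, subclasses_by_name. It resolves a fully-qualified name with importlib and then collects all descendants, not only the direct subclasses. The standard-library numbers module is used here only because it is a convenient, always-available class hierarchy to try it on:

```python
import importlib


def subclasses_by_name(name):
    """Resolve 'pkg.module.Class' and return all of its subclasses.
    subclasses_by_name is a hypothetical helper, not a library function."""
    modname, _, clsname = name.rpartition('.')
    if not modname:
        raise ValueError('expected a fully-qualified name like pkg.module.Foo')
    cls = getattr(importlib.import_module(modname), clsname)

    # Walk the hierarchy with a stack, visiting each class once even if it
    # is reachable through several bases (multiple inheritance).
    seen, found = set(), []
    stack = list(cls.__subclasses__())
    while stack:
        sub = stack.pop()
        if sub in seen:
            continue
        seen.add(sub)
        found.append(sub)
        stack.extend(sub.__subclasses__())
    return found


print([c.__name__ for c in subclasses_by_name('numbers.Number')])
```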

  • 2020-11-22 11:12

    Note: I see that someone (not @unutbu) changed the referenced answer so that it no longer uses vars()['Foo'] — so the primary point of my post no longer applies.

    FWIW, here's what I meant about @unutbu's answer only working with locally defined classes. Using eval() instead of vars() would make it work with any class visible from the calling code, including module-level globals, not only those defined in the current scope.

    For those who dislike using eval(), a way to avoid it is shown as well.

    First here's a concrete example demonstrating the potential problem with using vars():

    class Foo(object): pass
    class Bar(Foo): pass
    class Baz(Foo): pass
    class Bing(Bar): pass
    
    # unutbu's approach
    def all_subclasses(cls):
        return cls.__subclasses__() + [g for s in cls.__subclasses__()
                                           for g in all_subclasses(s)]
    
    print(all_subclasses(vars()['Foo']))  # Fine because Foo is in scope
    # -> [<class '__main__.Bar'>, <class '__main__.Baz'>, <class '__main__.Bing'>]
    
    def func():  # won't work because Foo class is not locally defined
        print(all_subclasses(vars()['Foo']))
    
    try:
        func()  # not OK because Foo is not local to func()
    except Exception as e:
        print('calling func() raised exception: {!r}'.format(e))
        # -> calling func() raised exception: KeyError('Foo')
    
    print(all_subclasses(eval('Foo')))  # OK
    # -> [<class '__main__.Bar'>, <class '__main__.Baz'>, <class '__main__.Bing'>]
    
    # using eval('xxx') instead of vars()['xxx']
    def func2():
        print(all_subclasses(eval('Foo')))
    
    func2()  # Works
    # -> [<class '__main__.Bar'>, <class '__main__.Baz'>, <class '__main__.Bing'>]
    

    This can be made easier to use by moving the eval('ClassName') call into the function itself. That keeps the extra generality of eval(), which, unlike vars() called inside a function, also searches the module's global names:

    # easier to use version
    def all_subclasses2(classname):
        direct_subclasses = eval(classname).__subclasses__()
        return direct_subclasses + [g for s in direct_subclasses
                                        for g in all_subclasses2(s.__name__)]
    
    # pass 'xxx' instead of eval('xxx')
    def func_ez():
        print(all_subclasses2('Foo'))  # simpler
    
    func_ez()
    # -> [<class '__main__.Bar'>, <class '__main__.Baz'>, <class '__main__.Bing'>]
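    One limitation of this version is that eval() inside all_subclasses2 resolves names in all_subclasses2's own locals and module globals, not in the caller's scope. A class defined locally inside the caller is therefore invisible to it. The following small sketch shows this; subclasses_via_eval, LocalBase and LocalChild are hypothetical names:

```python
# subclasses_via_eval, LocalBase and LocalChild are hypothetical names.
def subclasses_via_eval(classname):
    # eval() looks the name up in this function's locals and its module's
    # globals; it cannot see the local variables of whoever called it.
    return eval(classname).__subclasses__()


def caller():
    # Both classes exist only as locals of caller().
    class LocalBase(object): pass
    class LocalChild(LocalBase): pass
    try:
        return subclasses_via_eval('LocalBase')
    except NameError as e:
        return 'NameError: %s' % e


print(caller())  # NameError: name 'LocalBase' is not defined
```

    To avoid guessing which scope gets searched, you can pass an explicit namespace (eval(classname, namespace)) or simply pass the class object itself.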
    

    Lastly, it's possible, and perhaps even important in some cases, to avoid using eval() for security reasons, so here's a version without it:

    def get_all_subclasses(cls):
        """ Generator of all of a class's subclasses. """
        try:
            direct_subclasses = cls.__subclasses__()
        except TypeError:
            # Raised for `type` itself: type.__subclasses__ is unbound there.
            return
        for subclass in direct_subclasses:
            yield subclass
            yield from get_all_subclasses(subclass)
    
    def all_subclasses3(classname):
        # Find the first class with this name anywhere under object
        # (object is the base of all new-style classes).
        for cls in get_all_subclasses(object):
            if cls.__name__ == classname:
                break
        else:
            raise ValueError('class %s not found' % classname)
        # Recurse on the class objects found, not on their names, so the whole
        # hierarchy is scanned only once and same-named classes can't be mixed up.
        direct_subclasses = cls.__subclasses__()
        return direct_subclasses + [g for s in direct_subclasses
                                        for g in get_all_subclasses(s)]
    
    # no eval('xxx')
    def func3():
        print(all_subclasses3('Foo'))
    
    func3()  # Also works
    # -> [<class '__main__.Bar'>, <class '__main__.Baz'>, <class '__main__.Bing'>]
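    The except TypeError in get_all_subclasses is needed because the walk from object reaches type itself. Looked up on type, type.__subclasses__ is an unbound method that needs an argument. Calling the method through type explicitly, as type.__subclasses__(cls), works for every class, including type and other metaclasses. A sketch of an exception-free walk could therefore look like this (iter_all_classes is a hypothetical name):

```python
import abc


def iter_all_classes(root=object):
    """Yield every class reachable from `root` through __subclasses__, once each.
    iter_all_classes is a hypothetical name for this sketch."""
    seen = set()   # ids, so classes with unusual metaclasses need not be hashable
    stack = [root]
    while stack:
        cls = stack.pop()
        # Calling through `type` binds cls explicitly, so this also works when
        # cls is `type` itself, where cls.__subclasses__() raises TypeError.
        for sub in type.__subclasses__(cls):
            if id(sub) not in seen:
                seen.add(id(sub))
                yield sub
                stack.append(sub)


try:
    type.__subclasses__()
except TypeError as e:
    print('TypeError:', e)  # the error the original generator guards against

all_classes = list(iter_all_classes())
print(bool in all_classes, abc.ABCMeta in all_classes)  # True True
```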
    
  • 2020-11-22 11:18

    Here is a simple but efficient version:

    def get_all_subclasses(cls):
        subclass_list = []
    
        def recurse(klass):
            for subclass in klass.__subclasses__():
                subclass_list.append(subclass)
                recurse(subclass)
    
        recurse(cls)
    
        return set(subclass_list)
    

    Its time complexity is O(n), where n is the total number of subclasses, as long as there is no multiple inheritance. That makes it more efficient than the functions that recursively build and concatenate lists, or that re-yield classes through nested generators. Those functions copy or pass each class through every level above it, so their cost is O(n log n) when the class hierarchy is a balanced tree and O(n^2) when it is a degenerate, chain-like tree.
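    The accumulator version above still recurses once per level of the hierarchy, so a very deep hierarchy can hit Python's recursion limit (1000 frames by default). Here is a sketch of the same linear-time idea using an explicit stack instead of recursion. The seen set also keeps it linear when multiple inheritance makes a class reachable twice. get_all_subclasses_iter is a hypothetical name, and the 1500-class chain is an artificial hierarchy built with type() purely to exercise the depth:

```python
def get_all_subclasses_iter(cls):
    """Collect all subclasses of cls without recursion; each appears once.
    get_all_subclasses_iter is a hypothetical name for this sketch."""
    seen = set()
    stack = [cls]
    while stack:
        for subclass in stack.pop().__subclasses__():
            if subclass not in seen:  # skip classes reached via another base
                seen.add(subclass)
                stack.append(subclass)
    return seen


# Artificial single-inheritance chain: C1(C0), C2(C1), ..., C1499(C1498).
chain = [type('C0', (object,), {})]
for i in range(1, 1500):
    chain.append(type('C%d' % i, (chain[-1],), {}))

print(len(get_all_subclasses_iter(chain[0])))  # 1499
```

    Each class is pushed and popped once and each subclass link is examined once, so the work grows linearly with the size of the hierarchy regardless of its depth.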
