Method overloading decorator

遥遥无期 2021-01-12 04:42

I'm trying to write a decorator that provides method-overloading functionality to Python, similar to the one mentioned in PEP 3124.

The decorator I wrote works grea

3 answers
  • 2021-01-12 05:18

    For reference, here is the working implementation, thanks to the detailed explanation by glglgl:

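    def argtype_tuple(args):
        # The dispatch key: the exact types of the positional arguments.
        return tuple(type(a) for a in args)
    
    class Overload(object):
        def __init__(self, func):
            self.default = func  # fallback used when no overload matches
            self.map = {}        # maps a tuple of argument types to a function
    
        def __call__(self, *args, **kwargs):
            # Build the key *before* prepending the instance, so that the
            # bound object itself does not take part in dispatch.
            key_tuple = argtype_tuple(args)
            c_inst = kwargs.pop("c_inst", None)
            if c_inst is not None:  # "if c_inst:" would skip falsy instances
                args = (c_inst,) + args
            # Look the function up first, so a KeyError raised *inside* an
            # overload is not mistaken for "no overload registered".
            func = self.map.get(key_tuple, self.default)
            return func(*args, **kwargs)
    
        def __get__(self, obj, cls):
            if obj is None:
                # Accessed on the class (A.print_first): return the Overload itself.
                return self
            # Accessed on an instance: pass the instance along as c_inst.
            return lambda *args, **kwargs: self(*args, c_inst=obj, **kwargs)
    
        def overload(self, *types):
            def wrapper(f):
                for type_seq in types:
                    # Each entry is either a single type or a tuple of types.
                    if not isinstance(type_seq, tuple):
                        type_seq = (type_seq,)
                    self.map[type_seq] = f
                return self  # keep the decorated name bound to the Overload
            return wrapper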
    
    #Some tests/usage examples
    class A(object):
        @Overload
        def print_first(self, x):
            return x[0]
    
        @print_first.overload(str)
        def p_first(self, x):
            return x.split()[0]
    
        def __repr__(self):
            return "class A Instance"
    
    a = A()
    assert a.print_first([1,2,3]) == 1
    assert a.print_first("one two three") == "one"
    
    @Overload
    def flatten(seq):
        return [seq]
    
    @flatten.overload(list, tuple)
    def flat(seq):
        return sum((flatten(item) for item in seq), [])
    
    assert flatten([1,2,[3,4]]) == [1,2,3,4]
    assert flat([1,2,[3,4]]) == [1,2,3,4]
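For comparison, since Python 3.8 the standard library offers something close to this: functools.singledispatchmethod dispatches on the type of the first argument after self. The sketch below redoes the A example with it, assuming Python 3.8+. Unlike the Overload class above, it dispatches on a single argument only (not a tuple of types), and it also matches subclasses of a registered type rather than requiring an exact type match.

```python
from functools import singledispatchmethod


class A:
    @singledispatchmethod
    def print_first(self, x):
        # Default implementation, used when no registered type matches.
        return x[0]

    @print_first.register(str)
    def _(self, x):
        # Chosen when the first argument after self is a str (or a str subclass).
        return x.split()[0]


a = A()
print(a.print_first([1, 2, 3]))        # 1
print(a.print_first("one two three"))  # one
```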
    
  • Essentially, your Overload class needs a __get__ method:

    def __get__(self, obj, cls):
        # Called on access of MyClass.print_first_item.
        # We return a wrapper which calls our Overload object with obj
        # prepended to the arguments, just as a bound method would.
        print("get", self, obj, cls)
        if obj is None:
            # Accessed on the class itself; a function would do some checks here, but we leave that.
            return self
        else:
            return lambda *a, **k: self(obj, *a, **k)
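To see why the missing __get__ matters, here is a minimal sketch comparing a callable-class decorator with and without a __get__ like the one above. The class names NoGet, WithGet and Demo are hypothetical. Without __get__, accessing the attribute on an instance just returns the wrapper object, so the instance is never passed in and the first real argument ends up in self.

```python
class NoGet:
    """Callable wrapper without __get__: it is not bound when accessed via an instance."""

    def __init__(self, func):
        self.func = func

    def __call__(self, *args, **kwargs):
        return self.func(*args, **kwargs)


class WithGet(NoGet):
    """Same wrapper plus __get__, which prepends the instance like a function would."""

    def __get__(self, obj, cls):
        if obj is None:
            return self
        return lambda *a, **k: self(obj, *a, **k)


class Demo:
    @NoGet
    def plain(self, x):
        return (self, x)

    @WithGet
    def bound(self, x):
        return (self, x)


d = Demo()
print(d.bound(1) == (d, 1))  # True: d was passed as self
try:
    d.plain(1)  # 1 lands in self and x is missing
except TypeError as exc:
    print("TypeError:", exc)
```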
    

    Why?

    Well, you use your Overload object as a kind of function replacement. You want it, like a function, to bind itself to the instance when it is accessed as a method, so that in a method context it presents a different signature (with self already filled in).

    Short explanation how method access works:

    obj.meth(1, 2)
    

    gets translated to (methods are stored on the class, so the lookup goes through type(obj))

    type(obj).__dict__['meth'].__get__(obj, type(obj))(1, 2)
    

    A function's __get__() returns a method object which wraps the function by prepending the instance to the argument list (where it becomes self):

    realmethod = type(obj).__dict__['meth'].__get__(obj, type(obj))
    realmethod(1, 2)
    

    where realmethod is a method object which knows the function to be called and the self to be given to it, and calls the "real" function appropriately by transforming the call into

    meth(obj, 1, 2)

    This behaviour we imitate in this new __get__ method.
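The lookup described above can be checked directly in Python 3. This sketch (all class names are hypothetical) first performs the translation by hand for an ordinary method, then shows a callable wrapper whose __get__ returns types.MethodType(self, obj): a variant of the lambda approach that produces a real bound-method object instead of an anonymous function.

```python
import types


class MyClass:
    def meth(self, a, b):
        return (self, a, b)


obj = MyClass()
# Manual version of obj.meth: fetch the function from the class, then bind it.
realmethod = type(obj).__dict__["meth"].__get__(obj, type(obj))
print(realmethod(1, 2) == obj.meth(1, 2) == MyClass.meth(obj, 1, 2))  # True
print(realmethod.__self__ is obj)                                     # True


class Wrapper:
    """Hypothetical callable wrapper that binds like a plain function."""

    def __init__(self, func):
        self.func = func

    def __call__(self, *args, **kwargs):
        return self.func(*args, **kwargs)

    def __get__(self, obj, cls):
        if obj is None:
            return self  # accessed on the class: return the wrapper unchanged
        # A real bound method whose __func__ is this wrapper and __self__ is obj.
        return types.MethodType(self, obj)


class C:
    @Wrapper
    def meth(self, a, b):
        return (self, a, b)


c = C()
print(c.meth(1, 2) == (c, 1, 2))  # True
```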

  • 2021-01-12 05:26

    As abarnert says, because you are using a class as your decorator, 'self' is an instance of Overload rather than of MyClass, as you hoped/expected.

    I couldn't find a simple solution. The best thing I could come up with is to not use a class as the decorator, but a function instead, with a second argument whose default value is a dictionary. Because default values are evaluated only once, when the function is defined, and a dict is mutable, it will be the same dictionary every time the function is called. I use this to store my 'class variables'. The rest follows a similar pattern to your solution.
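The shared-dictionary trick relies on default values being evaluated once, when the def statement runs. A tiny illustration with a hypothetical remember function (it uses a list instead of a dict, but the mechanism is the same):

```python
def remember(item, seen=[]):
    # `seen` is created once, when `def` runs, and reused by every call
    # that does not pass its own list.
    seen.append(item)
    return seen


first = remember(1)
second = remember(2)
print(second)           # [1, 2]
print(first is second)  # True: both calls returned the same list object
```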

    Example:

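    import inspect
    
    def overload(funcOrType, registry={}, typ=None):
        # registry's default dict is created once, when overload is defined, so it
        # is shared by every call (and by every function decorated with @overload).
        if not inspect.isclass(funcOrType):
            # We have a function, so we are dealing with "@overload" (the default)
            # or with the inner call made by overloadWithType below.
            if typ is not None:
                registry[typ] = funcOrType
            else:
                registry['default_function'] = funcOrType
        else:
            # We have a type, so we are dealing with "@overload(SomeType)".
            def overloadWithType(func):
                return overload(func, registry, funcOrType)
            return overloadWithType
    
        def doOverload(*args, **kwargs):
            for t, impl in registry.items():
                # Note args[0] is 'self', i.e. the MyClass instance.
                if t != 'default_function' and isinstance(args[1], t):
                    return impl(*args, **kwargs)
            return registry['default_function'](*args, **kwargs)
    
        return doOverload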
    

    Then:

    from overload import overload  # assuming the code above is saved as overload.py
    
    class MyClass(object):
        def __init__(self):
            self.some_instance_var = 1
    
        @overload
        def print_first_item(self, x):
            return x[0], self.some_instance_var
    
        @overload(str)
        def print_first_item(self, x):
            return x.split()[0], self.some_instance_var
    
    
    m = MyClass()
    print(m.print_first_item(['a','b','c']))
    print(m.print_first_item("One Two Three"))
    

    Yields:

    ('a', 1)
    ('One', 1)
    