Method overloading decorator

后端 未结 3 1181
遥遥无期
遥遥无期 2021-01-12 04:42

I'm trying to write a decorator that provides method-overloading functionality to Python, similar to the one described in PEP 3124.

The decorator I wrote works great … (the rest of the question is cut off in this copy)

3条回答
  •  广开言路
    2021-01-12 05:18

    For reference, here is the working implementation, thanks to the detailed explanation by glglgl:

    def argtype_tuple(args):
        """Return a tuple holding the exact type of each element of *args*.

        Suitable as a dispatch key: ``argtype_tuple((1, "a"))`` is ``(int, str)``.
        """
        return tuple(type(a) for a in args)
    
    class Overload(object):
        """Decorator that dispatches a call on the exact types of its positional args.

        The decorated function becomes the default implementation. Alternatives
        are registered with ``@<name>.overload(type, ...)``. The dispatch key is
        the tuple of ``type(arg)`` for each positional argument. Because it uses
        exact types, a subclass does NOT match a registration for its base
        class. Keyword arguments are passed through but never take part in
        dispatch. When no registration matches, the default is called.

        When used on a method, the instance is bound automatically and is
        excluded from the dispatch key.
        """

        def __init__(self, func):
            import functools

            # Copy __name__, __doc__, __module__, ... from the default so the
            # dispatcher introspects like the function it replaces. This runs
            # first so it cannot clobber the attributes assigned below.
            functools.update_wrapper(self, func)
            self.default = func
            self.map = {}

        def _dispatch(self, prefix, args, kwargs):
            """Select the implementation for *args* and call it.

            *prefix* is a tuple of leading arguments (the bound instance, if
            any). They are passed to the implementation but excluded from the
            type key.
            """
            key_tuple = tuple(type(a) for a in args)
            # Look up first, then call outside any try/except. A KeyError
            # raised *inside* an implementation then propagates, instead of
            # being mistaken for "no registration" and re-run via the default.
            impl = self.map.get(key_tuple, self.default)
            return impl(*(prefix + args), **kwargs)

        def __call__(self, *args, **kwargs):
            return self._dispatch((), args, kwargs)

        def __get__(self, obj, cls=None):
            # Compare with None rather than testing truthiness: an instance
            # that is falsy (e.g. __len__ returns 0) must still be bound.
            if obj is None:
                return self

            def bound(*args, **kwargs):
                return self._dispatch((obj,), args, kwargs)

            return bound

        def overload(self, *types):
            """Return a decorator that registers its function for *types*.

            Each argument is either a single type (a one-argument signature) or
            a tuple of types (a multi-argument signature). For example,
            ``overload(list, tuple)`` registers two one-argument signatures,
            while ``overload((int, str))`` registers one two-argument signature.

            The decorator returns this Overload object, so the decorated name
            becomes another handle to the same dispatcher.
            """
            def wrapper(f):
                for type_seq in types:
                    if isinstance(type_seq, tuple):
                        # Normalise tuple subclasses to a plain tuple key.
                        type_seq = tuple(type_seq)
                    else:
                        type_seq = (type_seq,)
                    self.map[type_seq] = f
                return self
            return wrapper
    
    # Some tests/usage examples
    class A(object):
        """Example class exercising Overload on an instance method."""

        @Overload
        def print_first(self, x):
            # Default path: take the first element of an indexable argument.
            return x[0]

        @print_first.overload(str)
        def p_first(self, x):
            # str path: take the first whitespace-separated word instead.
            return x.split(None, 1)[0]

        def __repr__(self):
            return "class A Instance"

    a = A()
    # A list falls through to the default; a str dispatches to p_first.
    assert 1 == a.print_first([1, 2, 3])
    assert "one" == a.print_first("one two three")
    
    @Overload
    def flatten(seq):
        """Return *seq* flattened into a list.

        Default case: any value whose exact type is not list or tuple is a
        leaf and becomes ``[seq]``.
        """
        return [seq]

    @flatten.overload(list, tuple)
    def flat(seq):
        """Flatten a list or tuple by recursively flattening each item."""
        # Extend a single accumulator instead of sum(..., []): sum() builds a
        # new list at every step, which is quadratic in the output length.
        result = []
        for item in seq:
            result.extend(flatten(item))
        return result

    assert flatten([1,2,[3,4]]) == [1,2,3,4]
    # overload() returns the dispatcher itself, so `flat` is another name for
    # the same Overload object as `flatten` and dispatches identically.
    assert flat([1,2,[3,4]]) == [1,2,3,4]
    

提交回复
热议问题