Handling the classmethod pickling issue with copy_reg

花落未央 2021-02-04 14:17

I ran into a pickling error when using multiprocessing:

    from multiprocessing import Pool

    def test_func(x):
        return x**2

    class Test:
        @classmethod

3 answers
  • 2021-02-04 14:44

    I modified the recipe to make it work with classmethods. Here's the code:

    import copy_reg
    import inspect
    import types
    
    def _pickle_method(method):
        func_name = method.im_func.__name__
        obj = method.im_self
        cls = method.im_class
        if isinstance(obj, (type, types.ClassType)):
            # classmethod: im_self is the class and im_class is only its
            # metaclass, so rebuild it from the class with no instance
            cls, obj = obj, None
        if func_name.startswith('__') and not func_name.endswith('__'):
            #deal with mangled names
            cls_name = cls.__name__.lstrip('_')
            func_name = '_%s%s' % (cls_name, func_name)
        return _unpickle_method, (func_name, obj, cls)
    
    def _unpickle_method(func_name, obj, cls):
        # inspect.getmro also works for old-style classes, which lack __mro__
        for klass in inspect.getmro(cls):
            if func_name in klass.__dict__:
                # bind to the original cls, not to the base class that
                # happens to define the attribute
                return klass.__dict__[func_name].__get__(obj, cls)
        raise AttributeError(func_name)
    
    copy_reg.pickle(types.MethodType, _pickle_method, _unpickle_method)
    
  • 2021-02-04 14:49

    Instead of returning the actual class object from _pickle_method, return a string that can be used to import the class when unpickling, and then perform that import in _unpickle_method.

  • 2021-02-04 15:04

    The following solution now also handles class methods correctly. Please let me know if there is still something missing.

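    import copy_reg
    import inspect
    import types
    
    def _pickle_method(method):
        """
        Pickle methods properly, including class methods.
        """
        func_name = method.im_func.__name__
        obj = method.im_self
        cls = method.im_class
        if isinstance(obj, (type, types.ClassType)):
            # handle classmethods differently: im_self is the class itself,
            # while im_class is only its metaclass
            cls = obj
            obj = None
        if func_name.startswith('__') and not func_name.endswith('__'):
            #deal with mangled names
            cls_name = cls.__name__.lstrip('_')
            func_name = '_%s%s' % (cls_name, func_name)
    
        return _unpickle_method, (func_name, obj, cls)
    
    def _unpickle_method(func_name, obj, cls):
        """
        Unpickle methods properly, including class methods.
        """
        # walk the MRO for both instance methods and classmethods, so an
        # inherited classmethod is found too; inspect.getmro also handles
        # old-style classes, which have no __mro__
        for klass in inspect.getmro(cls):
            try:
                func = klass.__dict__[func_name]
            except KeyError:
                pass
            else:
                # bind to the original cls rather than the defining base class
                return func.__get__(obj, cls)
        raise AttributeError(func_name)
    
    copy_reg.pickle(types.MethodType, _pickle_method, _unpickle_method)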
    