Question
I need a unit test that loads a class instance previously saved with pickle. However, when I load the pickle inside the unit test (loading it outside of unittest works), it raises this error:
```
AttributeError: Can't get attribute 'Foo' on <module 'unittest.main' from '...\unittest\main.py'>
```
Code to save the class (saved as run_and_save_class.py):
```python
from pickle import dump
from pickle import load
from pickle import HIGHEST_PROTOCOL


class Foo(object):
    def __init__(self):
        self.bar = None
        self.file_out = "./out.pkl"

    def save_class(self):
        with open(self.file_out, "wb") as file_out:
            dump(self, file_out, protocol=HIGHEST_PROTOCOL)

    def load_class(self):
        with open(self.file_out, "rb") as file_out:
            cls = load(file_out)
        return cls


if __name__ == "__main__":
    # Run as a script, this module is __main__, so the pickle refers to __main__.Foo.
    cls = Foo()
    cls.bar = "saving a bar"
    cls.save_class()
```
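What `dump` writes for `Foo` is a reference to the class (a module name plus a class name), not the class code. The sketch below lists those references with the standard `pickletools` module. It is a minimal sketch under assumptions: `global_refs` is a hypothetical helper, the in-memory module `fake_main` is a hypothetical stand-in for `__main__` at save time, and the parsing only covers simple pickles like this one.

```python
import pickle
import pickletools
import sys
import types


def global_refs(data):
    """Return the (module, name) pairs a pickle will look up when loaded.

    Simplified sketch: STACK_GLOBAL is assumed to use the two most recent
    string arguments, which holds for simple pickles like the one below.
    """
    refs = []
    strings = []
    for opcode, arg, _pos in pickletools.genops(data):
        if opcode.name == "GLOBAL":  # protocols 0-3: arg is "module name"
            module, name = arg.split(" ", 1)
            refs.append((module, name))
        elif opcode.name == "STACK_GLOBAL":  # protocol 4+: names pushed as strings
            refs.append((strings[-2], strings[-1]))
        elif isinstance(arg, str):
            strings.append(arg)
    return refs


# Hypothetical module playing the role of __main__ when the save script ran.
fake_main = types.ModuleType("fake_main")
sys.modules["fake_main"] = fake_main
Foo = type("Foo", (), {"__module__": "fake_main"})
fake_main.Foo = Foo

obj = Foo()
obj.bar = "saving a bar"
data = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
print(global_refs(data))  # [('fake_main', 'Foo')]
```

Run on the question's file, e.g. `global_refs(pathlib.Path("./out.pkl").read_bytes())`, it would be expected to report `('__main__', 'Foo')`, since run_and_save_class.py was the `__main__` module when `save_class` ran.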
Code to test the class (saved as unittest_class.py):
```python
import unittest

from run_and_save_class import Foo


class ClassValidation(unittest.TestCase):
    def __init__(self, *args, **kwargs):
        print("init")
        self.cls = Foo
        self.instance = Foo().load_class()
        print("class loaded")
        unittest.TestCase.__init__(self, *args, **kwargs)

    def test_anything(self):
        pass
```
I run the following in Anaconda Prompt:
```
python run_and_save_class.py
python -m unittest -v unittest_class.py
```
The second command raises the error. However, the following works in a notebook:
```python
from run_and_save_class import Foo
cls = Foo().load_class()
```
I don't understand why it fails in a unit test.
Answer 1:
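The problem is that pickle does not store the class itself, only a reference to it: the name of the module where the class was defined plus the class name. `dump` was called (via `save_class`) while run_and_save_class.py was running as a script, so that module was `__main__` and the pickle refers to `__main__.Foo`. To load the object, you have to provide the same environment, i.e. `Foo` must be findable as an attribute of `__main__`. This is also likely why the notebook version works: there, `from run_and_save_class import Foo` binds `Foo` in the notebook's `__main__` namespace, whereas under `python -m unittest` the `__main__` module belongs to unittest (as the error message shows) and has no `Foo`. A workaround is to add the class to `__main__` in your test so that pickle can find it: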
```python
import unittest

import __main__

from run_and_save_class import Foo


class ClassValidation(unittest.TestCase):
    def __init__(self, *args, **kwargs):
        # Make Foo resolvable as __main__.Foo, the name stored in the pickle.
        __main__.Foo = Foo
        self.cls = Foo
        self.instance = Foo().load_class()
        unittest.TestCase.__init__(self, *args, **kwargs)

    def test_anything(self):
        self.assertEqual("saving a bar", self.instance.bar)
```
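Another way to get the same result without modifying `__main__`, which is not part of the answer above, is to subclass `pickle.Unpickler` and override its documented `find_class` hook so that classes recorded under one module name are looked up in another. This is a minimal sketch: `RemappingUnpickler`, `load_remapped`, and the in-memory modules `fake_main` and `fake_module` are hypothetical names used only for the demonstration, with `fake_main` standing in for `__main__` at save time.

```python
import io
import pickle
import sys
import types


class RemappingUnpickler(pickle.Unpickler):
    """Unpickler that looks up classes from renamed modules (sketch)."""

    def __init__(self, file, module_map):
        super().__init__(file)
        # Maps a module name stored in the pickle to the module to import instead.
        self.module_map = module_map

    def find_class(self, module, name):
        return super().find_class(self.module_map.get(module, module), name)


def load_remapped(data, module_map):
    return RemappingUnpickler(io.BytesIO(data), module_map).load()


# Demo: Foo is pickled while registered as fake_main.Foo (like __main__.Foo) ...
fake_main = types.ModuleType("fake_main")
fake_module = types.ModuleType("fake_module")
sys.modules["fake_main"] = fake_main
sys.modules["fake_module"] = fake_module

Foo = type("Foo", (), {"__module__": "fake_main"})
fake_main.Foo = Foo
obj = Foo()
obj.bar = "saving a bar"
data = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)

# ... but at load time only fake_module.Foo exists, as under unittest.
del fake_main.Foo
fake_module.Foo = Foo

try:
    pickle.loads(data)
    plain_load_failed = False
except AttributeError as exc:
    plain_load_failed = True
    print("plain load fails:", exc)

restored = load_remapped(data, {"fake_main": "fake_module"})
print(restored.bar)  # saving a bar
```

In the question's setting this would presumably be called as `load_remapped(data, {"__main__": "run_and_save_class"})`, assuming run_and_save_class is importable from the test. Unlike assigning `__main__.Foo`, it leaves `__main__` untouched, but it only helps where you control the loading code (here, `load_class`).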
Answer 2:
Try instantiating the Foo object in ClassValidation like this: `self.cls = Foo()`
Source: https://stackoverflow.com/questions/63827918/unittest-unable-to-import-class-from-pickle-attributeerror-cant-get-attribute