(Unit) Test Python signal handler

天涯浪人 2021-01-21 13:51

I have a simple Python service with a loop that performs some action forever. On various signals, sys.exit(0) is called, which causes SystemExit

2 Answers
  • 2021-01-21 14:16

    Let's refactor that to make it easier to test:

    import signal
    import sys
    import time
    
    def loop_body():
        try:
            print("Some action here")
        except BaseException:
            # SystemExit is a BaseException: clean up, then re-raise it
            print("Some cleanup")
            raise
    
    def main():
    
        def signal_handler(signalnum, _):
            # Turn SIGINT/SIGTERM into SystemExit in the main thread
            sys.exit(0)
    
        signal.signal(signal.SIGINT, signal_handler)
        signal.signal(signal.SIGTERM, signal_handler)
    
        while True:
            try:
                loop_body()
                time.sleep(10)
            except SystemExit:
                # A signal requested shutdown: leave the loop
                break
    
    if __name__ == '__main__':
        main()
    

    This still doesn't make the signal-handling code itself easy to test. However, that code is so small, rarely changes, and depends so strongly on the environment that it is feasible, and perhaps even better, to test it manually.
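What the refactor does buy you is that the loop body (called as loop_body() in main()) can be unit-tested without sending any real signal: make its action raise SystemExit, as the signal handler would, and check that the cleanup runs and the exception still propagates. This is a minimal sketch, assuming the action is the print() call from the example; the test class and method names are hypothetical, and it patches builtins.print with a Mock whose first call raises.

```python
import unittest
from unittest import mock


def loop_body():
    # Copy of the refactored loop body: one unit of work that cleans up
    # and re-raises if it is interrupted (e.g. by SystemExit)
    try:
        print("Some action here")
    except BaseException:
        print("Some cleanup")
        raise


class TestLoopBody(unittest.TestCase):
    def test_cleanup_runs_and_exit_propagates(self):
        # The first print() call raises SystemExit, simulating a signal that
        # arrives mid-action; the second call (the cleanup) returns None.
        fake_print = mock.Mock(side_effect=[SystemExit(0), None])
        with mock.patch("builtins.print", fake_print):
            with self.assertRaises(SystemExit):
                loop_body()
        # Both calls are recorded, including the one that raised
        self.assertEqual(
            fake_print.call_args_list,
            [mock.call("Some action here"), mock.call("Some cleanup")],
        )
```

Run it with `python -m unittest`. Patching print only works because the action here is a print call; in a real service you would patch whatever function performs the action.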

    For clarity, consider a context manager, which is usually a good idea when you have setup and shutdown code. You don't mention any setup code, but my Crystal Ball (TM) tells me it exists. The service could then be used like this:

    from time import sleep
    
    try:
        with my_service() as service:
            while True:
                service.run()
                sleep(10)
    except SystemExit:
        # The signal handler raised SystemExit; my_service() shuts down on the way out
        pass
    

    I'll leave the implementation of that context manager to you, but check out contextlib, which makes it easy and fun.
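One possible shape for that context manager, sketched with contextlib.contextmanager. Everything here is an assumption: the Service class and its run() method are hypothetical placeholders for your real setup code, and installing and restoring the signal handlers inside my_service() is a design choice of this sketch, not something the answer prescribes.

```python
import contextlib
import signal
import sys


class Service:
    # Hypothetical service object: run() performs one unit of work
    def run(self):
        print("Some action here")


@contextlib.contextmanager
def my_service():
    # Setup: route SIGINT/SIGTERM to sys.exit(0) so they raise SystemExit
    def signal_handler(signalnum, _frame):
        sys.exit(0)

    # signal.signal() returns the handler that was installed before
    previous = {
        signum: signal.signal(signum, signal_handler)
        for signum in (signal.SIGINT, signal.SIGTERM)
    }
    try:
        yield Service()
    finally:
        # Shutdown: runs on normal exit, on SystemExit and on any other error
        print("Some cleanup")
        # Put back whatever handlers were installed before
        for signum, handler in previous.items():
            if handler is not None:
                signal.signal(signum, handler)
```

Because the finally block runs whether the with body ends normally or through the SystemExit raised by the handler, the try/except SystemExit loop can stay exactly as written. Restoring the previous handlers also keeps the service from leaving sys.exit() handlers behind, which matters when it is used inside a test run.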

  • 2021-01-21 14:28

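    Python always runs signal handlers in the main thread, so you can send SIGINT (or any other signal) from a helper thread once the service has started (here: after its first print() call) and let the main thread handle it. You can then assert on its effects just as in any other test, as below.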

    import os
    import signal
    import time
    import threading
    import unittest
    from unittest.mock import patch
    
    import service
    
    class TestService(unittest.TestCase):
    
        @patch('service.print')
        def test_signal_handling(self, mock_print):
    
            pid = os.getpid()
    
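            def trigger_signal():
                # Wait until main() has printed once, so its handlers are installed
                while len(mock_print.mock_calls) < 1:
                    time.sleep(0.2)
                # Send SIGINT to this process; Python runs the handler in the main thread
                os.kill(pid, signal.SIGINT)
    
            thread = threading.Thread(target=trigger_signal)
            thread.daemon = True
            thread.start()
    
            # Blocks in the service loop until the signal makes it exit
            service.main()
    
            # The second print() call should be the cleanup message
            self.assertEqual(mock_print.mock_calls[1][1][0], 'Some cleanup')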
    
    
    if __name__ == '__main__':
        unittest.main()
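The same idea as a self-contained sketch that does not need a separate service module. Assumptions of this sketch: main() and the module-level events list are hypothetical stand-ins for service.main() and its print() calls; it targets the main thread with signal.pthread_kill (POSIX only) instead of os.kill, since CPython always runs Python signal handlers in the main thread; and it restores the original SIGINT/SIGTERM handlers with addCleanup so later tests are not left with handlers that call sys.exit().

```python
import signal
import sys
import threading
import time
import unittest

events = []  # hypothetical stand-in for the service's observable effects


def main():
    # Minimal stand-in for the question's service.main()
    def signal_handler(signalnum, _frame):
        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    while True:
        try:
            events.append("action")
            time.sleep(10)
        except SystemExit:
            events.append("cleanup")
            break


class TestSignalHandling(unittest.TestCase):
    def test_sigterm_triggers_cleanup(self):
        events.clear()
        # Restore the original handlers after the test
        for signum in (signal.SIGINT, signal.SIGTERM):
            self.addCleanup(signal.signal, signum, signal.getsignal(signum))

        main_thread_id = threading.main_thread().ident

        def trigger_signal():
            # Wait until main() has installed its handlers and started work
            while not events:
                time.sleep(0.05)
            # Deliver SIGTERM to the main thread, where Python runs handlers
            signal.pthread_kill(main_thread_id, signal.SIGTERM)

        threading.Thread(target=trigger_signal, daemon=True).start()
        start = time.monotonic()
        main()
        self.assertEqual(events, ["action", "cleanup"])
        # The 10 s sleep was cut short by the signal
        self.assertLess(time.monotonic() - start, 5)
```

Run it with `python -m unittest` from the main thread (signal.signal() raises ValueError elsewhere). The timing assertion checks that the signal actually interrupted the sleep rather than being handled only after it finished.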
    