Tests: environment: use a context manager to restore original signal handler
[lttng-tools.git] / tests / utils / lttngtest / environment.py
index 282e95643c33a6ece9cd2f815bc1d28f5f6f69f0..2710a7efb957dd46994f20454054bbc1acc3ddfd 100644 (file)
@@ -6,7 +6,7 @@
 #
 
 from types import FrameType
-from typing import Callable, Iterator, Optional, Tuple, List
+from typing import Callable, Iterator, Optional, Tuple, List, Generator
 import sys
 import pathlib
 import signal
@@ -63,6 +63,19 @@ class _SignalWaitQueue:
     def wait_for_signal(self):
         self._queue.get(block=True)
 
+    @contextlib.contextmanager
+    def intercept_signal(self, signal_number):
+        # type: (int) -> Generator[None, None, None]
+        original_handler = signal.getsignal(signal_number)
+        signal.signal(signal_number, self.signal)
+        try:
+            yield
+        except:
+            # Restore the original signal handler and forward the exception.
+            raise
+        finally:
+            signal.signal(signal_number, original_handler)
+
 
 class WaitTraceTestApplication:
     """
@@ -189,11 +202,13 @@ class WaitTraceTestApplication:
 
 class TraceTestApplication:
     """
-    Create an application to trace.
+    Create an application that emits events as soon as it is launched. In most
+    scenarios, it is preferable to use a WaitTraceTestApplication.
     """
 
-    def __init__(self, binary_path: pathlib.Path, environment: "Environment"):
-        self._environment: Environment = environment
+    def __init__(self, binary_path, environment):
+        # type: (pathlib.Path, Environment)
+        self._environment = environment  # type: Environment
         self._has_returned = False
 
         test_app_env = os.environ.copy()
@@ -204,11 +219,12 @@ class TraceTestApplication:
 
         test_app_args = [str(binary_path)]
 
-        self._process: subprocess.Popen = subprocess.Popen(
+        self._process = subprocess.Popen(
             test_app_args, env=test_app_env
-        )
+        )  # type: subprocess.Popen
 
-    def wait_for_exit(self) -> None:
+    def wait_for_exit(self):
+        # type: () -> None
         if self._process.wait() != 0:
             raise RuntimeError(
                 "Test application has exit with return code `{return_code}`".format(
@@ -355,35 +371,33 @@ class _Environment(logger._Logger):
         sessiond_env["LTTNG_HOME"] = str(self._lttng_home.path)
 
         wait_queue = _SignalWaitQueue()
-        signal.signal(signal.SIGUSR1, wait_queue.signal)
-
-        self._log(
-            "Launching session daemon with LTTNG_HOME=`{home_dir}`".format(
-                home_dir=str(self._lttng_home.path)
+        with wait_queue.intercept_signal(signal.SIGUSR1):
+            self._log(
+                "Launching session daemon with LTTNG_HOME=`{home_dir}`".format(
+                    home_dir=str(self._lttng_home.path)
+                )
+            )
+            process = subprocess.Popen(
+                [
+                    str(sessiond_path),
+                    consumerd_path_option_name,
+                    str(consumerd_path),
+                    "--sig-parent",
+                ],
+                stdout=subprocess.PIPE,
+                stderr=subprocess.STDOUT,
+                env=sessiond_env,
             )
-        )
-        process = subprocess.Popen(
-            [
-                str(sessiond_path),
-                consumerd_path_option_name,
-                str(consumerd_path),
-                "--sig-parent",
-            ],
-            stdout=subprocess.PIPE,
-            stderr=subprocess.STDOUT,
-            env=sessiond_env,
-        )
 
-        if self._logging_function:
-            self._sessiond_output_consumer = ProcessOutputConsumer(
-                process, "lttng-sessiond", self._logging_function
-            )  # type: Optional[ProcessOutputConsumer]
-            self._sessiond_output_consumer.daemon = True
-            self._sessiond_output_consumer.start()
+            if self._logging_function:
+                self._sessiond_output_consumer = ProcessOutputConsumer(
+                    process, "lttng-sessiond", self._logging_function
+                )  # type: Optional[ProcessOutputConsumer]
+                self._sessiond_output_consumer.daemon = True
+                self._sessiond_output_consumer.start()
 
-        # Wait for SIGUSR1, indicating the sessiond is ready to proceed
-        wait_queue.wait_for_signal()
-        signal.signal(signal.SIGUSR1, wait_queue.signal)
+            # Wait for SIGUSR1, indicating the sessiond is ready to proceed
+            wait_queue.wait_for_signal()
 
         return process
 
@@ -412,9 +426,8 @@ class _Environment(logger._Logger):
             self,
         )
 
-    def launch_trace_test_constructor_application(
-        self
-    ) -> TraceTestApplication:
+    def launch_trace_test_constructor_application(self):
+        # type () -> TraceTestApplication
         """
         Launch an application that will trace from within constructors.
         """
This page took 0.024494 seconds and 4 git commands to generate.