Tests: environment: use a context manager to restore original signal handler
[lttng-tools.git] / tests / utils / lttngtest / environment.py
index d9777810bdeb90fa7c4b7563598c43289304a8f6..2710a7efb957dd46994f20454054bbc1acc3ddfd 100644 (file)
@@ -6,7 +6,7 @@
 #
 
 from types import FrameType
-from typing import Callable, Optional, Tuple, List
+from typing import Callable, Iterator, Optional, Tuple, List, Generator
 import sys
 import pathlib
 import signal
@@ -23,14 +23,16 @@ import contextlib
 
 
 class TemporaryDirectory:
-    def __init__(self, prefix: str):
+    def __init__(self, prefix):
+        # type: (str) -> None
         self._directory_path = tempfile.mkdtemp(prefix=prefix)
 
     def __del__(self):
         shutil.rmtree(self._directory_path, ignore_errors=True)
 
     @property
-    def path(self) -> pathlib.Path:
+    def path(self):
+        # type: () -> pathlib.Path
         return pathlib.Path(self._directory_path)
 
 
@@ -49,14 +51,31 @@ class _SignalWaitQueue:
     """
 
     def __init__(self):
-        self._queue: queue.Queue = queue.Queue()
+        self._queue = queue.Queue()  # type: queue.Queue
 
-    def signal(self, signal_number, frame: Optional[FrameType]):
+    def signal(
+        self,
+        signal_number,
+        frame,  # type: Optional[FrameType]
+    ):
         self._queue.put_nowait(signal_number)
 
     def wait_for_signal(self):
         self._queue.get(block=True)
 
+    @contextlib.contextmanager
+    def intercept_signal(self, signal_number):
+        # type: (int) -> Generator[None, None, None]
+        original_handler = signal.getsignal(signal_number)
+        signal.signal(signal_number, self.signal)
+        try:
+            yield
+        except:
+            # Restore the original signal handler and forward the exception.
+            raise
+        finally:
+            signal.signal(signal_number, original_handler)
+
 
 class WaitTraceTestApplication:
     """
@@ -67,18 +86,18 @@ class WaitTraceTestApplication:
 
     def __init__(
         self,
-        binary_path: pathlib.Path,
-        event_count: int,
-        environment: "Environment",
-        wait_time_between_events_us: int = 0,
+        binary_path,  # type: pathlib.Path
+        event_count,  # type: int
+        environment,  # type: Environment
+        wait_time_between_events_us=0,  # type: int
     ):
-        self._environment: Environment = environment
+        self._environment = environment  # type: Environment
         if event_count % 5:
             # The test application currently produces 5 different events per iteration.
             raise ValueError("event count must be a multiple of 5")
-        self._iteration_count: int = int(event_count / 5)
+        self._iteration_count = int(event_count / 5)  # type: int
         # File that the application will wait to see before tracing its events.
-        self._app_start_tracing_file_path: pathlib.Path = pathlib.Path(
+        self._app_start_tracing_file_path = pathlib.Path(
             tempfile.mktemp(
                 prefix="app_",
                 suffix="_start_tracing",
@@ -94,11 +113,11 @@ class WaitTraceTestApplication:
         test_app_env["LTTNG_UST_REGISTER_TIMEOUT"] = "-1"
 
         # File that the application will create to indicate it has completed its initialization.
-        app_ready_file_path: str = tempfile.mktemp(
+        app_ready_file_path = tempfile.mktemp(
             prefix="app_",
             suffix="_ready",
             dir=self._compat_open_path(environment.lttng_home_location),
-        )
+        )  # type: str
 
         test_app_args = [str(binary_path)]
         test_app_args.extend(
@@ -112,10 +131,10 @@ class WaitTraceTestApplication:
             )
         )
 
-        self._process: subprocess.Popen = subprocess.Popen(
+        self._process = subprocess.Popen(
             test_app_args,
             env=test_app_env,
-        )
+        )  # type: subprocess.Popen
 
         # Wait for the application to create the file indicating it has fully
         # initialized. Make sure the app hasn't crashed in order to not wait
@@ -134,7 +153,8 @@ class WaitTraceTestApplication:
 
             time.sleep(0.1)
 
-    def trace(self) -> None:
+    def trace(self):
+        # type: () -> None
         if self._process.poll() is not None:
             # Application has unexepectedly returned.
             raise RuntimeError(
@@ -144,7 +164,8 @@ class WaitTraceTestApplication:
             )
         open(self._compat_open_path(self._app_start_tracing_file_path), mode="x")
 
-    def wait_for_exit(self) -> None:
+    def wait_for_exit(self):
+        # type: () -> None
         if self._process.wait() != 0:
             raise RuntimeError(
                 "Test application has exit with return code `{return_code}`".format(
@@ -154,12 +175,13 @@ class WaitTraceTestApplication:
         self._has_returned = True
 
     @property
-    def vpid(self) -> int:
+    def vpid(self):
+        # type: () -> int
         return self._process.pid
 
     @staticmethod
     def _compat_open_path(path):
-        # type: (pathlib.Path)
+        # type: (pathlib.Path) -> pathlib.Path | str
         """
         The builtin open() in python >= 3.6 expects a path-like object while
         prior versions expect a string or bytes object. Return the correct type
@@ -178,16 +200,61 @@ class WaitTraceTestApplication:
             self._process.wait()
 
 
+class TraceTestApplication:
+    """
+    Create an application that emits events as soon as it is launched. In most
+    scenarios, it is preferable to use a WaitTraceTestApplication.
+    """
+
+    def __init__(self, binary_path, environment):
+        # type: (pathlib.Path, Environment)
+        self._environment = environment  # type: Environment
+        self._has_returned = False
+
+        test_app_env = os.environ.copy()
+        test_app_env["LTTNG_HOME"] = str(environment.lttng_home_location)
+        # Make sure the app is blocked until it is properly registered to
+        # the session daemon.
+        test_app_env["LTTNG_UST_REGISTER_TIMEOUT"] = "-1"
+
+        test_app_args = [str(binary_path)]
+
+        self._process = subprocess.Popen(
+            test_app_args, env=test_app_env
+        )  # type: subprocess.Popen
+
+    def wait_for_exit(self):
+        # type: () -> None
+        if self._process.wait() != 0:
+            raise RuntimeError(
+                "Test application has exit with return code `{return_code}`".format(
+                    return_code=self._process.returncode
+                )
+            )
+        self._has_returned = True
+
+    def __del__(self):
+        if not self._has_returned:
+            # This is potentially racy if the pid has been recycled. However,
+            # we can't use pidfd_open since it is only available in python >= 3.9.
+            self._process.kill()
+            self._process.wait()
+
+
 class ProcessOutputConsumer(threading.Thread, logger._Logger):
     def __init__(
-        self, process: subprocess.Popen, name: str, log: Callable[[str], None]
+        self,
+        process,  # type: subprocess.Popen
+        name,  # type: str
+        log,  # type: Callable[[str], None]
     ):
         threading.Thread.__init__(self)
         self._prefix = name
         logger._Logger.__init__(self, log)
         self._process = process
 
-    def run(self) -> None:
+    def run(self):
+        # type: () -> None
         while self._process.poll() is None:
             assert self._process.stdout
             line = self._process.stdout.readline().decode("utf-8").replace("\n", "")
@@ -198,7 +265,9 @@ class ProcessOutputConsumer(threading.Thread, logger._Logger):
 # Generate a temporary environment in which to execute a test.
 class _Environment(logger._Logger):
     def __init__(
-        self, with_sessiond: bool, log: Optional[Callable[[str], None]] = None
+        self,
+        with_sessiond,  # type: bool
+        log=None,  # type: Optional[Callable[[str], None]]
     ):
         super().__init__(log)
         signal.signal(signal.SIGTERM, self._handle_termination_signal)
@@ -206,26 +275,31 @@ class _Environment(logger._Logger):
 
         # Assumes the project's hierarchy to this file is:
         # tests/utils/python/this_file
-        self._project_root: pathlib.Path = pathlib.Path(__file__).absolute().parents[3]
-        self._lttng_home: Optional[TemporaryDirectory] = TemporaryDirectory(
+        self._project_root = (
+            pathlib.Path(__file__).absolute().parents[3]
+        )  # type: pathlib.Path
+        self._lttng_home = TemporaryDirectory(
             "lttng_test_env_home"
-        )
+        )  # type: Optional[TemporaryDirectory]
 
-        self._sessiond: Optional[subprocess.Popen[bytes]] = (
+        self._sessiond = (
             self._launch_lttng_sessiond() if with_sessiond else None
-        )
+        )  # type: Optional[subprocess.Popen[bytes]]
 
     @property
-    def lttng_home_location(self) -> pathlib.Path:
+    def lttng_home_location(self):
+        # type: () -> pathlib.Path
         if self._lttng_home is None:
             raise RuntimeError("Attempt to access LTTng home after clean-up")
         return self._lttng_home.path
 
     @property
-    def lttng_client_path(self) -> pathlib.Path:
+    def lttng_client_path(self):
+        # type: () -> pathlib.Path
         return self._project_root / "src" / "bin" / "lttng" / "lttng"
 
-    def create_temporary_directory(self, prefix: Optional[str] = None) -> pathlib.Path:
+    def create_temporary_directory(self, prefix=None):
+        # type: (Optional[str]) -> pathlib.Path
         # Simply return a path that is contained within LTTNG_HOME; it will
         # be destroyed when the temporary home goes out of scope.
         assert self._lttng_home
@@ -239,7 +313,8 @@ class _Environment(logger._Logger):
     # Unpack a list of environment variables from a string
     # such as "HELLO=is_it ME='/you/are/looking/for'"
     @staticmethod
-    def _unpack_env_vars(env_vars_string: str) -> List[Tuple[str, str]]:
+    def _unpack_env_vars(env_vars_string):
+        # type: (str) -> List[Tuple[str, str]]
         unpacked_vars = []
         for var in shlex.split(env_vars_string):
             equal_position = var.find("=")
@@ -258,7 +333,8 @@ class _Environment(logger._Logger):
 
         return unpacked_vars
 
-    def _launch_lttng_sessiond(self) -> Optional[subprocess.Popen]:
+    def _launch_lttng_sessiond(self):
+        # type: () -> Optional[subprocess.Popen]
         is_64bits_host = sys.maxsize > 2**32
 
         sessiond_path = (
@@ -295,41 +371,38 @@ class _Environment(logger._Logger):
         sessiond_env["LTTNG_HOME"] = str(self._lttng_home.path)
 
         wait_queue = _SignalWaitQueue()
-        signal.signal(signal.SIGUSR1, wait_queue.signal)
-
-        self._log(
-            "Launching session daemon with LTTNG_HOME=`{home_dir}`".format(
-                home_dir=str(self._lttng_home.path)
+        with wait_queue.intercept_signal(signal.SIGUSR1):
+            self._log(
+                "Launching session daemon with LTTNG_HOME=`{home_dir}`".format(
+                    home_dir=str(self._lttng_home.path)
+                )
+            )
+            process = subprocess.Popen(
+                [
+                    str(sessiond_path),
+                    consumerd_path_option_name,
+                    str(consumerd_path),
+                    "--sig-parent",
+                ],
+                stdout=subprocess.PIPE,
+                stderr=subprocess.STDOUT,
+                env=sessiond_env,
             )
-        )
-        process = subprocess.Popen(
-            [
-                str(sessiond_path),
-                consumerd_path_option_name,
-                str(consumerd_path),
-                "--sig-parent",
-            ],
-            stdout=subprocess.PIPE,
-            stderr=subprocess.STDOUT,
-            env=sessiond_env,
-        )
 
-        if self._logging_function:
-            self._sessiond_output_consumer: Optional[
-                ProcessOutputConsumer
-            ] = ProcessOutputConsumer(process, "lttng-sessiond", self._logging_function)
-            self._sessiond_output_consumer.daemon = True
-            self._sessiond_output_consumer.start()
+            if self._logging_function:
+                self._sessiond_output_consumer = ProcessOutputConsumer(
+                    process, "lttng-sessiond", self._logging_function
+                )  # type: Optional[ProcessOutputConsumer]
+                self._sessiond_output_consumer.daemon = True
+                self._sessiond_output_consumer.start()
 
-        # Wait for SIGUSR1, indicating the sessiond is ready to proceed
-        wait_queue.wait_for_signal()
-        signal.signal(signal.SIGUSR1, wait_queue.signal)
+            # Wait for SIGUSR1, indicating the sessiond is ready to proceed
+            wait_queue.wait_for_signal()
 
         return process
 
-    def _handle_termination_signal(
-        self, signal_number: int, frame: Optional[FrameType]
-    ) -> None:
+    def _handle_termination_signal(self, signal_number, frame):
+        # type: (int, Optional[FrameType]) -> None
         self._log(
             "Killed by {signal_name} signal, cleaning-up".format(
                 signal_name=signal.strsignal(signal_number)
@@ -337,9 +410,8 @@ class _Environment(logger._Logger):
         )
         self._cleanup()
 
-    def launch_wait_trace_test_application(
-        self, event_count: int
-    ) -> WaitTraceTestApplication:
+    def launch_wait_trace_test_application(self, event_count):
+        # type: (int) -> WaitTraceTestApplication
         """
         Launch an application that will wait before tracing `event_count` events.
         """
@@ -354,8 +426,24 @@ class _Environment(logger._Logger):
             self,
         )
 
+    def launch_trace_test_constructor_application(self):
+        # type () -> TraceTestApplication
+        """
+        Launch an application that will trace from within constructors.
+        """
+        return TraceTestApplication(
+            self._project_root
+            / "tests"
+            / "utils"
+            / "testapp"
+            / "gen-ust-events-constructor"
+            / "gen-ust-events-constructor",
+            self,
+        )
+
     # Clean-up managed processes
-    def _cleanup(self) -> None:
+    def _cleanup(self):
+        # type: () -> None
         if self._sessiond and self._sessiond.poll() is None:
             # The session daemon is alive; kill it.
             self._log(
@@ -380,7 +468,8 @@ class _Environment(logger._Logger):
 
 
 @contextlib.contextmanager
-def test_environment(with_sessiond: bool, log: Optional[Callable[[str], None]] = None):
+def test_environment(with_sessiond, log=None):
+    # type: (bool, Optional[Callable[[str], None]]) -> Iterator[_Environment]
     env = _Environment(with_sessiond, log)
     try:
         yield env
This page took 0.027576 seconds and 4 git commands to generate.