ustfork: Fix possible race conditions
authorOlivier Dion <odion@efficios.com>
Wed, 9 Aug 2023 21:35:40 +0000 (17:35 -0400)
committerMathieu Desnoyers <mathieu.desnoyers@efficios.com>
Sat, 12 Aug 2023 17:14:58 +0000 (13:14 -0400)
Assuming that `dlsym(RTLD_NEXT, "symbol")' is invariant for "symbol",
then we could think that memory operations on the `plibc_func' pointers can
be safely done without atomics.

However, consider what would happen if a load to a`plibc_func' pointer
is torn apart by the compiler. Then a thread could see:

  1) NULL

  2) The stored value as returned by a dlsym() call

  3) A mix of 1) and 2)

The same goes for other optimizations that a compiler is authorized to
do (e.g. store tearing, load fusing).

One could question whether such race condition is even possible for the
clone(2) wrapper. Indeed, a thread must be cloned to get into
existence. Therefore, the main thread would always store the value of
`plibc_func' at least once before creating the first sibling thread,
preventing any possible race condition for this wrapper. However, this
assume that the main thread will not call the clone system call directly
before calling the libc wrapper! Thus, to be on the safe side, we do the
same for the clone wrapper.

Fix the race conditions by using the uatomic_read/uatomic_set functions,
on access to `plibc_func' pointers.

Change-Id: Ic4be25983b8836d2b333f367af9c18d2f6b75879
Signed-off-by: Olivier Dion <odion@efficios.com>
Signed-off-by: Mathieu Desnoyers <mathieu.desnoyers@efficios.com>
src/lib/lttng-ust-fork/ustfork.c

index 321ffc30c0fd993d7be43a62fa9063fbe3a2e5ca..9508cc7520eaeb23145814c9976cd24b6501a2a3 100644 (file)
 
 #include <lttng/ust-fork.h>
 
+#include <urcu/uatomic.h>
+
 pid_t fork(void)
 {
        static pid_t (*plibc_func)(void) = NULL;
+       pid_t (*func)(void);
        sigset_t sigset;
        pid_t retval;
        int saved_errno;
 
-       if (plibc_func == NULL) {
-               plibc_func = dlsym(RTLD_NEXT, "fork");
-               if (plibc_func == NULL) {
+       func = uatomic_read(plibc_func);
+       if (func == NULL) {
+               func = dlsym(RTLD_NEXT, "fork");
+               if (func == NULL) {
                        fprintf(stderr, "libustfork: unable to find \"fork\" symbol\n");
                        errno = ENOSYS;
                        return -1;
                }
+               uatomic_set(&plibc_func, func);
        }
 
        lttng_ust_before_fork(&sigset);
        /* Do the real fork */
-       retval = plibc_func();
+       retval = func();
        saved_errno = errno;
        if (retval == 0) {
                /* child */
@@ -51,22 +56,25 @@ pid_t fork(void)
 int daemon(int nochdir, int noclose)
 {
        static int (*plibc_func)(int nochdir, int noclose) = NULL;
+       int (*func)(int nochdir, int noclose);
        sigset_t sigset;
        int retval;
        int saved_errno;
 
-       if (plibc_func == NULL) {
-               plibc_func = dlsym(RTLD_NEXT, "daemon");
-               if (plibc_func == NULL) {
+       func = uatomic_read(plibc_func);
+       if (func == NULL) {
+               func = dlsym(RTLD_NEXT, "daemon");
+               if (func == NULL) {
                        fprintf(stderr, "libustfork: unable to find \"daemon\" symbol\n");
                        errno = ENOSYS;
                        return -1;
                }
+               uatomic_set(&plibc_func, func);
        }
 
        lttng_ust_before_fork(&sigset);
        /* Do the real daemon call */
-       retval = plibc_func(nochdir, noclose);
+       retval = func(nochdir, noclose);
        saved_errno = errno;
        if (retval == 0) {
                /* child, parent called _exit() directly */
@@ -82,20 +90,23 @@ int daemon(int nochdir, int noclose)
 int setuid(uid_t uid)
 {
        static int (*plibc_func)(uid_t uid) = NULL;
+       int (*func)(uid_t uid);
        int retval;
        int saved_errno;
 
-       if (plibc_func == NULL) {
-               plibc_func = dlsym(RTLD_NEXT, "setuid");
-               if (plibc_func == NULL) {
+       func = uatomic_read(plibc_func);
+       if (func == NULL) {
+               func = dlsym(RTLD_NEXT, "setuid");
+               if (func == NULL) {
                        fprintf(stderr, "libustfork: unable to find \"setuid\" symbol\n");
                        errno = ENOSYS;
                        return -1;
                }
+               uatomic_set(&plibc_func, func);
        }
 
        /* Do the real setuid */
-       retval = plibc_func(uid);
+       retval = func(uid);
        saved_errno = errno;
 
        lttng_ust_after_setuid();
@@ -107,20 +118,23 @@ int setuid(uid_t uid)
 int setgid(gid_t gid)
 {
        static int (*plibc_func)(gid_t gid) = NULL;
+       int (*func)(gid_t gid);
        int retval;
        int saved_errno;
 
-       if (plibc_func == NULL) {
-               plibc_func = dlsym(RTLD_NEXT, "setgid");
-               if (plibc_func == NULL) {
+       func = uatomic_read(plibc_func);
+       if (func == NULL) {
+               func = dlsym(RTLD_NEXT, "setgid");
+               if (func == NULL) {
                        fprintf(stderr, "libustfork: unable to find \"setgid\" symbol\n");
                        errno = ENOSYS;
                        return -1;
                }
+               uatomic_set(&plibc_func, func);
        }
 
        /* Do the real setgid */
-       retval = plibc_func(gid);
+       retval = func(gid);
        saved_errno = errno;
 
        lttng_ust_after_setgid();
@@ -132,20 +146,23 @@ int setgid(gid_t gid)
 int seteuid(uid_t euid)
 {
        static int (*plibc_func)(uid_t euid) = NULL;
+       int (*func)(uid_t euid);
        int retval;
        int saved_errno;
 
-       if (plibc_func == NULL) {
-               plibc_func = dlsym(RTLD_NEXT, "seteuid");
-               if (plibc_func == NULL) {
+       func = uatomic_read(plibc_func);
+       if (func == NULL) {
+               func = dlsym(RTLD_NEXT, "seteuid");
+               if (func == NULL) {
                        fprintf(stderr, "libustfork: unable to find \"seteuid\" symbol\n");
                        errno = ENOSYS;
                        return -1;
                }
+               uatomic_set(&plibc_func, func);
        }
 
        /* Do the real seteuid */
-       retval = plibc_func(euid);
+       retval = func(euid);
        saved_errno = errno;
 
        lttng_ust_after_seteuid();
@@ -157,20 +174,23 @@ int seteuid(uid_t euid)
 int setegid(gid_t egid)
 {
        static int (*plibc_func)(gid_t egid) = NULL;
+       int (*func)(gid_t egid);
        int retval;
        int saved_errno;
 
-       if (plibc_func == NULL) {
-               plibc_func = dlsym(RTLD_NEXT, "setegid");
-               if (plibc_func == NULL) {
+       func = uatomic_read(plibc_func);
+       if (func == NULL) {
+               func = dlsym(RTLD_NEXT, "setegid");
+               if (func == NULL) {
                        fprintf(stderr, "libustfork: unable to find \"setegid\" symbol\n");
                        errno = ENOSYS;
                        return -1;
                }
+               uatomic_set(&plibc_func, func);
        }
 
        /* Do the real setegid */
-       retval = plibc_func(egid);
+       retval = func(egid);
        saved_errno = errno;
 
        lttng_ust_after_setegid();
@@ -182,20 +202,23 @@ int setegid(gid_t egid)
 int setreuid(uid_t ruid, uid_t euid)
 {
        static int (*plibc_func)(uid_t ruid, uid_t euid) = NULL;
+       int (*func)(uid_t ruid, uid_t euid);
        int retval;
        int saved_errno;
 
-       if (plibc_func == NULL) {
-               plibc_func = dlsym(RTLD_NEXT, "setreuid");
-               if (plibc_func == NULL) {
+       func = uatomic_read(plibc_func);
+       if (func == NULL) {
+               func = dlsym(RTLD_NEXT, "setreuid");
+               if (func == NULL) {
                        fprintf(stderr, "libustfork: unable to find \"setreuid\" symbol\n");
                        errno = ENOSYS;
                        return -1;
                }
+               uatomic_set(&plibc_func, func);
        }
 
        /* Do the real setreuid */
-       retval = plibc_func(ruid, euid);
+       retval = func(ruid, euid);
        saved_errno = errno;
 
        lttng_ust_after_setreuid();
@@ -207,20 +230,23 @@ int setreuid(uid_t ruid, uid_t euid)
 int setregid(gid_t rgid, gid_t egid)
 {
        static int (*plibc_func)(gid_t rgid, gid_t egid) = NULL;
+       int (*func)(gid_t rgid, gid_t egid);
        int retval;
        int saved_errno;
 
-       if (plibc_func == NULL) {
-               plibc_func = dlsym(RTLD_NEXT, "setregid");
-               if (plibc_func == NULL) {
+       func = uatomic_read(plibc_func);
+       if (func == NULL) {
+               func = dlsym(RTLD_NEXT, "setregid");
+               if (func == NULL) {
                        fprintf(stderr, "libustfork: unable to find \"setregid\" symbol\n");
                        errno = ENOSYS;
                        return -1;
                }
+               uatomic_set(&plibc_func, func);
        }
 
        /* Do the real setregid */
-       retval = plibc_func(rgid, egid);
+       retval = func(rgid, egid);
        saved_errno = errno;
 
        lttng_ust_after_setregid();
@@ -253,6 +279,9 @@ int clone(int (*fn)(void *), void *child_stack, int flags, void *arg, ...)
        static int (*plibc_func)(int (*fn)(void *), void *child_stack,
                        int flags, void *arg, pid_t *ptid,
                        struct user_desc *tls, pid_t *ctid) = NULL;
+       int (*func)(int (*fn)(void *), void *child_stack,
+                       int flags, void *arg, pid_t *ptid,
+                       struct user_desc *tls, pid_t *ctid);
        /* var args */
        pid_t *ptid;
        struct user_desc *tls;
@@ -268,13 +297,15 @@ int clone(int (*fn)(void *), void *child_stack, int flags, void *arg, ...)
        ctid = va_arg(ap, pid_t *);
        va_end(ap);
 
-       if (plibc_func == NULL) {
-               plibc_func = dlsym(RTLD_NEXT, "clone");
-               if (plibc_func == NULL) {
+       func = uatomic_read(plibc_func);
+       if (func == NULL) {
+               func = dlsym(RTLD_NEXT, "clone");
+               if (func == NULL) {
                        fprintf(stderr, "libustfork: unable to find \"clone\" symbol.\n");
                        errno = ENOSYS;
                        return -1;
                }
+               uatomic_set(&plibc_func, func);
        }
 
        if (flags & CLONE_VM) {
@@ -282,16 +313,16 @@ int clone(int (*fn)(void *), void *child_stack, int flags, void *arg, ...)
                 * Creating a thread, no need to intervene, just pass on
                 * the arguments.
                 */
-               retval = plibc_func(fn, child_stack, flags, arg, ptid,
-                               tls, ctid);
+               retval = func(fn, child_stack, flags, arg, ptid,
+                       tls, ctid);
                saved_errno = errno;
        } else {
                /* Creating a real process, we need to intervene. */
                struct ustfork_clone_info info = { .fn = fn, .arg = arg };
 
                lttng_ust_before_fork(&info.sigset);
-               retval = plibc_func(clone_fn, child_stack, flags, &info,
-                               ptid, tls, ctid);
+               retval = func(clone_fn, child_stack, flags, &info,
+                       ptid, tls, ctid);
                saved_errno = errno;
                /* The child doesn't get here. */
                lttng_ust_after_fork_parent(&info.sigset);
@@ -303,20 +334,23 @@ int clone(int (*fn)(void *), void *child_stack, int flags, void *arg, ...)
 int setns(int fd, int nstype)
 {
        static int (*plibc_func)(int fd, int nstype) = NULL;
+       int (*func)(int fd, int nstype);
        int retval;
        int saved_errno;
 
-       if (plibc_func == NULL) {
-               plibc_func = dlsym(RTLD_NEXT, "setns");
-               if (plibc_func == NULL) {
+       func = uatomic_read(plibc_func);
+       if (func == NULL) {
+               func = dlsym(RTLD_NEXT, "setns");
+               if (func == NULL) {
                        fprintf(stderr, "libustfork: unable to find \"setns\" symbol\n");
                        errno = ENOSYS;
                        return -1;
                }
+               uatomic_set(&plibc_func, func);
        }
 
        /* Do the real setns */
-       retval = plibc_func(fd, nstype);
+       retval = func(fd, nstype);
        saved_errno = errno;
 
        lttng_ust_after_setns();
@@ -328,20 +362,23 @@ int setns(int fd, int nstype)
 int unshare(int flags)
 {
        static int (*plibc_func)(int flags) = NULL;
+       int (*func)(int flags);
        int retval;
        int saved_errno;
 
-       if (plibc_func == NULL) {
-               plibc_func = dlsym(RTLD_NEXT, "unshare");
-               if (plibc_func == NULL) {
+       func = uatomic_read(plibc_func);
+       if (func == NULL) {
+               func = dlsym(RTLD_NEXT, "unshare");
+               if (func == NULL) {
                        fprintf(stderr, "libustfork: unable to find \"unshare\" symbol\n");
                        errno = ENOSYS;
                        return -1;
                }
+               uatomic_set(&plibc_func, func);
        }
 
        /* Do the real setns */
-       retval = plibc_func(flags);
+       retval = func(flags);
        saved_errno = errno;
 
        lttng_ust_after_unshare();
@@ -353,20 +390,23 @@ int unshare(int flags)
 int setresuid(uid_t ruid, uid_t euid, uid_t suid)
 {
        static int (*plibc_func)(uid_t ruid, uid_t euid, uid_t suid) = NULL;
+       int (*func)(uid_t ruid, uid_t euid, uid_t suid);
        int retval;
        int saved_errno;
 
-       if (plibc_func == NULL) {
-               plibc_func = dlsym(RTLD_NEXT, "setresuid");
-               if (plibc_func == NULL) {
+       func = uatomic_read(plibc_func);
+       if (func == NULL) {
+               func = dlsym(RTLD_NEXT, "setresuid");
+               if (func == NULL) {
                        fprintf(stderr, "libustfork: unable to find \"setresuid\" symbol\n");
                        errno = ENOSYS;
                        return -1;
                }
+               uatomic_set(&plibc_func, func);
        }
 
        /* Do the real setresuid */
-       retval = plibc_func(ruid, euid, suid);
+       retval = func(ruid, euid, suid);
        saved_errno = errno;
 
        lttng_ust_after_setresuid();
@@ -378,20 +418,23 @@ int setresuid(uid_t ruid, uid_t euid, uid_t suid)
 int setresgid(gid_t rgid, gid_t egid, gid_t sgid)
 {
        static int (*plibc_func)(gid_t rgid, gid_t egid, gid_t sgid) = NULL;
+       int (*func)(gid_t rgid, gid_t egid, gid_t sgid);
        int retval;
        int saved_errno;
 
-       if (plibc_func == NULL) {
-               plibc_func = dlsym(RTLD_NEXT, "setresgid");
-               if (plibc_func == NULL) {
+       func = uatomic_read(plibc_func);
+       if (func == NULL) {
+               func = dlsym(RTLD_NEXT, "setresgid");
+               if (func == NULL) {
                        fprintf(stderr, "libustfork: unable to find \"setresgid\" symbol\n");
                        errno = ENOSYS;
                        return -1;
                }
+               uatomic_set(&plibc_func, func);
        }
 
        /* Do the real setresgid */
-       retval = plibc_func(rgid, egid, sgid);
+       retval = func(rgid, egid, sgid);
        saved_errno = errno;
 
        lttng_ust_after_setresgid();
@@ -405,22 +448,25 @@ int setresgid(gid_t rgid, gid_t egid, gid_t sgid)
 pid_t rfork(int flags)
 {
        static pid_t (*plibc_func)(int flags) = NULL;
+       pid_t (*func)(int flags);
        sigset_t sigset;
        pid_t retval;
        int saved_errno;
 
-       if (plibc_func == NULL) {
-               plibc_func = dlsym(RTLD_NEXT, "rfork");
-               if (plibc_func == NULL) {
+       func = uatomic_read(plibc_func);
+       if (func == NULL) {
+               func = dlsym(RTLD_NEXT, "rfork");
+               if (func == NULL) {
                        fprintf(stderr, "libustfork: unable to find \"rfork\" symbol\n");
                        errno = ENOSYS;
                        return -1;
                }
+               uatomic_set(&plibc_func, func);
        }
 
        lttng_ust_before_fork(&sigset);
        /* Do the real rfork */
-       retval = plibc_func(flags);
+       retval = func(flags);
        saved_errno = errno;
        if (retval == 0) {
                /* child */
This page took 0.03338 seconds and 4 git commands to generate.