Add userspace namespace contexts
[lttng-ust.git] / liblttng-ust-fork / ustfork.c
index 5e6acba16e87fb9893ee7c9fe3b756afae12dd38..25f9d4cc04c31d2589e36d148684bf61d846c5aa 100644 (file)
@@ -1,6 +1,6 @@
 /*
  * Copyright (C) 2009  Pierre-Marc Fournier
- * Copyright (C) 2011  Mathieu Desnoyers <mathieu.desnoyers@efficios.com>
+ * Copyright (C) 2011-2012  Mathieu Desnoyers <mathieu.desnoyers@efficios.com>
  *
  * This library is free software; you can redistribute it and/or
  * modify it under the terms of the GNU Lesser General Public
  */
 
 #define _GNU_SOURCE
-#include <dlfcn.h>
+#include <lttng/ust-dlfcn.h>
 #include <unistd.h>
 #include <stdio.h>
 #include <signal.h>
 #include <sched.h>
 #include <stdarg.h>
-#include "usterr.h"
+#include <errno.h>
 
 #include <lttng/ust.h>
 
-struct user_desc;
-
 pid_t fork(void)
 {
        static pid_t (*plibc_func)(void) = NULL;
-       ust_fork_info_t fork_info;
+       sigset_t sigset;
        pid_t retval;
+       int saved_errno;
 
        if (plibc_func == NULL) {
                plibc_func = dlsym(RTLD_NEXT, "fork");
                if (plibc_func == NULL) {
                        fprintf(stderr, "libustfork: unable to find \"fork\" symbol\n");
+                       errno = ENOSYS;
                        return -1;
                }
        }
 
-       ust_before_fork(&fork_info);
+       ust_before_fork(&sigset);
        /* Do the real fork */
        retval = plibc_func();
+       saved_errno = errno;
        if (retval == 0) {
                /* child */
-               ust_after_fork_child(&fork_info);
+               ust_after_fork_child(&sigset);
+       } else {
+               ust_after_fork_parent(&sigset);
+       }
+       errno = saved_errno;
+       return retval;
+}
+
+int daemon(int nochdir, int noclose)
+{
+       static int (*plibc_func)(int nochdir, int noclose) = NULL;
+       sigset_t sigset;
+       int retval;
+       int saved_errno;
+
+       if (plibc_func == NULL) {
+               plibc_func = dlsym(RTLD_NEXT, "daemon");
+               if (plibc_func == NULL) {
+                       fprintf(stderr, "libustfork: unable to find \"daemon\" symbol\n");
+                       errno = ENOSYS;
+                       return -1;
+               }
+       }
+
+       ust_before_fork(&sigset);
+       /* Do the real daemon call */
+       retval = plibc_func(nochdir, noclose);
+       saved_errno = errno;
+       if (retval == 0) {
+               /* child, parent called _exit() directly */
+               ust_after_fork_child(&sigset);
        } else {
-               ust_after_fork_parent(&fork_info);
+               /* on error in the parent */
+               ust_after_fork_parent(&sigset);
        }
+       errno = saved_errno;
        return retval;
 }
 
+#ifdef __linux__
+
+struct user_desc;
+
 struct ustfork_clone_info {
        int (*fn)(void *);
        void *arg;
-       ust_fork_info_t fork_info;
+       sigset_t sigset;
 };
 
 static int clone_fn(void *arg)
@@ -67,7 +104,7 @@ static int clone_fn(void *arg)
        struct ustfork_clone_info *info = (struct ustfork_clone_info *) arg;
 
        /* clone is now done and we are in child */
-       ust_after_fork_child(&info->fork_info);
+       ust_after_fork_child(&info->sigset);
        return info->fn(info->arg);
 }
 
@@ -83,6 +120,7 @@ int clone(int (*fn)(void *), void *child_stack, int flags, void *arg, ...)
        /* end of var args */
        va_list ap;
        int retval;
+       int saved_errno;
 
        va_start(ap, arg);
        ptid = va_arg(ap, pid_t *);
@@ -94,6 +132,7 @@ int clone(int (*fn)(void *), void *child_stack, int flags, void *arg, ...)
                plibc_func = dlsym(RTLD_NEXT, "clone");
                if (plibc_func == NULL) {
                        fprintf(stderr, "libustfork: unable to find \"clone\" symbol.\n");
+                       errno = ENOSYS;
                        return -1;
                }
        }
@@ -105,15 +144,110 @@ int clone(int (*fn)(void *), void *child_stack, int flags, void *arg, ...)
                 */
                retval = plibc_func(fn, child_stack, flags, arg, ptid,
                                tls, ctid);
+               saved_errno = errno;
        } else {
                /* Creating a real process, we need to intervene. */
-               struct ustfork_clone_info info = { fn: fn, arg: arg };
+               struct ustfork_clone_info info = { .fn = fn, .arg = arg };
 
-               ust_before_fork(&info.fork_info);
+               ust_before_fork(&info.sigset);
                retval = plibc_func(clone_fn, child_stack, flags, &info,
                                ptid, tls, ctid);
+               saved_errno = errno;
                /* The child doesn't get here. */
-               ust_after_fork_parent(&info.fork_info);
+               ust_after_fork_parent(&info.sigset);
+       }
+       errno = saved_errno;
+       return retval;
+}
+
+int setns(int fd, int nstype)
+{
+       static int (*plibc_func)(int fd, int nstype) = NULL;
+       int retval;
+       int saved_errno;
+
+       if (plibc_func == NULL) {
+               plibc_func = dlsym(RTLD_NEXT, "setns");
+               if (plibc_func == NULL) {
+                       fprintf(stderr, "libustfork: unable to find \"setns\" symbol\n");
+                       errno = ENOSYS;
+                       return -1;
+               }
        }
+
+       /* Do the real setns */
+       retval = plibc_func(fd, nstype);
+       saved_errno = errno;
+
+       ust_after_setns();
+
+       errno = saved_errno;
        return retval;
 }
+
+int unshare(int flags)
+{
+       static int (*plibc_func)(int flags) = NULL;
+       int retval;
+       int saved_errno;
+
+       if (plibc_func == NULL) {
+               plibc_func = dlsym(RTLD_NEXT, "unshare");
+               if (plibc_func == NULL) {
+                       fprintf(stderr, "libustfork: unable to find \"unshare\" symbol\n");
+                       errno = ENOSYS;
+                       return -1;
+               }
+       }
+
+       /* Do the real setns */
+       retval = plibc_func(flags);
+       saved_errno = errno;
+
+       ust_after_unshare();
+
+       errno = saved_errno;
+       return retval;
+}
+
+#elif defined (__FreeBSD__)
+
+pid_t rfork(int flags)
+{
+       static pid_t (*plibc_func)(void) = NULL;
+       sigset_t sigset;
+       pid_t retval;
+       int saved_errno;
+
+       if (plibc_func == NULL) {
+               plibc_func = dlsym(RTLD_NEXT, "rfork");
+               if (plibc_func == NULL) {
+                       fprintf(stderr, "libustfork: unable to find \"rfork\" symbol\n");
+                       errno = ENOSYS;
+                       return -1;
+               }
+       }
+
+       ust_before_fork(&sigset);
+       /* Do the real rfork */
+       retval = plibc_func();
+       saved_errno = errno;
+       if (retval == 0) {
+               /* child */
+               ust_after_fork_child(&sigset);
+       } else {
+               ust_after_fork_parent(&sigset);
+       }
+       errno = saved_errno;
+       return retval;
+}
+
+/*
+ * On BSD, no need to override vfork, because it runs in the context of
+ * the parent, with parent waiting until execve or exit is executed in
+ * the child.
+ */
+
+#else
+#warning "Unknown OS. You might want to ensure that fork/clone/vfork/fork handling is complete."
+#endif
This page took 0.026486 seconds and 4 git commands to generate.