Fix: lttng-ust-malloc ip context
[lttng-ust.git] / liblttng-ust-libc-wrapper / lttng-ust-malloc.c
index 54522b4258b69d4d12fbed4ea38b882773d55726..27624a3d821465bb0c1f9ceacfb09489d683f632 100644 (file)
 #include <urcu/system.h>
 #include <urcu/uatomic.h>
 #include <urcu/compiler.h>
+#include <urcu/tls-compat.h>
+#include <urcu/arch.h>
 #include <lttng/align.h>
 
 #define TRACEPOINT_DEFINE
 #define TRACEPOINT_CREATE_PROBES
+#define TP_IP_PARAM ip
 #include "ust_libc.h"
 
 #define STATIC_CALLOC_LEN 4096
@@ -47,6 +50,50 @@ struct alloc_functions {
 static
 struct alloc_functions cur_alloc;
 
+/*
+ * Make sure our own use of the LTS compat layer will not cause infinite
+ * recursion by calling calloc.
+ */
+
+static
+void *static_calloc(size_t nmemb, size_t size);
+
+/*
+ * pthread mutex replacement for URCU tls compat layer.
+ */
+static int ust_malloc_lock;
+
+static __attribute__((unused))
+void ust_malloc_spin_lock(pthread_mutex_t *lock)
+{
+       /*
+        * The memory barrier within cmpxchg takes care of ordering
+        * memory accesses with respect to the start of the critical
+        * section.
+        */
+       while (uatomic_cmpxchg(&ust_malloc_lock, 0, 1) != 0)
+               caa_cpu_relax();
+}
+
+static __attribute__((unused))
+void ust_malloc_spin_unlock(pthread_mutex_t *lock)
+{
+       /*
+        * Ensure memory accesses within the critical section do not
+        * leak outside.
+        */
+       cmm_smp_mb();
+       uatomic_set(&ust_malloc_lock, 0);
+}
+
+#define calloc static_calloc
+#define pthread_mutex_lock ust_malloc_spin_lock
+#define pthread_mutex_unlock ust_malloc_spin_unlock
+static DEFINE_URCU_TLS(int, malloc_nesting);
+#undef ust_malloc_spin_unlock
+#undef ust_malloc_spin_lock
+#undef calloc
+
 /*
  * Static allocator to use when initially executing dlsym(). It keeps a
  * size_t value of each object size prior to the object.
@@ -86,7 +133,6 @@ void *static_calloc(size_t nmemb, size_t size)
        void *retval;
 
        retval = static_calloc_aligned(nmemb, size, 1);
-       tracepoint(ust_libc, calloc, nmemb, size, retval);
        return retval;
 }
 
@@ -96,7 +142,6 @@ void *static_malloc(size_t size)
        void *retval;
 
        retval = static_calloc_aligned(1, size, 1);
-       tracepoint(ust_libc, malloc, size, retval);
        return retval;
 }
 
@@ -104,7 +149,6 @@ static
 void static_free(void *ptr)
 {
        /* no-op. */
-       tracepoint(ust_libc, free, ptr);
 }
 
 static
@@ -133,7 +177,6 @@ void *static_realloc(void *ptr, size_t size)
        if (ptr)
                memcpy(retval, ptr, *old_size);
 end:
-       tracepoint(ust_libc, realloc, ptr, size, retval);
        return retval;
 }
 
@@ -143,29 +186,23 @@ void *static_memalign(size_t alignment, size_t size)
        void *retval;
 
        retval = static_calloc_aligned(1, size, alignment);
-       tracepoint(ust_libc, memalign, alignment, size, retval);
        return retval;
 }
 
 static
 int static_posix_memalign(void **memptr, size_t alignment, size_t size)
 {
-       int retval = 0;
        void *ptr;
 
        /* Check for power of 2, larger than void *. */
        if (alignment & (alignment - 1)
                        || alignment < sizeof(void *)
                        || alignment == 0) {
-               retval = EINVAL;
                goto end;
        }
        ptr = static_calloc_aligned(1, size, alignment);
        *memptr = ptr;
-       if (size && !ptr)
-               retval = ENOMEM;
 end:
-       tracepoint(ust_libc, posix_memalign, *memptr, alignment, size, retval);
        return 0;
 }
 
@@ -214,6 +251,7 @@ void *malloc(size_t size)
 {
        void *retval;
 
+       URCU_TLS(malloc_nesting)++;
        if (cur_alloc.malloc == NULL) {
                lookup_all_symbols();
                if (cur_alloc.malloc == NULL) {
@@ -222,21 +260,29 @@ void *malloc(size_t size)
                }
        }
        retval = cur_alloc.malloc(size);
-       tracepoint(ust_libc, malloc, size, retval);
+       if (URCU_TLS(malloc_nesting) == 1) {
+               tracepoint(ust_libc, malloc, size, retval,
+                       __builtin_return_address(0));
+       }
+       URCU_TLS(malloc_nesting)--;
        return retval;
 }
 
 void free(void *ptr)
 {
-       tracepoint(ust_libc, free, ptr);
-
+       URCU_TLS(malloc_nesting)++;
        /*
         * Check whether the memory was allocated with
         * static_calloc_align, in which case there is nothing to free.
         */
        if (caa_unlikely((char *)ptr >= static_calloc_buf &&
                        (char *)ptr < static_calloc_buf + STATIC_CALLOC_LEN)) {
-               return;
+               goto end;
+       }
+
+       if (URCU_TLS(malloc_nesting) == 1) {
+               tracepoint(ust_libc, free, ptr,
+                       __builtin_return_address(0));
        }
 
        if (cur_alloc.free == NULL) {
@@ -247,12 +293,15 @@ void free(void *ptr)
                }
        }
        cur_alloc.free(ptr);
+end:
+       URCU_TLS(malloc_nesting)--;
 }
 
 void *calloc(size_t nmemb, size_t size)
 {
        void *retval;
 
+       URCU_TLS(malloc_nesting)++;
        if (cur_alloc.calloc == NULL) {
                lookup_all_symbols();
                if (cur_alloc.calloc == NULL) {
@@ -261,7 +310,11 @@ void *calloc(size_t nmemb, size_t size)
                }
        }
        retval = cur_alloc.calloc(nmemb, size);
-       tracepoint(ust_libc, calloc, nmemb, size, retval);
+       if (URCU_TLS(malloc_nesting) == 1) {
+               tracepoint(ust_libc, calloc, nmemb, size, retval,
+                       __builtin_return_address(0));
+       }
+       URCU_TLS(malloc_nesting)--;
        return retval;
 }
 
@@ -269,7 +322,9 @@ void *realloc(void *ptr, size_t size)
 {
        void *retval;
 
-       /* Check whether the memory was allocated with
+       URCU_TLS(malloc_nesting)++;
+       /*
+        * Check whether the memory was allocated with
         * static_calloc_align, in which case there is nothing
         * to free, and we need to copy the old data.
         */
@@ -289,6 +344,13 @@ void *realloc(void *ptr, size_t size)
                if (retval) {
                        memcpy(retval, ptr, *old_size);
                }
+               /*
+                * Mimick that a NULL pointer has been received, so
+                * memory allocation analysis based on the trace don't
+                * get confused by the address from the static
+                * allocator.
+                */
+               ptr = NULL;
                goto end;
        }
 
@@ -301,7 +363,11 @@ void *realloc(void *ptr, size_t size)
        }
        retval = cur_alloc.realloc(ptr, size);
 end:
-       tracepoint(ust_libc, realloc, ptr, size, retval);
+       if (URCU_TLS(malloc_nesting) == 1) {
+               tracepoint(ust_libc, realloc, ptr, size, retval,
+                       __builtin_return_address(0));
+       }
+       URCU_TLS(malloc_nesting)--;
        return retval;
 }
 
@@ -309,6 +375,7 @@ void *memalign(size_t alignment, size_t size)
 {
        void *retval;
 
+       URCU_TLS(malloc_nesting)++;
        if (cur_alloc.memalign == NULL) {
                lookup_all_symbols();
                if (cur_alloc.memalign == NULL) {
@@ -317,7 +384,11 @@ void *memalign(size_t alignment, size_t size)
                }
        }
        retval = cur_alloc.memalign(alignment, size);
-       tracepoint(ust_libc, memalign, alignment, size, retval);
+       if (URCU_TLS(malloc_nesting) == 1) {
+               tracepoint(ust_libc, memalign, alignment, size, retval,
+                       __builtin_return_address(0));
+       }
+       URCU_TLS(malloc_nesting)--;
        return retval;
 }
 
@@ -325,6 +396,7 @@ int posix_memalign(void **memptr, size_t alignment, size_t size)
 {
        int retval;
 
+       URCU_TLS(malloc_nesting)++;
        if (cur_alloc.posix_memalign == NULL) {
                lookup_all_symbols();
                if (cur_alloc.posix_memalign == NULL) {
@@ -333,7 +405,11 @@ int posix_memalign(void **memptr, size_t alignment, size_t size)
                }
        }
        retval = cur_alloc.posix_memalign(memptr, alignment, size);
-       tracepoint(ust_libc, posix_memalign, *memptr, alignment, size, retval);
+       if (URCU_TLS(malloc_nesting) == 1) {
+               tracepoint(ust_libc, posix_memalign, *memptr, alignment, size,
+                       retval, __builtin_return_address(0));
+       }
+       URCU_TLS(malloc_nesting)--;
        return retval;
 }
 
This page took 0.025872 seconds and 4 git commands to generate.