prio heap default alloc min value fix
[lttng-modules.git] / lib / prio_heap / prio_heap.c
index 8945c2a019804f7d0afe9e126b8c4a1b4dcce79d..58d5d6ae9df05a112026196b252a204cbd432407 100644 (file)
 /*
- * LICENSING: this file is copied from the Linux kernel. We should therefore
- * assume a GPLv2 license for the code that comes from the Linux mainline.
- */
-
-/*
- * Static-sized priority heap containing pointers. Based on CLR, chapter 7.
+ * prio_heap.c
+ *
+ * Priority heap containing pointers. Based on CLRS, chapter 6.
+ *
+ * Copyright 2011 - Mathieu Desnoyers <mathieu.desnoyers@efficios.com>
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
  */
 
 #include <linux/slab.h>
-#include <linux/prio_heap.h>
+#include "prio_heap.h"
+
+static
+size_t parent(size_t i)
+{
+       return i >> 1;
+}
+
+static
+size_t left(size_t i)
+{
+       return i << 1;
+}
+
+static
+size_t right(size_t i)
+{
+       return (i << 1) + 1;
+}
 
-int heap_init(struct ptr_heap *heap, size_t size, gfp_t gfp_mask,
-             int (*gt)(void *, void *))
+static
+int heap_grow(struct ptr_heap *heap, size_t new_len)
 {
-       heap->ptrs = kmalloc(size, gfp_mask);
-       if (!heap->ptrs)
+       void **new_ptrs;
+
+       if (heap->alloc_len >= new_len)
+               return 0;
+
+       heap->alloc_len = max_t(size_t, new_len, heap->alloc_len << 1);
+       new_ptrs = kmalloc(heap->alloc_len * sizeof(void *), heap->gfpmask);
+       if (!new_ptrs)
                return -ENOMEM;
-       heap->size = 0;
-       heap->max = size / sizeof(void *);
-       heap->gt = gt;
+       if (heap->ptrs)
+               memcpy(new_ptrs, heap->ptrs, heap->len * sizeof(void *));
+       kfree(heap->ptrs);
+       heap->ptrs = new_ptrs;
+       return 0;
+}
+
+static
+int heap_set_len(struct ptr_heap *heap, size_t new_len)
+{
+       int ret;
+
+       ret = heap_grow(heap, new_len);
+       if (ret)
+               return ret;
+       heap->len = new_len;
        return 0;
 }
 
+int heap_init(struct ptr_heap *heap, size_t alloc_len,
+             gfp_t gfpmask, int gt(void *a, void *b))
+{
+       heap->ptrs = NULL;
+       heap->len = 0;
+       heap->alloc_len = 0;
+       heap->gt = gt;
+       /*
+        * Minimum size allocated is 1 entry to ensure memory allocation
+        * never fails within heap_replace_max.
+        */
+       return heap_grow(heap, max_t(size_t, 1, alloc_len));
+}
+
 void heap_free(struct ptr_heap *heap)
 {
        kfree(heap->ptrs);
 }
 
-static void heapify(struct ptr_heap *heap, int pos)
+static void heapify(struct ptr_heap *heap, size_t i)
 {
        void **ptrs = heap->ptrs;
-       void *p = ptrs[pos];
-
-       while (1) {
-               int left = 2 * pos + 1;
-               int right = 2 * pos + 2;
-               int largest = pos;
-               if (left < heap->size && heap->gt(ptrs[left], p))
-                       largest = left;
-               if (right < heap->size && heap->gt(ptrs[right], ptrs[largest]))
-                       largest = right;
-               if (largest == pos)
+       size_t l, r, largest;
+
+       for (;;) {
+               l = left(i);
+               r = right(i);
+               if (l <= heap->len && ptrs[l] > ptrs[i])
+                       largest = l;
+               else
+                       largest = i;
+               if (r <= heap->len && ptrs[r] > ptrs[largest])
+                       largest = r;
+               if (largest != i) {
+                       void *tmp;
+
+                       tmp = ptrs[i];
+                       ptrs[i] = ptrs[largest];
+                       ptrs[largest] = tmp;
+                       i = largest;
+                       continue;
+               } else {
                        break;
-               /* Push p down the heap one level and bump one up */
-               ptrs[pos] = ptrs[largest];
-               ptrs[largest] = p;
-               pos = largest;
+               }
        }
 }
 
@@ -54,9 +120,9 @@ void *heap_replace_max(struct ptr_heap *heap, void *p)
        void *res;
        void **ptrs = heap->ptrs;
 
-       if (!heap->size) {
+       if (!heap->len) {
+               (void) heap_set_len(heap, 1);
                ptrs[0] = p;
-               heap->size = 1;
                return NULL;
        }
 
@@ -67,66 +133,54 @@ void *heap_replace_max(struct ptr_heap *heap, void *p)
        return res;
 }
 
-void *heap_insert(struct ptr_heap *heap, void *p)
+int heap_insert(struct ptr_heap *heap, void *p)
 {
        void **ptrs = heap->ptrs;
-       int pos;
-
-       if (heap->size < heap->max) {
-               /* Heap insertion */
-               pos = heap->size++;
-               while (pos > 0 && heap->gt(p, ptrs[(pos-1)/2])) {
-                       ptrs[pos] = ptrs[(pos-1)/2];
-                       pos = (pos-1)/2;
-               }
-               ptrs[pos] = p;
-               return NULL;
-       }
-
-       /* The heap is full, so something will have to be dropped */
-
-       /* If the new pointer is greater than the current max, drop it */
-       if (heap->gt(p, ptrs[0]))
-               return p;
-
-       /* Replace the current max and heapify */
-       return heap_replace_max(heap, p);
+       int ret;
+
+       ret = heap_set_len(heap, heap->len + 1);
+       if (ret)
+               return ret;
+       /* Add the element to the end */
+       ptrs[heap->len - 1] = p;
+       /* rebalance */
+       heapify(heap, 0);
+       return 0;
 }
 
 void *heap_remove(struct ptr_heap *heap)
 {
        void **ptrs = heap->ptrs;
 
-       switch (heap->size) {
+       switch (heap->len) {
        case 0:
                return NULL;
        case 1:
-               heap->size = 0;
+               (void) heap_set_len(heap, 0);
                return ptrs[0];
        }
-
        /* Shrink, replace the current max by previous last entry and heapify */
-       return heap_replace_max(heap, ptrs[--heap->size]);
+       heap_set_len(heap, heap->len - 1);
+       return heap_replace_max(heap, ptrs[heap->len - 1]);
 }
 
 void *heap_cherrypick(struct ptr_heap *heap, void *p)
 {
        void **ptrs = heap->ptrs;
-       size_t pos, size = heap->size;
+       size_t pos, len = heap->len;
 
-       for (pos = 0; pos < size; pos++)
+       for (pos = 0; pos < len; pos++)
                if (ptrs[pos] == p)
                        goto found;
        return NULL;
 found:
-       if (heap->size == 1) {
-               heap->size = 0;
+       if (heap->len == 1) {
+               (void) heap_set_len(heap, 0);
                return ptrs[0];
        }
-       /*
-        * Replace p with previous last entry and heapify.
-        */
-       ptrs[pos] = ptrs[--heap->size];
+       /* Replace p with previous last entry and heapify. */
+       heap_set_len(heap, heap->len - 1);
+       ptrs[pos] = ptrs[heap->len - 1];
        heapify(heap, pos);
        return p;
 }
This page took 0.028819 seconds and 4 git commands to generate.