/*
- * LICENSING: this file is copied from the Linux kernel. We should therefore
- * assume a GPLv2 license for the code that comes from the Linux mainline.
+ * prio_heap.c
+ *
+ * Priority heap containing pointers. Based on CLRS, chapter 6.
+ *
+ * Copyright 2011 - Mathieu Desnoyers <mathieu.desnoyers@efficios.com>
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
*/
+#include <linux/slab.h>
+#include "prio_heap.h"
+
+static
+size_t parent(size_t i)
+{
+ return i >> 1;
+}
+
+static
+size_t left(size_t i)
+{
+ return i << 1;
+}
+
+static
+size_t right(size_t i)
+{
+ return (i << 1) + 1;
+}
+
/*
- * Static-sized priority heap containing pointers. Based on CLR, chapter 7.
+ * Copy of heap->ptrs pointer is invalid after heap_grow.
*/
+static
+int heap_grow(struct ptr_heap *heap, size_t new_len)
+{
+ void **new_ptrs;
-#include <linux/slab.h>
-#include <linux/prio_heap.h>
+ if (heap->alloc_len >= new_len)
+ return 0;
-int heap_init(struct ptr_heap *heap, size_t size, gfp_t gfp_mask,
- int (*gt)(void *, void *))
-{
- heap->ptrs = kmalloc(size, gfp_mask);
- if (!heap->ptrs)
+ heap->alloc_len = max_t(size_t, new_len, heap->alloc_len << 1);
+ new_ptrs = kmalloc(heap->alloc_len * sizeof(void *), heap->gfpmask);
+ if (!new_ptrs)
return -ENOMEM;
- heap->size = 0;
- heap->max = size / sizeof(void *);
- heap->gt = gt;
+ if (heap->ptrs)
+ memcpy(new_ptrs, heap->ptrs, heap->len * sizeof(void *));
+ kfree(heap->ptrs);
+ heap->ptrs = new_ptrs;
return 0;
}
+static
+int heap_set_len(struct ptr_heap *heap, size_t new_len)
+{
+ int ret;
+
+ ret = heap_grow(heap, new_len);
+ if (ret)
+ return ret;
+ heap->len = new_len;
+ return 0;
+}
+
+int heap_init(struct ptr_heap *heap, size_t alloc_len,
+ gfp_t gfpmask, int gt(void *a, void *b))
+{
+ heap->ptrs = NULL;
+ heap->len = 0;
+ heap->alloc_len = 0;
+ heap->gt = gt;
+ /*
+ * Minimum size allocated is 1 entry to ensure memory allocation
+ * never fails within heap_replace_max.
+ */
+ return heap_grow(heap, max_t(size_t, 1, alloc_len));
+}
+
void heap_free(struct ptr_heap *heap)
{
kfree(heap->ptrs);
}
-static void heapify(struct ptr_heap *heap, int pos)
+static void heapify(struct ptr_heap *heap, size_t i)
{
void **ptrs = heap->ptrs;
- void *p = ptrs[pos];
-
- while (1) {
- int left = 2 * pos + 1;
- int right = 2 * pos + 2;
- int largest = pos;
- if (left < heap->size && heap->gt(ptrs[left], p))
- largest = left;
- if (right < heap->size && heap->gt(ptrs[right], ptrs[largest]))
- largest = right;
- if (largest == pos)
+ size_t l, r, largest;
+
+ for (;;) {
+ l = left(i);
+ r = right(i);
+ if (l <= heap->len && ptrs[l] > ptrs[i])
+ largest = l;
+ else
+ largest = i;
+ if (r <= heap->len && ptrs[r] > ptrs[largest])
+ largest = r;
+ if (largest != i) {
+ void *tmp;
+
+ tmp = ptrs[i];
+ ptrs[i] = ptrs[largest];
+ ptrs[largest] = tmp;
+ i = largest;
+ continue;
+ } else {
break;
- /* Push p down the heap one level and bump one up */
- ptrs[pos] = ptrs[largest];
- ptrs[largest] = p;
- pos = largest;
+ }
}
}
void *heap_replace_max(struct ptr_heap *heap, void *p)
{
void *res;
- void **ptrs = heap->ptrs;
- if (!heap->size) {
- ptrs[0] = p;
- heap->size = 1;
+ if (!heap->len) {
+ (void) heap_set_len(heap, 1);
+ heap->ptrs[0] = p;
return NULL;
}
/* Replace the current max and heapify */
- res = ptrs[0];
- ptrs[0] = p;
+ res = heap->ptrs[0];
+ heap->ptrs[0] = p;
heapify(heap, 0);
return res;
}
-void *heap_insert(struct ptr_heap *heap, void *p)
+int heap_insert(struct ptr_heap *heap, void *p)
{
- void **ptrs = heap->ptrs;
- int pos;
-
- if (heap->size < heap->max) {
- /* Heap insertion */
- pos = heap->size++;
- while (pos > 0 && heap->gt(p, ptrs[(pos-1)/2])) {
- ptrs[pos] = ptrs[(pos-1)/2];
- pos = (pos-1)/2;
+ void **ptrs;
+ size_t pos;
+ int ret;
+
+ ret = heap_set_len(heap, heap->len + 1);
+ if (ret)
+ return ret;
+ ptrs = heap->ptrs;
+ /* Add the element to the end */
+ ptrs[heap->len - 1] = p;
+ pos = heap->len - 1;
+ /* Bubble it up to the appropriate position. */
+ for (;;) {
+ if (pos > 0 && heap->gt(ptrs[pos], ptrs[parent(pos)])) {
+ void *tmp;
+
+ /* Need to exchange */
+ tmp = ptrs[pos];
+ ptrs[pos] = ptrs[parent(pos)];
+ ptrs[parent(pos)] = tmp;
+ pos = parent(pos);
+ /* rebalance */
+ heapify(heap, pos);
+ } else {
+ break;
}
- ptrs[pos] = p;
- return NULL;
}
-
- /* The heap is full, so something will have to be dropped */
-
- /* If the new pointer is greater than the current max, drop it */
- if (heap->gt(p, ptrs[0]))
- return p;
-
- /* Replace the current max and heapify */
- return heap_replace_max(heap, p);
+ return 0;
}
void *heap_remove(struct ptr_heap *heap)
{
- void **ptrs = heap->ptrs;
-
- switch (heap->size) {
+ switch (heap->len) {
case 0:
return NULL;
case 1:
- heap->size = 0;
- return ptrs[0];
+ (void) heap_set_len(heap, 0);
+ return heap->ptrs[0];
}
-
/* Shrink, replace the current max by previous last entry and heapify */
- return heap_replace_max(heap, ptrs[--heap->size]);
+ heap_set_len(heap, heap->len - 1);
+ return heap_replace_max(heap, heap->ptrs[heap->len - 1]);
}
void *heap_cherrypick(struct ptr_heap *heap, void *p)
{
- void **ptrs = heap->ptrs;
- size_t pos, size = heap->size;
+ size_t pos, len = heap->len;
- for (pos = 0; pos < size; pos++)
- if (ptrs[pos] == p)
+ for (pos = 0; pos < len; pos++)
+ if (heap->ptrs[pos] == p)
goto found;
return NULL;
found:
- if (heap->size == 1) {
- heap->size = 0;
- return ptrs[0];
+ if (heap->len == 1) {
+ (void) heap_set_len(heap, 0);
+ return heap->ptrs[0];
}
- /*
- * Replace p with previous last entry and heapify.
- */
- ptrs[pos] = ptrs[--heap->size];
+ /* Replace p with previous last entry and heapify. */
+ heap_set_len(heap, heap->len - 1);
+ heap->ptrs[pos] = heap->ptrs[heap->len - 1];
heapify(heap, pos);
return p;
}