rcuja: lower equal: handle concurrent removal with retry
[userspace-rcu.git] / rcuja / rcuja.c
index 41ac70c43fb457aa98ae7ba80ebfce4b515904b6..e6c6d5b51493e8b5bbc58d3044c04ae2a84bb183 100644 (file)
@@ -34,7 +34,6 @@
 #include <stdint.h>
 
 #include "rcuja-internal.h"
-#include "bitfield.h"
 
 #ifndef abs
 #define abs_int(a)     ((int) (a) > 0 ? (int) (a) : -((int) (a)))
@@ -370,6 +369,43 @@ struct cds_ja_inode_flag *ja_linear_node_get_nth(const struct cds_ja_type *type,
        return ptr;
 }
 
+static
+struct cds_ja_inode_flag *ja_linear_node_get_left(const struct cds_ja_type *type,
+               struct cds_ja_inode *node,
+               unsigned int n)
+{
+       uint8_t nr_child;
+       uint8_t *values;
+       struct cds_ja_inode_flag **pointers;
+       struct cds_ja_inode_flag *ptr;
+       unsigned int i, match_idx;
+       int match_v = -1;
+
+       assert(type->type_class == RCU_JA_LINEAR || type->type_class == RCU_JA_POOL);
+
+       nr_child = ja_linear_node_get_nr_child(type, node);
+       cmm_smp_rmb();  /* read nr_child before values and pointers */
+       assert(nr_child <= type->max_linear_child);
+       assert(type->type_class != RCU_JA_LINEAR || nr_child >= type->min_child);
+
+       values = &node->u.data[1];
+       for (i = 0; i < nr_child; i++) {
+               unsigned int v;
+
+               v = CMM_LOAD_SHARED(values[i]);
+               if (v < n && (int) v > match_v) {
+                       match_v = v;
+                       match_idx = i;
+               }
+       }
+       if (match_v < 0) {
+               return NULL;
+       }
+       pointers = (struct cds_ja_inode_flag **) align_ptr_size(&values[type->max_linear_child]);
+       ptr = rcu_dereference(pointers[match_idx]);
+       return ptr;
+}
+
 static
 void ja_linear_node_get_ith_pos(const struct cds_ja_type *type,
                struct cds_ja_inode *node,
@@ -442,6 +478,42 @@ struct cds_ja_inode *ja_pool_node_get_ith_pool(const struct cds_ja_type *type,
                &node->u.data[(unsigned int) i << type->pool_size_order];
 }
 
+static
+struct cds_ja_inode_flag *ja_pool_node_get_left(const struct cds_ja_type *type,
+               struct cds_ja_inode *node,
+               unsigned int n)
+{
+       unsigned int pool_nr;
+       int match_v = -1;
+       struct cds_ja_inode_flag *match_node_flag = NULL;
+
+       assert(type->type_class == RCU_JA_POOL);
+
+       for (pool_nr = 0; pool_nr < (1U << type->nr_pool_order); pool_nr++) {
+               struct cds_ja_inode *pool =
+                       ja_pool_node_get_ith_pool(type,
+                               node, pool_nr);
+               uint8_t nr_child =
+                       ja_linear_node_get_nr_child(type, pool);
+               unsigned int j;
+
+               for (j = 0; j < nr_child; j++) {
+                       struct cds_ja_inode_flag *iter;
+                       uint8_t v;
+
+                       ja_linear_node_get_ith_pos(type, pool,
+                                       j, &v, &iter);
+                       if (!iter)
+                               continue;
+                       if (v < n && (int) v > match_v) {
+                               match_v = v;
+                               match_node_flag = iter;
+                       }
+               }
+       }
+       return match_node_flag;
+}
+
 static
 struct cds_ja_inode_flag *ja_pigeon_node_get_nth(const struct cds_ja_type *type,
                struct cds_ja_inode *node,
@@ -461,6 +533,30 @@ struct cds_ja_inode_flag *ja_pigeon_node_get_nth(const struct cds_ja_type *type,
        return child_node_flag;
 }
 
+static
+struct cds_ja_inode_flag *ja_pigeon_node_get_left(const struct cds_ja_type *type,
+               struct cds_ja_inode *node,
+               unsigned int n)
+{
+       struct cds_ja_inode_flag **child_node_flag_ptr;
+       struct cds_ja_inode_flag *child_node_flag;
+       int i;
+
+       assert(type->type_class == RCU_JA_PIGEON);
+
+       /* n - 1 is first value left of n */
+       for (i = n - 1; i >= 0; i--) {
+               child_node_flag_ptr = &((struct cds_ja_inode_flag **) node->u.data)[i];
+               child_node_flag = rcu_dereference(*child_node_flag_ptr);
+               if (child_node_flag) {
+                       dbg_printf("ja_pigeon_node_get_left child_node_flag %p\n",
+                               child_node_flag);
+                       return child_node_flag;
+               }
+       }
+       return NULL;
+}
+
 static
 struct cds_ja_inode_flag *ja_pigeon_node_get_ith_pos(const struct cds_ja_type *type,
                struct cds_ja_inode *node,
@@ -503,6 +599,38 @@ struct cds_ja_inode_flag *ja_node_get_nth(struct cds_ja_inode_flag *node_flag,
        }
 }
 
+static
+struct cds_ja_inode_flag *ja_node_get_left(struct cds_ja_inode_flag *node_flag,
+               unsigned int n)
+{
+       unsigned int type_index;
+       struct cds_ja_inode *node;
+       const struct cds_ja_type *type;
+
+       node = ja_node_ptr(node_flag);
+       assert(node != NULL);
+       type_index = ja_node_type(node_flag);
+       type = &ja_types[type_index];
+
+       switch (type->type_class) {
+       case RCU_JA_LINEAR:
+               return ja_linear_node_get_left(type, node, n);
+       case RCU_JA_POOL:
+               return ja_pool_node_get_left(type, node, n);
+       case RCU_JA_PIGEON:
+               return ja_pigeon_node_get_left(type, node, n);
+       default:
+               assert(0);
+               return (void *) -1UL;
+       }
+}
+
+static
+struct cds_ja_inode_flag *ja_node_get_rightmost(struct cds_ja_inode_flag *node_flag)
+{
+       return ja_node_get_left(node_flag, JA_ENTRY_PER_NODE);
+}
+
 static
 int ja_linear_node_set_nth(const struct cds_ja_type *type,
                struct cds_ja_inode *node,
@@ -1610,6 +1738,105 @@ struct cds_hlist_head cds_ja_lookup(struct cds_ja *ja, uint64_t key)
        return head;
 }
 
+/*
+ * cds_ja_lookup_lower_equal() may need to retry if a concurrent removal
+ * causes failure to find the previous node.
+ */
+struct cds_hlist_head cds_ja_lookup_lower_equal(struct cds_ja *ja, uint64_t key)
+{
+       int tree_depth, level;
+       struct cds_ja_inode_flag *node_flag, *cur_node_depth[JA_MAX_DEPTH];
+       struct cds_hlist_head head = { NULL };
+
+       if (caa_unlikely(key > ja->key_max || !key))
+               return head;
+
+retry:
+       memset(cur_node_depth, 0, sizeof(cur_node_depth));
+       tree_depth = ja->tree_depth;
+       node_flag = rcu_dereference(ja->root);
+       cur_node_depth[0] = node_flag;
+
+       /* level 0: root node */
+       if (!ja_node_ptr(node_flag))
+               return head;
+
+       for (level = 1; level < tree_depth; level++) {
+               uint8_t iter_key;
+
+               iter_key = (uint8_t) (key >> (JA_BITS_PER_BYTE * (tree_depth - level - 1)));
+               node_flag = ja_node_get_nth(node_flag, NULL, iter_key);
+               if (!ja_node_ptr(node_flag))
+                       break;
+               cur_node_depth[level] = node_flag;
+               dbg_printf("cds_ja_lookup iter key lookup %u finds node_flag %p\n",
+                               (unsigned int) iter_key, node_flag);
+       }
+
+       if (level == tree_depth) {
+               /* Last level lookup succeded. We got an equal match. */
+               head.next = (struct cds_hlist_node *) node_flag;
+               return head;
+       }
+
+       /*
+        * Find highest value left of current node.
+        * Current node is cur_node_depth[level].
+        * Start at current level. If we cannot find any key left of
+        * ours, go one level up, seek highest value left of current
+        * (recursively), and when we find one, get the rightmost child
+        * of its rightmost child (recursively).
+        */
+       for (; level > 0; level--) {
+               uint8_t iter_key;
+
+               iter_key = (uint8_t) (key >> (JA_BITS_PER_BYTE * (tree_depth - level - 1)));
+               node_flag = ja_node_get_left(cur_node_depth[level - 1],
+                               iter_key);
+               /* If found left sibling, find rightmost child. */
+               if (ja_node_ptr(node_flag))
+                       break;
+       }
+
+       if (!level) {
+               /* Reached the root and could not find a left sibling. */
+               return head;
+       }
+
+       level++;
+
+       /*
+        * From this point, we should be able to find a "lower than"
+        * match. The only reason why we could fail to find such a match
+        * would be due to a concurrent removal of the branch that
+        * contains the match. If this happens, we have no choice but to
+        * retry the entire lookup. Indeed, just because we reached a
+        * dead-end due to concurrent removal of the branch does not
+        * mean that other match don't exist. However, this requires
+        * going up into the tree, hence the retry.
+        */
+
+       /* Find rightmost child of rightmost child (recursively). */
+       for (; level < tree_depth; level++) {
+               node_flag = ja_node_get_rightmost(node_flag);
+               /* If found left sibling, find rightmost child. */
+               if (!ja_node_ptr(node_flag))
+                       break;
+       }
+
+       if (level != tree_depth) {
+               /*
+                * We did not get a match. Caused by concurrent removal.
+                * We need to retry the entire lookup.
+                */
+               goto retry;
+       }
+
+       /* Last level lookup succeded. We got a "lower than" match. */
+       head.next = (struct cds_hlist_node *) node_flag;
+       return head;
+}
+
 /*
  * We reached an unpopulated node. Create it and the children we need,
  * and then attach the entire branch to the current node. This may
This page took 0.025723 seconds and 4 git commands to generate.