How to convert a recursive tree search function into a non-recursive one?

I'm implementing a nearest-neighbor search on a binary tree (a k-d tree: the splitting axis at each level is depth % DIM). However, I have to implement the recursive functions non-recursively.

I've looked at several techniques for turning recursion into iteration, but none of them seem to apply to my code.

Can this function be made non-recursive, and if so, how?

This is my code (the recursive function):

const NODE* getNN(const float pos[DIM], const NODE* cur, int depth) {
    if (!cur) return nullptr;

    // Search the side of cur's splitting plane that pos lies on first.
    bool flag = pos[depth % DIM] < cur->pos[depth % DIM];
    NODE* next{ flag ? cur->left : cur->right };
    NODE* other{ flag ? cur->right : cur->left };

    const NODE* temp = getNN(pos, next, depth + 1);
    const NODE* best{ closest(pos, temp, cur) };

    // r is compared with dr * dr, so distance() is presumably a squared distance.
    float r = this->distance(pos, best->pos);
    float dr = pos[depth % DIM] - cur->pos[depth % DIM];

    // Search the other side only if the splitting plane is within reach.
    if (r >= dr * dr) {
        temp = getNN(pos, other, depth + 1);
        best = closest(pos, temp, best);
    }
    return best;
}
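The standard recipe for removing recursion is to replace the call stack with an explicit stack of pending calls. It works directly when the recursive calls are independent and their results are simply combined, as in this sum over a tree (the Node type here is hypothetical, for illustration only):

```cpp
#include <cassert>
#include <stack>

// Hypothetical node type, for illustration only.
struct Node {
    int value;
    Node* left;
    Node* right;
};

// Recursive version: the two calls are independent; their results are just added.
int sumRec(const Node* n) {
    if (!n) return 0;
    return n->value + sumRec(n->left) + sumRec(n->right);
}

// Iterative version: each pending recursive call becomes a stack entry.
int sumIter(const Node* root) {
    int total = 0;
    std::stack<const Node*> pending;
    pending.push(root);
    while (!pending.empty()) {
        const Node* n = pending.top();
        pending.pop();
        if (!n) continue;        // base case: an empty subtree adds nothing
        total += n->value;       // the work done in the call itself
        pending.push(n->left);   // the two "recursive calls", in any order
        pending.push(n->right);
    }
    return total;
}
```

getNN does not fit this pattern as-is: it does work after the first recursive call returns, and whether the second call happens at all depends on that call's result (best). An iterative version therefore has to remember, for each pending node, where to resume and what to compare against.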

This is the signature I'd like the non-recursive version to have:

const NODE* getNN_NonRecur(const float pos[DIM]);
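One way to convert getNN line by line is to simulate the call stack: each entry (a hypothetical Frame struct) holds one call's arguments and locals plus a stage field recording which recursive call it is waiting on, and ret carries each finished call's return value back to its caller. This is a sketch under assumptions: NODE, DIM = 2, and the distance/closest helpers are stand-ins for definitions the question does not show (distance is taken to return the squared distance, since the code compares it with dr * dr), and the root is passed as a parameter instead of being read from a member.

```cpp
#include <cassert>
#include <vector>

constexpr int DIM = 2;  // assumption: 2-D points

// Hypothetical layout; the question does not show NODE.
struct NODE {
    float pos[DIM];
    NODE* left;
    NODE* right;
};

// Assumed helpers (members in the question): squared Euclidean distance,
// and whichever of two possibly-null candidates is nearer to pos.
float distance(const float a[DIM], const float b[DIM]) {
    float s = 0;
    for (int i = 0; i < DIM; ++i) s += (a[i] - b[i]) * (a[i] - b[i]);
    return s;
}

const NODE* closest(const float pos[DIM], const NODE* a, const NODE* b) {
    if (!a) return b;
    if (!b) return a;
    return distance(pos, a->pos) <= distance(pos, b->pos) ? a : b;
}

// One simulated call of getNN: arguments, locals, and where to resume.
struct Frame {
    const NODE* cur;
    int depth;
    int stage;          // 0: start, 1: back from getNN(next), 2: back from getNN(other)
    const NODE* other;
    const NODE* best;
};

const NODE* getNN_NonRecur(const float pos[DIM], const NODE* root) {
    std::vector<Frame> calls;
    calls.push_back({root, 0, 0, nullptr, nullptr});
    const NODE* ret = nullptr;  // return value of the call that just finished
    while (!calls.empty()) {
        Frame& f = calls.back();
        if (!f.cur) {                        // if (!cur) return nullptr;
            ret = nullptr;
            calls.pop_back();
            continue;
        }
        const int axis = f.depth % DIM;
        const int childDepth = f.depth + 1;
        if (f.stage == 0) {                  // pick sides, then "call" getNN(next)
            bool flag = pos[axis] < f.cur->pos[axis];
            const NODE* next = flag ? f.cur->left : f.cur->right;
            f.other = flag ? f.cur->right : f.cur->left;
            f.stage = 1;
            calls.push_back({next, childDepth, 0, nullptr, nullptr});
        } else if (f.stage == 1) {           // best = closest(pos, temp, cur)
            f.best = closest(pos, ret, f.cur);
            float r = distance(pos, f.best->pos);
            float dr = pos[axis] - f.cur->pos[axis];
            if (r >= dr * dr) {              // "call" getNN(other)
                f.stage = 2;
                calls.push_back({f.other, childDepth, 0, nullptr, nullptr});
            } else {                         // return best;
                ret = f.best;
                calls.pop_back();
            }
        } else {                             // best = closest(pos, temp, best); return best;
            ret = closest(pos, ret, f.best);
            calls.pop_back();
        }
    }
    return ret;
}
```

Because every frame resumes exactly where the recursive version would continue, this visits and prunes nodes in the same order as getNN; the price is the extra bookkeeping in each Frame.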

Answer by sapi

I've resolved it. Thank you for the advice.

const NODE* getNN_NR(const float pos[DIM])
{
    // Each entry: {{node on the search path, the child it did not take}, depth of that node}
    std::stack<std::pair<std::pair<NODE*, NODE*>, unsigned int>> st;
    composeStack(st, pos, this->root, 0);

    const NODE* best{ nullptr };  // closest() must accept nullptr, as in getNN
    while (!st.empty())
    {
        auto e = st.top(); st.pop();
        const NODE* cur = e.first.first;
        best = closest(pos, best, cur);

        // The skipped child lies across cur's splitting plane; visit it only
        // if that plane is no farther away than the best candidate so far.
        float r = distance(pos, best->pos);
        float dr = pos[e.second % DIM] - cur->pos[e.second % DIM];

        if (r >= dr * dr) {
            composeStack(st, pos, e.first.second, e.second + 1);
        }
    }
    return best;
}
// Walks from node down toward pos, as the first recursive call in getNN does,
// pushing every node it passes together with the child it did not take.
void composeStack(std::stack<std::pair<std::pair<NODE*, NODE*>, unsigned int>>& st,
    const float pos[DIM],
    NODE* node, unsigned int depth)
{
    NODE* cur = node;
    while (cur) {
        bool flag = pos[depth % DIM] < cur->pos[depth % DIM];
        NODE* next{ flag ? cur->left : cur->right };
        NODE* other{ flag ? cur->right : cur->left };

        st.push({ { cur, other }, depth });
        cur = next;
        ++depth;
    }
}
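To try the descend-and-backtrack approach outside a class, here is a standalone sketch. NODE, DIM = 2, distance (assumed to return the squared distance) and closest are hypothetical stand-ins for definitions not shown in the post, and the root is passed as a parameter. Each stack entry pairs a node on the search path with the child it did not take, so the pruning test for that child uses the node's own splitting plane, and a node with only one child (the last test case) is handled as well.

```cpp
#include <cassert>
#include <stack>
#include <utility>

constexpr int DIM = 2;  // assumption: 2-D points

// Hypothetical layout; the question does not show NODE.
struct NODE {
    float pos[DIM];
    NODE* left;
    NODE* right;
};

// Assumed helpers: squared Euclidean distance and the nearer of two
// possibly-null candidates.
float distance(const float a[DIM], const float b[DIM]) {
    float s = 0;
    for (int i = 0; i < DIM; ++i) s += (a[i] - b[i]) * (a[i] - b[i]);
    return s;
}

const NODE* closest(const float pos[DIM], const NODE* a, const NODE* b) {
    if (!a) return b;
    if (!b) return a;
    return distance(pos, a->pos) <= distance(pos, b->pos) ? a : b;
}

// {{node on the search path, child not taken}, depth of that node}
using Entry = std::pair<std::pair<NODE*, NODE*>, unsigned int>;

// Descend from node toward pos, pushing each visited node with its other child.
void composeStack(std::stack<Entry>& st, const float pos[DIM], NODE* node, unsigned int depth) {
    for (NODE* cur = node; cur != nullptr; ++depth) {
        bool flag = pos[depth % DIM] < cur->pos[depth % DIM];
        NODE* next = flag ? cur->left : cur->right;
        NODE* other = flag ? cur->right : cur->left;
        st.push({{cur, other}, depth});
        cur = next;
    }
}

const NODE* getNN_NR(const float pos[DIM], NODE* root) {
    std::stack<Entry> st;
    composeStack(st, pos, root, 0);
    const NODE* best = nullptr;
    while (!st.empty()) {
        Entry e = st.top();
        st.pop();
        const NODE* cur = e.first.first;
        best = closest(pos, best, cur);
        float r = distance(pos, best->pos);
        float dr = pos[e.second % DIM] - cur->pos[e.second % DIM];
        if (r >= dr * dr)  // the far side of cur's plane may hold something closer
            composeStack(st, pos, e.first.second, e.second + 1);
    }
    return best;
}
```

Carrying a single best across the whole search, rather than one per call as in getNN, only makes the pruning test stricter: the best candidate found anywhere so far is never farther away than the best within the current subtree.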