Segment Tree

Theorem

Some diagrams and content are based on Segment Tree – Algorithms for Competitive Programming.

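A segment tree is a binary tree like the one below. Moreover, we can prove that a segment tree has NO node of degree $1$. A node whose range holds a single element is a leaf. A node with range $[l, r]$, $l < r$, is split into the two non-empty halves $[l, mid]$ and $[mid + 1, r]$ with $mid = \lfloor (l + r) / 2 \rfloor$, so it has exactly two children.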

Using a basic property of binary trees, we can count the leaf and non-leaf nodes. Let $n_0, n_1, n_2$ denote the numbers of nodes with 0, 1 and 2 children, respectively. Then:

$$\begin{cases} n_0 = n_2 + 1 \\ n_1 = 0 \end{cases}$$

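The first equation holds for every binary tree, and the second is the property proved above. Since each leaf stores exactly one array element, $n_0 = \text{ElemCount}$, so the total node count is: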

$$n = n_0 + n_1 + n_2 = 2 n_0 - 1 = 2\,\text{ElemCount} - 1$$
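The node-count formula is easy to check numerically. The C++ sketch below is not part of the original implementation: `countNodes` and `countLeaves` are hypothetical helpers that recurse with the same midpoint split used by the `Range` class later in this post, and they assume ranges are never empty.

```cpp
#include <cassert>

// Counts all nodes of a segment tree over the index range [l, r],
// splitting at mid = (l + r) / 2 like Range::leftSubrange/rightSubrange.
long long countNodes(int l, int r) {
    assert(l <= r);                 // assumption: ranges are never empty
    if (l == r) return 1;           // a single-element range is a leaf
    int mid = (l + r) / 2;          // l <= mid < r, so both halves are non-empty
    return 1 + countNodes(l, mid) + countNodes(mid + 1, r);
}

// Counts only the leaves (degree-0 nodes) of the same tree.
long long countLeaves(int l, int r) {
    assert(l <= r);
    if (l == r) return 1;
    int mid = (l + r) / 2;
    return countLeaves(l, mid) + countLeaves(mid + 1, r);
}
```

For a 5-element array, `countNodes(0, 4)` returns 9 and `countLeaves(0, 4)` returns 5, matching $2 \cdot 5 - 1$.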

Initialize

Initialization is a post-order traversal of the binary tree: both children of a node must be initialized before the node's own value can be computed from them.

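def initialize_node(node, rng):
    node.range = rng                 # every node records its range, not only leaves
    if rng.left == rng.right:        # leaf: the range holds a single element
        node.value = arr[rng.left]
    else:
        # post-order: build both children first ...
        initialize_node(node.left_child, rng.left_subrange())
        initialize_node(node.right_child, rng.right_subrange())
        # ... then combine their values (e.g. their sum) into this node
        node.update_based_on_children()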
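For comparison with the iterative, pointer-based constructor in the code below, here is a minimal recursive C++ sketch of the same post-order initialization. It assumes a sum segment tree stored in an array with heap indexing (root at index 1, children of `idx` at `2 * idx` and `2 * idx + 1`). This is a common competitive-programming layout, not the post's own, and `build` and `buildTree` are hypothetical names.

```cpp
#include <cassert>
#include <vector>

// Post-order build of a sum segment tree over arr[l..r] into tree[idx].
void build(const std::vector<long long>& arr, std::vector<long long>& tree,
           int idx, int l, int r) {
    assert(l <= r);
    if (l == r) {                                   // leaf: copy the array element
        tree[idx] = arr[l];
        return;
    }
    int mid = (l + r) / 2;
    build(arr, tree, 2 * idx, l, mid);              // left child first ...
    build(arr, tree, 2 * idx + 1, mid + 1, r);      // ... then the right child ...
    tree[idx] = tree[2 * idx] + tree[2 * idx + 1];  // ... then the parent itself
}

// Convenience wrapper: returns the heap-indexed tree (root at index 1).
std::vector<long long> buildTree(const std::vector<long long>& arr) {
    assert(!arr.empty());
    // 4 * n is the usual safe bound: heap indexing leaves gaps, so the
    // 2n - 1 nodes may use indices beyond 2n.
    std::vector<long long> tree(4 * arr.size(), 0LL);
    build(arr, tree, 1, 0, static_cast<int>(arr.size()) - 1);
    return tree;
}
```

Building over {1, 2, 3, 4, 5} puts 15 at the root (`tree[1]`), 6 in its left child (range [0, 2]) and 9 in its right child (range [3, 4]).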

Range-Limited Traversal

Range-limited traversal is a traversal of the segment tree that is restricted to a target range.

More concretely, given a target range $[l, r]$, the traversal guarantees that only nodes whose ranges intersect $[l, r]$ are visited.

We also define two types of nodes:

  • Range-Intersection Node
  • Range-Contained Node

A Range-Intersection Node is a node whose range intersects the target range (labelled “I” in the image above).

A Range-Contained Node is a node whose range lies entirely within the target range (labelled “C” in the image above).

Since a range that lies inside the target range also intersects it, every contained node is an intersection node:

$$C \subseteq I$$
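To make the two node types concrete, this C++ sketch walks every node of the tree and collects the I and C sets for a given target range. The helpers `intersects`, `contains` and `classify` are hypothetical. The two tests use the same inclusive-range logic as `Range::hasIntersection` and `Range::contains` in the implementation below.

```cpp
#include <cassert>
#include <utility>
#include <vector>

using Seg = std::pair<int, int>;  // inclusive range [first, second]

bool intersects(Seg a, Seg b) { return a.second >= b.first && b.second >= a.first; }
bool contains(Seg outer, Seg inner) {
    return outer.first <= inner.first && inner.second <= outer.second;
}

// Visits every node of the tree over [l, r] (no pruning) and records which
// nodes are Range-Intersection (I) and Range-Contained (C) w.r.t. `target`.
void classify(int l, int r, Seg target, std::vector<Seg>& I, std::vector<Seg>& C) {
    assert(l <= r);
    Seg node{l, r};
    if (intersects(node, target)) I.push_back(node);
    if (contains(target, node)) C.push_back(node);
    if (l == r) return;  // leaf
    int mid = (l + r) / 2;
    classify(l, mid, target, I, C);
    classify(mid + 1, r, target, I, C);
}
```

Consider an 8-element array with the target range $[2, 5]$. The sketch finds 9 Range-Intersection Nodes: $[0,7]$, $[0,3]$, $[4,7]$, $[2,3]$, $[4,5]$ and the four leaves from $[2,2]$ to $[5,5]$. It finds 6 Range-Contained Nodes: $[2,3]$, $[4,5]$ and the same four leaves. Every contained node also appears in I.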

Now let’s write a general function that implements this traversal and lets the caller supply a pre-order visitor, a post-order visitor, and a Range-Contained Node visitor.

The basic logic looks like this:

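def tree_traversal(node):
    # limited range: skip any node that does not intersect the target range
    if node not in I:
        return

    # pre visit
    pre_visitor(node)

    if node in C:
        # contained visit: the whole subtree is covered, so do not descend
        contained_visitor(node)
    else:
        # sub tree: a node in I but not in C is never a leaf
        tree_traversal(node.left)
        tree_traversal(node.right)

    # post visit
    post_visitor(node)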
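A recursive C++ version of this traversal might look like the following sketch. It assumes an implicit array-indexed tree (children of `idx` at `2 * idx` and `2 * idx + 1`) and a hypothetical `Visitor` type that receives the node index and its range. The post's actual implementation uses `std::stack` and node pointers instead of recursion.

```cpp
#include <cassert>
#include <functional>

// Hypothetical visitor signature: node index plus its inclusive range [l, r].
using Visitor = std::function<void(int idx, int l, int r)>;

// Range-limited traversal: only nodes intersecting [ql, qr] are visited,
// and the traversal does not descend below Range-Contained nodes.
void traverse(int idx, int l, int r, int ql, int qr,
              const Visitor& pre, const Visitor& post, const Visitor& contained) {
    assert(l <= r && ql <= qr);
    if (r < ql || qr < l) return;          // not in I: skip the whole subtree
    if (pre) pre(idx, l, r);               // pre visit
    if (ql <= l && r <= qr) {              // in C: visit it, do not descend
        if (contained) contained(idx, l, r);
    } else {                               // in I but not in C: never a leaf
        int mid = (l + r) / 2;
        traverse(2 * idx, l, mid, ql, qr, pre, post, contained);
        traverse(2 * idx + 1, mid + 1, r, ql, qr, pre, post, contained);
    }
    if (post) post(idx, l, r);             // post visit
}
```

On an 8-element tree with target $[2, 5]$, the pre- and post-visitors each fire 5 times (for $[0,7]$, $[0,3]$, $[4,7]$, $[2,3]$, $[4,5]$). The contained visitor fires only for $[2,3]$ and $[4,5]$, which together cover exactly the 4 target elements. Stopping at contained nodes is what keeps the number of visited nodes at $O(\log n)$ per operation.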

Get Range Value

With the range-limited traversal in place, a range query only needs a preVisitor() and a containedVisitor():

  • preVisitor() clears the dirty flag at the current node (the mDelta field in the code implementation below), applying it to the node itself if it is a leaf and otherwise pushing it down to the node's children.
  • containedVisitor() adds the value of each Range-Contained Node to the answer.

preVisitor() ensures that every dirty flag on the path of the range-limited traversal has been cleared. As a result, no pending update stored in an ancestor is missed when a contained node's value is read.
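Putting the two visitors together, here is a C++ sketch of a range-sum query. It is an array-based rewrite rather than the post's pointer-based code, but it follows the `mValue`/`mDelta` convention. `delta[i]` is a pending per-element addition that is not yet included in `value[i]`, so a node's true sum is `value[i] + delta[i] * size`. `SumTree`, `pushDown` and `query` are hypothetical names.

```cpp
#include <cassert>
#include <vector>

// Array-backed sum tree (root = 1, children of i at 2i and 2i + 1).
struct SumTree {
    int n;
    std::vector<long long> value, delta;  // delta[i] is NOT included in value[i]

    explicit SumTree(const std::vector<long long>& arr)
        : n(static_cast<int>(arr.size())),
          value(4 * arr.size(), 0LL),
          delta(4 * arr.size(), 0LL) {
        assert(n > 0);
        build(1, 0, n - 1, arr);
    }

    // True sum of node i covering [l, r], including its own pending tag.
    long long get(int i, int l, int r) const { return value[i] + delta[i] * (r - l + 1); }

    void build(int i, int l, int r, const std::vector<long long>& arr) {
        if (l == r) { value[i] = arr[l]; return; }
        int mid = (l + r) / 2;
        build(2 * i, l, mid, arr);
        build(2 * i + 1, mid + 1, r, arr);
        value[i] = get(2 * i, l, mid) + get(2 * i + 1, mid + 1, r);
    }

    // preVisitor(): clear the dirty flag of node i.
    void pushDown(int i, int l, int r) {
        if (l == r) { value[i] += delta[i]; delta[i] = 0; return; }
        int mid = (l + r) / 2;
        delta[2 * i] += delta[i];
        delta[2 * i + 1] += delta[i];
        delta[i] = 0;
        value[i] = get(2 * i, l, mid) + get(2 * i + 1, mid + 1, r);
    }

    // Range-limited traversal specialised for a sum query.
    long long query(int i, int l, int r, int ql, int qr) {
        if (r < ql || qr < l) return 0;            // node not in I: skip it
        pushDown(i, l, r);                         // preVisitor()
        if (ql <= l && r <= qr) return value[i];   // containedVisitor(): delta[i] is 0 now
        int mid = (l + r) / 2;
        return query(2 * i, l, mid, ql, qr) + query(2 * i + 1, mid + 1, r, ql, qr);
    }

    long long query(int ql, int qr) {
        assert(0 <= ql && ql <= qr && qr < n);
        return query(1, 0, n - 1, ql, qr);
    }
};
```

For example, on {1, 2, 3, 4, 5}, `query(1, 3)` returns 9. Suppose a `+10` tag is then written directly into `delta[1]` (the root) to emulate a pending update. The same query now returns 39, because preVisitor() pushes the tag down along the path before any contained node's value is read.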

Update Range Value

As with range queries, preVisitor() pushes the dirty flags down. This time, containedVisitor() applies the update to each Range-Contained Node by adding it to the node's dirty flag, without touching the subtree below.

In addition, a postVisitor() rebuilds every visited node from its children on the way back up, so that the ancestors of the updated nodes reflect the change.
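The C++ sketch below implements a range update with all three visitors. `pushDown` is the preVisitor(), adding `d` to `delta[i]` is the containedVisitor(), and `rebuild` is the postVisitor(). `LazySumTree` is a hypothetical, array-based type that follows the post's `mValue`/`mDelta` convention (`delta[i]` is a pending per-element addition not yet included in `value[i]`). It is an illustration under our own naming, not the post's code.

```cpp
#include <cassert>
#include <vector>

// Array-backed sum tree (root = 1, children of i at 2i and 2i + 1).
struct LazySumTree {
    int n;
    std::vector<long long> value, delta;  // delta[i] is NOT included in value[i]

    explicit LazySumTree(const std::vector<long long>& arr)
        : n(static_cast<int>(arr.size())),
          value(4 * arr.size(), 0LL),
          delta(4 * arr.size(), 0LL) {
        assert(n > 0);
        build(1, 0, n - 1, arr);
    }

    long long get(int i, int l, int r) const { return value[i] + delta[i] * (r - l + 1); }

    // postVisitor(): recompute a non-leaf node from its children.
    void rebuild(int i, int l, int r) {
        if (l == r) return;
        int mid = (l + r) / 2;
        value[i] = get(2 * i, l, mid) + get(2 * i + 1, mid + 1, r);
    }

    void build(int i, int l, int r, const std::vector<long long>& arr) {
        if (l == r) { value[i] = arr[l]; return; }
        int mid = (l + r) / 2;
        build(2 * i, l, mid, arr);
        build(2 * i + 1, mid + 1, r, arr);
        rebuild(i, l, r);
    }

    // preVisitor(): clear the dirty flag of node i.
    void pushDown(int i, int l, int r) {
        if (l == r) { value[i] += delta[i]; delta[i] = 0; return; }
        delta[2 * i] += delta[i];
        delta[2 * i + 1] += delta[i];
        delta[i] = 0;
        rebuild(i, l, r);
    }

    void update(int i, int l, int r, int ql, int qr, long long d) {
        if (r < ql || qr < l) return;      // node not in I: skip it
        pushDown(i, l, r);                 // preVisitor()
        if (ql <= l && r <= qr) {
            delta[i] += d;                 // containedVisitor(): lazy tag only
        } else {
            int mid = (l + r) / 2;
            update(2 * i, l, mid, ql, qr, d);
            update(2 * i + 1, mid + 1, r, ql, qr, d);
        }
        rebuild(i, l, r);                  // postVisitor()
    }

    long long query(int i, int l, int r, int ql, int qr) {
        if (r < ql || qr < l) return 0;
        pushDown(i, l, r);
        if (ql <= l && r <= qr) return value[i];
        int mid = (l + r) / 2;
        return query(2 * i, l, mid, ql, qr) + query(2 * i + 1, mid + 1, r, ql, qr);
    }

    void update(int ql, int qr, long long d) {
        assert(0 <= ql && ql <= qr && qr < n);
        update(1, 0, n - 1, ql, qr, d);
    }

    long long query(int ql, int qr) {
        assert(0 <= ql && ql <= qr && qr < n);
        return query(1, 0, n - 1, ql, qr);
    }
};
```

For example, start from {1, 5, 4, 2, 3} and call `update(1, 2, 2)`; `query(2, 3)` then returns 8. A further `update(0, 4, 1)` makes `query(0, 4)` return 24.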

Code Implementation

The code below is the AC (accepted) code for the Segment Tree template problem on Luogu.

Note that the code uses std::stack to perform the tree traversals iteratively instead of with recursive functions. It also uses std::function (from the <functional> header) to define the visitor callbacks.
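The explicit-stack technique can be seen in isolation in the C++ sketch below. Each stack entry carries a flag, like the `visited` stack in `_iteration` and `isChildrenConstructed` in `ConstructInfo`. A node is expanded the first time it reaches the top of the stack and emitted the second time. The `postOrder` function and its child-list tree encoding are hypothetical.

```cpp
#include <cassert>
#include <stack>
#include <utility>
#include <vector>

// Iterative post-order traversal with an "already expanded" flag per entry.
// Hypothetical tree encoding: children[i] = {left, right}, -1 meaning "none".
std::vector<int> postOrder(const std::vector<std::pair<int, int>>& children, int root) {
    assert(root >= 0 && root < static_cast<int>(children.size()));
    std::vector<int> order;
    std::stack<std::pair<int, bool>> work;  // (node, already expanded?)
    work.push({root, false});
    while (!work.empty()) {
        auto& [node, expanded] = work.top();
        if (!expanded) {
            // First time on top: mark it and push its children.
            expanded = true;
            int left = children[node].first, right = children[node].second;
            // Push right first so the left subtree is finished first.
            if (right != -1) work.push({right, false});
            if (left != -1) work.push({left, false});
        } else {
            // Second time on top: both subtrees are done, so emit the node.
            order.push_back(node);
            work.pop();
        }
    }
    return order;
}
```

For a tree where node 0 has children 1 and 2 and node 1 has children 3 and 4, it returns the post-order 3, 4, 1, 2, 0. For a balanced segment tree the recursion depth would only be $O(\log n)$ anyway, so the explicit stack is mainly a stylistic choice here.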

#include <iostream>
#include <stack>
#include <functional>
#include <tuple>
#include <cstdlib> // exit

const bool debug = false;
using std::cin, std::cout;
const char endq = '\n';

using LL = long long; // signed, so negative values and sums are read and printed correctly

/**
 * Class used to record a range
 */
class Range
{
public:
    int l;
    int r;
    int size;
    int mid;
    Range(int l, int r)
        : l(l), r(r), size(r - l + 1), mid((l + r) / 2)
    {
        if (l > r)
        {
            cout << "Illegal range" << endq;
            exit(-1);
        }
    }

    Range leftSubrange()
    {
        return Range(l, mid);
    }

    Range rightSubrange()
    {
        return Range(mid + 1, r);
    }

    bool contains(const Range &another) const
    {
        return (l <= another.l && r >= another.r);
    }

    bool hasIntersection(const Range &another) const
    {
        return (r >= another.l && another.r >= l);
    }

    void display()
    {
        cout << "[" << l << ", " << r << "]";
    }
};

class SegmentTree
{
public:
    class SegmentTreeNode
    {
    public:
        using Visitor = std::function<void(SegmentTreeNode *)>;

        LL mValue = 0;
        LL mDelta = 0;
        SegmentTreeNode *mLeft = nullptr;
        SegmentTreeNode *mRight = nullptr;
        Range range = Range(0, 0);

        SegmentTreeNode(const Range &range) : range(range) {}

        /**
         * Clear the pending delta of this node: a leaf applies it to its own
         * value, while a non-leaf node pushes it down to both children.
         *
         * The node's value is then recomputed, so getValue() stays the same.
         */
        void clearDelta()
        {
            if (isLeaf())
            {
                mValue += mDelta;
                mDelta = 0;
            }
            else
            {
                mLeft->mDelta += mDelta;
                mRight->mDelta += mDelta;
                mDelta = 0;
                rebuild();
            }
        }

        // recalculate value of this node based on children
        void rebuild()
        {
            if (!isLeaf())
            {
                mValue = mLeft->getValue() + mRight->getValue();
            }
        }

        LL getValue()
        {
            return (mValue + mDelta * range.size);
        }

        bool isLeaf()
        {
            return (range.size == 1);
        }

        void display()
        {
            range.display();
            cout << ": TreeNode: value: " << mValue << ", delta: " << mDelta;
        };

        void preIteration(Visitor func)
        {
            func(this);
            if (mLeft != nullptr)
                mLeft->preIteration(func);
            if (mRight != nullptr)
                mRight->preIteration(func);
        }

        void postIteration(Visitor func)
        {
            if (mLeft != nullptr)
                mLeft->postIteration(func);
            if (mRight != nullptr)
                mRight->postIteration(func);
            func(this);
        }
    };

    /**
     * Root of this segment tree
     */
    SegmentTreeNode *root = nullptr;

    /**
     * Size of the array corresponding to this segment tree
     */
    int mSize;

private:
    class ConstructInfo
    {
    public:
        SegmentTreeNode *node;
        bool isChildrenConstructed = false;

        ConstructInfo(SegmentTreeNode *node) : node(node), isChildrenConstructed(0) {}

        /**
         * Create the two children of this node, link them to the node,
         * and return a ConstructInfo for each child
         */
        std::tuple<ConstructInfo, ConstructInfo>
        createChildrenConstruct()
        {
            if (node->isLeaf())
            {
                cout << "no child for leaf node";
                exit(-1);
            }

            SegmentTreeNode *l = new SegmentTreeNode(node->range.leftSubrange());
            SegmentTreeNode *r = new SegmentTreeNode(node->range.rightSubrange());
            node->mLeft = l;
            node->mRight = r;
            return {
                ConstructInfo(l),
                ConstructInfo(r),
            };
        }
    };

public:
    SegmentTree(LL arr[], unsigned count) : mSize(count)
    {
        // use an explicit stack to run a post-order traversal
        // that initializes the segment tree

        std::stack<ConstructInfo> workStack;
        root = new SegmentTreeNode(Range(0, mSize - 1));
        workStack.push(ConstructInfo(root));

        while (!workStack.empty())
        {
            // peek at the top of the stack
            ConstructInfo &curTop = workStack.top();

            // leaf node: initialize its value directly from the array
            if (curTop.node->isLeaf())
            {
                curTop.node->mValue = arr[curTop.node->range.l];
                workStack.pop();
            }

            // non-leaf node whose children have not been created yet
            else if (!curTop.isChildrenConstructed)
            {
                curTop.isChildrenConstructed = 1;
                auto [l, r] = curTop.createChildrenConstruct();
                workStack.push(l);
                workStack.push(r);
            }

            // non-leaf node whose children are fully initialized
            else
            {
                if (debug)
                {
                    // sanity checks: a non-leaf node must have two children and size > 1
                    if (curTop.node->mLeft == nullptr || curTop.node->mRight == nullptr)
                    {
                        cout << "non-leaf node should have children";
                        exit(-1);
                    }
                    if (curTop.node->range.size <= 1)
                    {
                        cout << "non-leaf node should have range > 1";
                        exit(-1);
                    }
                }
                curTop.node->rebuild();
                workStack.pop();
            }
        }

        // finish init
        if (debug)
        {
            root->preIteration(
                [](SegmentTreeNode *node)
                {
                    node->display();
                    cout << endq;
                });
        }
    }

    void _iteration(
        const Range &targetRange,
        const SegmentTreeNode::Visitor preVisitor,
        const SegmentTreeNode::Visitor postVisitor,
        const SegmentTreeNode::Visitor containedVisitor)
    {
        std::stack<SegmentTreeNode *> nodeStack;
        std::stack<bool> visited;
        nodeStack.push(root);
        visited.push(0);

        while (!nodeStack.empty())
        {
            SegmentTreeNode *curNode = nodeStack.top();

            // non-visited node
            if (!visited.top())
            {
                if (debug)
                {
                    cout << "visiting: ";
                    curNode->display();
                    cout << endq;
                }
                // pre-iteration
                if (preVisitor)
                    preVisitor(curNode);
                visited.top() = true;

                // contained node
                if (targetRange.contains(curNode->range))
                {
                    if (containedVisitor)
                        containedVisitor(curNode);
                }
                // intersection node (in I but not in C)
                // such a node is never a leaf: a leaf's range has size 1, so it is
                // either completely contained in the target range (handled above)
                // or completely outside it (never pushed onto the working stack)
                else
                {
                    if (curNode->mLeft->range.hasIntersection(targetRange))
                    {
                        nodeStack.push(curNode->mLeft);
                        visited.push(false);
                    }
                    if (curNode->mRight->range.hasIntersection(targetRange))
                    {
                        nodeStack.push(curNode->mRight);
                        visited.push(false);
                    }
                }
            }
            // visited node
            else
            {
                if (postVisitor)
                    postVisitor(curNode);
                nodeStack.pop();
                visited.pop();
            }
        }
    }

    LL getRangeValue(const Range &targetRange)
    {
        LL value = 0;
        std::function preVisitor = [](SegmentTreeNode *node)
        {
            node->clearDelta();
        };
        std::function containedVisitor = [&value](SegmentTreeNode *node)
        {
            value += node->getValue();
        };
        _iteration(targetRange, preVisitor, nullptr, containedVisitor);
        return value;
    }

    void mutateRangeValue(const Range &range, LL delta)
    {
        std::function preVisitor = [](SegmentTreeNode *node)
        {
            node->clearDelta();
        };

        std::function postVisitor = [](SegmentTreeNode *node)
        {
            node->rebuild();
        };
        std::function containedVisitor = [delta](SegmentTreeNode *node)
        {
            node->mDelta += delta;
        };
        _iteration(range, preVisitor, postVisitor, containedVisitor);
    }
};

LL arr[100010];

int main()
{
    LL n, m;
    cin >> n >> m;

    for (int i = 0; i < n; ++i)
    {
        cin >> arr[i];
    }

    SegmentTree segTree = SegmentTree(arr, n);

    LL mode, l, r, tmp;
    while (m--)
    {
        cin >> mode >> l >> r;
        auto range = Range(l - 1, r - 1);
        if (mode == 1)
        {
            cin >> tmp;
            segTree.mutateRangeValue(range, tmp);
        }
        else
        {
            cout << segTree.getRangeValue(range) << endq;
        }
    }
    return 0;
}


Published by Oyasumi

Just a normal person in the world
