Olox Olox

Theme

Documentation
Back to Home

Tree Algorithms - LCA, Binary Lifting, Serialization

10 min read

Advanced Tree Algorithms

📚 Summary

Advanced tree algorithms including Lowest Common Ancestor (LCA), Binary Lifting for O(log n) ancestor queries, tree serialization/deserialization, and Morris traversal for O(1) space traversal.


1️⃣ Lowest Common Ancestor (LCA)

LCA - Basic (LC 236)

class TreeNode:
    """Binary tree node: a value plus optional left and right children."""

    def __init__(self, val=0, left=None, right=None):
        self.val, self.left, self.right = val, left, right


def lowest_common_ancestor(root: TreeNode, p: TreeNode, q: TreeNode) -> TreeNode:
    """
    Return the lowest common ancestor of nodes p and q in the tree at root.

    Each recursive call reports: the node itself if it is p or q, the LCA if
    both targets are found beneath it, the single target found otherwise,
    or None when neither appears.
    Time: O(n), Space: O(h) recursion depth
    """
    if not root or root == p or root == q:
        return root

    found_left = lowest_common_ancestor(root.left, p, q)
    found_right = lowest_common_ancestor(root.right, p, q)

    if not (found_left and found_right):
        # At most one side saw a target; bubble that result (or None) upward.
        return found_left if found_left else found_right

    # p and q sit in different subtrees, so their paths meet here.
    return root


def lca_iterative(root: TreeNode, p: TreeNode, q: TreeNode) -> TreeNode:
    """
    Iterative LCA using a parent map.

    A DFS records each node's parent until both p and q have been reached.
    Then p's ancestor chain is collected into a set, and the search climbs
    from q until it hits a member of that set.
    """
    parent_of = {root: None}
    pending = [root]

    # Discover parent links only as far as needed to reach both targets.
    while not (p in parent_of and q in parent_of):
        current = pending.pop()
        for child in (current.left, current.right):
            if child:
                parent_of[child] = current
                pending.append(child)

    lineage = set()
    walker = p
    while walker:
        lineage.add(walker)
        walker = parent_of[walker]

    # The first ancestor of q (q included) on p's chain is the LCA.
    walker = q
    while walker not in lineage:
        walker = parent_of[walker]
    return walker

LCA - BST (LC 235)

def lca_bst(root: TreeNode, p: TreeNode, q: TreeNode) -> TreeNode:
    """
    LCA in a BST, using the ordering property instead of a full search.

    Descend left while both values are smaller than the current node and
    right while both are larger. The first node that splits them (or equals
    one of them) is the LCA. Returns None if the walk falls off the tree.
    Time: O(h)
    """
    low, high = min(p.val, q.val), max(p.val, q.val)
    node = root
    while node:
        if high < node.val:
            node = node.left
        elif low > node.val:
            node = node.right
        else:
            return node
    return None

LCA with Parent Pointers (LC 1650)

class Node:
    """Binary tree node that also stores a link to its parent."""

    def __init__(self, val=0, left=None, right=None, parent=None):
        self.val = val
        # downward links
        self.left, self.right = left, right
        # upward link; the LCA helpers below treat a falsy parent as "root"
        self.parent = parent


def lca_with_parent(p: 'Node', q: 'Node') -> 'Node':
    """
    LCA when nodes carry parent pointers (two-pointer trick).

    Each pointer climbs toward the root and then restarts at the other
    starting node. After at most depth(p) + depth(q) steps both have walked
    the same total distance, so they meet at the LCA. This is the same idea
    as finding the intersection of two linked lists.
    """
    x, y = p, q
    while x != y:
        x = q if not x else x.parent
        y = p if not y else y.parent
    return x


def lca_with_parent_depth(p: 'Node', q: 'Node') -> 'Node':
    """
    Alternative parent-pointer LCA: equalise depths, then climb in lockstep.

    Depth is the number of parent hops until a node whose parent is falsy.
    """
    def depth_of(node: 'Node') -> int:
        hops = 0
        while node.parent:
            node, hops = node.parent, hops + 1
        return hops

    gap = depth_of(p) - depth_of(q)

    # Lift whichever node is deeper; range() of a negative count is empty.
    for _ in range(gap):
        p = p.parent
    for _ in range(-gap):
        q = q.parent

    # Same depth now: step both up until they coincide.
    while p != q:
        p, q = p.parent, q.parent
    return p

2️⃣ Binary Lifting

"""
Binary Lifting: Preprocess to answer kth ancestor queries in O(log n)

Key idea: 
- up[node][j] = 2^j-th ancestor of node
- up[node][j] = up[up[node][j-1]][j-1]
"""

import math

class BinaryLifting:
    """
    Binary lifting for LCA and kth ancestor queries
    Preprocessing: O(n log n)
    Query: O(log n)

    up[i][j] holds the 2^j-th ancestor of node i (-1 if it does not exist);
    depth[i] is the number of edges from `root` to i.
    """

    def __init__(self, n: int, parent: list[int], root: int = 0):
        """
        n: number of nodes, labelled 0..n-1 (assumed n >= 1)
        parent: parent[i] = parent of node i (-1 for root)
        root: node from which depths are measured
        """
        self.n = n
        # Smallest LOG with 2**LOG >= n (at least 1). Depths never exceed
        # n - 1, so LOG jump levels cover every reachable ancestor. This
        # equals ceil(log2(n)) but uses exact integer arithmetic.
        self.LOG = max(1, (n - 1).bit_length())

        # up[node][j] = 2^j ancestor of node
        self.up = [[-1] * self.LOG for _ in range(n)]
        self.depth = [0] * n

        # Initialize with direct parents
        for i in range(n):
            self.up[i][0] = parent[i]

        # Build sparse table level by level: a 2^j jump is two 2^(j-1) jumps.
        for j in range(1, self.LOG):
            for i in range(n):
                if self.up[i][j - 1] != -1:
                    self.up[i][j] = self.up[self.up[i][j - 1]][j - 1]

        # Compute depths using BFS
        self._compute_depths(root, parent)

    def _compute_depths(self, root: int, parent: list[int]):
        """Fill self.depth by BFS from root over the child lists derived from parent."""
        from collections import deque

        # Build adjacency list
        children = [[] for _ in range(self.n)]
        for i in range(self.n):
            if parent[i] != -1:
                children[parent[i]].append(i)

        queue = deque([root])
        self.depth[root] = 0

        while queue:
            node = queue.popleft()
            for child in children[node]:
                self.depth[child] = self.depth[node] + 1
                queue.append(child)

    def kth_ancestor(self, node: int, k: int) -> int:
        """
        Return the k-th ancestor of node, or -1 if it does not exist.

        k is assumed non-negative. Previously, bits of k at or above LOG were
        silently ignored, so a too-large k could return a wrong node instead
        of -1.
        """
        # Every depth is < n <= 2**LOG, so no node has an ancestor this far up.
        if k >= (1 << self.LOG):
            return -1
        for j in range(self.LOG):
            if node == -1:
                break
            if k & (1 << j):
                node = self.up[node][j]
        return node

    def lca(self, u: int, v: int) -> int:
        """Find LCA of nodes u and v"""
        # Make u the deeper node
        if self.depth[u] < self.depth[v]:
            u, v = v, u

        # Bring u to same depth as v
        diff = self.depth[u] - self.depth[v]
        u = self.kth_ancestor(u, diff)

        if u == v:
            return u

        # Jump both nodes up as far as possible while they stay distinct;
        # afterwards they are children of the LCA.
        for j in range(self.LOG - 1, -1, -1):
            if self.up[u][j] != self.up[v][j]:
                u = self.up[u][j]
                v = self.up[v][j]

        return self.up[u][0]

    def distance(self, u: int, v: int) -> int:
        """Number of edges between u and v: depth(u) + depth(v) - 2*depth(lca)."""
        ancestor = self.lca(u, v)
        return self.depth[u] + self.depth[v] - 2 * self.depth[ancestor]


# LC 1483: Kth Ancestor of a Tree Node
class TreeAncestor:
    """
    LeetCode version (LC 1483): k-th ancestor queries via binary lifting.

    up[i][j] holds the 2^j-th ancestor of node i, or -1 if none exists.
    """

    def __init__(self, n: int, parent: list[int]):
        # Size the table from n instead of a fixed 20 levels: 2**LOG >= n, which
        # exceeds any depth, so this works for every n and saves memory on small inputs.
        self.LOG = max(1, (n - 1).bit_length())
        self.up = [[-1] * self.LOG for _ in range(n)]

        for i in range(n):
            self.up[i][0] = parent[i]

        # A 2^j jump is two consecutive 2^(j-1) jumps.
        for j in range(1, self.LOG):
            for i in range(n):
                if self.up[i][j - 1] != -1:
                    self.up[i][j] = self.up[self.up[i][j - 1]][j - 1]

    def getKthAncestor(self, node: int, k: int) -> int:
        """Return the k-th ancestor of node (k >= 0), or -1 if it does not exist."""
        # Bits of k at or above LOG would otherwise be ignored, producing a
        # wrong node; such a k exceeds every depth, so the answer is -1.
        if k >= (1 << self.LOG):
            return -1
        for j in range(self.LOG):
            if node == -1:
                break
            if k & (1 << j):
                node = self.up[node][j]
        return node

3️⃣ Tree Serialization/Deserialization (LC 297)

BFS Serialization

from collections import deque

class Codec:
    """Level-order (BFS) codec: values separated by commas, "null" for missing children."""

    def serialize(self, root: TreeNode) -> str:
        """Encode the tree breadth-first; an empty tree encodes to ""."""
        if not root:
            return ""

        tokens = []
        frontier = deque([root])
        while frontier:
            current = frontier.popleft()
            if not current:
                tokens.append("null")
                continue
            tokens.append(str(current.val))
            frontier.extend((current.left, current.right))

        # Trailing "null" markers carry no information, so drop them.
        while tokens and tokens[-1] == "null":
            tokens.pop()

        return ",".join(tokens)

    def deserialize(self, data: str) -> TreeNode:
        """Rebuild the tree by handing out tokens two at a time (left, right) per parent."""
        if not data:
            return None

        tokens = data.split(",")
        root = TreeNode(int(tokens[0]))
        parents = deque([root])
        idx = 1

        while parents and idx < len(tokens):
            current = parents.popleft()
            for side in ("left", "right"):
                if idx < len(tokens) and tokens[idx] != "null":
                    child = TreeNode(int(tokens[idx]))
                    setattr(current, side, child)
                    parents.append(child)
                idx += 1

        return root

Preorder Serialization

class CodecPreorder:
    """Preorder codec: node values in root-left-right order, "null" for every empty slot."""

    def serialize(self, root: TreeNode) -> str:
        """Emit the preorder token stream with an explicit stack."""
        tokens = []
        pending = [root]
        while pending:
            node = pending.pop()
            if not node:
                tokens.append("null")
                continue
            tokens.append(str(node.val))
            # Push right first so the whole left subtree is emitted before it.
            pending.append(node.right)
            pending.append(node.left)
        return ",".join(tokens)

    def deserialize(self, data: str) -> TreeNode:
        """Consume tokens in preorder, recursing for left then right children."""
        tokens = iter(data.split(","))

        def build():
            token = next(tokens)
            if token == "null":
                return None
            node = TreeNode(int(token))
            # Left is built (and its tokens consumed) before right.
            node.left, node.right = build(), build()
            return node

        return build()

BST Serialization (LC 449)

class CodecBST:
    """
    BST serialization - no need for null markers
    Using preorder, BST property determines structure
    """

    def serialize(self, root: TreeNode) -> str:
        """Preorder without null markers"""
        if not root:
            return ""

        result = []

        def preorder(node):
            if node:
                result.append(str(node.val))
                preorder(node.left)
                preorder(node.right)

        preorder(root)
        return ",".join(result)

    def deserialize(self, data: str) -> TreeNode:
        """
        Reconstruct BST from preorder.

        Consumes the value list with a moving index rather than list.pop(0),
        which shifted the entire list for every node and made decoding O(n^2).
        This version is O(n).
        """
        if not data:
            return None

        values = [int(x) for x in data.split(",")]
        pos = 0  # index of the next unconsumed preorder value

        def build(min_val: float, max_val: float) -> TreeNode:
            nonlocal pos
            # The next value belongs to this subtree only if it lies within
            # the inclusive bounds inherited from its ancestors.
            if pos == len(values) or not min_val <= values[pos] <= max_val:
                return None

            val = values[pos]
            pos += 1
            node = TreeNode(val)
            node.left = build(min_val, val)
            node.right = build(val, max_val)
            return node

        return build(float('-inf'), float('inf'))

4️⃣ Morris Traversal

"""
Morris Traversal: O(1) space tree traversal
Uses threading to return from leaves without stack

Key idea: Temporarily modify tree by creating links from 
rightmost node of left subtree back to current node
"""

def morris_inorder(root: TreeNode) -> list[int]:
    """
    Inorder traversal in O(1) extra space using temporary threads.

    Before descending into a left subtree, the subtree's rightmost node gets
    a right-link back to the current node. Reaching that link again means the
    left side is finished, so the link is removed and the node is emitted.
    Time: O(n)
    """
    visited = []
    node = root

    while node:
        left = node.left
        if not left:
            visited.append(node.val)
            node = node.right
            continue

        # Find the inorder predecessor, stopping if it already threads back here.
        pred = left
        while pred.right and pred.right != node:
            pred = pred.right

        if pred.right:
            # Second arrival: undo the thread, emit, and move on.
            pred.right = None
            visited.append(node.val)
            node = node.right
        else:
            # First arrival: install the thread and explore the left side.
            pred.right = node
            node = left

    return visited


def morris_preorder(root: TreeNode) -> list[int]:
    """
    Preorder traversal in O(1) extra space using temporary threads.

    This is the same threading scheme as the inorder variant, but a node is
    emitted on its first arrival (when the thread is created) instead of
    when the thread is removed.
    """
    visited = []
    node = root

    while node:
        left = node.left
        if not left:
            visited.append(node.val)
            node = node.right
            continue

        pred = left
        while pred.right and pred.right != node:
            pred = pred.right

        if pred.right:
            # Came back through the thread: the left subtree is already emitted.
            pred.right = None
            node = node.right
        else:
            # First arrival: emit before descending left (preorder).
            visited.append(node.val)
            pred.right = node
            node = left

    return visited


def flatten_to_linked_list_morris(root: TreeNode) -> None:
    """
    LC 114: flatten the tree in place into a right-linked list in preorder.

    At each node with a left subtree, the node's right subtree is hung off the
    rightmost node of that left subtree. The left subtree then moves to the
    right slot. Uses O(1) extra space.
    """
    node = root
    while node:
        head = node.left
        if head:
            tail = head
            while tail.right:
                tail = tail.right
            # Splice: old right subtree after the tail, left subtree becomes right.
            tail.right, node.right, node.left = node.right, head, None
        node = node.right


def recover_bst_morris(root: TreeNode) -> None:
    """
    LC 99: repair a BST in which exactly two nodes' values were swapped.

    A Morris inorder walk (O(1) extra space) finds the inversions. The first
    out-of-order pair provides the first misplaced node, and the last one
    provides the second. Their values are then swapped back in place.
    """
    first = second = prev = None

    def inspect(node):
        # Compare node with its inorder predecessor and record any inversion.
        nonlocal first, second, prev
        if prev and prev.val > node.val:
            if not first:
                first = prev
            second = node
        prev = node

    node = root
    while node:
        if not node.left:
            inspect(node)
            node = node.right
            continue

        pred = node.left
        while pred.right and pred.right != node:
            pred = pred.right

        if pred.right:
            # Back via the thread: remove it and process this node in order.
            pred.right = None
            inspect(node)
            node = node.right
        else:
            pred.right = node
            node = node.left

    if first and second:
        first.val, second.val = second.val, first.val

5️⃣ Tree Distance Problems

Diameter of Tree (LC 543)

def diameter_of_binary_tree(root: TreeNode) -> int:
    """
    Length (in edges) of the longest path between any two nodes.

    A post-order height computation. At every node, the left height plus the
    right height is the longest path that bends at that node.
    """
    best = [0]

    def depth(node: TreeNode) -> int:
        if not node:
            return 0
        lh, rh = depth(node.left), depth(node.right)
        best[0] = max(best[0], lh + rh)
        return max(lh, rh) + 1

    depth(root)
    return best[0]

Distance Between Nodes

def distance_between_nodes(root: 'TreeNode', p: int, q: int) -> int:
    """
    Number of edges between the nodes whose values are p and q.

    distance = depth(p) + depth(q) - 2 * depth(LCA)

    The LCA's depth is tracked in the same traversal. Earlier versions looked
    it up afterwards by value, which gave the wrong depth whenever a shallower
    node shared the LCA's value. Node values are assumed unique; if a value
    repeats, its deepest occurrence is used.

    Returns -1 if p or q does not occur in the tree (including an empty tree).
    Time: O(n), Space: O(h) recursion depth.
    """
    def dfs(node, depth):
        # Returns (depth of p, depth of q, depth of their LCA) for this subtree;
        # -1 means "not found here".
        if not node:
            return -1, -1, -1

        left_p, left_q, left_lca = dfs(node.left, depth + 1)
        right_p, right_q, right_lca = dfs(node.right, depth + 1)

        d_p = max(depth if node.val == p else -1, left_p, right_p)
        d_q = max(depth if node.val == q else -1, left_q, right_q)

        if left_lca != -1:
            lca_depth = left_lca
        elif right_lca != -1:
            lca_depth = right_lca
        elif d_p != -1 and d_q != -1:
            # Both targets are at or below this node, but no single child
            # subtree holds both, so their paths meet here.
            lca_depth = depth
        else:
            lca_depth = -1

        return d_p, d_q, lca_depth

    depth_p, depth_q, lca_depth = dfs(root, 0)
    if lca_depth == -1:
        return -1
    return depth_p + depth_q - 2 * lca_depth

6️⃣ Tree Isomorphism and Subtree

Subtree of Another Tree (LC 572)

def is_subtree(root: TreeNode, subRoot: TreeNode) -> bool:
    """
    Return True if some node of root heads a subtree identical to subRoot.

    Every node of root is tried as a starting point for a structural
    comparison. Worst case O(n * m).
    """
    def same(a: TreeNode, b: TreeNode) -> bool:
        if not a or not b:
            # Identical only when both sides are empty.
            return not a and not b
        return a.val == b.val and same(a.left, b.left) and same(a.right, b.right)

    candidates = [root]
    while candidates:
        node = candidates.pop()
        if not node:
            continue
        if same(node, subRoot):
            return True
        candidates.extend((node.right, node.left))
    return False


def is_subtree_merkle(root: 'TreeNode', subRoot: 'TreeNode') -> bool:
    """
    Return True if some node of root heads a subtree identical to subRoot.

    Each distinct subtree shape gets a canonical integer ID, computed bottom-up
    from (left ID, value, right ID) in a shared table. Two subtrees get the
    same ID exactly when they are structurally identical with equal values.
    Each tree is encoded in a single pass, so the whole check runs in
    O(n + m) expected time. The previous version rebuilt a full string for
    every node's subtree, which was at least quadratic despite claiming O(n).

    An empty subRoot never matches, which is consistent with the earlier
    implementation.
    """
    ids = {}  # (left_id, val, right_id) -> canonical id; 0 is reserved for "empty"

    def encode(node, seen=None) -> int:
        if not node:
            return 0
        key = (encode(node.left, seen), node.val, encode(node.right, seen))
        node_id = ids.setdefault(key, len(ids) + 1)
        if seen is not None:
            seen.add(node_id)
        return node_id

    target = encode(subRoot)
    root_ids = set()
    encode(root, root_ids)
    # root_ids only contains IDs of real nodes (>= 1), so an empty subRoot (0) never matches.
    return target in root_ids

⏱️ Complexity Summary

| Algorithm | Preprocessing | Query | Space |
|---|---|---|---|
| LCA (basic recursive) | – | O(n) | O(h) |
| Binary Lifting | O(n log n) | O(log n) | O(n log n) |
| Morris Traversal | – | O(n) | O(1) |
| Serialization | O(n) | – | O(n) |

📚 Practice Problems

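| # | Problem | Key Concept |
|---|---|---|
| 1 | LC 236: LCA of a Binary Tree | Basic LCA |
| 2 | LC 235: LCA of a BST | BST LCA |
| 3 | LC 1650: LCA III | Parent pointers |
| 4 | LC 1483: Kth Ancestor of a Tree Node | Binary lifting |
| 5 | LC 297: Serialize and Deserialize Binary Tree | Serialization |
| 6 | LC 449: Serialize and Deserialize BST | BST serialization |
| 7 | LC 114: Flatten Binary Tree to Linked List | Morris variant |
| 8 | LC 99: Recover BST | Morris + swapped nodes |
| 9 | LC 543: Diameter of Binary Tree | Distance |
| 10 | LC 572: Subtree of Another Tree | Isomorphism |
| 11 | LC 1123: LCA of Deepest Leaves | LCA variant |
| 12 | LC 1644: LCA II | Nodes may not exist |
| 13 | LC 1676: LCA IV | Multiple nodes |
| 14 | LC 863: All Nodes Distance K | Distance |
| 15 | LC 979: Distribute Coins | Post-order |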

Last Updated: 2024