Meet in the Middle (Alternative to DP)

1. Overview

Core Concept

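Meet in the Middle (MITM) is a technique that turns an O(2^n) brute-force search into roughly O(2^(n/2)) work: split the input into two halves, enumerate every possibility for each half independently, then combine the two result lists, typically by sorting one of them and binary searching it, or by hashing it.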

When to Use

  • Brute force is O(2^n) and n is moderate (roughly 30-40): too slow for 2^n, fine for 2^(n/2)
  • Problem can be split into independent halves
  • Results from halves can be combined efficiently

Complexity Comparison

| Problem Size | Brute Force | Meet in Middle |
|---|---|---|
| n = 40 | 2^40 ≈ 10^12 | 2 × 2^20 ≈ 2 × 10^6 |
| n = 30 | 2^30 ≈ 10^9 | 2 × 2^15 ≈ 65,000 |
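The table's numbers are plain arithmetic. The sketch below reproduces them, counting only the subset sums each approach enumerates and ignoring the log factor from sorting and binary search. `enumeration_counts` is an illustrative name, not part of any library.

```python
def enumeration_counts(n: int) -> tuple[int, int]:
    """Return (subsets visited by brute force, subset sums enumerated by MITM).

    Brute force visits all 2^n subsets. MITM enumerates 2^(n//2) sums for
    the first half and 2^(n - n//2) for the second half. The O(log) factor
    from sorting / binary search is deliberately ignored.
    """
    brute = 2 ** n
    mitm = 2 ** (n // 2) + 2 ** (n - n // 2)
    return brute, mitm


for n in (30, 40):
    brute, mitm = enumeration_counts(n)
    print(f"n = {n}: brute force {brute:,} vs MITM {mitm:,}")
# n = 30: brute force 1,073,741,824 vs MITM 65,536
# n = 40: brute force 1,099,511,627,776 vs MITM 2,097,152
```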

2. Classic Subset Sum

Basic Template

from typing import List
from bisect import bisect_left, bisect_right

def subset_sum_mitm(nums: List[int], target: int) -> bool:
    """
    Check if any subset sums to target.
    
    Time: O(2^(n/2) × log(2^(n/2))) = O(2^(n/2) × n)
    Space: O(2^(n/2))
    """
    n = len(nums)
    mid = n // 2
    
    # Generate all 2^len(arr) subset sums (with multiplicity) by doubling
    def get_all_sums(arr: List[int]) -> List[int]:
        sums = [0]
        for num in arr:
            sums += [s + num for s in sums]
        return sums
    
    left_sums = get_all_sums(nums[:mid])
    right_sums = get_all_sums(nums[mid:])
    
    # Sort right half for binary search
    right_sums.sort()
    
    # For each left sum, binary search for complement in right
    for s in left_sums:
        complement = target - s
        idx = bisect_left(right_sums, complement)
        if idx < len(right_sums) and right_sums[idx] == complement:
            return True
    
    return False


def count_subset_sums_mitm(nums: List[int], target: int) -> int:
    """
    Count subsets that sum to target.
    
    Time: O(2^(n/2) × log(2^(n/2)))
    """
    n = len(nums)
    mid = n // 2
    
    def get_all_sums(arr: List[int]) -> List[int]:
        sums = [0]
        for num in arr:
            sums += [s + num for s in sums]
        return sums
    
    left_sums = get_all_sums(nums[:mid])
    right_sums = get_all_sums(nums[mid:])
    right_sums.sort()
    
    count = 0
    for s in left_sums:
        complement = target - s
        # Count occurrences of complement
        left_idx = bisect_left(right_sums, complement)
        right_idx = bisect_right(right_sums, complement)
        count += right_idx - left_idx
    
    return count


# Test
nums = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
print(subset_sum_mitm(nums, 15))  # True
print(count_subset_sums_mitm(nums, 15))  # 20 (subsets of 1..10 that sum to 15)
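The template above combines the halves with sorting plus binary search. A sketch of the hash-map alternative, assuming only the standard library: the right half's sums go into a `collections.Counter`, so each complement lookup is O(1) expected. The helper names (`subset_sums`, `count_subset_sums_hash`, `count_subset_sums_brute`) are hypothetical, and the brute-force function exists only to cross-check the MITM count on small inputs.

```python
from collections import Counter
from itertools import combinations
from typing import List


def subset_sums(arr: List[int]) -> List[int]:
    """All 2^len(arr) subset sums (with multiplicity), built by doubling."""
    sums = [0]
    for x in arr:
        sums += [s + x for s in sums]
    return sums


def count_subset_sums_hash(nums: List[int], target: int) -> int:
    """MITM count where the right half is hashed instead of sorted."""
    mid = len(nums) // 2
    right_counts = Counter(subset_sums(nums[mid:]))
    # Counter returns 0 for missing keys without inserting them.
    return sum(right_counts[target - s] for s in subset_sums(nums[:mid]))


def count_subset_sums_brute(nums: List[int], target: int) -> int:
    """O(2^n) reference: check every subset of indices (duplicates count)."""
    return sum(
        1
        for r in range(len(nums) + 1)
        for combo in combinations(range(len(nums)), r)
        if sum(nums[i] for i in combo) == target
    )


nums = list(range(1, 11))
print(count_subset_sums_hash(nums, 15), count_subset_sums_brute(nums, 15))  # 20 20
```

Hashing trades the sorted list for a dictionary: lookups get cheaper, but the "closest" style queries in the next section still need a sorted half.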

3. Closest Subset Sum

def closest_sum_mitm(nums: List[int], target: int) -> int:
    """
    Find subset sum closest to target.
    
    Time: O(2^(n/2) × log(2^(n/2)))
    """
    n = len(nums)
    mid = n // 2
    
    def get_all_sums(arr: List[int]) -> List[int]:
        sums = [0]
        for num in arr:
            sums += [s + num for s in sums]
        return sums
    
    left_sums = get_all_sums(nums[:mid])
    right_sums = sorted(get_all_sums(nums[mid:]))
    
    closest = float('inf')
    best_sum = 0
    
    for s in left_sums:
        complement = target - s
        
        # Binary search for closest value
        idx = bisect_left(right_sums, complement)
        
        # Check idx and idx-1
        for i in [idx - 1, idx]:
            if 0 <= i < len(right_sums):
                total = s + right_sums[i]
                if abs(total - target) < closest:
                    closest = abs(total - target)
                    best_sum = total
    
    return best_sum


def partition_min_difference(nums: List[int]) -> int:
    """
    LeetCode 2035 - Partition Array Into Two Arrays to Minimize Sum Difference
    
    Split an even-length array into two arrays of equal size, minimizing the absolute difference of their sums.
    
    Time: O(2^(n/2) × log(2^(n/2)))
    """
    n = len(nums)
    half = n // 2
    total = sum(nums)
    target = total / 2
    
    # For each half, generate sums for each possible count of elements
    def get_sums_by_count(arr: List[int]):
        m = len(arr)
        # sums[k] = list of sums using exactly k elements
        sums = [[] for _ in range(m + 1)]
        
        for mask in range(1 << m):
            s = 0
            cnt = 0
            for i in range(m):
                if mask & (1 << i):
                    s += arr[i]
                    cnt += 1
            sums[cnt].append(s)
        
        return sums
    
    left_sums = get_sums_by_count(nums[:half])
    right_sums = get_sums_by_count(nums[half:])
    
    # Sort right sums for binary search
    for i in range(len(right_sums)):
        right_sums[i].sort()
    
    min_diff = float('inf')
    
    # For each count k in left half, pair with half-k in right half
    for k in range(half + 1):
        right_k = half - k
        if right_k < 0 or right_k >= len(right_sums):
            continue
        
        for ls in left_sums[k]:
            # Binary search in right_sums[right_k] for value closest to target - ls
            rs_list = right_sums[right_k]
            if not rs_list:
                continue
            
            complement = target - ls
            idx = bisect_left(rs_list, complement)
            
            for i in [idx - 1, idx]:
                if 0 <= i < len(rs_list):
                    part1_sum = ls + rs_list[i]
                    part2_sum = total - part1_sum
                    min_diff = min(min_diff, abs(part1_sum - part2_sum))
    
    return int(min_diff)


# Test
print(closest_sum_mitm([1, 2, 3, 4, 5], 7))  # 7
print(partition_min_difference([3, 9, 7, 3]))  # 2 ([3, 9] vs [7, 3])
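`closest_sum_mitm` runs one binary search per left sum. Another common combine step (for example in LeetCode 1755) sorts both halves and sweeps them with two pointers. The sketch below assumes that approach; `closest_diff_two_pointers` is a hypothetical name, and it returns the minimum |sum - goal| rather than the sum itself.

```python
from typing import List


def subset_sums(arr: List[int]) -> List[int]:
    """All subset sums of arr (with multiplicity), built by doubling."""
    sums = [0]
    for x in arr:
        sums += [s + x for s in sums]
    return sums


def closest_diff_two_pointers(nums: List[int], goal: int) -> int:
    """Minimum |subset sum - goal| via a two-pointer sweep over sorted halves."""
    mid = len(nums) // 2
    left = sorted(subset_sums(nums[:mid]))
    right = sorted(subset_sums(nums[mid:]))

    best = abs(goal)  # the empty subset (sum 0) is always available
    i, j = 0, len(right) - 1
    while i < len(left) and j >= 0:
        total = left[i] + right[j]
        best = min(best, abs(total - goal))
        if total < goal:
            i += 1      # too small: move to a larger left sum
        elif total > goal:
            j -= 1      # too large: move to a smaller right sum
        else:
            return 0    # exact hit
    return best


print(closest_diff_two_pointers([5, -7, 3, 5], 6))  # 0
```

After sorting, the sweep itself is linear in the number of sums, since each step discards one left or one right value for good.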

4. 4-Sum Problem

def four_sum_mitm(nums: List[int], target: int) -> List[List[int]]:
    """
    LeetCode 18 - 4Sum
    
    Find all unique quadruplets that sum to target.
    
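    Meet in the middle: index every pair sum, then look up the complement
    of each pair.
    
    Time: O(n²) expected to build the pair-sum table, plus the total length of
    the matched pair lists, which can reach O(n⁴) when many pairs share a sum
    (e.g. an array of identical values).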
    """
    from collections import defaultdict
    
    n = len(nums)
    if n < 4:
        return []
    
    nums = sorted(nums)  # sorted copy, so i < j < k < l gives ordered tuples and the set removes duplicates
    
    # Store all pair sums with their indices
    pair_sums = defaultdict(list)
    for i in range(n):
        for j in range(i + 1, n):
            pair_sums[nums[i] + nums[j]].append((i, j))
    
    result = set()
    
    # For each pair, find complement pair
    for i in range(n):
        for j in range(i + 1, n):
            complement = target - nums[i] - nums[j]
            
            if complement in pair_sums:
                for k, l in pair_sums[complement]:
                    # Ensure no overlapping indices
                    if k > j:
                        quad = (nums[i], nums[j], nums[k], nums[l])
                        result.add(quad)
    
    return [list(q) for q in result]


def four_sum_count(nums1, nums2, nums3, nums4) -> int:
    """
    LeetCode 454 - 4Sum II
    
    Count tuples (i, j, k, l) where nums1[i] + nums2[j] + nums3[k] + nums4[l] = 0.
    
    Time: O(n²)
    """
    from collections import Counter
    
    # Sum pairs from first two arrays
    sum12 = Counter()
    for a in nums1:
        for b in nums2:
            sum12[a + b] += 1
    
    # For each pair from last two arrays, count complement
    count = 0
    for c in nums3:
        for d in nums4:
            count += sum12[-(c + d)]
    
    return count
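4Sum II is one case of a more general pattern: with k lists, put the first ⌊k/2⌋ lists on one side and the rest on the other. The following sketch assumes the same count-with-a-hash-map combine as `four_sum_count`; `count_k_lists_sum` is an illustrative name.

```python
from collections import Counter
from itertools import product
from typing import Sequence


def count_k_lists_sum(lists: Sequence[Sequence[int]], target: int = 0) -> int:
    """Count tuples (one element per list) whose sum equals target.

    The first half of the lists is enumerated into a Counter of combined
    sums; every combined sum of the second half then looks up its
    complement. With k lists of length n this is about n^ceil(k/2) time
    and n^floor(k/2) memory instead of n^k.
    """
    mid = len(lists) // 2
    # product() of zero lists yields one empty tuple, whose sum is 0.
    left_counts = Counter(sum(combo) for combo in product(*lists[:mid]))
    return sum(left_counts[target - sum(combo)] for combo in product(*lists[mid:]))


print(count_k_lists_sum([[1, 2], [-2, -1], [-1, 2], [0, 2]]))  # 2
```

With the four LeetCode 454 arrays as `lists` and `target=0`, the count should agree with `four_sum_count`.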

5. Equation with Multiple Variables

def solve_equation_mitm(a: List[int], b: List[int], target: int) -> bool:
    """
    Check if a[i1] + a[i2] + b[j1] + b[j2] = target for some indices.
    
    Time: O(n²) expected (hash-set lookups)
    """
    # Generate all pair sums from a
    sums_a = set()
    for i in range(len(a)):
        for j in range(len(a)):  # Can use same element twice
            sums_a.add(a[i] + a[j])
    
    # Check if any pair sum from b complements
    for i in range(len(b)):
        for j in range(len(b)):
            complement = target - b[i] - b[j]
            if complement in sums_a:
                return True
    
    return False
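`solve_equation_mitm` only answers yes or no. When the indices themselves are needed, one option (sketched here with the hypothetical name `find_equation_witness`) is to keep one witness pair per sum in a dict instead of a set.

```python
from typing import Dict, List, Optional, Tuple


def find_equation_witness(
    a: List[int], b: List[int], target: int
) -> Optional[Tuple[int, int, int, int]]:
    """Return (i1, i2, j1, j2) with a[i1] + a[i2] + b[j1] + b[j2] == target,
    or None. As in solve_equation_mitm, an index may be used twice.
    """
    # Pair sum of a -> first index pair that produced it.
    witness: Dict[int, Tuple[int, int]] = {}
    for i1 in range(len(a)):
        for i2 in range(len(a)):
            witness.setdefault(a[i1] + a[i2], (i1, i2))

    for j1 in range(len(b)):
        for j2 in range(len(b)):
            pair = witness.get(target - b[j1] - b[j2])
            if pair is not None:
                return (*pair, j1, j2)
    return None


print(find_equation_witness([1, 5], [2, 10], 18))  # (0, 1, 0, 1)
```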


def count_arithmetic_tuples(arr: List[int]) -> int:
    """
    Count tuples (i, j, k) where arr[j] - arr[i] = arr[k] - arr[j].
    (i < j < k)
    
    Equivalent to: 2 × arr[j] = arr[i] + arr[k]
    
    Time: O(n²) with hash map
    """
    n = len(arr)
    count = 0
    
    for j in range(1, n - 1):
        target = 2 * arr[j]
        
        # Count pairs (i, k) where i < j < k and arr[i] + arr[k] = target
        left_vals = {}
        for i in range(j):
            left_vals[arr[i]] = left_vals.get(arr[i], 0) + 1
        
        for k in range(j + 1, n):
            complement = target - arr[k]
            count += left_vals.get(complement, 0)
    
    return count

6. Bidirectional BFS

def bidirectional_bfs(start: str, end: str, word_list: List[str]) -> int:
    """
    LeetCode 127 - Word Ladder (optimized)
    
    Meet in middle with BFS from both ends.
    
    Time: O(b^(d/2)) instead of O(b^d) where b is branching factor
    """
    word_set = set(word_list)  # O(1) membership tests
    if end not in word_set:
        return 0
    
    # BFS from both ends
    front = {start}
    back = {end}
    visited = {start, end}
    steps = 1
    
    while front and back:
        # Always expand smaller set
        if len(front) > len(back):
            front, back = back, front
        
        next_front = set()
        
        for word in front:
            for i in range(len(word)):
                for c in 'abcdefghijklmnopqrstuvwxyz':
                    new_word = word[:i] + c + word[i+1:]
                    
                    if new_word in back:
                        return steps + 1
                    
                    if new_word in word_set and new_word not in visited:
                        visited.add(new_word)
                        next_front.add(new_word)
        
        front = next_front
        steps += 1
    
    return 0
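Both BFS functions here hard-code how neighbors are generated. A generic sketch of the same idea, assuming an undirected graph exposed through a `neighbors` callable (the name `bidirectional_shortest_path` is hypothetical), keeps one distance map per side and expands one full layer of the smaller frontier per round. Note it returns the number of edges, whereas Word Ladder counts words (edges + 1).

```python
from typing import Callable, Dict, Hashable, Iterable, List


def bidirectional_shortest_path(
    start: Hashable,
    goal: Hashable,
    neighbors: Callable[[Hashable], Iterable[Hashable]],
) -> int:
    """Edges on a shortest start-goal path, or -1 if unreachable.

    Assumes an undirected graph (v in neighbors(u) iff u in neighbors(v)),
    so the goal-side search can reuse the same neighbors function.
    """
    if start == goal:
        return 0
    # Distances from each endpoint; the two maps stay disjoint until they meet.
    dist_a: Dict[Hashable, int] = {start: 0}
    dist_b: Dict[Hashable, int] = {goal: 0}
    frontier_a: List[Hashable] = [start]
    frontier_b: List[Hashable] = [goal]

    while frontier_a and frontier_b:
        # Always expand the smaller frontier; swap roles so side "a" is expanded.
        if len(frontier_a) > len(frontier_b):
            frontier_a, frontier_b = frontier_b, frontier_a
            dist_a, dist_b = dist_b, dist_a
        next_frontier: List[Hashable] = []
        for u in frontier_a:
            for v in neighbors(u):
                if v in dist_b:
                    # Each round expands one complete BFS layer, so the first
                    # edge crossing into the other side closes a shortest path.
                    return dist_a[u] + 1 + dist_b[v]
                if v not in dist_a:
                    dist_a[v] = dist_a[u] + 1
                    next_frontier.append(v)
        frontier_a = next_frontier
    return -1


cycle = {i: [(i - 1) % 6, (i + 1) % 6] for i in range(6)}
print(bidirectional_shortest_path(0, 3, cycle.__getitem__))  # 3
```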


def minimum_genetic_mutation(start: str, end: str, bank: List[str]) -> int:
    """
    LeetCode 433 - Minimum Genetic Mutation
    
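    Bidirectional BFS as in Word Ladder, over genes built from 'ACGT'.
    Returns the number of mutations, or -1 if end is unreachable.
    """
    if start == end:
        return 0
    
    bank_set = set(bank)
    if end not in bank_set:
        return -1
    genes = "ACGT"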
    
    front = {start}
    back = {end}
    visited = {start, end}
    steps = 0
    
    while front and back:
        if len(front) > len(back):
            front, back = back, front
        
        next_front = set()
        steps += 1
        
        for gene in front:
            for i in range(len(gene)):  # LeetCode 433 genes have length 8
                for g in genes:
                    if g == gene[i]:
                        continue
                    new_gene = gene[:i] + g + gene[i+1:]
                    
                    if new_gene in back:
                        return steps
                    
                    if new_gene in bank_set and new_gene not in visited:
                        visited.add(new_gene)
                        next_front.add(new_gene)
        
        front = next_front
    
    return -1

7. NP-Hard Problem Optimization

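def traveling_salesman_mitm(dist: List[List[int]]) -> int:
    """
    TSP (cheapest cycle through every city, starting and ending at city 0)
    by splitting the tour at its midpoint.
    
    The first half is a path 0 -> ... -> m through a set S of h cities; the
    second half is a path x -> ... -> 0 through the complementary set T.
    Each half is a bitmask DP that only grows subsets up to about half the
    cities, and the halves are joined through the edge m -> x.
    
    Time: still O(n² × 2^n) in the worst case. There are Θ(2^n / √n) subsets
    of size ≈ n/2, so unlike subset sum this split gives no O(2^(n/2)) bound
    and does not beat Held-Karp bitmask DP asymptotically; it shows how two
    half-tours are combined.
    """
    n = len(dist)
    if n == 1:
        return 0
    if n == 2:
        return dist[0][1] + dist[1][0]
    
    INF = float('inf')
    others = n - 1     # cities 1..n-1 are mapped to bits 0..n-2
    h = others // 2    # number of cities in the first half of the tour
    
    def half_dp(d: List[List[int]], max_size: int) -> List[List[float]]:
        # best[mask][j]: cheapest path from city 0 through exactly the cities
        # in mask, ending at city j + 1, using edge costs d. Subsets are only
        # extended while they have fewer than max_size cities.
        best = [[INF] * others for _ in range(1 << others)]
        for j in range(others):
            best[1 << j][j] = d[0][j + 1]
        for mask in range(1, 1 << others):
            if bin(mask).count('1') >= max_size:
                continue
            for j in range(others):
                cost = best[mask][j]
                if cost == INF:
                    continue
                for k in range(others):
                    if not mask & (1 << k):
                        new_mask = mask | (1 << k)
                        best[new_mask][k] = min(best[new_mask][k], cost + d[j + 1][k + 1])
        return best
    
    # Forward half: 0 -> ... -> m through S (|S| = h).
    forward = half_dp(dist, h)
    # Backward half: the same DP on the transposed matrix gives the cost of
    # x -> ... -> 0 through T in the original graph.
    transposed = [list(row) for row in zip(*dist)]
    backward = half_dp(transposed, others - h)
    
    # Combine: every S of size h, its end city m, and the first city x of T.
    full = (1 << others) - 1
    best_tour = INF
    for S in range(1 << others):
        if bin(S).count('1') != h:
            continue
        T = full ^ S
        for m in range(others):
            if forward[S][m] == INF:  # also skips m not in S
                continue
            for x in range(others):
                if backward[T][x] == INF:  # also skips x not in T
                    continue
                best_tour = min(best_tour, forward[S][m] + dist[m + 1][x + 1] + backward[T][x])
    return best_tour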


def max_weight_k_items_mitm(items: List[int], k: int) -> int:
    """
    Select exactly k items with maximum sum.
    
    For large n with small k, MITM can help.
    """
    n = len(items)
    mid = n // 2
    
    # Generate sums for all subsets of first half with count
    from collections import defaultdict
    left_sums = defaultdict(list)  # count -> list of sums
    
    for mask in range(1 << mid):
        s = 0
        cnt = bin(mask).count('1')
        for i in range(mid):
            if mask & (1 << i):
                s += items[i]
        left_sums[cnt].append(s)
    
    # Similarly for right half
    right_sums = defaultdict(list)
    for mask in range(1 << (n - mid)):
        s = 0
        cnt = bin(mask).count('1')
        for i in range(n - mid):
            if mask & (1 << i):
                s += items[mid + i]
        right_sums[cnt].append(s)
    
    # For each left count, find best right count to make k total
    max_sum = float('-inf')  # items may be negative, so 0 is not a safe start
    for left_cnt in range(min(k, mid) + 1):
        right_cnt = k - left_cnt
        if right_cnt < 0 or right_cnt > n - mid:
            continue
        
        if left_sums[left_cnt] and right_sums[right_cnt]:
            max_left = max(left_sums[left_cnt])
            max_right = max(right_sums[right_cnt])
            max_sum = max(max_sum, max_left + max_right)
    
    return max_sum
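The practice list mentions knapsack variants with weight constraints. One standard MITM treatment of 0/1 knapsack, sketched below under the assumptions of non-negative weights and capacity, enumerates (weight, value) pairs per half, sorts the right half by weight, keeps a running maximum of value, and binary searches the remaining capacity for each left pair. `knapsack_mitm` is a hypothetical name; this is one common way to do it, not code from the original text.

```python
from bisect import bisect_right
from itertools import accumulate
from typing import List, Tuple


def knapsack_mitm(weights: List[int], values: List[int], capacity: int) -> int:
    """Max total value with total weight <= capacity (0/1 knapsack).

    Meant for n up to about 40 with huge weights, where a DP table indexed
    by capacity would be too large. Assumes weights and capacity are >= 0.
    """
    n = len(weights)
    mid = n // 2

    def enumerate_half(lo: int, hi: int) -> List[Tuple[int, int]]:
        # All (weight, value) pairs for subsets of items lo..hi-1, by doubling.
        pairs = [(0, 0)]
        for i in range(lo, hi):
            pairs += [(w + weights[i], v + values[i]) for w, v in pairs]
        return pairs

    left = enumerate_half(0, mid)
    right = sorted(enumerate_half(mid, n))  # sorted by weight, then value
    right_weights = [w for w, _ in right]
    # best_value[i]: highest value among right[0..i], i.e. within weight right_weights[i].
    best_value = list(accumulate((v for _, v in right), max))

    best = 0  # the empty selection is always feasible
    for w, v in left:
        if w > capacity:
            continue
        # Last right-half subset whose weight still fits the remaining capacity.
        idx = bisect_right(right_weights, capacity - w) - 1
        if idx >= 0:
            best = max(best, v + best_value[idx])
    return best


print(knapsack_mitm([1, 3, 4, 5], [1, 4, 5, 7], 7))  # 9
```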

8. Practice Problems

LeetCode Problems

| # | Problem | MITM Application | Difficulty |
|---|---|---|---|
| 18 | 4Sum | Pair sums | Medium |
| 127 | Word Ladder | Bidirectional BFS | Hard |
| 416 | Partition Equal Subset | Subset sum | Medium |
| 454 | 4Sum II | Pair sum counts | Medium |
| 805 | Split Array Same Average | MITM subset | Hard |
| 956 | Tallest Billboard | Difference sum | Hard |
| 1049 | Last Stone Weight II | Min partition diff | Medium |
| 1755 | Closest Subsequence Sum | Classic MITM | Hard |
| 2035 | Partition Min Difference | MITM | Hard |

Competitive Programming

| Problem | Source | Notes |
|---|---|---|
| Balanced Cow Subsets | USACO | Subset diff |
| 4 Values | SPOJ | 4 variable equation |
| Double Knapsack | Various | Weight constraints |

9. Key Patterns Summary

Meet in the Middle Decision Tree:

1. Is brute force O(2^n) over subsets or binary choices?
   → If n ≤ 20: standard approach
   → If 20 < n ≤ 40: consider MITM
   
2. Can problem be split into halves?
   → Subset problems: Yes
   → Sequence problems: Usually no
   
3. Can halves be combined efficiently?
   → Binary search: O(log M) per lookup (M = number of sums in the sorted half)
   → Hash map: O(1) expected per lookup
   → Two pointers: O(M) total after sorting both halves

MITM Template:
1. Split input into two halves
2. Enumerate all possibilities for each half
3. Sort or hash one half
4. For each element in other half, find complement
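The four steps map onto a small higher-order skeleton. The sketch below is hypothetical (`meet_in_the_middle`, `all_subset_sums`, and `count_matching` are illustrative names): only the per-half enumeration and the combine function are problem-specific, shown here for counting subsets that hit a target via a hash map.

```python
from collections import Counter
from typing import Callable, List, Sequence, TypeVar

Item = TypeVar("Item")
Partial = TypeVar("Partial")
Answer = TypeVar("Answer")


def meet_in_the_middle(
    items: Sequence[Item],
    enumerate_half: Callable[[Sequence[Item]], List[Partial]],
    combine: Callable[[List[Partial], List[Partial]], Answer],
) -> Answer:
    """Step 1 splits the input, step 2 enumerates each half, and steps 3-4
    are delegated to combine (sort or hash one list, look up complements)."""
    mid = len(items) // 2
    left = enumerate_half(items[:mid])
    right = enumerate_half(items[mid:])
    return combine(left, right)


def all_subset_sums(arr: Sequence[int]) -> List[int]:
    """Every subset sum of arr (with multiplicity), built by doubling."""
    sums = [0]
    for x in arr:
        sums += [s + x for s in sums]
    return sums


def count_matching(target: int) -> Callable[[List[int], List[int]], int]:
    """Combine step for 'count subsets summing to target'."""
    def combine(left: List[int], right: List[int]) -> int:
        right_counts = Counter(right)                       # step 3: hash one half
        return sum(right_counts[target - s] for s in left)  # step 4: complements
    return combine


print(meet_in_the_middle([3, 34, 4, 12, 5, 2], all_subset_sums, count_matching(9)))  # 2
```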

Common Pitfalls:
- Forgetting to handle edge cases (empty half)
- Integer overflow with large sums (in fixed-width languages such as C++ or Java; Python integers do not overflow)
- Duplicate counting
- Incorrect index handling when combining
