Meet in the Middle
1. Overview
Core Concept
Meet in the Middle (MITM) is a technique that reduces time complexity by splitting the problem into two halves, solving each half independently, and combining results.
When to Use
- Brute force is O(2^n) or O(n!) but n is moderate (30-40)
- Problem can be split into independent halves
- Results from halves can be combined efficiently
Complexity Comparison
| Problem Size | Brute Force | Meet in Middle |
|---|---|---|
| n = 40 | 2^40 ≈ 10^12 | 2 × 2^20 ≈ 2 × 10^6 |
| n = 30 | 2^30 ≈ 10^9 | 2 × 2^15 ≈ 65,000 |
2. Classic Subset Sum
Basic Template
from typing import List
from bisect import bisect_left, bisect_right
def subset_sum_mitm(nums: List[int], target: int) -> bool:
    """
    Return True if some subset of ``nums`` (the empty subset included)
    sums to ``target``.

    The list is cut in two; every subset sum of each half is enumerated, the
    second half's sums are sorted, and each first-half sum looks up its
    complement by binary search.

    Time:  O(2^(n/2) * n)
    Space: O(2^(n/2))
    """
    split = len(nums) // 2

    def enumerate_sums(values: List[int]) -> List[int]:
        # Each element doubles the list: the old sums, then old sums + element.
        acc = [0]
        for value in values:
            acc.extend([prev + value for prev in acc])
        return acc

    front = enumerate_sums(nums[:split])
    back = sorted(enumerate_sums(nums[split:]))
    size = len(back)

    def present(wanted: int) -> bool:
        # Binary search for an exact hit in the sorted second-half sums.
        pos = bisect_left(back, wanted)
        return pos < size and back[pos] == wanted

    return any(present(target - s) for s in front)
def count_subset_sums_mitm(nums: List[int], target: int) -> int:
    """
    Count the subsets of ``nums`` (the empty subset included) whose sum
    equals ``target``.

    Subsets are distinguished by position, so equal values at different
    indices give distinct subsets. Each first-half sum is paired with every
    matching second-half sum; those form a contiguous run in the sorted list.

    Time: O(2^(n/2) * n)
    """
    split = len(nums) // 2

    def enumerate_sums(values: List[int]) -> List[int]:
        # Each element doubles the list: the old sums, then old sums + element.
        acc = [0]
        for value in values:
            acc.extend([prev + value for prev in acc])
        return acc

    front = enumerate_sums(nums[:split])
    back = sorted(enumerate_sums(nums[split:]))

    total = 0
    for s in front:
        need = target - s
        # Width of the run of `need` values inside the sorted list.
        total += bisect_right(back, need) - bisect_left(back, need)
    return total
# Test: subset sums of 1..10 against target 15
nums = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
print(subset_sum_mitm(nums, 15)) # True
print(count_subset_sums_mitm(nums, 15)) # 20 (3 pairs + 10 triples + 6 quadruples + {1,2,3,4,5})
3. Closest Subset Sum
def closest_sum_mitm(nums: List[int], target: int) -> int:
    """
    Return the subset sum of ``nums`` (the empty subset included) that is
    nearest to ``target``.

    Ties keep the first candidate met in scan order: first-half sums in
    generation order, and for each one the sorted neighbour just below the
    complement before the one at/above it.

    Time: O(2^(n/2) * n)
    """
    split = len(nums) // 2

    def enumerate_sums(values: List[int]) -> List[int]:
        # Each element doubles the list: the old sums, then old sums + element.
        acc = [0]
        for value in values:
            acc.extend([prev + value for prev in acc])
        return acc

    front = enumerate_sums(nums[:split])
    back = enumerate_sums(nums[split:])
    back.sort()
    last = len(back) - 1

    best_gap = float('inf')
    answer = 0
    for s in front:
        pos = bisect_left(back, target - s)
        # Only the two neighbours straddling the complement can be nearest.
        for cand in (pos - 1, pos):
            if cand < 0 or cand > last:
                continue
            total = s + back[cand]
            gap = abs(total - target)
            if gap < best_gap:
                best_gap, answer = gap, total
    return answer
def partition_min_difference(nums: List[int]) -> int:
    """
    LeetCode 2035 - Partition Array Into Two Arrays to Minimize Sum Difference.

    Split ``nums`` into a part of ``len(nums) // 2`` elements and the rest,
    minimizing ``|sum(part) - sum(rest)|``. For the LeetCode input (even
    length) both parts have equal size.

    Choosing k elements from the left half forces ``half - k`` from the right
    half. Subset sums are therefore grouped by element count, and each left sum
    is matched against the sorted right sums of the complementary count.

    All arithmetic stays in integers. The ideal part sum is ``total / 2``, but
    in floating point that value (and the complements derived from it) round
    once sums exceed 2**53, so the wrong neighbours would be probed. Comparing
    ``2 * part`` with ``total`` is exact.

    Time: O(n * 2^(n/2))
    """
    n = len(nums)
    half = n // 2
    total = sum(nums)

    def sums_by_count(arr: List[int]) -> List[List[int]]:
        # by_count[c] = sums of every subset of arr with exactly c elements.
        by_count: List[List[int]] = [[0]] + [[] for _ in arr]
        for used, value in enumerate(arr):
            # Walk counts downward so this element is added at most once
            # (0/1-knapsack order); `used` elements seen so far cap the count.
            for c in range(used, -1, -1):
                by_count[c + 1].extend([s + value for s in by_count[c]])
        return by_count

    left_sums = sums_by_count(nums[:half])
    right_sums = [sorted(group) for group in sums_by_count(nums[half:])]

    min_diff = float('inf')
    # k elements from the left half force half - k from the right half. The
    # right half has at least `half` elements, so right_sums[half - k] exists
    # and is non-empty for every k in 0..half.
    for k, left_group in enumerate(left_sums):
        rs_list = right_sums[half - k]
        for ls in left_group:
            # First right sum r with 2 * (ls + r) >= total, i.e.
            # r >= ceil((total - 2 * ls) / 2); -((2*ls - total) // 2) is that
            # ceiling computed exactly with integer floor division.
            idx = bisect_left(rs_list, -((2 * ls - total) // 2))
            # The nearest candidates straddle that boundary.
            for i in (idx - 1, idx):
                if 0 <= i < len(rs_list):
                    # |part1 - part2| with part2 = total - part1.
                    diff = abs(2 * (ls + rs_list[i]) - total)
                    if diff < min_diff:
                        min_diff = diff
    return int(min_diff)
# Test: closest subset sum of [1..5] to 7
print(closest_sum_mitm([1, 2, 3, 4, 5], 7)) # 7 (reachable exactly, e.g. 2 + 5)
4. 4-Sum Problem
def four_sum_mitm(nums: List[int], target: int) -> List[List[int]]:
    """
    LeetCode 18 - 4Sum.

    Return every unique value-quadruplet ``[a, b, c, d]`` with
    ``a + b + c + d == target``. Each quadruplet is drawn from four distinct
    positions and is listed in non-decreasing order. The result list is
    sorted, and ``nums`` itself is not modified.

    Meet in the middle: index every pair sum once. Then, for each "first"
    pair (i, j), look up pairs (k, l) with the complementary sum and k > j.
    This enumerates each index quadruple i < j < k < l exactly once.

    Time: O(n^2) to build the index, plus the matches examined. That is
    O(n^2) expected when few pairs share a sum, but degrades to O(n^4) when
    many pairs do (e.g. many repeated values).
    """
    from collections import defaultdict

    values = sorted(nums)  # sorted copy: the caller's list is left untouched
    n = len(values)
    if n < 4:
        return []

    # pair sum -> index pairs (i, j), i < j, appended in lexicographic order
    pair_sums = defaultdict(list)
    for i in range(n):
        for j in range(i + 1, n):
            pair_sums[values[i] + values[j]].append((i, j))

    found = set()
    for i in range(n):
        for j in range(i + 1, n):
            # .get avoids inserting empty lists into the defaultdict
            for k, l in pair_sums.get(target - values[i] - values[j], ()):
                if k > j:  # disjoint, ordered indices i < j < k < l
                    found.add((values[i], values[j], values[k], values[l]))
    # Sort for a deterministic result order (set iteration order is not).
    return [list(q) for q in sorted(found)]
def four_sum_count(nums1, nums2, nums3, nums4) -> int:
    """
    LeetCode 454 - 4Sum II.

    Count index tuples (i, j, k, l) with
    nums1[i] + nums2[j] + nums3[k] + nums4[l] == 0.

    The first two arrays are folded into a multiset of pair sums. Every pair
    from the last two arrays then looks up how many first-half pairs cancel it.

    Time: O(n^2)
    """
    from collections import Counter

    first_half = Counter(x + y for x in nums1 for y in nums2)
    # Counter returns 0 for missing keys without inserting them.
    return sum(first_half[-(z + w)] for z in nums3 for w in nums4)
5. Equation with Multiple Variables
def solve_equation_mitm(a: List[int], b: List[int], target: int) -> bool:
    """
    Return True if ``a[i1] + a[i2] + b[j1] + b[j2] == target`` for some
    indices. ``i1`` may equal ``i2``, and ``j1`` may equal ``j2``.

    All pair sums of ``a`` go into a hash set. Each pair from ``b`` then
    needs a single O(1)-expected membership test, so no sorting or binary
    search is involved.

    Time:  O(|a|^2 + |b|^2) expected
    Space: O(|a|^2)
    """
    sums_a = {x + y for x in a for y in a}  # x and y may be the same element
    return any((target - x - y) in sums_a for x in b for y in b)
def count_arithmetic_tuples(arr: List[int]) -> int:
    """
    Count index triples i < j < k with arr[j] - arr[i] == arr[k] - arr[j],
    equivalently arr[i] + arr[k] == 2 * arr[j].

    Each middle index j is scanned in turn. A running frequency table holds
    the values strictly left of j, and every value to the right looks up the
    partner it needs.

    Time: O(n^2) with hash map
    """
    seen_left = {}
    result = 0
    for mid in range(1, len(arr) - 1):
        # Grow the left table by the one value that just moved left of `mid`.
        prev = arr[mid - 1]
        seen_left[prev] = seen_left.get(prev, 0) + 1
        doubled = 2 * arr[mid]
        result += sum(seen_left.get(doubled - right, 0) for right in arr[mid + 1:])
    return result
6. Bidirectional BFS / Search
def bidirectional_bfs(start: str, end: str, word_list: List[str]) -> int:
    """
    LeetCode 127 - Word Ladder, solved with bidirectional BFS.

    Return the number of words in the shortest chain from ``start`` to
    ``end``, counting both ends. Consecutive words differ in one lowercase
    letter, and every word after ``start`` must come from ``word_list``.
    Return 0 when ``end`` is not in ``word_list`` or no chain exists.

    Two frontiers grow toward each other, and the smaller one is always
    expanded, so each side searches only about half the depth.

    Time: O(b^(d/2)) instead of O(b^d) where b is branching factor
    """
    if end not in word_list:
        return 0
    dictionary = set(word_list)
    letters = 'abcdefghijklmnopqrstuvwxyz'

    near, far = {start}, {end}
    seen = {start, end}
    length = 1
    while near and far:
        # Grow whichever frontier is currently smaller.
        if len(far) < len(near):
            near, far = far, near
        upcoming = set()
        for word in near:
            for pos in range(len(word)):
                head, tail = word[:pos], word[pos + 1:]
                for letter in letters:
                    candidate = head + letter + tail
                    if candidate in far:
                        # The two searches meet: chain length is one more.
                        return length + 1
                    if candidate in dictionary and candidate not in seen:
                        seen.add(candidate)
                        upcoming.add(candidate)
        near = upcoming
        length += 1
    return 0
def minimum_genetic_mutation(start: str, end: str, bank: List[str]) -> int:
    """
    LeetCode 433 - Minimum Genetic Mutation, solved with bidirectional BFS.

    Return the fewest single-character mutations (over the alphabet "ACGT")
    that turn ``start`` into ``end``. Every intermediate gene and ``end``
    must be in ``bank``. Return -1 if that is impossible.

    Edge cases:
    - ``start == end`` (with ``end`` in ``bank``) returns 0. Without this
      check both frontiers would start as the same set, which a mutation can
      never hit.
    - Gene length is read from each gene instead of being assumed to be 8.
    """
    bank_set = set(bank)
    if end not in bank_set:
        return -1
    if start == end:
        return 0
    genes = "ACGT"
    front = {start}
    back = {end}
    visited = {start, end}
    steps = 0
    while front and back:
        # Expand the smaller frontier to keep each BFS layer small.
        if len(front) > len(back):
            front, back = back, front
        next_front = set()
        steps += 1
        for gene in front:
            for i, current in enumerate(gene):
                for g in genes:
                    if g == current:
                        continue  # not a mutation
                    new_gene = gene[:i] + g + gene[i + 1:]
                    if new_gene in back:
                        return steps
                    if new_gene in bank_set and new_gene not in visited:
                        visited.add(new_gene)
                        next_front.add(new_gene)
        front = next_front
    return -1
7. NP-Hard Problem Optimization
def traveling_salesman_mitm(dist: List[List[int]]) -> int:
    """
    Return the cost of the cheapest tour that starts and ends at city 0 and
    visits every city exactly once. ``dist[i][j]`` is the cost of travelling
    from i to j, and it may be asymmetric.

    Uses exact Held-Karp bitmask DP over the cities 1..n-1, with city 0
    fixed as the start. No O*(2^(n/2)) meet-in-the-middle algorithm is known
    for TSP. Splitting a tour at its midpoint and running the DP from both
    ends still touches about 2^n subset states in total, so it does not
    beat Held-Karp asymptotically.

    Time:  O(n^2 * 2^(n-1))
    Space: O(n * 2^(n-1))
    In pure Python, expect this to get slow beyond roughly n = 16.

    Edge cases: an empty matrix costs 0, and a single city costs ``dist[0][0]``.
    """
    n = len(dist)
    if n == 0:
        return 0
    if n == 1:
        return dist[0][0]

    INF = float('inf')
    m = n - 1  # city c (1..n-1) is encoded as bit c - 1
    full = (1 << m) - 1
    # dp[mask][j]: cheapest path that starts at city 0, visits exactly the
    # cities in `mask`, and ends at city j + 1 (bit j is set in mask).
    dp = [[INF] * m for _ in range(1 << m)]
    for j in range(m):
        dp[1 << j][j] = dist[0][j + 1]

    for mask in range(1, 1 << m):
        row = dp[mask]
        for last in range(m):
            cost = row[last]
            if cost == INF:
                continue  # unreachable state (or `last` not in mask)
            d_last = dist[last + 1]
            for nxt in range(m):
                bit = 1 << nxt
                if mask & bit:
                    continue
                new_cost = cost + d_last[nxt + 1]
                if new_cost < dp[mask | bit][nxt]:
                    dp[mask | bit][nxt] = new_cost

    # Close the tour by returning to city 0.
    return min(dp[full][j] + dist[j + 1][0] for j in range(m))
def max_weight_k_items_mitm(items: List[int], k: int) -> int:
    """
    Return the maximum sum of exactly ``k`` items. Return 0 when no selection
    of exactly ``k`` items exists (``k < 0`` or ``k > len(items)``).

    Meet in the middle: for each half, record the best sum achievable with
    each item count, then pair count ``c`` on the left with ``k - c`` on the
    right. Only the best sum per count matters, so no sum lists are kept.

    Negative items are handled: the best selection may itself be negative.

    This version is illustrative; ``sum(heapq.nlargest(k, items))`` gives the
    same answer in O(n log k).

    Time: O(n * 2^(n/2))
    """
    n = len(items)
    mid = n // 2

    def best_by_count(arr: List[int]) -> List[int]:
        # best[c] = largest sum of any c-element subset of arr. Every count
        # 0..len(arr) is reachable, so each -inf placeholder gets replaced.
        best = [float('-inf')] * (len(arr) + 1)
        for mask in range(1 << len(arr)):
            cnt = bin(mask).count('1')
            s = sum(x for i, x in enumerate(arr) if mask >> i & 1)
            if s > best[cnt]:
                best[cnt] = s
        return best

    left_best = best_by_count(items[:mid])
    right_best = best_by_count(items[mid:])

    result = None
    # left_cnt must leave a feasible count (0..n-mid) for the right half.
    for left_cnt in range(max(0, k - (n - mid)), min(k, mid) + 1):
        total = left_best[left_cnt] + right_best[k - left_cnt]
        if result is None or total > result:
            result = total
    return 0 if result is None else result
8. Practice Problems
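LeetCode Problems
| Problem | Technique | Covered in |
|---|---|---|
| 1755. Closest Subsequence Sum | Subset sums + binary search | Section 3 pattern |
| 2035. Partition Array Into Two Arrays to Minimize Sum Difference | Subset sums grouped by count | Section 3 |
| 18. 4Sum | Pair sums | Section 4 |
| 454. 4Sum II | Pair-sum hash map | Section 4 |
| 127. Word Ladder | Bidirectional BFS | Section 6 |
| 433. Minimum Genetic Mutation | Bidirectional BFS | Section 6 |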
Competitive Programming
| Problem | Source | Notes |
|---|---|---|
| Balanced Cow Subsets | USACO | Subset diff |
| 4 Values | SPOJ | 4 variable equation |
| Double Knapsack | Various | Weight constraints |
9. Key Patterns Summary
Meet in the Middle Decision Tree:
1. Is brute force O(2^n) or O(n!)?
→ If n ≤ 20: plain brute force is usually fast enough (2^20 ≈ 10^6 states)
→ If 20 < n ≤ 40: consider MITM
2. Can problem be split into halves?
→ Subset problems: Yes
→ Sequence problems: Usually no
3. Can halves be combined efficiently?
→ Binary search: O(log n) combination
→ Hash map: O(1) lookup
→ Two pointers: O(n) combination
MITM Template:
1. Split input into two halves
2. Enumerate all possibilities for each half
3. Sort or hash one half
4. For each element in other half, find complement
Common Pitfalls:
- Forgetting to handle edge cases (empty half)
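- Integer overflow with large sums (in fixed-width languages; Python ints do not overflow, but mixing in floats such as total / 2 loses precision beyond 2^53)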
- Duplicate counting
- Incorrect index handling when combining
10. References
- Competitive Programming 3 - Meet in the Middle
- USACO Guide: Meet in the Middle
- Codeforces Blog: MITM Technique
- TopCoder Tutorial: Meet in the Middle