We're reasoning about an imagined merged sorted array that we never actually build.
We binary search over the smaller array for the position at which to cut it. The cut in the second array is then implied, and the two cuts together form a candidate split.
We fix the left partition size at the midpoint of the combined length of the two arrays.
At each step, i is the midpoint of the current search range in the first array.
j is whatever is left over to fill the left partition, so it is implicit.
If we have found a valid split of the merged array, the median sits right at the boundary, and this property holds:
Every element on the left side is <= every element on the right side.
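To make the property concrete, here is a small brute-force check. It is only a sketch: the helper name is_valid_split and the sample arrays are made up for illustration, and it scans whole slices, which the real algorithm avoids.

```python
def is_valid_split(a, b, i, j):
    """Return True if a[:i] + b[:j] can serve as the left side of the merged array."""
    left = a[:i] + b[:j]
    right = a[i:] + b[j:]
    # An empty side imposes no constraint.
    if not left or not right:
        return True
    # The property: every element on the left is <= every element on the right.
    return max(left) <= min(right)


a = [1, 3, 8]
b = [2, 4, 9, 10]
print(is_valid_split(a, b, 2, 2))  # left = [1, 3, 2, 4], right = [8, 9, 10] -> True
print(is_valid_split(a, b, 3, 1))  # left = [1, 3, 8, 2], right = [4, 9, 10] -> False
```

Because both arrays are sorted, max(left) can only be a[i - 1] or b[j - 1] and min(right) can only be a[i] or b[j], which is why four boundary lookups are enough.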
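Then it's just basic binary search for whether or not we took too many elements from A into the left half. If we did, then we need to "move the bar left" in A so that fewer large elements in A end up on the left. So we shrink the right boundary to look for a smaller i. If instead we took too few from A (the last B element on the left is bigger than the first A element on the right), we raise the left boundary to look for a larger i.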
i represents something semantically richer than the midpoint index; it's the number of elements we take FROM the first array into the left partition.
We compare both arrays at the partition boundaries via j = left_partition_size - i. That covers every possible left-half composition, needs only O(1) lookups, and the too many/too few checks give us the monotone signal needed for binary search.
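The monotone signal can be seen by classifying every i for one example. This is a sketch under assumptions: the function name classify, the labels it returns, and the sample arrays are illustrative, and a is assumed to be the shorter array so that j stays within bounds.

```python
INF = float("inf")


def classify(a, b, i, left_size):
    """Classify taking i elements from a (and left_size - i from b) into the left half."""
    j = left_size - i
    # Sentinels stand in for boundaries that fall off either end of an array.
    a_left = a[i - 1] if i > 0 else -INF
    a_right = a[i] if i < len(a) else INF
    b_left = b[j - 1] if j > 0 else -INF
    b_right = b[j] if j < len(b) else INF
    if a_left > b_right:
        return "too many from a"
    if b_left > a_right:
        return "too few from a"
    return "valid"


a = [1, 3, 8]
b = [2, 4, 9, 10]
left_size = (len(a) + len(b) + 1) // 2  # 4
for i in range(len(a) + 1):
    print(i, left_size - i, classify(a, b, i, left_size))
# 0 4 too few from a
# 1 3 too few from a
# 2 2 valid
# 3 1 too many from a
```

The labels never go back from "too many" to "too few" as i grows, which is what lets binary search discard half of the range at each step.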
We're not "stuck" in the first half. i ranges from 0 to m inclusive. When i is small, j is large; when i is large, j is small. That sweeps across every pair (i, j) with i + j = left_partition_size, which is every split that achieves the correct left size. Because A is the shorter array, j always stays between 0 and n. The correct (i, j) pair must be hit.
We want the left half of the imagined merged array to contain exactly left_partition_size = (m + n + 1) // 2 elements. Any valid split of that merged array with that left size corresponds to some pair (i, j): i = how many we take from A into the left half. j = how many we take from B into the left half.
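A quick arithmetic check of the formula for odd and even totals. The helper name left_partition_size mirrors the variable in the notes, and the sample lengths are arbitrary.

```python
def left_partition_size(m, n):
    # The +1 gives the left half the extra element when m + n is odd.
    return (m + n + 1) // 2


for m, n in [(3, 4), (3, 5), (0, 1), (2, 2)]:
    total = m + n
    left = left_partition_size(m, n)
    print(total, left, total - left)
# 7 4 3
# 8 4 4
# 1 1 0
# 4 2 2
```

With an odd total the left half is one larger than the right, so the median is simply the largest element on the left. With an even total the halves are equal and the median averages the two boundary elements.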
We systematically enumerate every possible way to populate the left half from the two arrays. There is no split we fail to consider.
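from typing import List


class Solution:
    def findMedianSortedArrays(self, nums1: List[int], nums2: List[int]) -> float:
        # Binary search over the shorter array so that j stays within [0, n].
        if len(nums1) > len(nums2):
            nums1, nums2 = nums2, nums1
        m, n = len(nums1), len(nums2)
        left, right = 0, m
        # The left half holds the extra element when m + n is odd.
        left_partition_size = (m + n + 1) // 2
        while left <= right:
            i = (left + right) // 2          # elements taken from nums1
            j = left_partition_size - i      # elements taken from nums2
            # Infinite sentinels stand in for boundaries past either end.
            Aleft = float('-inf') if i == 0 else nums1[i - 1]
            Aright = float('inf') if i == m else nums1[i]
            Bleft = float('-inf') if j == 0 else nums2[j - 1]
            Bright = float('inf') if j == n else nums2[j]
            if Aleft <= Bright and Bleft <= Aright:
                # Valid split: everything on the left <= everything on the right.
                if (m + n) % 2 == 1:
                    return max(Aleft, Bleft)
                else:
                    return (max(Aleft, Bleft) + min(Aright, Bright)) / 2.0
            elif Aleft > Bright:
                # Took too many from nums1: look for a smaller i.
                right = i - 1
            else:
                # Took too few from nums1: look for a larger i.
                left = i + 1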
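
One way to gain confidence in the solution is to compare it against a version that really does build the merged array. The sketch below is only a testing aid: median_by_merging and agrees_with_reference are made-up names, and the random input sizes and value ranges are arbitrary choices.

```python
import random
from statistics import median


def median_by_merging(a, b):
    """Slow reference: actually build the merged sorted array and take its median."""
    return float(median(sorted(a + b)))


def agrees_with_reference(candidate, trials=200, seed=0):
    """Compare candidate(a, b) with the reference on small random sorted inputs."""
    rng = random.Random(seed)
    for _ in range(trials):
        a = sorted(rng.randint(-20, 20) for _ in range(rng.randint(0, 6)))
        # b always has at least one element so the median is defined.
        b = sorted(rng.randint(-20, 20) for _ in range(rng.randint(1, 6)))
        if candidate(a, b) != median_by_merging(a, b):
            return False
    return True


print(median_by_merging([1, 3], [2]))     # 2.0
print(median_by_merging([1, 2], [3, 4]))  # 2.5
```

Passing the method from the class above, as in agrees_with_reference(Solution().findMedianSortedArrays), should return True if the binary search is correct.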