LeetCode 1803. Count Pairs With XOR in a Range

Trie

Given a (0-indexed) integer array nums and two integers low and high, return the number of nice pairs.

A nice pair is a pair (i, j) where 0 <= i < j < nums.length and low <= (nums[i] XOR nums[j]) <= high.

Example 1:

Input: nums = [1,4,2,7], low = 2, high = 6
Output: 6
Explanation: All nice pairs (i, j) are as follows:
    - (0, 1): nums[0] XOR nums[1] = 5 
    - (0, 2): nums[0] XOR nums[2] = 3
    - (0, 3): nums[0] XOR nums[3] = 6
    - (1, 2): nums[1] XOR nums[2] = 6
    - (1, 3): nums[1] XOR nums[3] = 3
    - (2, 3): nums[2] XOR nums[3] = 5

Example 2:

Input: nums = [9,8,4,2,1], low = 5, high = 14
Output: 8
Explanation: All nice pairs (i, j) are as follows:
    - (0, 2): nums[0] XOR nums[2] = 13
    - (0, 3): nums[0] XOR nums[3] = 11
    - (0, 4): nums[0] XOR nums[4] = 8
    - (1, 2): nums[1] XOR nums[2] = 12
    - (1, 3): nums[1] XOR nums[3] = 10
    - (1, 4): nums[1] XOR nums[4] = 9
    - (2, 3): nums[2] XOR nums[3] = 6
    - (2, 4): nums[2] XOR nums[4] = 5

Constraints:

  • 1 <= nums.length <= 2 * 10^4

  • 1 <= nums[i] <= 2 * 10^4

  • 1 <= low <= high <= 2 * 10^4

Solution

English version on YouTube

中文版解答 (Chinese-language solution) YouTube link

中文版解答 (Chinese-language solution) Bilibili link

class Solution {
    // Binary trie keyed on bits height..0 of non-negative ints.
    // Nodes live in a flat std::vector and refer to their children by index, so the
    // whole trie is released automatically when it goes out of scope (no naked new).
    class BinaryTrie {
    public:
        // height: index of the most significant bit stored; every value inserted or
        // queried must fit in bits height..0.
        explicit BinaryTrie(int height) : height_(height), nodes_(1) {}

        // Adds num, bumping the pass-through count of every node on its path.
        void Insert(int num) {
            int cur = 0;
            for (int j = height_; j >= 0; --j) {
                const int bit = (num >> j) & 1;
                if (nodes_[cur].next[bit] == 0) {
                    // Store the new index before emplace_back, which may reallocate nodes_.
                    nodes_[cur].next[bit] = static_cast<int>(nodes_.size());
                    nodes_.emplace_back();
                }
                cur = nodes_[cur].next[bit];
                ++nodes_[cur].cnt;
            }
        }

        // Returns how many inserted values y satisfy (num XOR y) < limit.
        int CountXorLess(int num, long long limit) const {
            // XOR of non-negative values is never negative, so nothing is below limit <= 0.
            if (limit <= 0) return 0;
            int cur = 0;
            int cnt = 0;
            for (int j = height_; j >= 0; --j) {
                const int bit_num = (num >> j) & 1;
                const int bit_limit = static_cast<int>((limit >> j) & 1);
                if (bit_limit == 1) {
                    // Choosing y's bit == bit_num makes this XOR bit 0 while limit's bit is 1,
                    // so every value under that child is already < limit.
                    const int same = nodes_[cur].next[bit_num];
                    if (same != 0) cnt += nodes_[same].cnt;
                    // Keep tracking limit's prefix: the XOR bit must be 1 here.
                    cur = nodes_[cur].next[1 - bit_num];
                } else {
                    // The XOR bit must be 0 here so the prefix does not exceed limit.
                    cur = nodes_[cur].next[bit_num];
                }
                if (cur == 0) break;  // no inserted value continues along this prefix
            }
            return cnt;
        }

    private:
        struct Node {
            int next[2] = {0, 0};  // child indices into nodes_; 0 means "absent"
                                   // (safe: the root, index 0, is never anyone's child)
            int cnt = 0;           // number of inserted values whose path passes here
        };

        int height_;
        std::vector<Node> nodes_;  // nodes_[0] is the root
    };

public:
    // Counts pairs (i, j) with i < j and low <= (nums[i] XOR nums[j]) <= high.
    // When nums[i] is processed, the trie holds exactly nums[0..i-1], and
    // #{XOR < high + 1} - #{XOR < low} is the number whose XOR lies in [low, high].
    // Assumes non-negative values (the problem guarantees nums[i] >= 1).
    int countPairs(std::vector<int>& nums, int low, int high) {
        // high + 1LL avoids int overflow when high == INT_MAX.
        const long long upper = high + 1LL;

        // Size the trie to the widest value involved instead of a fixed 15 bits, so
        // values outside the stated constraints are not silently truncated.
        long long widest = upper;
        for (int num : nums) {
            if (num > widest) widest = num;
        }
        int height = 0;
        while ((widest >> (height + 1)) != 0) ++height;

        BinaryTrie trie(height);
        int ans = 0;
        for (int num : nums) {
            ans += trie.CountXorLess(num, upper) - trie.CountXorLess(num, low);
            trie.Insert(num);
        }
        return ans;
    }
};

Last updated