Count Complete Tree Nodes
Problem
Given the root of a complete binary tree, return the number of nodes in the tree.
According to Wikipedia, in a complete binary tree every level except possibly the last is completely filled, and all nodes in the last level are as far left as possible. The last level h can contain between 1 and 2^h nodes inclusive.
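As a quick check of the definition, assume the root is on level 0. Then levels 0 to h - 1 are full and together hold 2^h - 1 nodes, so a complete tree whose last level is h has between 2^h and 2^(h+1) - 1 nodes in total. A minimal sketch of that arithmetic (the helper name `nodeCountRange` is hypothetical):

```cpp
#include <cassert>
#include <utility>

// Hypothetical helper: inclusive range of total node counts for a complete
// binary tree whose last level is h (the root is on level 0).
std::pair<long long, long long> nodeCountRange(int h) {
    assert(h >= 0);
    long long fullAbove = (1LL << h) - 1;  // levels 0..h-1 are completely filled
    long long minLast = 1;                 // at least one node on level h
    long long maxLast = 1LL << h;          // at most 2^h nodes on level h
    return {fullAbove + minLast, fullAbove + maxLast};
}
```

For example, `nodeCountRange(2)` gives the range 4 to 7.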
Design an algorithm that runs in less than O(n) time complexity.
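For contrast, the straightforward approach visits every node and therefore runs in O(n) time; this is the baseline the problem asks us to beat. A minimal sketch, assuming a node type that mirrors the LeetCode `TreeNode` (the function name `countNodesLinear` is hypothetical):

```cpp
#include <cassert>

// Minimal node type mirroring the LeetCode TreeNode definition.
struct TreeNode {
    int val;
    TreeNode *left;
    TreeNode *right;
    TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
};

// Baseline: count every node with a plain recursive traversal, O(n) time.
int countNodesLinear(TreeNode *root) {
    if (root == nullptr) return 0;
    return 1 + countNodesLinear(root->left) + countNodesLinear(root->right);
}
```

It is correct for any binary tree, but it ignores the completeness guarantee, which is exactly what a faster algorithm has to exploit.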
Constraints
- The number of nodes in the tree is in the range [0, 5 * 10^4].
- 0 <= Node.val <= 5 * 10^4
- The tree is guaranteed to be complete.
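These bounds also limit how deep the tree can be. With the root on level 0, the last level is the largest d with 2^d <= n, so for n = 5 * 10^4 the deepest possible level is 15, and every position on it (numbered 1, 2, 3, ... in level order) is at most 2^16 - 1 = 65535. A small sketch of that calculation (`lastLevelDepth` is a hypothetical helper):

```cpp
#include <cassert>

// Hypothetical helper: level of the last row of a complete tree with n >= 1
// nodes (root = level 0), i.e. the largest d such that 2^d <= n.
int lastLevelDepth(int n) {
    assert(n >= 1);
    int d = 0;
    while ((n >> (d + 1)) != 0)
        ++d;
    return d;
}
```

So under these constraints a 16-bit quantity is enough to hold any last-level position.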
Solution
The problem Count Complete Tree Nodes can be solved with a binary search. Because the tree is guaranteed to be complete, every level except the last is full, so the total count is determined by how many nodes the last level contains. Instead of a depth-first traversal, which visits every node and takes O(n) time, we binary search for the position of the last existing node on the last level.

The last level h has 2^h candidate positions, and since n >= 2^h this is at most n, so the binary search needs O(log n) probes. Each probe walks from the root down to the last level, which takes O(log n) steps. The total time complexity is therefore O(log^2 n), which is less than O(n) as the problem requires.
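The key step is turning a last-level position into a path from the root. If nodes are numbered 1, 2, 3, ... in level order (heap numbering, where node i has children 2i and 2i + 1), the last level h holds indices 2^h to 2^(h+1) - 1, and the bits of an index below its leading 1, read from most to least significant, are the moves: 0 means go left and 1 means go right. A small sketch of that mapping (`pathToIndex` is a hypothetical helper, not part of the solution):

```cpp
#include <cassert>
#include <string>

// Hypothetical helper: path from the root to the node with heap index `index`
// (root = 1, children of i are 2i and 2i+1) on level `depth`.
// Each bit below the leading 1 becomes 'L' (bit 0) or 'R' (bit 1).
std::string pathToIndex(int index, int depth) {
    assert(depth >= 0 && (index >> depth) == 1);  // index must lie on level depth
    std::string path;
    for (int i = depth - 1; i >= 0; --i)
        path += ((index >> i) & 1) ? 'R' : 'L';
    return path;
}
```

For instance, index 6 is 110 in binary, so `pathToIndex(6, 2)` yields "RL": right to node 3, then left to node 6.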
Implementation
/**
* Definition for a binary tree node.
* struct TreeNode {
* int val;
* TreeNode *left;
* TreeNode *right;
* TreeNode() : val(0), left(nullptr), right(nullptr) {}
* TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
* TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
* };
*/
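class Solution
{
public:
    int countNodes(TreeNode *root)
    {
        if (root == nullptr)
            return 0;

        // Level of the last row (root = 0): follow left children, since the
        // leftmost position of the last level is always filled.
        int depth = 0;
        TreeNode *node = root;
        while ((node = node->left) != nullptr)
            depth += 1;

        // With heap numbering (root = 1, children of i are 2i and 2i+1), the
        // last level spans indices [2^depth, 2^(depth+1) - 1].
        int left = 1 << depth;
        int right = (left << 1) - 1;

        // Binary search for the largest index that exists on the last level.
        while (left <= right)
        {
            int mid = left + (right - left) / 2;

            // The bits of mid below its leading 1, from high to low, give the
            // path from the root: 0 = go left, 1 = go right. All levels above
            // the last are full, so node is non-null until the final step.
            node = root;
            for (int i = depth - 1; i >= 0; i--)
                node = ((mid >> i) & 1) ? node->right : node->left;

            if (node == nullptr)
                right = mid - 1; // mid is missing: the answer lies to its left
            else
                left = mid + 1;  // mid exists: the answer is mid or to its right
        }

        // left - 1 is the index of the last existing node, which in heap
        // numbering equals the total number of nodes.
        return left - 1;
    }
};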
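To experiment outside the judge, here is a self-contained sketch. Note that it does not reuse the binary search above; it implements a different, widely used O(log^2 n) technique that compares the lengths of the leftmost and rightmost paths: if they match, the subtree is perfect and has 2^h - 1 nodes, otherwise the count recurses into both children. The test helper `buildComplete` and the other function names are hypothetical.

```cpp
#include <cassert>
#include <vector>

// Minimal node type mirroring the LeetCode TreeNode definition.
struct TreeNode {
    int val;
    TreeNode *left;
    TreeNode *right;
    TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
};

// Number of nodes on the path that always goes left.
int leftHeight(TreeNode *node) {
    int h = 0;
    for (; node != nullptr; node = node->left) ++h;
    return h;
}

// Number of nodes on the path that always goes right.
int rightHeight(TreeNode *node) {
    int h = 0;
    for (; node != nullptr; node = node->right) ++h;
    return h;
}

// Alternative O(log^2 n) count for a complete tree: equal outer path lengths
// mean the subtree is perfect with 2^h - 1 nodes; otherwise recurse.
int countNodesByHeights(TreeNode *root) {
    if (root == nullptr) return 0;
    int lh = leftHeight(root);
    int rh = rightHeight(root);
    if (lh == rh) return (1 << lh) - 1;
    return 1 + countNodesByHeights(root->left) + countNodesByHeights(root->right);
}

// Hypothetical test helper: fill `pool` with n nodes, then link node i to
// children 2i and 2i+1 (1-based heap numbering), giving a complete tree.
// Linking happens after all insertions, so no pointer is invalidated.
TreeNode *buildComplete(std::vector<TreeNode> &pool, int n) {
    pool.clear();
    for (int i = 1; i <= n; ++i) pool.emplace_back(i);
    for (int i = 1; i <= n; ++i) {
        if (2 * i <= n) pool[i - 1].left = &pool[2 * i - 1];
        if (2 * i + 1 <= n) pool[i - 1].right = &pool[2 * i];
    }
    return n > 0 ? &pool[0] : nullptr;
}
```

Building trees of every size from 0 up to a few hundred with `buildComplete` and comparing the count against n is a cheap way to cross-check either approach.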