# Question 79 - LC 222. Count Complete Tree Nodes

Given the `root` of a **complete** binary tree, return the number of nodes in the tree.

According to Wikipedia, in a complete binary tree every level, except possibly the last, is completely filled, and all nodes in the last level are as far left as possible. The last level `h` (levels are numbered from 0) can contain between `1` and `2^h` nodes inclusive.

Design an algorithm that runs in less than `O(n)` time complexity.

**Example 1:**

    Input: root = [1,2,3,4,5,6]
    Output: 6

**Example 2:**

    Input: root = []
    Output: 0

**Example 3:**

    Input: root = [1]
    Output: 1

**Constraints:**

- The number of nodes in the tree is in the range `[0, 5 * 10^4]`.
- `0 <= Node.val <= 5 * 10^4`
- The tree is guaranteed to be **complete**.


In [1]:
from __future__ import annotations
from typing import Optional
from collections import deque

In [2]:
# Definition for a binary tree node.
class TreeNode:
    """A binary tree node holding a value and optional left/right children."""

    def __init__(
        self,
        val: int = 0,
        left: Optional[TreeNode] = None,
        right: Optional[TreeNode] = None,
    ):
        self.val = val
        self.left = left
        self.right = right

    @classmethod
    def from_list(cls, values: list[Optional[int]]) -> Optional[TreeNode]:
        """
        Construct a binary tree from a level-order (BFS) list of values.

        ``None`` entries mark absent children. Children of absent nodes are not
        listed, so every node that is created consumes the next two entries.

        Parameters:
        - values (list): Level-order values; ``None`` marks an absent node.

        Returns:
        TreeNode | None: The root of the constructed tree, or ``None`` when
        ``values`` is empty or its first entry is ``None`` (an empty tree).
        """
        # A missing root means an empty tree in level-order serialization.
        if not values or values[0] is None:
            return None

        root = cls(val=values[0])
        queue = deque([root])
        n = len(values)
        i = 1

        while queue and i < n:
            current_node = queue.popleft()

            if values[i] is not None:
                current_node.left = cls(val=values[i])
                queue.append(current_node.left)
            i += 1

            # The list may end right after a left child.
            if i < n and values[i] is not None:
                current_node.right = cls(val=values[i])
                queue.append(current_node.right)
            i += 1

        return root

    def print_level_order(self) -> list[list]:
        """
        Return the level-order (BFS) traversal of the tree rooted at ``self``.

        Despite its name, this method prints nothing; it returns the levels so
        the caller can inspect or print them.

        Returns:
        list: A list of lists, one per depth, holding the node values at that
              depth from left to right. Missing children of existing nodes
              appear as the string ``"None"``; levels made up only of such
              placeholders are dropped.
        """
        queue = deque([(self, 0)])
        result = []
        current_level = []
        level_number = 0

        while queue:
            current_node, node_level = queue.popleft()

            # BFS yields nodes in depth order, so a deeper node closes the level.
            if node_level > level_number:
                result.append(current_level)
                current_level = []
                level_number = node_level

            if current_node:
                current_level.append(current_node.val)
                queue.append((current_node.left, node_level + 1))
                queue.append((current_node.right, node_level + 1))
            else:
                current_level.append("None")

        if current_level:
            result.append(current_level)

        # Drop levels that contain nothing but placeholders (the leaves' children).
        result = [level for level in result if any(elem != "None" for elem in level)]

        return result

In [3]:
class Solution:
    def countNodes(self, root: Optional[TreeNode]) -> int:
        """
        Return the number of nodes in a complete binary tree.

        Every level above the deepest one is full, so only the deepest level
        has to be counted. Its slots are numbered 0 .. 2**depth - 1 from left
        to right and, because the tree is complete, the occupied slots form a
        prefix of that range. A binary search locates the end of the prefix;
        each probe walks one root-to-leaf path of ``depth`` steps.
        """
        if not root:
            return 0

        # Depth = number of edges on the leftmost root-to-leaf path.
        depth = 0
        walker = root
        while walker.left:
            depth += 1
            walker = walker.left

        slots = 1 << depth  # capacity of the deepest level

        def occupied(slot: int) -> bool:
            """Report whether deepest-level position ``slot`` holds a node.

            Reading the bits of ``slot`` from most to least significant spells
            the path from the root: a 0 bit steps left, a 1 bit steps right.
            """
            cursor = root
            for shift in reversed(range(depth)):
                cursor = cursor.right if (slot >> shift) & 1 else cursor.left
            return cursor is not None

        # Half-open search for the first empty slot; its index equals the
        # number of occupied slots on the deepest level.
        lo, hi = 0, slots
        while lo < hi:
            mid = (lo + hi) // 2
            if occupied(mid):
                lo = mid + 1
            else:
                hi = mid

        # The full levels above hold 1 + 2 + ... + 2**(depth-1) = 2**depth - 1.
        return (slots - 1) + lo

In [4]:
# Test

s = Solution()

# Examples from the problem statement plus a few hand-picked shapes.
cases = [
    ([1, 2, 3, 4, 5, 6], 6),
    ([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], 14),
    ([1, 2, 3, 4, 5, 6, 7], 7),
    ([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], 13),
    ([1, 2], 2),
    ([1], 1),
    ([], 0),
]
for values, expected in cases:
    root = TreeNode.from_list(values)
    assert s.countNodes(root) == expected, values

# Exhaustive sweep: a level-order list of 1..size always describes a complete
# tree, so sizes 0..127 cover every last-level fill state for heights 0..6.
for size in range(128):
    root = TreeNode.from_list(list(range(1, size + 1)))
    assert s.countNodes(root) == size, size