Skip to content

Count Submatrices With All Ones⚓︎

Link

Description⚓︎

Given an m x n binary matrix mat, return the number of submatrices that have all ones.

Example 1:

Ex1

  • Input: mat = [[1,0,1],[1,1,0],[1,1,0]]
  • Output: 13
  • Explanation:
There are 6 rectangles of size 1x1.
There are 2 rectangles of size 1x2.
There are 3 rectangles of size 2x1.
There is 1 rectangle of size 2x2.
There is 1 rectangle of size 3x1.
Total number of rectangles = 6 + 2 + 3 + 1 + 1 = 13.

Example 2:

Ex2

  • Input: mat = [[0,1,1,0],[0,1,1,1],[1,1,1,0]]
  • Output: 24
  • Explanation:
There are 8 rectangles of size 1x1.
There are 5 rectangles of size 1x2.
There are 2 rectangles of size 1x3.
There are 4 rectangles of size 2x1.
There are 2 rectangles of size 2x2.
There are 2 rectangles of size 3x1.
There is 1 rectangle of size 3x2.
Total number of rectangles = 8 + 5 + 2 + 4 + 2 + 2 + 1 = 24.

Constraints:

  • 1 <= m, n <= 150
  • mat[i][j] is either 0 or 1.

Solution⚓︎

Monotonic Stack⚓︎

/**
 * Counts the submatrices of a binary matrix that consist entirely of ones.
 *
 * The matrix is swept one row at a time. For each row, a histogram is kept
 * whose j-th bar is the number of consecutive ones in column j ending at that
 * row; the all-ones submatrices whose bottom edge lies on the row are then
 * counted from the histogram with a monotonic stack.
 */
class Solution {
private:
    /**
     * Counts the rectangles that fit under a histogram and touch its base.
     *
     * @param height Bar heights; expected to be non-negative (numSubmat only
     *               ever produces 0 or a positive run length).
     * @return Number of (column interval, rectangle height) pairs such that
     *         every bar in the interval is at least the rectangle height.
     */
    int countFromBottom(const std::vector<int>& height) {
        const int width = static_cast<int>(height.size());
        // Column indices whose heights strictly increase from bottom to top.
        std::stack<int> stk;
        int res = 0;

        // i == width is a sentinel column of height 0: it pops every bar that
        // is still on the stack, so no separate drain loop is needed.
        for (int i = 0; i <= width; i++) {
            const int incoming = i < width ? height[i] : 0;

            while (!stk.empty() && height[stk.top()] >= incoming) {
                const int current = stk.top();
                stk.pop();

                // Every bar strictly between `left` and `i` is at least as
                // tall as `current`, so `current` spans `len` columns.
                const int left = stk.empty() ? -1 : stk.top();
                const int len = i - left - 1;

                // Rectangle heights up to `bottom` also fit over a wider span
                // and are counted when that lower neighbour is popped; only
                // the heights in (bottom, height[current]] are counted here.
                const int bottom = std::max(left == -1 ? 0 : height[left], incoming);

                // len * (len + 1) / 2 column intervals fit inside the span.
                res += (height[current] - bottom) * len * (len + 1) / 2;
            }

            if (i < width) {
                stk.push(i);
            }
        }

        return res;
    }

public:
    /**
     * Returns the number of submatrices of `mat` that contain only ones.
     *
     * @param mat Binary matrix; rows are assumed to have the same length as
     *            the first row. The matrix is not modified.
     * @return The number of all-ones submatrices; 0 when the matrix has no
     *         rows or no columns.
     */
    int numSubmat(std::vector<std::vector<int>>& mat) {
        // Guard: indexing mat[0] on an empty matrix would be undefined behaviour.
        if (mat.empty() || mat[0].empty()) {
            return 0;
        }

        const int rows = static_cast<int>(mat.size());
        const int cols = static_cast<int>(mat[0].size());
        int res = 0;

        // height[j]: number of consecutive ones in column j ending at row i.
        std::vector<int> height(cols, 0);
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                height[j] = mat[i][j] == 0 ? 0 : height[j] + 1;
            }
            // Add the submatrices whose bottom edge is row i.
            res += countFromBottom(height);
        }

        return res;
    }
};