class Solution {
private:
    // Counts the all-ones submatrices whose bottom edge lies on the current row.
    //
    // `height[j]` is the number of consecutive 1s ending at the current row in
    // column j (always >= 0). The wanted count equals the sum, over every
    // contiguous column range, of the minimum height in that range.
    //
    // A monotonic stack with strictly increasing heights finds, for each column
    // c, the span (left, right) in which height[c] is the minimum:
    //   - left  = previous column with a strictly smaller height;
    //   - right = next column with a smaller-or-equal height.
    // Each height layer in (bottom, height[c]] forms one maximal run of length
    // len = right - left - 1. Here bottom is the taller of the two bounding
    // columns. Each such run contributes len * (len + 1) / 2 column ranges.
    // Ties yield bottom == height[c], i.e. zero, so no range is double-counted.
    //
    // The result is accumulated in long long so the intermediate product
    // cannot overflow int.
    long long countFromBottom(const vector<int>& height) {
        const int m = height.size();
        stack<int> stk;
        long long res = 0;
        // i == m acts as a sentinel column of height 0 that flushes the stack.
        for (int i = 0; i <= m; ++i) {
            const int cur = (i == m) ? 0 : height[i];
            while (!stk.empty() && height[stk.top()] >= cur) {
                const int top = stk.top();
                stk.pop();
                const int left = stk.empty() ? -1 : stk.top();
                const long long len = i - left - 1;
                const int bottom = max(left == -1 ? 0 : height[left], cur);
                res += static_cast<long long>(height[top] - bottom) * len * (len + 1) / 2;
            }
            if (i < m) {
                stk.push(i);
            }
        }
        return res;
    }

public:
    // Returns the number of submatrices of `mat` made up entirely of 1s.
    //
    // Processes the matrix row by row. It maintains a histogram of consecutive
    // 1s ending at each row and adds the number of all-ones submatrices whose
    // bottom edge lies on that row.
    //
    // An empty matrix yields 0. Rows are assumed to all have mat[0].size()
    // columns. The total is accumulated in long long; because the return type
    // is int, the final count must fit in int to be represented exactly.
    int numSubmat(vector<vector<int>>& mat) {
        if (mat.empty()) {
            return 0;
        }
        const int n = mat.size();
        const int m = mat[0].size();
        long long res = 0;
        vector<int> height(m, 0);
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < m; ++j) {
                height[j] = mat[i][j] == 0 ? 0 : height[j] + 1;
            }
            res += countFromBottom(height);
        }
        return static_cast<int>(res);
    }
};