class Solution {
private:
    // Disjoint-set parent links; p[i] == i marks a component root.
    vector<int> p;
    // For a root r: how many nodes of r's component have value vals[r].
    // unionSets keeps the root on the node with the component's maximum
    // value, so this is the count of maximum-valued nodes. Unused for
    // non-roots.
    vector<int> maxCnt;

    // Reset the forest to n singleton components. Each node is the sole
    // (hence maximum-valued) member of its own component.
    void build(int n) {
        p.assign(n, 0);
        for (int i = 0; i < n; i++) {
            p[i] = i;
        }
        maxCnt.assign(n, 1);
    }

    // Return the root of i's component and point every node on the way
    // directly at it (path compression). Iterative so a long parent chain
    // cannot exhaust the call stack.
    int find(int i) {
        int root = i;
        while (p[root] != root) {
            root = p[root];
        }
        while (p[i] != root) {
            const int next = p[i];
            p[i] = root;
            i = next;
        }
        return root;
    }

    // Merge the components containing x and y, keeping as root whichever
    // root has the larger value, so every root holds its component's max.
    //
    // Returns the number of new good paths created by this merge: when both
    // components share the same maximum value, every pairing of a max-valued
    // node from one side with one from the other side. Returns 0 otherwise,
    // including when x and y are already connected (duplicate edge or cycle);
    // without that guard the same pairs would be counted again.
    long long unionSets(int x, int y, const vector<int>& vals) {
        const int fx = find(x);
        const int fy = find(y);
        if (fx == fy) {
            return 0;
        }
        if (vals[fx] > vals[fy]) {
            p[fy] = fx;
            return 0;
        }
        if (vals[fx] < vals[fy]) {
            p[fx] = fy;
            return 0;
        }
        // Equal maxima: widen before multiplying so the pair count cannot
        // overflow int.
        const long long path = 1LL * maxCnt[fx] * maxCnt[fy];
        p[fy] = fx;
        maxCnt[fx] += maxCnt[fy];
        return path;
    }

public:
    // Counts the paths whose two endpoints share a value v and whose nodes
    // all have value <= v. Every single node counts as one such path.
    //
    // vals[i] is the value of node i; each edges[k] is a pair {u, v} of node
    // indices. Note: `edges` is sorted in place as a side effect.
    int numberOfGoodPaths(vector<int>& vals, vector<vector<int>>& edges) {
        const int n = static_cast<int>(vals.size());
        build(n);
        // Every node is a trivial good path by itself.
        long long res = n;
        // Add edges in ascending order of their larger endpoint value. A
        // component then only ever holds nodes <= the current edge's max, so
        // joining two components with equal maxima v links max-valued nodes
        // through nodes of value <= v.
        sort(edges.begin(), edges.end(), [&](const auto& lhs, const auto& rhs) {
            return max(vals[lhs[0]], vals[lhs[1]]) < max(vals[rhs[0]], vals[rhs[1]]);
        });
        // Bind by reference: iterating by value copied every inner vector.
        for (const auto& edge : edges) {
            res += unionSets(edge[0], edge[1], vals);
        }
        return static_cast<int>(res);
    }
};