class Solution {
public:
    // Adjacency list of the current input tree; sized to the vertex count in
    // collectTheCoins (previously a fixed 30005-slot array).
    vector<vector<int>> graph;

    /// Post-order pass over the subtree of `vertex`, whose parent is `parentVertex`
    /// (-1 for the root).
    /// Fills four_distances[vertex][k] with the number of coins in that subtree at
    /// distance k from `vertex` for k = 0, 1, 2, and at distance > 2 for k = 3.
    /// Fills cost[vertex] with the number of edge traversals (each walked there and
    /// back) needed inside the subtree so that every coin in it is within distance 2
    /// of some visited vertex.
    void traverse(vector<int> &coins,
                  vector<int> &cost,
                  vector<vector<int>> &four_distances,
                  int vertex,
                  int parentVertex)
    {
        int atZero = coins[vertex] == 1 ? 1 : 0;
        int atOne = 0;
        int atTwo = 0;
        int beyondTwo = 0;
        int totalCost = 0;
        for (int child : graph[vertex]) {
            if (child == parentVertex)
                continue;
            traverse(coins, cost, four_distances, child, vertex);
            const vector<int> &sub = four_distances[child];
            // Each coin in child's subtree is one edge farther from `vertex`.
            atOne += sub[0];
            atTwo += sub[1];
            beyondTwo += sub[2] + sub[3];
            // A coin at distance >= 2 from child is >= 3 from `vertex`. It cannot be
            // collected from here, so walk vertex->child and back (+2) and pay
            // child's own requirement as well.
            if (sub[2] + sub[3])
                totalCost += cost[child] + 2;
        }
        four_distances[vertex][0] = atZero;
        four_distances[vertex][1] = atOne;
        four_distances[vertex][2] = atTwo;
        four_distances[vertex][3] = beyondTwo;
        cost[vertex] = totalCost;
    }

    /// Rerooting pass.
    /// On entry, four_distances[vertex] and cost[vertex] describe the whole tree as
    /// if it were rooted at `vertex`. Records cost[vertex] into `minCost`. Then, for
    /// each unvisited neighbour, it temporarily re-roots the tree at that neighbour,
    /// recurses, and restores every entry it changed.
    /// `coins` is not read here; it is kept for signature compatibility.
    /// Recursion depth equals the length of the longest path explored from the
    /// start vertex.
    void dfs(vector<int> &coins,
             vector<int> &cost,
             vector<vector<int>> &four_distances,
             vector<bool> &vis,
             int &minCost,
             int vertex)
    {
        if (vis[vertex])
            return;
        vis[vertex] = true;
        minCost = min(minCost, cost[vertex]);
        for (int child : graph[vertex]) {
            // The neighbour we arrived from has already been evaluated as a root.
            if (vis[child])
                continue;

            vector<int> &up = four_distances[vertex];
            vector<int> &down = four_distances[child];

            // Snapshot everything this step mutates so it can be undone afterwards.
            const int savedVertexCost = cost[vertex];
            const int savedChildCost = cost[child];
            const int u1 = up[1], u2 = up[2], u3 = up[3];
            const int d0 = down[0], d1 = down[1], d2 = down[2], d3 = down[3];

            // 1) Detach child's subtree from `vertex`.
            up[1] -= d0;
            up[2] -= d1;
            up[3] -= d2 + d3;
            if (d2 + d3)
                cost[vertex] -= cost[child] + 2;

            // 2) Hang the remainder (rooted at `vertex`) under child. Every coin in it
            //    is one edge farther from child than from `vertex`.
            down[1] += up[0];
            down[2] += up[1];
            down[3] += up[2] + up[3];
            if (up[2] + up[3])
                cost[child] += cost[vertex] + 2;

            dfs(coins, cost, four_distances, vis, minCost, child);

            // 3) Restore the rooting at `vertex` (index 0 is never modified).
            up[1] = u1;
            up[2] = u2;
            up[3] = u3;
            down[1] = d1;
            down[2] = d2;
            down[3] = d3;
            cost[vertex] = savedVertexCost;
            cost[child] = savedChildCost;
        }
    }

    /// Returns the minimum number of edge traversals needed to collect every coin
    /// (a coin is collectable from any vertex within distance 2) and return to the
    /// starting vertex, minimised over all starting vertices.
    /// Assumes `edges` describes a tree on vertices 0..coins.size()-1.
    int collectTheCoins(vector<int>& coins, vector<vector<int>>& edges) {
        const int N = coins.size();
        if (!N)
            return 0;
        // Size the adjacency list to this input; also resets state from earlier calls.
        graph.assign(N, vector<int>());
        for (const auto &edge : edges) {
            graph[edge[0]].push_back(edge[1]);
            graph[edge[1]].push_back(edge[0]);
        }
        vector<vector<int>> four_distances(N, vector<int>(4, 0));
        vector<int> cost(N, INT_MAX);
        traverse(coins, cost, four_distances, 0, -1);
        int best = cost[0];
        vector<bool> vis(N, false);
        dfs(coins, cost, four_distances, vis, best, 0);
        return best;
    }
};