
problem link

counting on tree

i should probably make a cf blog or something but the people on there are more unpredictable than my history teacher

so i came upon this problem and found that the editorial was absolutely awful

so we have a tree with nodes and values, normal crap
and they want us to count how many subtrees (connected groups of nodes) have a sum of exactly \(t\) (\(t \le 100\))

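this is actually a pretty standard tree dp problem, we just keep an array `sum_ways` and define it like this:

```
sum_ways[n][s] = number of subtrees rooted at node n that have a sum s
```

where "rooted at node n" means n is the topmost node of the subtree when the whole tree is rooted at node 0. seems pretty simple to understand (the answer we output is the sum of `sum_ways[n][t]` across all `n`, and since every subtree has exactly one topmost node, nothing gets counted twice)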

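to handle the base case for a leaf (and every node, really), we increment `sum_ways[node][node_value]` by 1, since a node on its own is a subtree too (the code skips this when `node_value > t`, since that sum is already too big)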

now to actually merge a child's array into its parent's array: the actual dp part of tree dp
let's say our parent and child array were something like this (who cares if these are actually possible):

```
parent = [0, 0, 4, 1, 5, 0, 3]  (let's just say t = 6)
child  = [0, 1, 3, 3, 0, 1, 1]
```

how do we merge these?
let's just start with the target value because why not
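there are 7 ways for the parent & child to come together and make a new subtree that sums to 6: the parent's part sums to `a` and the child's part sums to `6 - a`, for each `a` from 0 to 6. so `p_contrib = parent[a]` and `c_contrib = child[6 - a]`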

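and for each of these `p_contrib` and `c_contrib` pairs, we just increment `parent[t]` by `p_contrib * c_contrib` (always using the parent's counts from *before* this child got merged in), then do the same thing for every smaller sum too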
and then we win

"but wait, kevin!" i can hear you typing in the comment section already
"this dp relation only accounts for the current child *adding on* to what the parent already has with the others!"
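"it doesn't account for subtrees made of just the parent node plus this child's subtrees!"
but note that our base case of incrementing `sum_ways[node][node_value]` by 1 accounted for that! the lone parent is already sitting in `parent[node_value]`, so multiplying it by the child's counts covers exactly those subtrees. checkmate, liberal!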

but enough of me, you can just take the code and go now ig

```cpp
#include <iostream>
#include <cassert>
#include <vector>

using std::cout;
using std::endl;
using std::vector;

constexpr int MOD = 1e9 + 7;

// some of the values here have different names than in the editorial
// i hope that doesn't ruin things for you
class Tree {
  private:
    static const int ROOT = 0;

    const vector<vector<int>>& neighbors;
    const vector<int>& node_vals;
    int target;
    // sum_ways[n][s] = number of subtrees rooted at node n that have a sum s
    vector<vector<long long>> sum_ways;

    void process_sums(int at, int prev) {
        int val = node_vals[at];  // just a shorthand
        // base case: the subtree that's just `at` by itself
        if (val <= target) {
            sum_ways[at][val]++;
        }

        for (int n : neighbors[at]) {
            if (n == prev) {
                continue;
            }
            process_sums(n, at);

            // every subtree containing `at` is already over the target
            // (values are nonnegative, which the indexing above needs anyway)
            if (val > target) {
                continue;
            }
            // go from high sums to low so sum_ways[at][0..t] still hold their
            // pre-merge values when we read them (the update acts simultaneous)
            for (int t = target; t >= 0; t--) {
                long long new_val = sum_ways[at][t];  // subtrees that skip child n
                for (int a = 0; a <= t; a++) {
                    int b = t - a;
                    new_val = (new_val + sum_ways[at][a] * sum_ways[n][b]) % MOD;
                }
                sum_ways[at][t] = new_val;
            }
        }
    }

  public:
    Tree(const vector<vector<int>>& neighbors, const vector<int>& node_vals, int target)
        : neighbors(neighbors), node_vals(node_vals), target(target),
          sum_ways(neighbors.size(), vector<long long>(target + 1)) {
        assert(node_vals.size() == neighbors.size());
        process_sums(ROOT, -1);  // the root has no parent
    }

    long long sum_num(int n) { return sum_ways[n][target]; }
};

int main() {
    int test_num;
    std::cin >> test_num;
    for (int t = 0; t < test_num; t++) {
        int node_num;
        int target;
        std::cin >> node_num >> target;

        vector<int> node_vals(node_num);
        for (int& v : node_vals) {
            std::cin >> v;
        }

        vector<vector<int>> neighbors(node_num);
        for (int e = 0; e < node_num - 1; e++) {
            int a;
            int b;
            std::cin >> a >> b;
            neighbors[--a].push_back(--b);  // input is 1-indexed
            neighbors[b].push_back(a);
        }

        Tree tree(neighbors, node_vals, target);
        long long total_sums = 0;
        for (int n = 0; n < node_num; n++) {
            total_sums = (total_sums + tree.sum_num(n)) % MOD;
        }
        cout << total_sums << endl;
    }
}
```