http://www.lydsy.com/JudgeOnline/problem.php?id=5024
首先吐槽一下题面是有错误的
那一个"或" 应该改成","
这道题目条件是非常绕的
先看一个简化的问题
我们对于\(u\), \(v\)
如果\(u-v\)是良的 那么染成白色
否则染成黑色
问题转化成一个经典问题
给定一个完全图 求同色三角形的个数
做法就是考虑非同色三角形的个数
找角的个数\(jiao\) 答案就是\(C_n ^ 3 - jiao / 2\)
考虑如何在\(O(distance(u, v))\) 判断是否\(u - v\)是否是良的
观察计算\(P\)的方式
由于是末尾是\(011\)
可以考虑在模\(8\)的意义下来计算
于是$P = 1 \times s_1 + 5 \times s_2 + 1 \times s_3 ... $
此时考虑\(1 = 001_2\), \(5 = 101_2\)
末尾的\(01\)是相同的
于是我们只需要考虑倒数第三位是否是\(1\)或者是\(0\)
倒数第三位的贡献是由谁提供的呢?
分成两部分
- 所有数的和 的倒数第三位
- 偶数位的数 的倒数第一位
由于第一部分贡献已知
只需要看第二部分是否满足
问题变成
考虑一个\(01\)串 可以从头尾取
是否能满足偶数位为\(0\)为\(1\)
如果全\(0\)全\(1\)那么答案固定
否则\(01\)都是满足的
证明的话考虑两头类似于括号序列的消去
到现在为止
我们可以在\(O(distance(u, v))\) 判断是否\(u - v\)是否是良的
于是就可以树形dp了啊
记录状态\(f[sum][dep][k]\) 表示子树的答案
\(sum\)表示路径的和 模 \(8\)
\(dep\)表示最低位的和 模 \(4\)
\(k\) 表示 \(all_0 / all_1 / others\)
然后我们就求出了一个点为根的答案
然后换根 dp 一下就好了
复杂度\(O(n)\)
以下代码bzoj不能过 爆栈了.. xjoi可以过..
不会写手工栈.jpg
#pragma GCC optimize(2)#pragma comment(linker, "/STACK:2048000000,2048000000")#include<bits/stdc++.h>#define int long long#define fo(i, n) for(int i = 1; i <= (n); i ++)#define out(x) cerr << #x << " = " << x << "\n"#define type(x) __typeof((x).begin())#define foreach(it, x) for(type(x) it = (x).begin(); it != (x).end(); ++ it)using namespace std;// by pianotemplate<typename tp> inline void read(tp &x) { ?x = 0; char c = getchar(); bool f = 0; ?for(; c < '0' || c > '9'; f |= (c == '-'), c = getchar()); ?for(; c >= '0' && c <= '9'; x = (x << 3) + (x << 1) + c - '0', c = getchar()); ?if(f) x = -x;}template<typename tp> inline void arr(tp *a, int n) { ?for(int i = 1; i <= n; i ++) ???cout << a[i] << " "; ?puts("");}const int N = 3e5 + 233;struct E { ?int nxt, to;}e[N << 1];int head[N], e_cnt = 0; ?inline void add(int u, int v) { ?e[++ e_cnt] = (E) {head[u], v}; head[u] = e_cnt;} ?struct Node { ?// sum % 8, dep % 4, all_0 / all_1 / others ?int f[8][4][3]; ??inline void clear(void) { ???memset(f, 0, sizeof f); ?} ????inline void init(int val, int fff) { ???int sum = val % 8, dep = val & 1; ???int k = dep; ???f[sum][dep][k] += fff; ?} ?inline void ovo(void) { ???for(int sum = 0; sum < 8; sum ++) ?????for(int dep = 0; dep < 4; dep ++) ???????for(int k = 0; k < 3; k ++) ?????????if(f[sum][dep][k]) ???????????printf("f[%lld][%lld][%lld] = %lld\n", sum, dep, k, f[sum][dep][k]); ????????}}p[N], tmp;int n, fat, val[N], ans[N]; inline void U(Node &a, Node b, int fff) { ?for(int sum = 0; sum < 8; sum ++) ???for(int dep = 0; dep < 4; dep ++) ?????for(int k = 0; k < 3; k ++) ???????a.f[sum][dep][k] += b.f[sum][dep][k] * fff;} inline void Get(Node u, int val) { ?tmp.clear(); ?for(int sum = 0; sum < 8; sum ++) { ???for(int dep = 0; dep < 4; dep ++) { ?????int ns = (val % 8 + sum) % 8; ?????int nd = ((val & 1) + dep) % 4; ?????int ok = val & 1; ?????if(ns >= 8 || nd >= 4) while(1); ?????for(int k = 0; k < 3; k ++) { ???????if(ok == k) ?????????tmp.f[ns][nd][k] += u.f[sum][dep][k]; ???????else ?????????tmp.f[ns][nd][2] += u.f[sum][dep][k]; ?????} ???} ?}} inline int Getans(const Node &u) { ?int ans = 0; ?for(int dep = 0; dep < 4; dep ++) ???ans += u.f[3][dep][0] + u.f[3][dep][2]; ?for(int dep = 0; dep < 4; dep ++) ???if((dep / 2) % 2 == 0) ?????ans += u.f[3][dep][1]; ?for(int dep = 0; dep < 4; dep ++) { ???if((dep / 2) % 2 == 1) ?????ans += u.f[7][dep][1]; ???ans += u.f[7][dep][2]; ?} ?return ans * (n - 1 - ans);} inline void dfs(int u, int fat) { ?for(int i = head[u]; i; i = e[i].nxt) { ???int v = e[i].to; ???if(v != fat) dfs(v, u), Get(p[v], val[u]), U(p[u], tmp, 1); ?} ?p[u].init(val[u], 1);} inline void frt(int u, int fat) { ?p[u].init(val[u], -1); ?ans[u] = Getans(p[u]); ?p[u].init(val[u], 1); ?for(int i = head[u]; i; i = e[i].nxt) { ???int v = e[i].to; ???if(v != fat) { ?????Get(p[v], val[u]); U(p[u], tmp, -1); ?????Get(p[u], val[v]); U(p[v], tmp, 1); ?????frt(v, u); ?????Get(p[u], val[v]); U(p[v], tmp, -1); ?????Get(p[v], val[u]); U(p[u], tmp, 1); ???} ?}} ?main(void) { ?read(n); ?for(int i = 1; i <= n; i ++) { ???read(fat); read(val[i]); val[i] &= 7; ???add(i, fat); add(fat, i); ?} ?dfs(1, 0); frt(1, 0); ?int res = 0; ?fo(i, n) res += ans[i]; ?res /= 2; ?cout << n * (n - 1) * (n - 2) / 6 - res << "\n"; ????}