POJ - 2114 Boatherds(点分治)

思路

点分治裸题,但统计长度为k的路径时需要注意要使用O(n)的算法(我用了O(n^2)的居然没发现,TLE。。)具体思路就是先对depth排序,然后当depth[l]+depth[r]==k是,因为可能有与depth[l],depth[r]深度相同的,所以统计相同的个数直接numl*numr计算。若depth[l]+depth[r]<k那么l++,因为这时对当前的l能满足条件的r都已经计算过了,所以直接l++。若depth[l]+depth[r]>k,思路一样,r—。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#include <cstdio>
#include <cstring>
#include <algorithm>
#define ll long long
#define INF 0x7fffffff
#define ms(a,val) memset((a),(val),(sizeof(a)))

using namespace std;

const int MAXN=10005,MAXM=10005*2;
int head[MAXN],nxt[MAXM],to[MAXM],val[MAXM],edgecnt,n,sum;
void addedge(int u,int v,int w){
to[edgecnt]=v;
val[edgecnt]=w;
nxt[edgecnt]=head[u];
head[u]=edgecnt++;
}
int son[MAXN],maxson[MAXN],vis[MAXN],root,depth[MAXN],d[MAXN],ans,k;
void getroot(int u,int fa){
son[u]=1,maxson[u]=0;
for(int i=head[u];~i;i=nxt[i]){
int v=to[i];
if(v!=fa&&!vis[v]){
getroot(v,u);
son[u]+=son[v];
maxson[u]=max(maxson[u],son[v]);
}
}
maxson[u]=max(maxson[u],sum-son[u]);
if(maxson[u]<maxson[root])root=u;
}
void getdeep(int u,int fa){
depth[++depth[0]]=d[u];
for(int i=head[u];~i;i=nxt[i]){
int v=to[i];
if(v!=fa&&!vis[v]){
d[v]=d[u]+val[i];
getdeep(v,u);
}
}
}
int cal(int u,int cost){
d[u]=cost,depth[0]=0;
getdeep(u,-1);
sort(depth+1,depth+depth[0]+1);
int r=depth[0],res=0;
for(int l=1;l<r;){
if(depth[l]+depth[r]==k){
if(depth[l]==depth[r]){
res+=(r-l)*(r-l+1)/2;
break;
}
int tl=l,tr=r;
while(depth[tl]==depth[l])tl++;
while(depth[tr]==depth[r])tr--;
res+=(tl-l)*(r-tr);
l=tl,r=tr;
}
else if(depth[l]+depth[r]<k)l++;
else r--;
}
return res;
}
void solve(int u){
ans+=cal(u,0);
vis[u]=1;
for(int i=head[u];~i;i=nxt[i]){
int v=to[i];
if(!vis[v]){
ans-=cal(v,val[i]);
sum=son[v];
root=0;
getroot(v,-1);
solve(root);
}
}
}

int main() {
int q;
while(scanf("%d",&n)&&n){
ms(head,-1);
edgecnt=0;
for(int i=1;i<=n;i++){
int v,w;
while(scanf("%d",&v)&&v){
scanf("%d",&w);
addedge(i,v,w),addedge(v,i,w);
}
}
while(scanf("%d",&q)&&q){
ms(depth,0),ms(son,0),ms(maxson,0),ms(vis,0);
ans=0,k=q,root=0,sum=n,maxson[0]=INF;
getroot(1,-1);
solve(root);
if(ans)printf("AYE\n");
else printf("NAY\n");
}
printf(".\n");
}
return 0;
}