Post 37

终于调出了树链第k大。。。

Count on a tree

题目描述

给定一棵N个节点的树,每个点有一个权值,对于M个询问(u,v,k),你需要回答u xor lastans和v这两个节点间第K小的点权。其中lastans是上一个询问的答案,初始为0,即第一个询问的u是明文。

输入输出格式

输入格式:

第一行两个整数N,M。

第二行有N个整数,其中第i个整数表示点i的权值。

后面N-1行每行两个整数(x,y),表示点x到点y有一条边。

最后M行每行两个整数(u,v,k),表示一组询问。

输出格式:

M行,表示每个询问的答案。

输入输出样例

输入样例#1:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
8 5
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5 1
0 5 2
10 5 3
11 5 4
110 8 2

输出样例#1:

1
2
3
4
5
2
8
9
105
7

题解

这题思路学过可持久化不难想,只不过我写出了好几个sb的错误

一开始在学校写,我建可持久化线段树居然是枚举每个点插入。。。。可以想到这是多么的荒谬,因为枚举到的点的父亲必须在它之前插入才能维护到跟的前缀权值线段树啊。。这还是我今天又看才突然反应过来的。。

其次一个更sb的错误是我居然在查询中把$tot \leq k$当成是进左子树的条件。。。

这种把左右写反的错误已经不是一次两次了,根源是写的太快不过脑子。

一定要避免这种情况,否则调试代价不小(主要是这次调试也挺失败,输出关键信息锁定错误的能力还需要提高)。

现在来说说思路。

主要思路是利用可持久化权值线段树维护树上前缀路径在每个值域的个数,然后利用个数的可减性快速计算出当前值域下链上的个数,假如小于等于$k$就继续递归左子树查找,假如大于$k$就在右子树查找$k-cur$大的值。

主要利用了值域二分和差分统计的思想,不是很难想, 其实也不难写,结果还是挂了好几次。。

话说我也不知道为什么win10下Vim的tab缩进在Vim下显示的好好的,一复制出来就没了。。凑合着看吧。

Code:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <iostream>
#include <vector>
#define maxn 200005
std::vector<int> g[maxn];
int n , rev[maxn] , lsa , range , m , tot , val[maxn] , sgt[maxn<<5] , lc[maxn<<5] , rc[maxn<<5] , sz[maxn] , hs[maxn] , f[maxn] , top[maxn] , dep[maxn] , rt[maxn];
struct Node{
int v , id;
bool operator<(const Node& p)const{
return v < p.v;
}
}p[maxn];
void dfs1(int x , int fx)
{
sz[x] = 1;
dep[x] = dep[fx] + 1;
f[x] = fx;
for(int i = 0 ; i < (int)g[x].size() ; ++i)
{
int v = g[x][i];
if(v == fx) continue;
dfs1(v , x) , sz[x] += sz[v];
if(sz[hs[x]] < sz[v]) hs[x] = v;
}
}
void dfs2(int x , int tp)
{
top[x] = tp;
if(!hs[x]) return ;
dfs2(hs[x] , tp);
for(int i = 0 ; i < (int)g[x].size() ; ++i)
{
int v = g[x][i];
if(v == hs[x] || v == f[x]) continue;
dfs2(v , v);
}
}

inline int LCA(int x , int y)
{
while(top[x] != top[y])
{
if(dep[top[x]] < dep[top[y]]) std::swap(x,y);
x = f[top[x]];
}
if(dep[x] > dep[y]) std::swap(x,y);
return x;
}
inline void pushup(int x){
sgt[x] = sgt[lc[x]] + sgt[rc[x]];
}

inline void cpy(int x , int y){
lc[x] = lc[y];
rc[x] = rc[y];
sgt[x] = sgt[y];
}

void build(int& p , int l, int r)
{
p = ++tot;
if(l == r) return;
int mid = l + r >> 1;
build(lc[p] , l , mid);
build(rc[p] , mid + 1, r);
}

void insert(int& p , int fp , int l , int r , int v)
{
p = ++tot;
cpy(p,fp);
if(l == r){
sgt[p] ++ ;
return ;
}
int mid = l + r >> 1;
if(v <= mid) insert(lc[p] , lc[fp] , l , mid , v);
else insert(rc[p] , rc[fp] , mid + 1 ,r , v);
pushup(p);
}

void buildSgt(int x , int fx)
{
insert(rt[x] , rt[fx] , 1 , range , val[x]);
for(int i = 0 ; i < (int)g[x].size() ; ++i)
{
int v = g[x][i];
if(v == fx) continue;
buildSgt(v , x);
}
}

int query(int x , int y , int lca, int flca , int l , int r , int k)
{
if(l == r) return l;
int tot = sgt[lc[x]] + sgt[lc[y]] - sgt[lc[lca]] - sgt[lc[flca]] , mid = l + r >> 1;
// printf("QUERY INFO LINK INFO : %d %d %d %d\n",sgt[x],sgt[y],sgt[lca],sgt[flca]);
if(k <= tot)
return query(lc[x] , lc[y] , lc[lca] , lc[flca] , l, mid , k);
else
return query(rc[x] , rc[y] , rc[lca] , rc[flca] , mid + 1 , r , k-tot);
}

void print(int p , int l , int r)
{
if(l == r){
// printf("LEAF VAL : %d number:%d\n",l,sgt[p]);
return;
}
int mid = l + r >> 1;
// printf("THE NUMBER OF %d to %d:%d\n",l,r,sgt[p]);
print(lc[p],l,mid);
print(rc[p],mid+1,r);
}

inline void pre()
{
for(int i = 1 ; i <= n ; ++i)
p[i].v = val[i] , p[i].id = i;
std::sort(p+1,p+n+1);
for(int i = 1 ; i <= n ; ++i)
{
if(p[i].v != p[i-1].v) ++range;
rev[range] = p[i].v; val[p[i].id] = range;
}
dfs1(1,0);
dfs2(1,0);
build(rt[0],1,range);
buildSgt(1,0);
}

int main()
{
scanf("%d%d",&n,&m);
for(int i = 1 ; i <= n ; ++i)
scanf("%d",&val[i]);
for(int i = 1 ; i <= n - 1; ++i)
{
int x , y;
scanf("%d%d",&x,&y);
g[x].push_back(y);
g[y].push_back(x);
}
pre();
for(int i = 1 ; i <= m ; ++i)
{
int x , y, k , lca;
scanf("%d%d%d",&x,&y,&k);
// x ^= lsa;
lca = LCA(x,y);
printf("%d\n",lsa = rev[query(rt[x],rt[y],rt[lca],rt[f[lca]],1,range,k)]);
}
}