Query on a tree again!
解法
頂点0を根として考える。
ある頂点が黒に代わると、その頂点を根とする部分木内の頂点のみ影響を受ける。よって、ある部分木にいい感じの操作ができると嬉しい。ここで、オイラーツアーテクニックを考えると、ある頂点がdfs順序に初めて現れるindexから最後に現れるindexの中に部分木内の頂点のみがすべて現れていて、非常に嬉しい。列に対する操作はsegtreeでやれば出来そうで、値を追加、削除したり最小値を求めたりするので、ノードにsetを持たせれば行けそうだと分かる。これでO(N(logN)^2)で解ける。
普段書くsegtreeは、「一点更新、区間質問」や、「(遅延評価を用いて)区間更新、区間質問」、といった形だが、今回使ったsegtreeは「区間更新、一点質問」であり、勉強になった。(というか、segtreeのノードに列やsetを持たせる場合はそうなる??)
また、この問題の想定解(多分)は、結構頭いいと思った(ググると出てくる)。
コード
#include<cstdio> #include<iostream> #include<algorithm> #include<vector> #include<cstring> #include<set> using namespace std; //#define int long long typedef pair<int,int>pint; typedef vector<int>vint; typedef vector<pint>vpint; #define pb push_back #define mp make_pair #define fi first #define se second #define all(v) (v).begin(),(v).end() #define rep(i,n) for(int i=0;i<(n);i++) #define reps(i,f,n) for(int i=(f);i<(n);i++) #define each(it,v) for(__typeof((v).begin()) it=(v).begin();it!=(v).end();it++) template<class T,class U>void chmin(T &t,U f){if(t>f)t=f;} template<class T,class U>void chmax(T &t,U f){if(t<f)t=f;} struct BIT{ int N; vint dat; void init(int n){ N=n; dat.resize(N+1,0); } void add(int k,int x){ for(k++;k<=N;k+=k&-k)dat[k]+=x; } int sum(int k){ int ret=0; for(k++;k;k-=k&-k)ret+=dat[k]; return ret; } }; const int SIZE=100000; const int LOG=20; int N,Q; vint G[SIZE]; int tt,tin[SIZE],tout[SIZE],par[LOG][SIZE],dep[SIZE]; BIT bit; bool black[SIZE]; void dfs(int v,int p,int d){ tin[v]=tt++; par[0][v]=p; dep[v]=d; rep(i,G[v].size()){ int to=G[v][i]; if(to==p)continue; dfs(to,v,d+1); } tout[v]=tt; } void init(){ dfs(0,-1,0); bit.init(N+114); rep(i,LOG-1){ rep(j,N){ if(par[i][j]==-1)par[i+1][j]=-1; else par[i+1][j]=par[i][par[i][j]]; } } } int get(int v,int k){ rep(i,LOG)if(k>>i&1)v=par[i][v]; return v; } signed main(){ scanf("%d%d",&N,&Q); rep(i,N-1){ int a,b;scanf("%d%d",&a,&b);a--;b--; G[a].pb(b);G[b].pb(a); } init(); while(Q--){ int t,k; scanf("%d%d",&t,&k); k--; if(t==0){ if(black[k]){ bit.add(tin[k],-1); bit.add(tout[k],1); } else{ bit.add(tin[k],1); bit.add(tout[k],-1); } black[k]=!black[k]; } else{ if(bit.sum(tin[k])==0){ puts("-1"); continue; } int lb=0,ub=dep[k]+1; while(ub-lb>1){ int mid=(ub+lb)/2; int v=get(k,mid); if(bit.sum(tin[v]))lb=mid; else ub=mid; } printf("%d\n",get(k,lb)+1); } } return 0; }
想定解を実装してみたコード
#include<cstdio> #include<iostream> #include<algorithm> #include<vector> #include<cstring> #include<set> using namespace std; //#define int long long typedef pair<int,int>pint; typedef vector<int>vint; typedef vector<pint>vpint; #define pb push_back #define mp make_pair #define fi first #define se second #define all(v) (v).begin(),(v).end() #define rep(i,n) for(int i=0;i<(n);i++) #define reps(i,f,n) for(int i=(f);i<(n);i++) #define each(it,v) for(__typeof((v).begin()) it=(v).begin();it!=(v).end();it++) template<class T,class U>void chmin(T &t,U f){if(t>f)t=f;} template<class T,class U>void chmax(T &t,U f){if(t<f)t=f;} struct BIT{ int N; vint dat; void init(int n){ N=n; dat.resize(N+1,0); } void add(int k,int x){ for(k++;k<=N;k+=k&-k)dat[k]+=x; } int sum(int k){ int ret=0; for(k++;k;k-=k&-k)ret+=dat[k]; return ret; } }; const int SIZE=100000; const int LOG=20; int N,Q; vint G[SIZE]; int tt,tin[SIZE],tout[SIZE],par[LOG][SIZE],dep[SIZE]; BIT bit; bool black[SIZE]; void dfs(int v,int p,int d){ tin[v]=tt++; par[0][v]=p; dep[v]=d; rep(i,G[v].size()){ int to=G[v][i]; if(to==p)continue; dfs(to,v,d+1); } tout[v]=tt; } void init(){ dfs(0,-1,0); bit.init(N+114); rep(i,LOG-1){ rep(j,N){ if(par[i][j]==-1)par[i+1][j]=-1; else par[i+1][j]=par[i][par[i][j]]; } } } int get(int v,int k){ rep(i,LOG)if(k>>i&1)v=par[i][v]; return v; } signed main(){ scanf("%d%d",&N,&Q); rep(i,N-1){ int a,b;scanf("%d%d",&a,&b);a--;b--; G[a].pb(b);G[b].pb(a); } init(); while(Q--){ int t,k; scanf("%d%d",&t,&k); k--; if(t==0){ if(black[k]){ bit.add(tin[k],-1); bit.add(tout[k],1); } else{ bit.add(tin[k],1); bit.add(tout[k],-1); } black[k]=!black[k]; } else{ if(bit.sum(tin[k])==0){ puts("-1"); continue; } int lb=0,ub=dep[k]+1; while(ub-lb>1){ int mid=(ub+lb)/2; int v=get(k,mid); if(bit.sum(tin[v]))lb=mid; else ub=mid; } printf("%d\n",get(k,lb)+1); } } return 0; }