らての精進日記

修行をします

Query on a tree again!

解法

頂点0を根として考える。
ある頂点が黒に代わると、その頂点を根とする部分木内の頂点のみ影響を受ける。よって、ある部分木にいい感じの操作ができると嬉しい。ここで、オイラーツアーテクニックを考えると、ある頂点がdfs順序に初めて現れるindexから最後に現れるindexの中に部分木内の頂点のみがすべて現れていて、非常に嬉しい。列に対する操作はsegtreeでやれば出来そうで、値を追加、削除したり最小値を求めたりするので、ノードにsetを持たせれば行けそうだと分かる。これでO(N(logN)^2)で解ける。

普段書くsegtreeは、「一点更新、区間質問」や、「(遅延評価を用いて)区間更新、区間質問」、といった形だが、今回使ったsegtreeは「区間更新、一点質問」であり、勉強になった。(というか、segtreeのノードに列やsetを持たせる場合はそうなる??)



また、この問題の想定解(多分)は、結構頭いいと思った(ググると出てくる)。

コード

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<vector>
#include<cstring>
#include<set>
using namespace std;
 
//#define int long long
 
typedef pair<int,int>pint;
typedef vector<int>vint;
typedef vector<pint>vpint;
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define all(v) (v).begin(),(v).end()
#define rep(i,n) for(int i=0;i<(n);i++)
#define reps(i,f,n) for(int i=(f);i<(n);i++)
#define each(it,v) for(__typeof((v).begin()) it=(v).begin();it!=(v).end();it++)
template<class T,class U>void chmin(T &t,U f){if(t>f)t=f;}
template<class T,class U>void chmax(T &t,U f){if(t<f)t=f;}
 
struct BIT{
    int N;
    vint dat;
    void init(int n){
        N=n;
        dat.resize(N+1,0);
    }
    void add(int k,int x){
        for(k++;k<=N;k+=k&-k)dat[k]+=x;
    }
    int sum(int k){
        int ret=0;
        for(k++;k;k-=k&-k)ret+=dat[k];
        return ret;
    }
};
 
const int SIZE=100000;
const int LOG=20;
 
int N,Q;
vint G[SIZE];
int tt,tin[SIZE],tout[SIZE],par[LOG][SIZE],dep[SIZE];
BIT bit;
bool black[SIZE];
 
void dfs(int v,int p,int d){
    tin[v]=tt++;
    par[0][v]=p;
    dep[v]=d;
 
    rep(i,G[v].size()){
        int to=G[v][i];
        if(to==p)continue;
        dfs(to,v,d+1);
    }
    tout[v]=tt;
}
 
void init(){
    dfs(0,-1,0);
 
    bit.init(N+114);
    rep(i,LOG-1){
        rep(j,N){
            if(par[i][j]==-1)par[i+1][j]=-1;
            else par[i+1][j]=par[i][par[i][j]];
        }
    }
}
 
int get(int v,int k){
    rep(i,LOG)if(k>>i&1)v=par[i][v];
    return v;
}
 
signed main(){
    scanf("%d%d",&N,&Q);
    rep(i,N-1){
        int a,b;scanf("%d%d",&a,&b);a--;b--;
        G[a].pb(b);G[b].pb(a);
    }
 
    init();
 
    while(Q--){
        int t,k;
        scanf("%d%d",&t,&k);
        k--;
        if(t==0){
            if(black[k]){
                bit.add(tin[k],-1);
                bit.add(tout[k],1);
            }
            else{
                bit.add(tin[k],1);
                bit.add(tout[k],-1);
            }
            black[k]=!black[k];
        }
        else{
            if(bit.sum(tin[k])==0){
                puts("-1");
                continue;
            }
 
            int lb=0,ub=dep[k]+1;
            while(ub-lb>1){
                int mid=(ub+lb)/2;
                int v=get(k,mid);
                if(bit.sum(tin[v]))lb=mid;
                else ub=mid;
            }
            printf("%d\n",get(k,lb)+1);
        }
    }
 
    return 0;
}

想定解を実装してみたコード

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<vector>
#include<cstring>
#include<set>
using namespace std;

//#define int long long

typedef pair<int,int>pint;
typedef vector<int>vint;
typedef vector<pint>vpint;
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define all(v) (v).begin(),(v).end()
#define rep(i,n) for(int i=0;i<(n);i++)
#define reps(i,f,n) for(int i=(f);i<(n);i++)
#define each(it,v) for(__typeof((v).begin()) it=(v).begin();it!=(v).end();it++)
template<class T,class U>void chmin(T &t,U f){if(t>f)t=f;}
template<class T,class U>void chmax(T &t,U f){if(t<f)t=f;}

struct BIT{
    int N;
    vint dat;
    void init(int n){
        N=n;
        dat.resize(N+1,0);
    }
    void add(int k,int x){
        for(k++;k<=N;k+=k&-k)dat[k]+=x;
    }
    int sum(int k){
        int ret=0;
        for(k++;k;k-=k&-k)ret+=dat[k];
        return ret;
    }
};

const int SIZE=100000;
const int LOG=20;

int N,Q;
vint G[SIZE];
int tt,tin[SIZE],tout[SIZE],par[LOG][SIZE],dep[SIZE];
BIT bit;
bool black[SIZE];

void dfs(int v,int p,int d){
    tin[v]=tt++;
    par[0][v]=p;
    dep[v]=d;

    rep(i,G[v].size()){
        int to=G[v][i];
        if(to==p)continue;
        dfs(to,v,d+1);
    }
    tout[v]=tt;
}

void init(){
    dfs(0,-1,0);

    bit.init(N+114);
    rep(i,LOG-1){
        rep(j,N){
            if(par[i][j]==-1)par[i+1][j]=-1;
            else par[i+1][j]=par[i][par[i][j]];
        }
    }
}

int get(int v,int k){
    rep(i,LOG)if(k>>i&1)v=par[i][v];
    return v;
}

signed main(){
    scanf("%d%d",&N,&Q);
    rep(i,N-1){
        int a,b;scanf("%d%d",&a,&b);a--;b--;
        G[a].pb(b);G[b].pb(a);
    }

    init();

    while(Q--){
        int t,k;
        scanf("%d%d",&t,&k);
        k--;
        if(t==0){
            if(black[k]){
                bit.add(tin[k],-1);
                bit.add(tout[k],1);
            }
            else{
                bit.add(tin[k],1);
                bit.add(tout[k],-1);
            }
            black[k]=!black[k];
        }
        else{
            if(bit.sum(tin[k])==0){
                puts("-1");
                continue;
            }

            int lb=0,ub=dep[k]+1;
            while(ub-lb>1){
                int mid=(ub+lb)/2;
                int v=get(k,mid);
                if(bit.sum(tin[v]))lb=mid;
                else ub=mid;
            }
            printf("%d\n",get(k,lb)+1);
        }
    }

    return 0;
}