らての精進日記

修行をします

Query on a tree 2

解法

各頂点について、根からの深さと距離を求めておく。あと、ダブリングしてO(logN)でlcaを求められるようにしておく。
距離を求めるやつは、p=lca(u,v)として、根から頂点uまでの距離をdist[u]とすれば、dist[u]+dist[v]-2*dist[p]となる。こっちはめっちゃ自明。

もう一方のほうも結構自明で、if文とかを駆使して、「uとvをつなぐパス上のk番目の点を求める」->「u(またはv)の_k個上の親を求める」と簡単に変形できる。ダブリングで2^h個先の親は求めてるから、これもO(logN)で簡単にできる。

Query on a tree(1)よりめっちゃ簡単。

コード

#include<cstdio>
#include<algorithm>
#include<vector>
#include<cstring>
using namespace std;

//#define int long long

typedef pair<int,int>pint;
typedef vector<int>vint;
typedef vector<pint>vpint;
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define all(v) (v).begin(),(v).end()
#define rep(i,n) for(int i=0;i<(n);i++)
#define reps(i,f,n) for(int i=(f);i<(n);i++)
#define each(it,v) for(__typeof((v).begin()) it=(v).begin();it!=(v).end();it++)
template<class T,class U>void chmin(T &t,U f){if(t>f)t=f;}
template<class T,class U>void chmax(T &t,U f){if(t<f)t=f;}

struct edge{
    int to,cost;
    edge(int to,int cost):to(to),cost(cost){}
};

const int SIZE=10000;
const int LOG=20;
int N;
vector<edge>G[SIZE];

int par[LOG][SIZE],dep[SIZE],dist[SIZE];

void dfs(int v,int p,int d1,int d2){
    par[0][v]=p;
    dep[v]=d1;
    dist[v]=d2;
    rep(i,G[v].size()){
        edge &e=G[v][i];
        if(e.to==p)continue;
        dfs(e.to,v,d1+1,d2+e.cost);
    }
}

void init(){
    dfs(0,-1,0,0);
    rep(i,LOG-1){
        rep(j,N){
            if(par[i][j]==-1)par[i+1][j]=-1;
            else par[i+1][j]=par[i][par[i][j]];
        }
    }
}

int lca(int u,int v){
    if(dep[u]<dep[v])swap(u,v);
    rep(i,LOG)if((dep[u]-dep[v])>>i&1)u=par[i][u];
    for(int i=LOG-1;i>=0;i--)if(par[i][u]!=par[i][v])u=par[i][u],v=par[i][v];
    return u==v?u:par[0][u];
}

void solve(){
    scanf("%d",&N);
    rep(i,N)G[i].clear();
    rep(i,N-1){
        int a,b,c;
        scanf("%d%d%d",&a,&b,&c);
        a--;b--;
        G[a].pb(edge(b,c));
        G[b].pb(edge(a,c));
    }

    init();

    char type[11];
    while(scanf("%s",type),strcmp(type,"DONE")){
        if(type[0]=='D'){
            int a,b;
            scanf("%d%d",&a,&b);
            a--;b--;
            int p=lca(a,b);
            printf("%d\n",dist[a]+dist[b]-2*dist[p]);
        }
        else{
            int a,b,k;
            scanf("%d%d%d",&a,&b,&k);
            a--;b--;k--;
            int p=lca(a,b);
            if(dep[a]-dep[p]<k){
                k=dep[a]+dep[b]-2*dep[p]-k;
                swap(a,b);
            }
            rep(i,LOG)if(k>>i&1)a=par[i][a];
            printf("%d\n",a+1);
        }
    }
}

signed main(){
    int T;
    scanf("%d",&T);
    while(T--)solve();
    return 0;
}