Query on a tree 2
解法
各頂点について、根からの深さと距離を求めておく。あと、ダブリングしてO(logN)でlcaを求められるようにしておく。
距離を求めるやつは、p=lca(u,v)として、根から頂点uまでの距離をdist[u]とすれば、dist[u]+dist[v]-2*dist[p]となる。こっちはめっちゃ自明。
もう一方のほうも結構自明で、if文とかを駆使して、「uとvをつなぐパス上のk番目の点を求める」->「u(またはv)の_k個上の親を求める」と簡単に変形できる。ダブリングで2^h個先の親は求めてるから、これもO(logN)で簡単にできる。
Query on a tree(1)よりめっちゃ簡単。
コード
#include<cstdio> #include<algorithm> #include<vector> #include<cstring> using namespace std; //#define int long long typedef pair<int,int>pint; typedef vector<int>vint; typedef vector<pint>vpint; #define pb push_back #define mp make_pair #define fi first #define se second #define all(v) (v).begin(),(v).end() #define rep(i,n) for(int i=0;i<(n);i++) #define reps(i,f,n) for(int i=(f);i<(n);i++) #define each(it,v) for(__typeof((v).begin()) it=(v).begin();it!=(v).end();it++) template<class T,class U>void chmin(T &t,U f){if(t>f)t=f;} template<class T,class U>void chmax(T &t,U f){if(t<f)t=f;} struct edge{ int to,cost; edge(int to,int cost):to(to),cost(cost){} }; const int SIZE=10000; const int LOG=20; int N; vector<edge>G[SIZE]; int par[LOG][SIZE],dep[SIZE],dist[SIZE]; void dfs(int v,int p,int d1,int d2){ par[0][v]=p; dep[v]=d1; dist[v]=d2; rep(i,G[v].size()){ edge &e=G[v][i]; if(e.to==p)continue; dfs(e.to,v,d1+1,d2+e.cost); } } void init(){ dfs(0,-1,0,0); rep(i,LOG-1){ rep(j,N){ if(par[i][j]==-1)par[i+1][j]=-1; else par[i+1][j]=par[i][par[i][j]]; } } } int lca(int u,int v){ if(dep[u]<dep[v])swap(u,v); rep(i,LOG)if((dep[u]-dep[v])>>i&1)u=par[i][u]; for(int i=LOG-1;i>=0;i--)if(par[i][u]!=par[i][v])u=par[i][u],v=par[i][v]; return u==v?u:par[0][u]; } void solve(){ scanf("%d",&N); rep(i,N)G[i].clear(); rep(i,N-1){ int a,b,c; scanf("%d%d%d",&a,&b,&c); a--;b--; G[a].pb(edge(b,c)); G[b].pb(edge(a,c)); } init(); char type[11]; while(scanf("%s",type),strcmp(type,"DONE")){ if(type[0]=='D'){ int a,b; scanf("%d%d",&a,&b); a--;b--; int p=lca(a,b); printf("%d\n",dist[a]+dist[b]-2*dist[p]); } else{ int a,b,k; scanf("%d%d%d",&a,&b,&k); a--;b--;k--; int p=lca(a,b); if(dep[a]-dep[p]<k){ k=dep[a]+dep[b]-2*dep[p]-k; swap(a,b); } rep(i,LOG)if(k>>i&1)a=par[i][a]; printf("%d\n",a+1); } } } signed main(){ int T; scanf("%d",&T); while(T--)solve(); return 0; }