「题解」「PKUWC2018」Minimax

PKUWC2018 Day 1 T1,一道有趣的 线段树合并 题目。

题目链接:Luogu P5298LibreOJ 2537

题目

题目描述

小 \(C\) 有一棵 \(n\) 个结点的有根树,根是 \(1\) 号结点,且每个结点最多有两个子结点。

定义结点 \(x\) 的权值为:

  1. 若 \(x\) 没有子结点,那么它的权值会在输入里给出,保证这类点中每个结点的权值互不相同
  2. 若 \(x\) 有子结点,那么它的权值有 \(p _ x\) 的概率是它的子结点的权值的最大值,有 \(1-p _ x\) 的概率是它的子结点的权值的最小值。

现在小 \(C\) 想知道,假设 \(1\) 号结点的权值有 \(m\) 种可能性,权值第 \(i\) 小的可能性的权值是 \(V_i\)​,它的概率为 \(D _ i(D _ i>0)\),求:
\[\sum^m _ {i=1}i\cdot V _ i\cdot D^2 _ i\]

你需要输出答案对 \(998244353\) 取模的值。

数据范围

对于 \(10\%\) 的数据,有 \(1\leq n\leq 20\);
对于 \(20\%\) 的数据,有 \(1\leq n\leq 400\);
对于 \(40\%\) 的数据,有 \(1\leq n\leq 5000\);
对于 \(60\%\) 的数据,有 \(1\leq n\leq 10^5\);
另有 \(10\%\) 的数据保证树的形态随机。
对于 \(100\%\) 的数据,有 \(1\leq n\leq 3\times 10^5\),\(1\leq w _ i\leq 10^9\)。

对于所有数据,满足 \(0<p _ i\cdot 10000<100000\),所以易证明所有叶子的权值都有概率被根取到。

题解

思路

思路比较显然,就不写了。

代码

#include<bits/stdc++.h>
using namespace std;
#define reg register
typedef long long ll;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++)
static char buf[100000],*p1=buf,*p2=buf;
inline int read(void){
	reg char ch=getchar();
	reg int res=0;
	while(ch<'0'||'9'<ch)ch=getchar();
	while('0'<=ch&&ch<='9')res=10*res+ch-'0',ch=getchar();
	return res;
}

inline int pow(reg int x,reg int exp,reg int mod){
	reg int res=1;
	while(exp){
		if(exp&1)
			res=1ll*res*x%mod;
		x=1ll*x*x%mod;
		exp>>=1;
	}
	return res;
}

const int MAXN=3e5+5;
const int mod=998244353;
const int inv1e4=pow(1e4,mod-2,mod);

inline int add(reg int a,reg int b){
	reg int sum=a+b;
	return sum>=mod?sum-mod:sum;
}

inline int mul(reg int a,reg int b){
	return 1ll*a*b%mod;
}

int n,m;
int p[MAXN];
vector<int> V;
int root[MAXN];

namespace SegmentTree{
	#define mid ( ( (l) + (r) ) >> 1 )
	struct Node{
		int lson,rson;
		int sum;
		int tagM;
		#define lson(x) unit[(x)].lson
		#define rson(x) unit[(x)].rson
		#define sum(x) unit[(x)].sum
		#define tagM(x) unit[(x)].tagM
	};
	const int MAXSIZE=40*MAXN;
	int tot;
	Node unit[MAXSIZE];
	int top,S[MAXSIZE];
	inline int New(void){
		reg int k;
		if(top)
			k=S[top--];
		else
			k=++tot;
		lson(k)=rson(k)=sum(k)=0;
		tagM(k)=1;
		return k;
	}
	inline void del(reg int k){
		lson(k)=rson(k)=sum(k)=0;
		tagM(k)=1;
		S[++top]=k;
		return;
	}
	inline void pushup(reg int k){
		sum(k)=add(sum(lson(k)),sum(rson(k)));
		return;
	}
	inline void Mul(reg int k,reg int val){
		sum(k)=mul(sum(k),val);
		tagM(k)=mul(tagM(k),val);
		return;
	}
	inline void pushdown(reg int k){
		if(tagM(k)!=1){
			if(lson(k))
				Mul(lson(k),tagM(k));
			if(rson(k))
				Mul(rson(k),tagM(k));
			tagM(k)=1;
		}
		return;
	}
	inline void update(reg int &k,reg int l,reg int r,reg int pos,reg int val){
		if(!k)
			k=New();
		if(l==r){
			sum(k)=add(sum(k),val);
			return;
		}
		pushdown(k);
		if(pos<=mid)
			update(lson(k),l,mid,pos,val);
		else
			update(rson(k),mid+1,r,pos,val);
		pushup(k);
		return;
	}
	inline int merge(reg int x,reg int y,reg int l,reg int r,reg int xmul,reg int ymul,const int& val){
		if(!x&&!y)
			return 0;
		else if(!x){
			Mul(y,ymul);
			return y;
		}
		else if(!y){
			Mul(x,xmul);
			return x;
		}
		else{
			pushdown(x),pushdown(y);
			reg int k=New();
			reg int lsumx=sum(lson(x)),lsumy=sum(lson(y)),rsumx=sum(rson(x)),rsumy=sum(rson(y));
			lson(k)=merge(lson(x),lson(y),l,mid,add(xmul,mul(rsumy,add(1,mod-val))),add(ymul,mul(rsumx,add(1,mod-val))),val);
			rson(k)=merge(rson(x),rson(y),mid+1,r,add(xmul,mul(lsumy,val)),add(ymul,mul(lsumx,val)),val);
			pushup(k);
			del(x),del(y);
			return k;
		}
	}
	int cnt,T[MAXN];
	inline void query(reg int k,reg int l,reg int r){
		if(l==r){
			T[++cnt]=mul(sum(k),sum(k));
			return;
		}
		pushdown(k);
		if(lson(k))
			query(lson(k),l,mid);
		if(rson(k))
			query(rson(k),mid+1,r);
		return;
	}
	#undef mid
}

int deg[MAXN];
int son[MAXN][2];

inline void dfs(reg int u){
	switch(deg[u]){
		case 0:{
			SegmentTree::update(root[u],1,m,p[u],1);
			break;
		}
		case 1:{
			dfs(son[u][0]);
			root[u]=root[son[u][0]];
			break;
		}
		case 2:{
			dfs(son[u][0]),dfs(son[u][1]);
			root[u]=SegmentTree::merge(root[son[u][0]],root[son[u][1]],1,m,0,0,p[u]);
			break;
		}
	}
	return;
}

int main(void){
	n=read(),read();
	for(reg int i=2;i<=n;++i){
		static int f;
		f=read();
		son[f][deg[f]++]=i;
	}
	for(reg int i=1;i<=n;++i){
		p[i]=read();
		if(deg[i])
			p[i]=mul(p[i],inv1e4);
		else
			V.push_back(p[i]);
	}
	sort(V.begin(),V.end());
	V.erase(unique(V.begin(),V.end()),V.end());
	m=V.size();
	for(reg int i=1;i<=n;++i)
		if(!deg[i])
			p[i]=lower_bound(V.begin(),V.end(),p[i])-V.begin()+1;
	dfs(1);
	SegmentTree::query(root[1],1,m);
	reg int ans=0;
	for(reg int i=1;i<=SegmentTree::cnt;++i)
		ans=add(ans,mul(i,mul(V[i-1],SegmentTree::T[i])));
	printf("%d\n",ans);
	return 0;
}