[ACNOI2022]异或
考察知识点多样、覆盖面广,代码实现有所要求,可谓良心好题;这样的出题人,不拉到油锅里炸一炸,怎么对得起 Ta 的辛勤付出!
题目
题目描述
求出有多少个长度为
L
L
L 的序列
{
a
}
\{a\}
{a},满足
a
i
∈
S
a_i\in S
ai∈S 且
⨁
j
=
1
i
a
i
∉
T
(
1
⩽
i
⩽
L
)
\bigoplus_{j=1}^{i}a_i\notin T\;(1\leqslant i\leqslant L)
⨁j=1iai∈/T(1⩽i⩽L) 。其中
⊕
\oplus
⊕ 表示按位异或。答案对
998244353
998244353
998244353 取模。
数据范围与提示
L
⩽
4000
L\leqslant 4000
L⩽4000 且
∣
S
∣
=
n
⩽
35
|S|=n\leqslant 35
∣S∣=n⩽35 且
∣
T
∣
=
m
⩽
20
|T|=m\leqslant 20
∣T∣=m⩽20 且
S
,
T
S,T
S,T 中元素均为不超过
1
0
9
10^9
109 的正整数。时限
3
s
\rm 3s
3s 。
思路
首先写个暴力
d
p
\tt dp
dp 转移出来。类似这道题,我们简单 “容斥” 一下。记
f
(
L
,
x
)
f(L,x)
f(L,x) 为,长度为
L
L
L 的序列,满足
d
L
=
x
d_L=x
dL=x 且
d
1
,
d
2
,
d
3
,
…
,
d
L
−
1
∉
T
d_1,d_2,d_3,\dots,d_{L-1}\notin T
d1,d2,d3,…,dL−1∈/T 的数量。
g
(
L
,
x
)
g(L,x)
g(L,x) 为,长度为
L
L
L 的序列,满足
d
L
=
x
d_L=x
dL=x 的数量。有转移
f
(
L
,
x
)
=
g
(
L
,
x
)
−
∑
i
=
1
L
−
1
∑
d
∈
T
f
(
i
,
d
)
×
g
(
L
−
i
,
x
⊕
d
)
f(L,x)=g(L,x)-\sum_{i=1}^{L-1}\sum_{d\in T}f(i,d)\times g(L{-}i,\;x{\oplus} d)
f(L,x)=g(L,x)−i=1∑L−1d∈T∑f(i,d)×g(L−i,x⊕d)
只求出 f ( i , x ) ( x ∈ T ) f(i,x)\;(x\in T) f(i,x)(x∈T),即可向后转移;于是问题变为求出 g ( i , x ) g(i,x) g(i,x),其中 x = a ⊕ b ( a , b ∈ T ∪ { 0 } ) x=a\oplus b\;(a,b\in T\cup\{0\}) x=a⊕b(a,b∈T∪{0}) 。
我本以为,可以压缩到 T T T 中元素的张成子空间;但是仔细想想,其实每个 b i t \rm bit bit 都可能是独一无二的,所以还是要记录每个 b i t \rm bit bit 的情况。于是我就卡住了。
不妨先考虑暴力:用 2 ∣ S ∣ 2^{|S|} 2∣S∣ 枚举每种数字的出现次数奇偶性,设数量为 t t t,则其对 g ( i , x ) g(i,x) g(i,x) 的贡献为 ω ( i , t ) \omega(i,t) ω(i,t) 。这个 ω ( L , x ) \omega(L,x) ω(L,x) 表示长度为 L L L 的序列中,恰有 x x x 种数字需要出现奇数次;考虑这一位是否在 x x x 中,就可以 O ( L n ) \mathcal O(Ln) O(Ln) 递推出整个数组。
所以优化就是,求出异或和为 x x x 且 ∣ S ′ ∣ |S'| ∣S′∣ 为奇数(或偶数)的不同 S ′ ⫅ S S'\subseteqq S S′⫅S 数量,然后对于每个 i i i 都可以 O ( n ) \mathcal O(n) O(n) 硬算了。
对 v ∈ S v\in S v∈S 建立线性基,设线性基大小为 B B B 。当 B B B 较小时,将所有元素用线性基表示,可以将值域压缩到 2 B 2^{B} 2B,此时要算异或卷积的集合幂级数 ∏ v ∈ S ( 1 + x v y ) \prod_{v\in S}(1+x^vy) ∏v∈S(1+xvy),手算 FWT \textit{FWT} FWT 结果,乘起来,然后做 FWT \textit{FWT} FWT 逆变换。复杂度 O ( n 2 2 B ) \mathcal O(n^22^B) O(n22B),然后 O ( m 2 n ) \mathcal O(m^2n) O(m2n) 查询不同的 x x x 和 y y y 的指数即可。
当
B
B
B 较大时,对于非线性基元素,求出它们要怎样得到。然后
2
n
−
B
2^{n-B}
2n−B 枚举它们的出现情况,递推、状压、__builtin_popcount
能
O
(
1
)
\mathcal O(1)
O(1) 得出线性基内元素使用量。但是
x
x
x 需额外枚举,复杂度
O
(
n
2
+
m
2
2
n
−
B
)
\mathcal O(n^2+m^22^{n-B})
O(n2+m22n−B) 。
平衡复杂度约为 B = n 2 B=\frac{n}{2} B=2n 。复杂度约为 O ( n 2 2 B ) \mathcal O(n^22^B) O(n22B) 。
第一步复杂度 O ( m 2 L 2 ) \mathcal O(m^2L^2) O(m2L2),怎么办?分治 NTT \textit{NTT} NTT,变为 O ( m 2 L log 2 L ) \mathcal O(m^2L\log^2L) O(m2Llog2L),还不够快?我口胡了一个 “分治过程中维护点值” 的做法,然后被 Tiw-Air-OAO \textsf{Tiw-Air-OAO} Tiw-Air-OAO 嘲笑了……他说:
跟此题一模一样,这就是一个 多项式方程组。它实际上等价于
v
g
−
A
g
v
f
=
v
f
v_g-A_gv_f=v_f
vg−Agvf=vf
其中
v
f
v_f
vf 是元素为多项式(生成函数)的列向量,同理有
A
g
A_g
Ag 为多项式的矩阵。移项即有
(
I
+
A
g
)
v
f
=
v
g
(I+A_g)v_f=v_g
(I+Ag)vf=vg
其中 I I I 为对角矩阵。如果我们能求出 ( I + A g ) − 1 (I+A_g)^{-1} (I+Ag)−1,问题就解决了——右边进行 m 2 m^2 m2 次多项式乘法,复杂度仍是正确的。多项式对 x L x^{L} xL 取模。
直接高斯消元是危险而缓慢的。我们需要一些神奇操作。就像形式幂级数,系数放在哪里,只是 “占位” 作用;也就是说,它实际上是矩阵下标
(
i
,
j
)
(i,j)
(i,j) 和
x
t
x^t
xt 指数
t
t
t 的 “直和”,二者分别有自己的加法。所以,谁当主元都可以。将
x
t
x^t
xt 系数构成的矩阵写为
A
t
A_t
At,则原矩阵等于
F
(
x
)
=
∑
i
⩾
0
A
i
x
i
F(x)=\sum_{i\geqslant 0}A_ix^i
F(x)=i⩾0∑Aixi
没错,只是简单地将系数变为矩阵的多项式。反正矩阵构成环,生成函数就构成多项式环。然后我们对这个多项式求逆元,时间复杂度 O ( m 3 L + m 2 L log L ) \mathcal O(m^3L+m^2L\log L) O(m3L+m2LlogL) 。只需要 A 0 A_0 A0 满秩即可。而 A g A_g Ag 中多项式,即某个 g g g 的生成函数,不含 x 0 x^0 x0 项;于是求 ( I + A g ) − 1 (I+A_g)^{-1} (I+Ag)−1 时 A 0 = I A_0=I A0=I,肯定可行。
故 f f f 可知。最后求答案,类似于 f f f 的转移,枚举前缀第一个不合法之处;只是总方案数改为 ∣ S ∣ L |S|^L ∣S∣L、后面随便放的方案数改为 ∣ S ∣ L − i |S|^{L-i} ∣S∣L−i 罢了。
代码
长度惊人。注意,矩阵不构成交换环,求逆元的乘法顺序不能改变!逆元是左逆元。
#include <cstdio> // JZM yydJUNK!!!
#include <iostream> // XJX yyds!!!
#include <algorithm> // XYX yydLONELY!!!
#include <cstring> // (the STRONG long for LONELINESS)
#include <cctype> // ZXY yydSISTER!!!
using namespace std;
# define rep(i,a,b) for(int i=(a); i<=(b); ++i)
# define drep(i,a,b) for(int i=(a); i>=(b); --i)
typedef long long llong;
inline int readint(){
int a = 0, c = getchar(), f = 1;
for(; !isdigit(c); c=getchar())
if(c == '-') f = -f;
for(; isdigit(c); c=getchar())
a = (a<<3)+(a<<1)+(c^48);
return a*f;
}
const int MOD = 998244353;
inline int modAdd(int x, const int &y){
return ((x += y) >= MOD) ? (x-MOD) : x;
}
inline llong qkpow(llong b, int q){
llong a = 1;
for(; q; q>>=1,b=b*b%MOD)
if(q&1) a = a*b%MOD;
return a;
}
extern int m; ///< for Matrix
struct Matrix{
static const int MAXM = 21;
int v[MAXM][MAXM];
Matrix operator + (const Matrix &t) const {
Matrix c; rep(i,1,m) rep(j,1,m)
c.v[i][j] = modAdd(v[i][j],t.v[i][j]);
return c;
}
Matrix operator - (const Matrix &t) const {
Matrix c; rep(i,1,m) rep(j,1,m)
c.v[i][j] = modAdd(v[i][j],MOD-t.v[i][j]);
return c;
}
Matrix operator * (const Matrix &t) const {
Matrix c; memset(c.v,0,sizeof(c.v));
rep(i,1,m) rep(j,1,m) rep(k,1,m)
c.v[i][k] = int((c.v[i][k]+
llong(v[i][j])*t.v[j][k])%MOD);
return c;
}
Matrix operator * (llong k) const {
Matrix c; rep(i,1,m) rep(j,1,m)
c.v[i][j] = int(k*v[i][j]%MOD);
return c;
}
static Matrix I(){
Matrix c; rep(i,1,m) rep(j,1,m) c.v[i][j] = (i == j);
return c;
}
};
const int LOGMOD = 30;
int yg[LOGMOD], inv2[LOGMOD];
void prepare(){
int p = MOD-1, x = 0; inv2[0] = 1;
for(inv2[1]=(MOD+1)>>1; !(p&1); p>>=1,++x)
inv2[x+1] = int(llong(inv2[x])*inv2[1]%MOD);
for(yg[x]=int(qkpow(3,p)); x; --x)
yg[x-1] = int(llong(yg[x])*yg[x]%MOD);
}
void NTT(Matrix poly[], int n){
for(int w=1<<n>>1,x=n; x; w>>=1,--x)
for(Matrix *p=poly; p!=poly+(1<<n); p+=w<<1)
for(int i=0,v=1; i!=w; ++i,v=int(llong(yg[x])*v%MOD)){
const Matrix mat = p[i+w];
p[i+w] = (p[i]-mat)*v, p[i] = p[i]+mat;
}
}
void DNTT(Matrix poly[], int n){
for(int w=1,x=1; x<=n; w<<=1,++x)
for(Matrix *p=poly; p!=poly+(1<<n); p+=w<<1)
for(int i=0,v=1; i!=w; ++i,v=int(llong(yg[x])*v%MOD)){
const Matrix mat = p[i+w]*v;
p[i+w] = p[i]-mat, p[i] = p[i]+mat;
}
std::reverse(poly+1,poly+(1<<n));
for(Matrix *i=poly; i!=poly+(1<<n); ++i) *i = (*i)*inv2[n];
}
const int MAXL = 4003;
const size_t _BLOCK = sizeof(Matrix);
Matrix inv[MAXL<<2], tmp[MAXL<<2];
void getInv(const Matrix poly[], int n){
inv[0] = Matrix::I();
for(int l=1; l<=n; ++l){
memcpy(tmp,poly,(1<<l)*_BLOCK);
NTT(tmp,l+1), NTT(inv,l+1);
for(int i=0; i!=(2<<l); ++i)
inv[i] = inv[i]*2-inv[i]*tmp[i]*inv[i];
DNTT(inv,l+1); // polluted
memset(inv+(1<<l),0,(1<<l)*_BLOCK);
}
}
template < int __length_of_poly >
void DFWT(int a[][__length_of_poly], int n){
for(int w=1; w!=(1<<n); w<<=1)
for(int p=0; p!=(1<<n); p+=(w<<1))
for(int i=p,j=p+w; i!=p+w; ++i,++j)
for(int k=0,t; k!=__length_of_poly; ++k){
t = a[i][k], a[i][k] = modAdd(t,a[j][k]);
a[j][k] = modAdd(t,MOD-a[j][k]);
}
const llong coe = inv2[n];
for(int i=0; i!=(1<<n); ++i)
for(int j=0; j!=__length_of_poly; ++j)
a[i][j] = int(coe*a[i][j]%MOD);
}
const int MAXN = 36;
int n, m, L, a[MAXN], b[MAXN];
int omega[MAXL][MAXN];
void input(){
n = readint(), m = readint(), L = readint();
omega[0][0] = 1; // no one want to be odd
rep(i,1,L) rep(j,0,n){
omega[i][j] = int(llong(n-j)*omega[i-1][j+1]%MOD);
if(j) omega[i][j] = int((omega[i][j]+
llong(j)*omega[i-1][j-1])%MOD);
}
rep(i,1,n) a[i] = readint();
rep(i,1,m) b[i] = readint();
}
/// g[i,j,k]: to get b_i xor b_j, k numbers appear odd times
int g[MAXN][MAXN][MAXN];
const int HALF = 17;
namespace linear_base{
const int LOGV = 30;
int bas[LOGV-1]; llong from[LOGV-1];
bool insert(int x, uint8_t id){
llong now_from = 1ull<<id;
drep(i,LOGV-1,0) if(x>>i&1){
if(bas[i]) x ^= bas[i], now_from ^= from[i];
else return bas[i] = x, from[i] = now_from, true;
}
return false; // no place to be held
}
int discretize(int x) noexcept {
int res = 0; // int32_t gaurantee
drep(i,LOGV-1,0) if(bas[i]){
res <<= 1; // one at a time
if(x>>i&1) x ^= bas[i], res ^= 1;
}
return x ? -1 : res;
}
llong query(int x){
llong res = 0;
drep(i,LOGV-1,0) if(x>>i&1)
x ^= bas[i], res ^= from[i];
return x ? -1 : res;
}
}
# define __cnt_bit(x) __builtin_popcountll(x)
void solve_g(){
static int bitcnt[1<<HALF];
for(int S=bitcnt[0]=0; !(S>>HALF); ++S)
bitcnt[S] = bitcnt[S>>1]+(S&1);
uint8_t tot = 0; static bool is_in[MAXN];
rep(i,1,n) if((is_in[i] = linear_base::insert(a[i],tot))) ++ tot;
if(tot <= HALF){ // small_base
rep(i,1,n) a[i] = linear_base::discretize(a[i]);
static int buc[1<<HALF][MAXN];
for(int S=0; !(S>>HALF); ++S){
memset(buc[S],0,MAXN<<2), *buc[S] = 1;
rep(i,1,n){ // manual calculation
int sgn = (bitcnt[S&a[i]]&1) ? -1 : 1;
for(int j=i; j; --j){
buc[S][j] = sgn*buc[S][j-1]+buc[S][j];
if(buc[S][j] > MOD) buc[S][j] -= MOD;
buc[S][j] += (buc[S][j]>>31)&MOD;
}
}
}
DFWT(buc,HALF); // longer is not unacceptable
rep(i,0,m) rep(j,1,m){
int v = linear_base::discretize(b[i]^b[j]);
if(v == -1) continue; // cannot be
memcpy(g[i][j],buc[v],(n+1)<<2);
}
}
else{ // large_base
static llong to_got[1<<HALF]; tot = 0;
rep(i,1,n) if(!is_in[i]) // way to got them
to_got[1<<tot] = linear_base::query(a[i]), ++ tot;
for(int S=0; !(S>>tot); ++S) if(S&(S-1))
to_got[S] = to_got[S&(S-1)]^to_got[S&-S];
rep(i,0,m) rep(j,1,m){
llong v = linear_base::query(b[i]^b[j]);
if(v == -1) continue; // cannot be
for(int S=0; !(S>>tot); ++S){
const llong now = to_got[S]^v;
++ g[i][j][__cnt_bit(now)+bitcnt[S]];
}
}
}
}
inline int get_g(int len, int x, int y){
int res = 0;
rep(i,0,n) res = int((res+llong(
omega[len][i])*g[x][y][i])%MOD);
return res;
}
Matrix poly[MAXL<<2], vg[MAXL<<2];
int f[MAXL][MAXN];
int main(){
prepare(); input(); solve_g();
rep(i,1,L) rep(j,1,m) rep(k,1,m)
poly[i].v[j][k] = get_g(i,j,k);
poly[0] = Matrix::I();
const int len = 32-__builtin_clz(L);
getInv(poly,len); // only once
rep(i,1,L) rep(j,1,m) vg[i].v[j][1] = get_g(i,0,j);
NTT(vg,len+1), NTT(inv,len+1);
for(int i=0; i!=(2<<len); ++i)
vg[i] = inv[i]*vg[i]; // not commutative!
DNTT(vg,len+1); // answer
rep(i,1,L) rep(j,1,m) f[i][j] = vg[i].v[j][1];
int ans = int(qkpow(n,L));
rep(i,1,L) rep(j,1,m) ans = int((ans+MOD
-llong(f[i][j])*qkpow(n,L-i)%MOD)%MOD);
printf("%d\n",ans);
return 0;
}
更多推荐
所有评论(0)