Skip to content

爆搜

a160: P-2-9. 子集合乘積

內容

輸入n個正整數A[1..n],以及一個質數P,請計算A中元素各種組合中,有多少種組合其相乘積除以P的餘數等於1。

每個元素可以選取或不選取但不可重複選,A中的數字可能重複。P<=1000000009,0 < n < 37,且假設A中元素皆小於P。

輸入說明

第一行是n與P,第二行n個整數是A[i],同行數字以空白間隔。

輸出說明

滿足條件的組合數,因為數字可能太大,請輸出該組合數除以P的餘數

範例輸入 #1

5 11
1 1 2 6 10

範例輸出 #1

7

解題方法:

一、使用位元窮舉

位元窮舉 「NA (score:20%)(AC (0. 7s, 340KB))」
#include<bits/stdc++.h>
using namespace std;
#define nn "\n"
#define N 100001
#define int long long 


int v[N];

#undef int 
int main(){
#define int long long 


    int n,P,ans=0;
    cin>>n>>P;

    for(int i=0;i<n;i++){
        cin>>v[i];
    }

    int _one=1;


    for(int i=1;i<(_one<<n);i++){
        int pow=1;
        for(int j=0;j<n;j++){
            if(i & _one<<j){
                pow*=v[j]%P;
                pow%=P;
            }
        }
        if(pow==1){
            ans++;
        }
    }
    cout<<ans;
}
"詳細評分NA(score:20%)"

#0: 20% TLE (1s)
Killed
#1: 20% AC (0.7s, 340KB)
通過檢測
#2: 20% TLE (1s)
Killed
#3: 20% TLE (1s)
Killed
#4: 20% TLE (1s)
Killed

二、遞迴

寫遞迴直接想像樹狀圖,每一個數字有兩個選擇:選、不選
所以是:

前一個數字
|
|------------選下一個數字(加入計算)
|
|
|------------不選下一個數字(不加入計算)

遞迴 「NA (score:20%)(AC (0. 32s, 360KB))」
#include<bits/stdc++.h>
using namespace std;
#define nn "\n"
#define N 100001


int v[N],ans=0;
int n,P;


int a(int i,int p){
    if(i>=n){//跑完全部
        if(p==1){
            ans++;
        }
        return 0;
    }
    a(i+1,(v[i]*p)%P);
    a(i+1,p);
    return 0;
}


int main(){
    ios::sync_with_stdio(0);
    cin.tie(0);
    cin>>n>>P;
    for(int i=0;i<n;i++){
        cin>>v[i];
    }
    a(0,1);
    cout<<ans-1;//減去空集合
}
"詳細評分NA (score:20%)"

#0: 20% TLE (1s)
Killed
#1: 20% AC (32ms, 360KB)
通過檢測
#2: 20% TLE (1s)
Killed
#3: 20% TLE (1s)
Killed
#4: 20% TLE (1s)
Killed

三、折半枚舉

分成兩堆,再從其中一堆找另一堆

NA (score:80%) map解
#include<bits/stdc++.h>
using namespace std;
#define nn "\n"
#define N 100001
#define int long long
int v[N];
int n,p;
int ans=0;
map<int,int>m1,m2;


int mod(int a,int y){
    if(y==0){
        return 1;
    }
    else if(y%2==0){//2
        return (mod(a,y/2))%p*(mod(a,y/2))%p;
    }
    else{//1
        return (mod(a,y-1))%p*(a)%p;
    }
}

int make1(int it,int sum,bool is_o){
    if(it>=n/2){
        if(is_o){
            m1[sum%p]++;
        }
        return 0;
    }
    make1(it+1,(sum*v[it])%p,true);
    make1(it+1,sum%p,is_o);
    return 0;
}
int make2(int it,int sum,bool is_o){
    if(it>=n){
        if(is_o){
            m2[sum%p]++;
        }
        return 0;
    }
    make2(it+1,(sum*v[it])%p,true);
    make2(it+1,sum%p,is_o);
    return 0;
}

#undef int
int main(){
#define int long long

    ios::sync_with_stdio(0);
    cin.tie(0);


    //istringstream cin("5 11 1 1 2 6 10");


    cin>>n>>p;
    for(int i=0;i<n;i++){
        cin>>v[i];
    }

    //分兩半存到map<乘積,個數>
    make1(0,1,false);//  0~n/2-1
    make2(n/2,1,false);//  n/2~n-1

    ans=m1[1]+m2[1];//兩半餘數是1的個數

    for(auto i:m1){
        int x=i.first, num=i.second;
        int y=mod(x,p-2);
        auto it=m2.find(y);
        if (it!=m2.end()){
            ans = (ans + num*it->second)%p; //數字1的個數 乘以 數字2的個數
        }
    }

    cout<<ans;
}
"詳細評分NA (score:80%)"

#0: 20% AC (39ms, 336KB)

通過檢測

#1: 20% AC (14ms, 484KB)

通過檢測

#2: 20% AC (63ms, 2.3MB)

通過檢測

#3: 20% TLE (1s)

Killed

#4: 20% AC (40ms, 332KB)

通過檢測

AC (0. 9s, 26. 6MB) map解,快速幂改成迴圈
#include <bits/stdc++.h>
using namespace std;
#define nn "\n"
#define N 100001
#define int long long

int v[N];
int n, p;
int ans = 0;
unordered_map<int, int> m1, m2;

int mod(int a, int y, int p) {
    int res = 1;
    while (y > 0) {
        if (y % 2 == 1) {
            res = (res * a) % p;
        }
        a = (a * a) % p;
        y /= 2;
    }
    return res;
}

void make1(int it, int sum, bool is_o) {
    if (it >= n / 2) {
        if (is_o) {
            m1[sum % p]++;
        }
        return;
    }
    make1(it + 1, (sum * v[it]) % p, true);
    make1(it + 1, sum % p, is_o);
}

void make2(int it, int sum, bool is_o) {
    if (it >= n) {
        if (is_o) {
            m2[sum % p]++;
        }
        return;
    }
    make2(it + 1, (sum * v[it]) % p, true);
    make2(it + 1, sum % p, is_o);
}

#undef int
int main() {
#define int long long

    ios::sync_with_stdio(0);
    cin.tie(0);

    cin >> n >> p;
    for (int i = 0; i < n; i++) {
        cin >> v[i];
    }

    // 分兩半存到map<乘積,個數>
    make1(0, 1, false); // 0 ~ n/2-1
    make2(n / 2, 1, false); // n/2 ~ n-1

    ans = m1[1] + m2[1]; // 兩半餘數是1的個數

    for (auto i : m1) {
        int x = i.first, num = i.second;
        int y = mod(x, p - 2, p);
        auto it = m2.find(y);
        if (it != m2.end()) {
            ans = (ans + num * it->second) % p; // 數字1的個數乘以數字2的個數
        }
    }

    cout << ans << nn;
    return 0;
}
AC (0. 7s, 6. 3MB) 陣列解
// subset product = 1 mod P, O(n*2^(n/2)), sort
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
LL sa[1<<19], sb[1<<19]; // subset product of a and b

// generate all products of subsets of v[]
// save result in prod[], return length of prod[]
int subset(LL v[], int len, LL prod[], LL p) {
    int k=0; // size of prod[]
    for (int i=0;i<len;i++) {
        for (int j=0;j<k;j++) { // (each subset)*v[i]
            prod[k+j]=(prod[j]*v[i])%p;
        }
        prod[k+k]=v[i]; // for subset {v[i]}
        k += k+1;
    }
    return k;
}
// find x^y mod P
LL exp(LL x, LL y, LL p) {
    if (y==0) return 1;
    if (y & 1) return (exp(x, y-1,p)*x)%p;
    // otherwise y is even
    LL t=exp(x, y/2, p);
    return (t*t)%p;
}

int main() {
    int i, n;
    LL a[30], b[30]; // input data
    LL p;
    scanf("%d%lld", &n, &p);
    int len_a=n/2;
    int len_b=n-len_a;
    for (i=0;i<len_a;i++)  // half in a
        scanf("%lld", &a[i]);
    for (i=0;i<len_b;i++)  // half in b
        scanf("%lld", &b[i]);
    int len_sa=subset(a,len_a,sa,p); // all subsets of a
    int len_sb=subset(b,len_b,sb,p); // all subsets of a
    sort(sb, sb+len_sb);
    // merge same element of sb, assume not empty
    LL num[1<<19], len_sb2=1;
    num[0]=1; //its multiplicity
    for (i=1;i<len_sb;i++) {
        if (sb[i]!=sb[i-1]) { // new element
            sb[len_sb2]=sb[i];
            num[len_sb2]=1;
            len_sb2++;
        }
        else {
            num[len_sb2-1]++;
        }
    }
    LL ans = (sb[0]==1) ? num[0]%p : 0; // the number of 1 in sb2
    // compute 1 in sa and cross the two sides
    // for each x in sa, find its inverse in sb2
    for (i=0; i<len_sa; i++) {
        if (sa[i]==1) ans=(ans+1)%p;
        LL y = exp(sa[i], p-2, p); // inverse
        int it = lower_bound(sb, sb+len_sb2, y) - sb;
        if (it<len_sb2 && sb[it]==y)  // found
            ans = (ans + num[it])%p;
    }
    printf("%lld\n", ans);
    return 0;
}