abc147_dの解き方の解説 桁ごとに数え上げ
2020/05/28 13:05
感謝の正拳突き1万回ならぬ、競プロサイトで 1 日 1AC(正解)するように頑張る。
ABC147 の D 問題 Xor Sum 4を解いていきます。こういう問題苦手だ・・
XOR は繰り上がりがない
XOR は i == j の場合 0、i != j の場合 1 となります。他の bit は影響しない演算です。
つまり bit ごとにバラバラに考えても大丈夫です。
bitset べんり
整数を 2 進数に直して bit ごとに扱うとき、std::bitsetが便利です。
コンストラクタに整数型を渡すと bit ごとにばらして、1 桁目を 0 番目、2 桁目を 1 番目の配列にセットしたようなコンテナにしてくれます。
bitset<3> bs(2); // {0, 1, 0}
bitset<3> bs(5); // {1, 0, 1}
数え上げの問題に置き換える
問題を単純化するために 1 桁目だけ考えてみます。
入力例
5
1 0 1 0 1
出力例
6
このとき、i == j は 0 になるので、捨てます。つまり i != j になる個数を数え上げます。
i != j となるのは、{1, 0, 1, 0, 1}の 5 個のボールから 2 個取り出し、1, 0 になる個数です。
0 が 2 個、1 が 3 個なので 3 * 2 パターンですね。答えは 6 になります。
同様に{5, 0, 5, 0, 5}の場合も考えてみましょう。5 は 2 進数にすると 101 なります。
- 1 桁目は 0 が 2 個 * 1 が 3 個なので 6 個
- 2 桁目は 0 が 5 個 * 1 が 0 個なので 0 個
- 3 桁目は 0 が 2 個 * 1 が 3 個なので 6 個
1 桁目は 2^0 = 1 で 1 _ 6 = 6、3 桁目は 2^2 = 4 で 4 _ 6 = 24、合計 30 が答えとなります。
MOD したあまりなので、mint を使う
MOD したあまりの計算はabc156_d の解説で紹介した mintを使うと便利です。
できあがり
#include <bits/stdc++.h>
using namespace std;
#define ll long long
const int MOD = 1e9 + 7;
class mint {
long long x;
public:
mint(long long x=0) : x((x % MOD + MOD) % MOD) {}
mint operator-() const {
return mint(-x);
}
mint& operator+=(const mint& a) {
if ((x += a.x) >= MOD) x -= MOD;
return *this;
}
mint& operator-=(const mint& a) {
if ((x += MOD-a.x) >= MOD) x -= MOD;
return *this;
}
mint& operator*=(const mint& a) {
(x *= a.x) %= MOD;
return *this;
}
mint operator+(const mint& a) const {
mint res(*this);
return res+=a;
}
mint operator-(const mint& a) const {
mint res(*this);
return res-=a;
}
mint operator*(const mint& a) const {
mint res(*this);
return res*=a;
}
mint pow(long long t) const {
if (!t) return 1;
mint a = pow(t>>1);
a *= a;
if (t&1) a *= *this;
return a;
}
// for prime MOD
mint inv() const {
return pow(MOD-2);
}
mint& operator/=(const mint& a) {
return (*this) *= a.inv();
}
mint operator/(const mint& a) const {
mint res(*this);
return res/=a;
}
friend ostream& operator<<(ostream& os, const mint& m){
os << m.x;
return os;
}
};
int main() {
ll n, a;
mint result;
vector<ll> one(61);
cin >> n;
for(int i = 0; i < n; i++) {
cin >> a;
// 各bitが1になってる個数を集計する
// 制約がa <= 2 ^ 60なので60
bitset<60> bs(a);
for(int j = 0; j < 60; j++) {
one[j] += bs[j];
}
}
// 2進数の桁は2のn乗で表す
mint bit = 1;
for(int i = 0; i < 60; i++) {
// 0の個数 = 入力のトータル - 1の個数
result += bit * one[i] * (n - one[i]);
bit *= 2; // 毎回2でかけて2のn乗にする
}
cout << fixed << result;
}