剰余を取った値の四則演算

c++algorithmmath

競プロの、64bit整数でもオーバーフローする大きな値が出てくる問題では、言語による差が出ないよう、10^9+7 といった数で割ったときの余りが要求される。

四則演算の内、割り算以外は各項の剰余を取ってから計算してもその剰余の値は変わらないので、都度剰余を取ってオーバーフローを回避できる。

#include <bits/stdc++.h>
using namespace std;

int mod(int a, int n)
{
    int x = a % n;
    if (x >= 0)
    {
        return x;
    }
    else
    {
        return n + x;
    }
}

int main()
{
    int X = 7;
    int Y = 9;
    int Z = 5;

    // (AZ + B) + (CZ + D) = (A+C)Z + (B+D) ≡ B+D (mod Z)
    cout << mod(X + Y, Z) << endl;                 // 1
    cout << mod(X, Z) + mod(Y, Z) << endl;         // 6
    cout << mod(mod(X, Z) + mod(Y, Z), Z) << endl; // 1

    // (AZ + B) - (CZ + D) = (A-C)Z + (B-D) ≡ B-D (mod Z)
    cout << mod(X - Y, Z) << endl;                 // 3
    cout << mod(X, Z) - mod(Y, Z) << endl;         // -1
    cout << mod(mod(X, Z) - mod(Y, Z), Z) << endl; // 3

    // (AZ + B) * (CZ + D) = (ACZ+AD+BC)Z + (BD) ≡ BD (mod Z)
    cout << mod(X * Y, Z) << endl;                 // 2
    cout << mod(X, Z) + mod(Y, Z) << endl;         // 6
    cout << mod(mod(X, Z) * mod(Y, Z), Z) << endl; // 3

    return 0;
}

割り算は b^{-1}b ≡ 1 (mod n) となるようなモジュラ逆数b^{-1}があれば、a/b ≡ (a mod n)(b^{-1} mod n) (mod n)が成り立つ。bとnが互いに素(gcd(b,n)=1)であることが必要十分条件。 モジュラ逆数を計算する方法としては、mx + ny = gcd(m, n) の解 x, y が得られる、拡張されたユークリッドの互除法などがあり、これにより bx - ny = 1 のxを求める。

a=100, b=50, n=3 だとすると、モジュラ係数 b^{-1}

x = ay + b
=> b = x - ay

y = nb + m
=> m = y - nb
     = y - n(x - ay)
     = -nx + (1+na)y
...
---

50 = 16 * 3 + 2
=> 2 = 1 * 50 + 16 * 3

3 = 1 * 2 + 1
=> 1 = 1 * 3 - 1 * (1 * 50 + 16 * 3)
     = -1 * 50 - 15 * 3 

より -1 ≡ 2 (mod 3) で、実際 100 / 50 ≡ (100 mod 3)(2 mod 3) = 2 (mod 3) となる。

コードで表すと次のようになる。

...

int modinv(int b, int n)
{
    int x = b, y = n;
    map<int, pair<int, int>> history({{x, {1, 0}}, // b = 1 * b + 0 * n
                                      {y, {0, 1}}});
    while (y > 1)
    {
        int next = x % y;
        history[next] = {
            history[x].first - history[y].first * (x / y),
            history[x].second - history[y].second * (x / y)}; // (x % y) = y - nx
        x = y;
        y = next;
    }
    return mod(history[1].first, n);
}

int main()
{
    int X = 100;
    int Y = 50;
    int Z = 3;
    cout << modinv(Y, Z) << endl;                             // 2
    cout << mod(X / Y, Z) << endl;                            // 2
    cout << mod(mod(X, Z) / mod(Y, Z), Z) << endl;            // 0
    cout << mod(mod(X, Z) * mod(modinv(Y, Z), Z), Z) << endl; // 2
    return 0;
}

B - Kleene Inversionは余りを返す問題で、 途中に出てくる 1+2+...+K=(K+1)*(K/2)の計算を単純に行うとlongでもオーバーフローしてしまうが、都度剰余を取ればそれを回避できる。

long KS = mod(mod(mod(K + 1, m) * mod(K, m), m) * modinv(2, m), m);