剰余を取った値の四則演算

c++algorithmmath

競プロの64bit整数でもオーバーフローする大きな値が出てくる問題では、言語による差が出ないよう \(10^9+7\) といった数で割ったときの余りが要求される。

四則演算の内、割り算以外は各項の剰余を取ってから計算してもその剰余の値は変わらないので、都度剰余を取ってオーバーフローを回避できる。

#include <bits/stdc++.h>
using namespace std;

int mod(int a, int n)
{
    int x = a % n;
    if (x >= 0)
    {
        return x;
    }
    else
    {
        return n + x;
    }
}

int main()
{
    int X = 7;
    int Y = 9;
    int Z = 5;

    // (AZ + B) + (CZ + D) = (A+C)Z + (B+D) ≡ B+D (mod Z)
    cout << mod(X + Y, Z) << endl;                 // 1
    cout << mod(X, Z) + mod(Y, Z) << endl;         // 6
    cout << mod(mod(X, Z) + mod(Y, Z), Z) << endl; // 1

    // (AZ + B) - (CZ + D) = (A-C)Z + (B-D) ≡ B-D (mod Z)
    cout << mod(X - Y, Z) << endl;                 // 3
    cout << mod(X, Z) - mod(Y, Z) << endl;         // -1
    cout << mod(mod(X, Z) - mod(Y, Z), Z) << endl; // 3

    // (AZ + B) * (CZ + D) = (ACZ+AD+BC)Z + (BD) ≡ BD (mod Z)
    cout << mod(X * Y, Z) << endl;                 // 2
    cout << mod(X, Z) + mod(Y, Z) << endl;         // 6
    cout << mod(mod(X, Z) * mod(Y, Z), Z) << endl; // 3

    return 0;
}

割り算は \(b^{-1}b \equiv 1 \mod n\) となるようなモジュラ逆数\(b^{-1}\)があれば、\(\frac{a}{b} \equiv (a \mod n)(b^{-1} \mod n) \mod n\)が成り立つ。bとnが互いに素(\(gcd(b,n)=1\))であることが必要十分条件で、\(bx - ny = gcd(b, n) = 1\)の\(x\)を拡張されたユークリッドの互除法で求める。

RSA暗号とPEM/DERの構造 - sambaiz-net

例えば \(a=100, b=50, n=3\) のときモジュラ係数 \(b^{-1}\) は

50 = 16 * 3 + 2
              2 = 1 * 50 + 16 * 3

3 = 1 * 2 + 1
            1 = 1 * 3 - 1 * (1 * 50 + 16 * 3)
              = -1 * 50 - 15 * 3 

より \(-1 \equiv 2 \mod 3\) となり、\(\frac{100}{50} \equiv (100 \mod 3)(2 \mod 3) \equiv 2 \mod 3\)が成り立つ。

...

int modinv(int b, int n)
{
    int x = b, y = n;
    map<int, pair<int, int>> history({{x, {1, 0}}, // b = 1 * b + 0 * n
                                      {y, {0, 1}}});
    while (y > 1)
    {
        int next = x % y;
        history[next] = {
            history[x].first - history[y].first * (x / y),
            history[x].second - history[y].second * (x / y)}; // (x % y) = y - nx
        x = y;
        y = next;
    }
    return mod(history[1].first, n);
}

int main()
{
    int X = 100;
    int Y = 50;
    int Z = 3;
    cout << modinv(Y, Z) << endl;                             // 2
    cout << mod(X / Y, Z) << endl;                            // 2
    cout << mod(mod(X, Z) / mod(Y, Z), Z) << endl;            // 0
    cout << mod(mod(X, Z) * mod(modinv(Y, Z), Z), Z) << endl; // 2
    return 0;
}

B - Kleene Inversionは余りを返す問題で、 途中に出てくる \(1+2+…+K=(K+1)*\frac{K}{2}\)の計算を単純に行うとlongでもオーバーフローしてしまうが、都度剰余を取ればそれを回避できる。

long KS = mod(mod(mod(K + 1, m) * mod(K, m), m) * modinv(2, m), m);