ラビン-カープアルゴリズムをC++で実装する

algorithmc++

ラビン-カープアルゴリズムはローリングハッシュを用いて部分文字列を探索するアルゴリズム。 ローリングハッシュは前のハッシュから先頭の要素を取り除き、次の要素を追加することによってO(1)で次のハッシュが得られる。 これを満たせるハッシュ関数としては様々なものが考えられるが、例えば文字コードを単に足し合わせる場合、同じ文字が含まれているだけでハッシュが衝突してしまうので、通常は次のように文字列を数値化したものが用いられる。剰余を取れば長い文字列でもオーバーフローを回避できる。

剰余を取った値の四則演算 - sambaiz-net

#include<bits/stdc++.h>
using namespace std;

class rolling_hash {
    const int base = 256;
    const int mod = pow(10, 9) + 7; // prime modulus

    string str;
    int window_length;
    long tail_p; // = pow(base,window_length)

    long current_hash = 0;
    int head_idx = 0;
public:
    rolling_hash(string str, int window_length) {
        this->str = str;
        this->window_length = window_length;
        long p = 1;
        for (int i = window_length - 1; i>= 0; i--) {
            current_hash = (current_hash + ((str[i] * p) % mod)) % mod;
            if (i != 0) p = (p * base) % mod;
            else this->tail_p = p;
        }
    }
    int get() {
        return current_hash;
    }
    //    a_0 * base^n + a_1 * base^n-1 + ... + a_n * base^0
    // =>                a_1 * base^n + ... + a_n * base^1 + a_n+1 * base^0
    void roll() {
        head_idx++;
        current_hash = (current_hash - ((str[head_idx-1] * this->tail_p) % mod)) % mod; // a_1 * base^n-1 + ... + a_n * base^0
        int tail_idx = head_idx+window_length-1;
        if (tail_idx >= str.length()) throw new out_of_range("the window is out of range");
        current_hash = (((current_hash * base) % mod) + str[tail_idx]) % mod; // a_1 * base^n + ... + a_n * base^n-1 + a_n+1 * base^0
    }
};

このハッシュが一致する部分を探索していく。ハッシュは衝突する可能性があるので一致した後に文字列比較を行う必要がある。

int index_of(string str, string substr) {
    auto str_hash = rolling_hash(str, substr.length());
    auto substr_hash = rolling_hash(substr, substr.length());
    for (int i = 0; i < str.length() - substr.length(); i++) {
        if(str_hash.get() == substr_hash.get() && str.substr(i, substr.length()) == substr) return i;
        str_hash.roll();
    }
    return -1;
}

int main(void){
    cout << index_of("ABCDEF", "CDE") << endl; // 2
    cout << index_of("ABEF", "BCD") << endl; // -1
}

全体の文字列の長さをn、部分文字列の長さをmとすると、平均としてはO(n+m)で探すことができるが、全てのハッシュが衝突する最悪の場合は毎回文字列比較を行うことになるので総当たりと同じO(nm)となってしまう。ただ、複数の部分文字列を探索する場合に実行時間がその数に依存しない特長があり、盗用の検出などに使われる。

参考

Rabin–Karp algorithm - Wikipedia

プログラミングコンテストチャレンジブック [第2版]