BK-tree: 編集距離が近い単語をO(log n)で列挙する

最終更新: 2017-08-13 11:53

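BK-treeは、N個の文字列からなる集合の中から、queryとの編集距離がd以下の文字列を効率よく検索するためのデータ構造です。探索時には三角不等式を使って調べる必要のない部分木を枝刈りするので、dが小さければ全件と比較するよりずっと少ない回数の編集距離計算で済みます(O(log N)というのは経験的な目安であり、最悪計算量の保証ではありません)。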

詳しい解説はGeeksforGeeksの記事「BK-Tree | Introduction & Implementation」を見てください。この記事のコードを使いやすく書き換えたものを以下に示します。

#include <algorithm>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>
 
// One BK-tree node: a stored word plus its children.
// `child` maps the edit distance between `word` and a child's word to that
// child's index in the owning tree's node vector (indices rather than
// pointers, so they stay valid when the vector reallocates).
struct Node {
    std::string word;
    std::map<int, int> child;

    // Takes the word by value and moves it into place, so callers passing a
    // temporary avoid a second string copy.
    Node(std::string w) : word(std::move(w)), child() {
    }
};
 
// BK-tree over strings under the Levenshtein (edit) distance.
//
// All nodes are stored contiguously in `tree`. Once a word has been added,
// tree[0] is the root, and every Node::child entry maps an edit distance to
// the index of a child node inside `tree`.
class BKTree {
public:
    // Inserts `str` into the tree; the first word added becomes the root.
    void add(const std::string& str);

    // Collects the stored words whose edit distance to `s` is at most `d`.
    std::vector<std::string> getSimilarWords(const std::string& s, const int d);

private:
    // Node storage; children refer to each other by index into this vector.
    std::vector<Node> tree;

    // Smallest of three integers.
    int min(int a, int b, int c);

    // Levenshtein distance between `a` and `b`.
    int editDistance(const std::string& a, const std::string& b);

    // Recursive/iterative helpers that start from a given node of `tree`.
    void add(Node& root, const std::string& str);
    std::vector<std::string> getSimilarWords(Node& root, const std::string& s, const int d);
};
 
// Returns the smallest of the three arguments.
int BKTree::min(int a, int b, int c) {
    return std::min({a, b, c});
}
 
// Returns the Levenshtein distance between `a` and `b`: the minimum number of
// single-character insertions, deletions and substitutions that turn `a`
// into `b`.
//
// Wagner-Fischer dynamic programming, keeping only two rows of the table.
// The previous version declared `int dp[m+1][n+1]`, a variable-length array.
// VLAs are not standard C++ (only a compiler extension), and they put
// O(|a|*|b|) ints on the stack, which can overflow for long inputs. Two
// heap-allocated rows need only O(|b|) memory.
int BKTree::editDistance(const std::string& a, const std::string& b) {
    const std::size_t n = b.size();

    // prev[j]: distance between the first i-1 chars of a and first j chars of b.
    // curr[j]: the same for the first i chars of a (the row being filled).
    std::vector<int> prev(n + 1), curr(n + 1);

    // Base case: turning "" into b[0..j) takes j insertions.
    for (std::size_t j = 0; j <= n; ++j)
        prev[j] = static_cast<int>(j);

    for (std::size_t i = 1; i <= a.size(); ++i) {
        // Turning a[0..i) into "" takes i deletions.
        curr[0] = static_cast<int>(i);
        for (std::size_t j = 1; j <= n; ++j) {
            if (a[i-1] == b[j-1]) {
                curr[j] = prev[j-1];
            } else {
                curr[j] = 1 + min(prev[j],      // deletion
                                  curr[j-1],    // insertion
                                  prev[j-1]);   // replacement
            }
        }
        // The finished row becomes the "previous" row for the next i.
        prev.swap(curr);
    }
    return prev[n];
}
 
// Inserts `str` into the tree. The very first word becomes the root
// (tree[0]); every later word is routed down starting from that root.
void BKTree::add(const std::string& str) {
    if (!tree.empty()) {
        add(tree.front(), str);
        return;
    }
    tree.emplace_back(str);
}
        
// Inserts `str` into the subtree rooted at `root` (a node stored in `tree`).
//
// Starting at `root`, follow the child whose key equals the edit distance
// between the current node's word and `str`. When no such child exists,
// `str` becomes that child.
//
// A distance of 0 means `str` is already stored. The tree represents a set
// of strings, so the duplicate is dropped rather than kept as a redundant
// key-0 child.
//
// The walk is iterative, so a degenerate (deep) tree cannot exhaust the
// call stack. `node` stays valid during the walk because `tree` is only
// modified by the final push_back, after which `node` is never used again.
void BKTree::add(Node& root, const std::string& str) {
    Node* node = &root;
    for (;;) {
        const int dist = editDistance(node->word, str);
        if (dist == 0) return;  // identical word already present

        const auto it = node->child.find(dist);
        if (it == node->child.end()) {
            // Record the new node's index before push_back, which may
            // reallocate `tree` and leave `node` dangling.
            node->child[dist] = static_cast<int>(tree.size());
            tree.push_back(Node(str));
            return;
        }
        node = &tree[it->second];
    }
}
 
// Returns the stored words within edit distance `d` of `s`, searching from
// the root; an empty tree yields an empty result.
std::vector <std::string> BKTree::getSimilarWords(const std::string& s, const int d) {
    return tree.empty() ? std::vector<std::string>{}
                        : getSimilarWords(tree.front(), s, d);
}

// Collects every word in the subtree rooted at `root` whose edit distance to
// `s` is at most `d`.
//
// Pruning relies on the triangle inequality. Let `dist` be the distance from
// root.word to `s`, and let `k` be a child's key (the distance from
// root.word to every word in that child's subtree). A word in that subtree
// can be within `d` of `s` only if |k - dist| <= d. So only children with
// keys in the CLOSED range [max(0, dist - d), dist + d] are visited.
//
// The previous loop used `< dist + d`, which skipped the upper end and
// missed matches. It also mapped a negative lower bound to 1 instead of 0,
// which skipped key-0 children inconsistently.
std::vector <std::string> BKTree::getSimilarWords(Node& root, const std::string& s, const int d) {
    std::vector <std::string> ret;
    // No word can be at a negative distance.
    if (d < 0) return ret;

    const int dist = editDistance(root.word, s);
    if (dist <= d) ret.push_back(root.word);

    // Keys are edit distances and hence never negative.
    const int lo = std::max(0, dist - d);
    // Walk the ordered child map from `lo` up to dist + d. Using
    // `key - dist <= d` avoids overflowing `dist + d` for a huge `d`.
    for (auto it = root.child.lower_bound(lo);
         it != root.child.end() && it->first - dist <= d; ++it) {
        std::vector <std::string> sub = getSimilarWords(tree[it->second], s, d);
        for (auto& w : sub)
            ret.push_back(std::move(w));
    }
    return ret;
}
 
// Demo driver: builds a BK-tree from a small dictionary and prints, for each
// query word, the dictionary words within edit distance 2 of it.
int main()
{
    BKTree t;
    // dictionary words
    const std::vector<std::string> dictionary =
        {"hell", "help", "shel", "smell",
         "fell", "felt", "oops", "pop", "oouch", "halt"};

    for (const auto& w : dictionary) {
        t.add(w);
    }

    // Both sample queries are searched; "ops" was previously declared but
    // never used.
    const std::vector<std::string> queries = {"ops", "helt"};
    const int maxDistance = 2;

    for (const auto& query : queries) {
        const auto match = t.getSimilarWords(query, maxDistance);

        std::cout << "Correct words in dictionary for " << query << ":\n";
        for (const auto& x : match) std::cout << x << '\n';
    }
}