Update tokenizer.hpp

huangzhengxiang 3 weeks ago
commit d5078fa364
1 changed file with 68 additions and 2 deletions

+ 68 - 2
transformers/llm/engine/src/tokenizer.hpp

@@ -63,6 +63,68 @@ namespace MNN {
 namespace Transformer {
 // std::string_view impl in c++11 start
 
+
+// flat-vector trie over raw bytes, used for longest-prefix token matching
+class Trie {
+public:
+    struct TrieNode {
+        std::unordered_map<char, int> children; // edge byte -> child node index
+        int id = -1;                            // token id if a token ends at this node, otherwise -1
+    };
+private:
+    std::vector<TrieNode> list; // node pool; index 0 is the root
+    int size = 1;               // number of nodes currently in use
+    // return the index of the next unused node, growing the pool if necessary
+    int getFree() {
+        if (size < (int)list.size()) { return size++; }
+        list.resize(list.size() * 2);
+        return size++;
+    }
+    // insert the string [it, end) below node nid and record token_id at its terminal node
+    void insert(int nid, int token_id, std::string::const_iterator it, std::string::const_iterator end) {
+        if (it == end) {
+            if (list[nid].id == -1) { list[nid].id = token_id; }
+            return;
+        }
+        auto cid = list[nid].children.find(*it);
+        if (cid == list[nid].children.end()) {
+            int new_id = getFree();
+            // getFree() may reallocate the node pool, so index list[nid] again instead of caching a reference
+            list[nid].children.insert({*it, new_id});
+            insert(new_id, token_id, it + 1, end);
+        } else {
+            insert(cid->second, token_id, it + 1, end);
+        }
+    }
+    // walk the trie from node nid, tracking the longest complete token seen so far in
+    // (current_matched, current_it); advances it past the match and returns the token id, or -1
+    int find(int nid, int current_matched, std::string::const_iterator current_it, std::string::const_iterator& it, const std::string::const_iterator& end) {
+        const auto& node = list[nid];
+        if (node.id != -1) {
+            // a token ends at this node: remember it as the longest match so far
+            current_matched = node.id;
+            current_it = it;
+        }
+        if (it != end) {
+            auto cid = node.children.find(*it);
+            if (cid != node.children.end()) {
+                return find(cid->second, current_matched, current_it, ++it, end);
+            }
+        }
+        // no edge for the next byte (or end of input): fall back to the longest match seen so far
+        if (node.id != -1) { return node.id; }
+        it = current_it;
+        return current_matched;
+    }
+public:
+    Trie(int initial_size = 10000) {
+        list.resize(initial_size); // pre-allocate the node pool
+        size = 1;                  // slot 0 is the root
+    }
+    void insert(std::pair<const std::string&, int> entry) {
+        insert(0, entry.second, entry.first.begin(), entry.first.end());
+    }
+    // greedily match the longest token starting at it; advances it and returns the id, or -1 on a miss
+    int find(std::string::const_iterator& it, const std::string::const_iterator& end) {
+        if (it == end) { return -1; }
+        return find(0, -1, it + 1, it, end);
+    }
+};
+
+
 class Tokenizer {
 public:
     static constexpr int MAGIC_NUMBER = 430;
@@ -149,15 +211,19 @@ public:
 protected:
     virtual bool load_vocab(std::ifstream& file) override;
     virtual void encode(const std::string& str, std::vector<int>& ids) override;
-    std::unordered_map<std::string, int> encoder_;
+    Trie encoder_;
     std::vector<std::string> decoder_;
 };
 
-class BertTokenizer : public Tiktoken {
+class BertTokenizer : public Tokenizer {
 public:
     BertTokenizer() = default;
+    virtual std::string decode(int id) override;
 protected:
+    virtual bool load_vocab(std::ifstream& file) override;
     virtual void encode(const std::string& str, std::vector<int>& ids) override;
+    std::unordered_map<std::string, int> encoder_;
+    std::vector<std::string> decoder_;
 private:
     std::vector<int> word_piece(const std::string& token);
 };
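
For reference, a minimal usage sketch of the new Trie (not part of this commit): the greedy_tokenize helper and the toy vocabulary below are hypothetical and only assume the Trie class declared in the diff above. It builds the trie from a vocabulary and repeatedly calls find() to perform greedy longest-match tokenization.

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// hypothetical helper, not part of the commit: greedy longest-match tokenization using the Trie above
std::vector<int> greedy_tokenize(Trie& trie, const std::string& text) {
    std::vector<int> ids;
    auto it = text.begin();
    const auto end = text.end();
    while (it != end) {
        int id = trie.find(it, end); // advances it past the longest matching token
        if (id != -1) { ids.push_back(id); }
        // on a miss, find() still advances it by one byte, so the loop always terminates
    }
    return ids;
}

int main() {
    Trie trie;
    // toy vocabulary for illustration only
    std::unordered_map<std::string, int> vocab = {{"hell", 0}, {"hello", 1}, {"o", 2}, {" world", 3}};
    for (const auto& kv : vocab) {
        trie.insert({kv.first, kv.second});
    }
    for (int id : greedy_tokenize(trie, "hello world")) {
        std::cout << id << " "; // prints "1 3": the longest match "hello" wins over "hell"
    }
    std::cout << std::endl;
    return 0;
}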