瀏覽代碼

Merge pull request #3727 from huangzhengxiang/bug-fix

[Bugfix] Ensure the penalty sampler runs first among mixed samplers; accelerate Tiktoken encoding with a Trie.
jxt1234 2 周之前
父節點
當前提交
dbe7ed318f

+ 14 - 0
transformers/llm/engine/src/sampler.cpp

@@ -235,6 +235,20 @@ void Sampler::SamplerConfig::configMixed(std::shared_ptr<LlmConfig> llmConfig) {
         this->configSampler(samplerName, llmConfig);
         // std::cout << samplerName << " " << std::flush;
     }
+    // remove all "penalty", and add one to begin if presence.
+    std::vector<std::string> newSamplers;
+    bool hasPenalty = false;
+    for (auto sampler:mixedSamplers) {
+        if (sampler!="penalty") {
+            newSamplers.push_back(sampler);
+        } else {
+            hasPenalty = true;
+        }
+    }
+    if (hasPenalty) {
+        newSamplers.insert(newSamplers.begin(), "penalty");
+    }
+    mixedSamplers = newSamplers;
     // std::cout << std::endl;
     // set select type
     // the final sampler select the token

+ 29 - 27
transformers/llm/engine/src/tokenizer.cpp

@@ -449,33 +449,13 @@ void Tiktoken::encode(const std::string& str, std::vector<int>& ids) {
     if (str.empty()) {
         return;
     }
-    size_t i = 0;
-    while (i < str.size()) {
-        bool found_pair = false;
-        // Attempt to match the longest possible symbol
-        size_t longest_match_len = 0;
-        std::string longest_match;
-
-        // Check substrings of decreasing length
-        for (size_t len = str.size() - i; len > 0; --len) {
-            std::string token = str.substr(i, len);
-            auto it = encoder_.find(token);
-            if (it != encoder_.end()) {
-                if (len > longest_match_len) {
-                    longest_match_len = len;
-                    longest_match = it->first;
-                }
-            }
-        }
-
-        if (!longest_match.empty()) {
-            ids.push_back(encoder_.at(longest_match));
-            i += longest_match_len;
-        } else {
-            // If no matching symbol is found, this typically means an error in the encoding
-            // or the input text contains characters that the encoder doesn't know how to handle
-            std::cerr << "Error: No encoding found for the sequence starting at position " << i << " , symbol: " << str[i-2] << std::endl;
-            return;
+    auto it = str.begin();
+    while(it!=str.end()) {
+        auto last_it = it;
+        int token_id = encoder_.find(it, str.end());
+        if (token_id>=0) { ids.push_back(token_id); }
+        else {
+            MNN_ERROR("Error: No encoding found for the sequence %s\n", std::string(last_it, it).c_str());
         }
     }
 }
@@ -487,6 +467,28 @@ std::string Tiktoken::decode(int id) {
     return decoder_[id];
 }
 
+bool BertTokenizer::load_vocab(std::ifstream& tok_file) {
+    std::string line;
+    std::getline(tok_file, line);
+    int vocab_len = std::stoi(line);
+    // load vocab
+    decoder_.resize(vocab_len);
+    for (int i = 0; i < vocab_len; i++) {
+        std::getline(tok_file, line);
+        auto token = base64_decode(line);
+        encoder_.insert({token, i});
+        decoder_[i] = token;
+    }
+    return true;
+}
+
+std::string BertTokenizer::decode(int id) {
+    if (id >= decoder_.size()) {
+        return "";
+    }
+    return decoder_[id];
+}
+
 std::vector<int> BertTokenizer::word_piece(const std::string& token) {
     auto it = encoder_.find(token);
     if (it != encoder_.end()) {

+ 68 - 2
transformers/llm/engine/src/tokenizer.hpp

@@ -63,6 +63,68 @@ namespace MNN {
 namespace Transformer {
 // std::string_view impl in c++11 start
 
+
// Flat-array trie used to accelerate Tiktoken longest-prefix tokenization.
// Nodes live contiguously in a vector (node 0 is the root); children are
// stored as char -> node-index maps.
class Trie {
public:
    struct TrieNode
    {
        std::unordered_map<char, int> children;
        int id = -1; // token id if a token ends at this node, -1 otherwise
    };
private:
    std::vector<TrieNode> list;
    int size = 1;
    // Return the index of a fresh node, growing storage geometrically.
    // NOTE: may reallocate `list`, invalidating references/iterators into it.
    int getFree() {
        if (static_cast<std::size_t>(size) >= list.size()) {
            list.resize(list.size() * 2);
        }
        return size++;
    }
public:
    explicit Trie(int initial_size = 10000) {
        // Pre-allocate; guard against a non-positive hint so that the
        // doubling growth in getFree() always makes progress.
        list.resize(initial_size > 0 ? initial_size : 1);
        size = 1; // node 0 is the root
    }
    // Insert token -> id. The first insertion of a given string wins,
    // matching the original behavior (existing ids are never overwritten).
    // Written iteratively: the recursive original kept a reference into
    // `list` across a potential reallocation, which only worked because of
    // a careful re-index; here we always re-index through `list[nid]`.
    void insert(std::pair<const std::string&, int> entry) {
        int nid = 0;
        for (char c : entry.first) {
            auto child = list[nid].children.find(c);
            if (child == list[nid].children.end()) {
                int new_id = getFree(); // may reallocate `list`
                list[nid].children.emplace(c, new_id);
                nid = new_id;
            } else {
                nid = child->second;
            }
        }
        if (list[nid].id == -1) {
            list[nid].id = entry.second;
        }
    }
    // Longest-prefix match starting at `it`.
    // On a match: returns the token id and advances `it` past the token.
    // On no match: returns -1 and advances `it` by one char so the caller
    // can report/skip it without looping forever.
    // Fixes two defects of the recursive original: it dereferenced `*it`
    // without checking `it == end` (UB when a token ends exactly at the
    // string end), and an empty-string token could be "matched" without
    // advancing `it`, hanging the caller.
    int find(std::string::const_iterator& it, const std::string::const_iterator& end) {
        if (it == end) { return -1; }
        int best_id = -1;
        auto best_end = it + 1; // default: skip one char when nothing matches
        int nid = 0;
        auto cur = it;
        while (true) {
            // Record a terminal node, but never a zero-length match.
            if (list[nid].id != -1 && cur != it) {
                best_id = list[nid].id;
                best_end = cur;
            }
            if (cur == end) { break; } // do not dereference the end iterator
            auto child = list[nid].children.find(*cur);
            if (child == list[nid].children.end()) { break; }
            nid = child->second;
            ++cur;
        }
        it = best_end;
        return best_id;
    }
};
+
+
 class Tokenizer {
 public:
     static constexpr int MAGIC_NUMBER = 430;
@@ -149,15 +211,19 @@ public:
 protected:
     virtual bool load_vocab(std::ifstream& file) override;
     virtual void encode(const std::string& str, std::vector<int>& ids) override;
-    std::unordered_map<std::string, int> encoder_;
+    Trie encoder_;
     std::vector<std::string> decoder_;
 };
 
-class BertTokenizer : public Tiktoken {
+class BertTokenizer : public Tokenizer {
 public:
     BertTokenizer() = default;
+    virtual std::string decode(int id) override;
 protected:
+    virtual bool load_vocab(std::ifstream& file) override;
     virtual void encode(const std::string& str, std::vector<int>& ids) override;
+    std::unordered_map<std::string, int> encoder_;
+    std::vector<std::string> decoder_;
 private:
     std::vector<int> word_piece(const std::string& token);
 };