|
@@ -63,6 +63,68 @@ namespace MNN {
|
|
|
namespace Transformer {
|
|
|
// std::string_view impl in c++11 start
|
|
|
|
|
|
+
|
|
|
+class Trie {
|
|
|
+public:
|
|
|
+ struct TrieNode
|
|
|
+ {
|
|
|
+ std::unordered_map<char, int> children;
|
|
|
+ int id = -1;
|
|
|
+ };
|
|
|
+private:
|
|
|
+ std::vector<TrieNode> list;
|
|
|
+ int size = 1;
|
|
|
+ int getFree() {
|
|
|
+ if (size<list.size()) { return size++; }
|
|
|
+ else {
|
|
|
+ list.resize(list.size()*2);
|
|
|
+ return size++;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ void insert(int nid, int token_id, std::string::const_iterator it, std::string::const_iterator end) {
|
|
|
+ auto& node = list[nid];
|
|
|
+ if (it==end) {
|
|
|
+ if (node.id==-1) { node.id=token_id; }
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ auto cid = node.children.find(*it);
|
|
|
+ if (cid==node.children.end()) {
|
|
|
+ int new_id = getFree();
|
|
|
+ list[nid].children.insert({*it, new_id}); // access the node again even after reallocation!!!
|
|
|
+ insert(new_id, token_id, it+1, end);
|
|
|
+ } else{
|
|
|
+ insert(cid->second, token_id, it+1, end);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ int find(int nid, int current_matched, std::string::const_iterator current_it, std::string::const_iterator& it, const std::string::const_iterator& end) {
|
|
|
+ const auto& node = list[nid];
|
|
|
+ if (node.id!=-1) {
|
|
|
+ current_matched = node.id;
|
|
|
+ current_it = it;
|
|
|
+ }
|
|
|
+ auto cid = node.children.find(*it);
|
|
|
+ if (cid != node.children.end()) {
|
|
|
+ return find(cid->second, current_matched, current_it, ++it, end);
|
|
|
+ } else {
|
|
|
+ if (node.id!=-1) { return node.id; }
|
|
|
+ else { it = current_it; return current_matched;}
|
|
|
+ }
|
|
|
+ }
|
|
|
+public:
|
|
|
+ Trie(int initial_size=10000) {
|
|
|
+ list.resize(initial_size); // init the allocate size
|
|
|
+ size = 1; // root
|
|
|
+ }
|
|
|
+ void insert(std::pair<const std::string&, int> entry) {
|
|
|
+ insert(0, entry.second, entry.first.begin(), entry.first.end());
|
|
|
+ }
|
|
|
+ int find(std::string::const_iterator& it, const std::string::const_iterator& end) {
|
|
|
+ if (it==end) { return -1; }
|
|
|
+ return find(0, -1, it+1, it, end);
|
|
|
+ }
|
|
|
+};
|
|
|
+
|
|
|
+
|
|
|
class Tokenizer {
|
|
|
public:
|
|
|
static constexpr int MAGIC_NUMBER = 430;
|
|
@@ -149,15 +211,19 @@ public:
|
|
|
protected:
|
|
|
virtual bool load_vocab(std::ifstream& file) override;
|
|
|
virtual void encode(const std::string& str, std::vector<int>& ids) override;
|
|
|
- std::unordered_map<std::string, int> encoder_;
|
|
|
+ Trie encoder_;
|
|
|
std::vector<std::string> decoder_;
|
|
|
};
|
|
|
|
|
|
-class BertTokenizer : public Tiktoken {
|
|
|
+class BertTokenizer : public Tokenizer {
|
|
|
public:
|
|
|
BertTokenizer() = default;
|
|
|
+ virtual std::string decode(int id) override;
|
|
|
protected:
|
|
|
+ virtual bool load_vocab(std::ifstream& file) override;
|
|
|
virtual void encode(const std::string& str, std::vector<int>& ids) override;
|
|
|
+ std::unordered_map<std::string, int> encoder_;
|
|
|
+ std::vector<std::string> decoder_;
|
|
|
private:
|
|
|
std::vector<int> word_piece(const std::string& token);
|
|
|
};
|