diff --git a/examples/llama_qnn_aot/aot_run.cpp b/examples/llama_qnn_aot/aot_run.cpp index c19183533..4847954e0 100644 --- a/examples/llama_qnn_aot/aot_run.cpp +++ b/examples/llama_qnn_aot/aot_run.cpp @@ -4,8 +4,7 @@ #include #include "mllm/backends/qnn/aot_rt/QnnAOTRuntime.hpp" #include "configuration_llama3.hpp" -#include "mllm/models/llama/tokenization_tiny_llama.hpp" -#include "mllm/models/qwen3/tokenization_qwen3.hpp" +#include "mllm/models/llama/tokenization_llama.hpp" using mllm::Argparse; using namespace mllm::qnn::aot; // NOLINT @@ -16,8 +15,8 @@ MLLM_MAIN({ auto& tokenizer_path = Argparse::add("-t|--tokenizer").help("Tokenizer path").def("tokenizer.json"); auto& config_path = Argparse::add("-c|--config").help("Config path").required(true); auto& ar_len = Argparse::add("--ar_len").help("Autoregressive length (chunk size)").def(128); - auto& seq_len = Argparse::add("--seq_len").help("Input sequence length").def(800); - auto& gen_len = Argparse::add("--gen_len").help("Generate token length").def(32); + // auto& seq_len = Argparse::add("--seq_len").help("Input sequence length").def(800); + // auto& gen_len = Argparse::add("--gen_len").help("Generate token length").def(32); Argparse::parse(argc, argv); @@ -37,22 +36,36 @@ MLLM_MAIN({ config.vocab_size = llama_cfg.vocab_size; config.context_len = 1024; config.ar_len = ar_len.get(); + config.type = "llama3"; // Note: Using Qwen3 tokenizer as a placeholder. // For production use, you should implement a Llama3Tokenizer or use // the appropriate tokenizer for your model. - auto tokenizer = mllm::models::llama::TinyLlamaTokenizer(tokenizer_path.get()); + auto tokenizer = mllm::models::llama::LlamaTokenizer(tokenizer_path.get()); - auto input_tensor = tokenizer.convertMessage({{ - .role = "user", - .content = "hello", - }}); + // auto input_tensor = tokenizer.convertMessage({{ + // .role = "user", + // .content = "hello", + // }}); - input_tensor["sequence"] = mllm::Tensor::arange(0, seq_len.get(), 1, mllm::kInt64, mllm::kCPU).view({1, -1}); + // input_tensor["sequence"] = mllm::Tensor::arange(0, seq_len.get(), 1, mllm::kInt64, mllm::kCPU).view({1, -1}); - // DBG: - mllm::print(input_tensor["sequence"].shape()); - mllm::print(input_tensor["sequence"]); + // // DBG: + // mllm::print(input_tensor["sequence"].shape()); + // mllm::print(input_tensor["sequence"]); + + // Runner runner(config, &tokenizer); + // if (!runner.load()) { + // std::cerr << "Failed to load model\n"; + // return 1; + // } + + + std::string prompt_text; + fmt::print("💬 Prompt text (or 'exit/quit'): "); + std::getline(std::cin, prompt_text); + + auto input_tensor = tokenizer.convertMessage({{.role = "user", .content = prompt_text}}); Runner runner(config, &tokenizer); if (!runner.load()) { @@ -60,8 +73,8 @@ MLLM_MAIN({ return 1; } - runner.generate( - input_tensor["sequence"], gen_len.get(), [](const std::string& token) { std::cout << token << std::flush; }, true); + runner.generate(input_tensor["sequence"], config.context_len, + [](const std::string& token) { std::cout << token << std::flush; }); std::cout << "\n"; return 0; diff --git a/mllm/backends/qnn/aot_rt/QnnAOTConfig.hpp b/mllm/backends/qnn/aot_rt/QnnAOTConfig.hpp index 8943d6cec..3486cc3f9 100644 --- a/mllm/backends/qnn/aot_rt/QnnAOTConfig.hpp +++ b/mllm/backends/qnn/aot_rt/QnnAOTConfig.hpp @@ -3,11 +3,14 @@ #pragma once +#include #include "mllm/core/DataTypes.hpp" namespace mllm::qnn::aot { struct QnnAOTConfig { + std::string type = "qwen3"; + int num_layers = 28; int num_heads = 12; int head_dim = 128; diff --git a/mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp b/mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp index ae1fafa29..68d002c67 100644 --- a/mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp +++ b/mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp @@ -46,8 +46,22 @@ bool Runner::load() { // init token generator(decode) // TODO: EOS IDs auto eos_ids = std::make_unique>(); - eos_ids->insert(151643); - eos_ids->insert(151645); + // eos_ids->insert(151643); + // eos_ids->insert(151645); + + // Dynamically determine the currently loaded model based on the model name. + if (config_.type == "llama3") { + eos_ids->insert(128001); // <|end_of_text|> + eos_ids->insert(128008); // <|eom_id|> + eos_ids->insert(128009); // <|eot_id|> + } else if (config_.type == "qwen2"){ + eos_ids->insert(151643); + eos_ids->insert(151645); + } else{ + // qwen3 + eos_ids->insert(151643); + eos_ids->insert(151645); + } token_generator_ = std::make_unique>(tokenizer_, kv_manager_.get(), std::move(eos_ids), config_); diff --git a/mllm/models/llama/tokenization_llama.hpp b/mllm/models/llama/tokenization_llama.hpp new file mode 100644 index 000000000..ad5f2ca15 --- /dev/null +++ b/mllm/models/llama/tokenization_llama.hpp @@ -0,0 +1,245 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "mllm/preprocessor/tokenizers/BPE.hpp" +#include "mllm/models/ARGeneration.hpp" +#include "mllm/preprocessor/tokenizers/Unicode.hpp" +#include "mllm/preprocessor/tokenizers/AutoTokenizer.hpp" + +namespace mllm::models::llama { + +// 适配 Llama 3 的正则切分逻辑 +inline bool llama3TokenizerMatchPattern(const std::wstring& str, size_t& pos, std::wstring& matched) { + if (pos >= str.size()) return false; + + // 1. 匹配缩写 + static const std::wstring contractions[] = {L"'s", L"'t", L"'re", L"'ve", L"'m", L"'ll", L"'d", L"'S", L"'T", L"'RE", L"'VE", L"'M", L"'LL", L"'D"}; + for (const auto& contraction : contractions) { + if (pos + contraction.size() <= str.size() && str.compare(pos, contraction.size(), contraction) == 0) { + matched = contraction; + pos += contraction.size(); + return true; + } + } + + // 2. 匹配字母 + { + size_t original_pos = pos; + matched.clear(); + if (!preprocessor::isLetter(str[pos]) && !preprocessor::isDigit(str[pos]) && str[pos] != L'\r' && str[pos] != L'\n') { + matched += str[pos]; + ++pos; + } + if (pos < str.size() && preprocessor::isLetter(str[pos])) { + do { + matched += str[pos]; + ++pos; + } while (pos < str.size() && preprocessor::isLetter(str[pos])); + return true; + } + pos = original_pos; + } + + // 3. 匹配数字 + if (preprocessor::isDigit(str[pos])) { + matched = str.substr(pos, 1); + ++pos; + return true; + } + + // 4. 匹配符号 + { + size_t start = pos; + if (str[pos] == L' ') { ++pos; } + if (pos < str.size() && !std::iswspace(str[pos]) && !preprocessor::isLetter(str[pos]) && !preprocessor::isDigit(str[pos])) { + do { ++pos; } while (pos < str.size() && !std::iswspace(str[pos]) && !preprocessor::isLetter(str[pos]) && !preprocessor::isDigit(str[pos])); + matched = str.substr(start, pos - start); + while (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) { + matched += str[pos]; + ++pos; + } + return true; + } + pos = start; + } + + // 5. 匹配空格 + if (std::iswspace(str[pos])) { + size_t start = pos; + while (pos < str.size() && std::iswspace(str[pos])) ++pos; + matched = str.substr(start, pos - start); + return true; + } + + return false; +} + +inline void llama3Regex(const std::string& str, std::vector& splitted) { + auto w_string = preprocessor::utf8string2WideString(str); + size_t pos = 0; + while (pos < w_string.size()) { + std::wstring matched; + if (llama3TokenizerMatchPattern(w_string, pos, matched)) { + splitted.push_back(matched); + } else { + ++pos; + } + } +} + +struct LlamaMessage { + std::string role; + std::string content; +}; + +// 恢复继承自原版的 AutoTokenizer,满足 aot_run.cpp 的要求 +class LlamaTokenizer final : public mllm::preprocessor::AutoTokenizer { + public: + explicit LlamaTokenizer(const std::string& file_path, bool add_bos = true) : add_bos_(add_bos) { + preprocessor::initLocal(); + // 恢复内建的字典映射机制 + preprocessor::makeBytes2UnicodeMap(bytes_2_unicode_dict_); + for (auto& kv : bytes_2_unicode_dict_) { bytes_2_unicode_dict_inverse_.insert({kv.second, kv.first}); } + + bpe_.initFromSentencePieceJson(file_path); + + special_tokens_trie_.add(L"<|begin_of_text|>"); + special_tokens_trie_.add(L"<|end_of_text|>"); + special_tokens_trie_.add(L"<|start_header_id|>"); + special_tokens_trie_.add(L"<|end_header_id|>"); + special_tokens_trie_.add(L"<|eot_id|>"); + } + + std::string getSystemPromptPrefix() { + std::time_t t = std::time(nullptr); + std::tm tm_ = *std::localtime(&t); + std::ostringstream oss; + oss << std::put_time(&tm_, "%d %b %Y"); + return "Cutting Knowledge Date: December 2023\nToday Date: " + oss.str() + "\n\n"; + } + + inline std::string applyChatTemplate(const std::vector& messages, bool add_generation_prompt = true) { + std::string result = ""; + if (add_bos_) result += "<|begin_of_text|>"; + for (const auto& msg : messages) { + std::string content = msg.content; + if (msg.role == "system") content = getSystemPromptPrefix() + content; + result += "<|start_header_id|>" + msg.role + "<|end_header_id|>\n\n" + content + "<|eot_id|>"; + } + if (add_generation_prompt) result += "<|start_header_id|>assistant<|end_header_id|>\n\n"; + return result; + } + + std::vector _tokenize(const std::string& str) override { + std::vector ret; + std::vector splitted; + llama3Regex(str, splitted); + for (const auto& s : splitted) { + auto utf_8_str = preprocessor::wideString2Utf8String(s); + std::wstring mapped_str; + // 执行字节映射 + for (unsigned char c : utf_8_str) { mapped_str.push_back(bytes_2_unicode_dict_[c]); } + auto bpe_ts = bpe_._bpe(mapped_str); + for (const auto& bpe_t : bpe_ts) { ret.push_back(bpe_t); } + } + return ret; + } + + std::vector tokenize(const std::string& str) override { + std::string processed_str = str; + bool text_has_bos = (processed_str.find("<|begin_of_text|>") == 0); + if (add_bos_ && !text_has_bos) { + processed_str = "<|begin_of_text|>" + processed_str; + } + + auto tokens = special_tokens_trie_.split(preprocessor::utf8string2WideString(processed_str)); + std::vector all_tokens; + for (const auto& token : tokens) { + if (special_tokens_trie_.isSpecialToken(token)) { + all_tokens.emplace_back(token); + continue; + } + auto tmp_tokens = _tokenize(preprocessor::wideString2Utf8String(token)); + all_tokens.insert(all_tokens.end(), tmp_tokens.begin(), tmp_tokens.end()); + } + return all_tokens; + } + + std::wstring _detokenize(int64_t pos_idx) override { return bpe_._lookup_inverse_vocab(pos_idx); } + + std::wstring detokenize(int64_t pos_idx) override { + auto str = _detokenize(pos_idx); + std::string utf_8_str; + for (wchar_t c : str) { + if (bytes_2_unicode_dict_inverse_.count(c)) { + utf_8_str.push_back((unsigned char)(bytes_2_unicode_dict_inverse_[c])); + } else { + return str; + } + } + return mllm::preprocessor::utf8string2WideString(utf_8_str); + } + + Tensor convert2Ids(const std::vector& strs) override { + std::vector ids; + for (const auto& str : strs) { ids.emplace_back(bpe_._lookup_vocab(str)); } + Tensor ret = Tensor::empty({1, (int32_t)ids.size()}, kInt64, kCPU) + .setMemType(kExtraInput) + .setName("llama-tokenizer-i0") + .alloc(); + auto ptr = ret.ptr(); + for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; } + return ret; + } + + // 供 test_c.cpp 调用的便捷接口 + std::vector encode(const std::string& str) { + auto sub_tokens = tokenize(str); + std::vector ret; + for (auto& token : sub_tokens) { ret.emplace_back(bpe_._lookup_vocab(token)); } + return ret; + } + + std::string decode(const std::vector& ids) { + std::string ret; + for (auto& each_id : ids) { + auto wstr = detokenize(each_id); + ret += mllm::preprocessor::wideString2Utf8String(wstr); + } + return ret; + } + + ARGenerationOutputPast convertMessage(const std::vector& messages) { + auto applied_string = applyChatTemplate(messages, true); + auto sequence_str = tokenize(applied_string); + std::vector ids; + for (const auto& str : sequence_str) { ids.emplace_back(bpe_._lookup_vocab(str)); } + + Tensor sequence = Tensor::empty({1, (int32_t)ids.size()}, kInt64, kCPU) + .setMemType(kNormal) + .setName("llama-tokenizer-i0") + .alloc(); + auto ptr = sequence.ptr(); + for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; } + + return { + {"sequence", sequence}, + }; + } + + private: + bool add_bos_ = true; + preprocessor::BPE bpe_; + std::unordered_map bytes_2_unicode_dict_; + std::unordered_map bytes_2_unicode_dict_inverse_; +}; + +} // namespace mllm::models::llama \ No newline at end of file diff --git a/pymllm/mobile/backends/qualcomm/transformers/llama/modeling_llama.py b/pymllm/mobile/backends/qualcomm/transformers/llama/modeling_llama.py index 6b65f34b9..8ebf0afcd 100644 --- a/pymllm/mobile/backends/qualcomm/transformers/llama/modeling_llama.py +++ b/pymllm/mobile/backends/qualcomm/transformers/llama/modeling_llama.py @@ -302,8 +302,8 @@ def __init__(self, config: LlamaConfig, layer_idx: int): # QDQ self.q_proj_input_qdq = ActivationQDQ(bits=16) - self.k_proj_input_qdq = ActivationQDQ(bits=16) - self.v_proj_input_qdq = ActivationQDQ(bits=16) + # self.k_proj_input_qdq = ActivationQDQ(bits=16) + # self.v_proj_input_qdq = ActivationQDQ(bits=16) self.q_proj_output_qdq = ActivationQDQ(bits=16) self.k_proj_output_qdq = ActivationQDQ(bits=16) @@ -336,13 +336,13 @@ def __init__(self, config: LlamaConfig, layer_idx: int): ) self.k_rope_neg_half_qdq = ActivationQDQ(bits=16) self.k_rope_concat_observer.add_observer( - self.k_proj_input_qdq.fake_quant.activation_post_process + self.k_proj_output_qdq.fake_quant.activation_post_process ) self.k_rope_concat_observer.add_observer( self.k_rope_neg_half_qdq.fake_quant.activation_post_process ) self.q_rope_concat_observer.add_observer( - self.q_proj_input_qdq.fake_quant.activation_post_process + self.q_proj_output_qdq.fake_quant.activation_post_process ) self.q_rope_concat_observer.add_observer( self.q_rope_neg_half_qdq.fake_quant.activation_post_process @@ -384,12 +384,12 @@ def forward( query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) query_states = self.q_proj_output_qdq(query_states) - hidden_states_k = self.k_proj_input_qdq(hidden_states) - key_states = self.k_proj(hidden_states_k).view(hidden_shape).transpose(1, 2) + # hidden_states_k = self.k_proj_input_qdq(hidden_states) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj_output_qdq(key_states) - hidden_states_v = self.v_proj_input_qdq(hidden_states) - value_states = self.v_proj(hidden_states_v).view(hidden_shape).transpose(1, 2) + # hidden_states_v = self.v_proj_input_qdq(hidden_states) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings cos = cos.unsqueeze(1) @@ -399,7 +399,7 @@ def forward( + self.q_rope_mul_1_output_qdq( rotate_half( query_states, - self.q_proj_input_qdq.fake_quant.activation_post_process, + self.q_proj_output_qdq.fake_quant.activation_post_process, self.q_rope_neg_half_qdq, self.q_rope_concat_observer, ) @@ -411,7 +411,7 @@ def forward( + self.k_rope_mul_1_output_qdq( rotate_half( key_states, - self.k_proj_input_qdq.fake_quant.activation_post_process, + self.k_proj_output_qdq.fake_quant.activation_post_process, self.k_rope_neg_half_qdq, self.k_rope_concat_observer, ) diff --git a/pymllm/mobile/backends/qualcomm/transformers/llama/runner.py b/pymllm/mobile/backends/qualcomm/transformers/llama/runner.py index 45af95f8f..9aa1a4f73 100644 --- a/pymllm/mobile/backends/qualcomm/transformers/llama/runner.py +++ b/pymllm/mobile/backends/qualcomm/transformers/llama/runner.py @@ -251,6 +251,12 @@ def compile(self): print("Compile done.") def infer(self, prompt: str): + messages = [{"role": "user", "content": prompt}] + prompt = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) # Llama models typically don't use chat templates, so we tokenize directly model_inputs = self.tokenizer([prompt], return_tensors="pt").to( self.model.device @@ -308,12 +314,13 @@ def calibrate(self, num_samples=64, max_seq_length=512): for entry in dataset: if samples_processed >= num_samples: break - - if len(entry["text"].strip()) < 1024: + + text = entry["text"].strip() + if len(text) < 50: continue # Llama models typically don't use chat templates - text = entry["text"] + # text = entry["text"] model_inputs = self.tokenizer( [text], return_tensors="pt", @@ -322,16 +329,18 @@ def calibrate(self, num_samples=64, max_seq_length=512): padding=False, ).to(self.model.device) - # Only need Prefill stage: directly call forward - # This will trigger observer update statistics in ActivationQDQ - self.model.generate( - **model_inputs, - max_new_tokens=1, - do_sample=False, - temperature=None, - top_p=None, - top_k=None, - ) + self.model(**model_inputs) + + # # Only need Prefill stage: directly call forward + # # This will trigger observer update statistics in ActivationQDQ + # self.model.generate( + # **model_inputs, + # max_new_tokens=1, + # do_sample=False, + # temperature=None, + # top_p=None, + # top_k=None, + # ) samples_processed += 1 pbar.update(1) diff --git a/pymllm/mobile/backends/qualcomm/transformers/qwen2/modeling_qwen2.py b/pymllm/mobile/backends/qualcomm/transformers/qwen2/modeling_qwen2.py index a43d8b7ea..f8ad9ec56 100644 --- a/pymllm/mobile/backends/qualcomm/transformers/qwen2/modeling_qwen2.py +++ b/pymllm/mobile/backends/qualcomm/transformers/qwen2/modeling_qwen2.py @@ -186,12 +186,12 @@ def __init__(self, config: Qwen2Config, layer_idx: int): # QDQ self.q_proj_input_qdq = ActivationQDQ(bits=16) - self.k_proj_input_qdq = ActivationQDQ(bits=16) + # self.k_proj_input_qdq = ActivationQDQ(bits=16) self.q_proj_output_qdq = ActivationQDQ(bits=16) self.k_proj_output_qdq = ActivationQDQ(bits=16) - self.v_proj_input_qdq = ActivationQDQ(bits=16) + # self.v_proj_input_qdq = ActivationQDQ(bits=16) self.q_rope_mul_0_output_qdq = ActivationQDQ(bits=16) self.q_rope_mul_1_output_qdq = ActivationQDQ(bits=16) self.q_rope_add_0_output_qdq = ActivationQDQ(bits=16) @@ -220,13 +220,13 @@ def __init__(self, config: Qwen2Config, layer_idx: int): ) self.k_rope_neg_half_qdq = ActivationQDQ(bits=16) self.k_rope_concat_observer.add_observer( - self.k_proj_input_qdq.fake_quant.activation_post_process + self.k_proj_output_qdq.fake_quant.activation_post_process ) self.k_rope_concat_observer.add_observer( self.k_rope_neg_half_qdq.fake_quant.activation_post_process ) self.q_rope_concat_observer.add_observer( - self.q_proj_input_qdq.fake_quant.activation_post_process + self.q_proj_output_qdq.fake_quant.activation_post_process ) self.q_rope_concat_observer.add_observer( self.q_rope_neg_half_qdq.fake_quant.activation_post_process @@ -268,12 +268,12 @@ def forward( query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) query_states = self.q_proj_output_qdq(query_states) - hidden_states_k = self.k_proj_input_qdq(hidden_states) - key_states = self.k_proj(hidden_states_k).view(hidden_shape).transpose(1, 2) + # hidden_states_k = self.k_proj_input_qdq(hidden_states) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj_output_qdq(key_states) - hidden_states_v = self.v_proj_input_qdq(hidden_states) - value_states = self.v_proj(hidden_states_v).view(hidden_shape).transpose(1, 2) + # hidden_states_v = self.v_proj_input_qdq(hidden_states) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings cos = cos.unsqueeze(1) @@ -283,7 +283,7 @@ def forward( + self.q_rope_mul_1_output_qdq( rotate_half( query_states, - self.q_proj_input_qdq.fake_quant.activation_post_process, + self.q_proj_output_qdq.fake_quant.activation_post_process, self.q_rope_neg_half_qdq, self.q_rope_concat_observer, ) @@ -295,7 +295,7 @@ def forward( + self.k_rope_mul_1_output_qdq( rotate_half( key_states, - self.k_proj_input_qdq.fake_quant.activation_post_process, + self.k_proj_output_qdq.fake_quant.activation_post_process, self.k_rope_neg_half_qdq, self.k_rope_concat_observer, )