word similarity
This commit is contained in:
@@ -4,10 +4,24 @@ plugins {
|
||||
id 'kotlin-kapt'
|
||||
}
|
||||
|
||||
kapt {
|
||||
// Room uses javac stubs under kapt; keep parameter names for :bind variables.
|
||||
javacOptions {
|
||||
option("-parameters")
|
||||
}
|
||||
}
|
||||
|
||||
android {
|
||||
namespace 'com.digitalperson'
|
||||
compileSdk 34
|
||||
|
||||
sourceSets {
|
||||
main {
|
||||
// app/note/ref → assets 中为 ref/...(与 AppConfig.RefCorpus.ASSETS_ROOT 一致)
|
||||
assets.srcDirs = ['src/main/assets', 'note']
|
||||
}
|
||||
}
|
||||
|
||||
buildFeatures {
|
||||
buildConfig true
|
||||
}
|
||||
@@ -100,4 +114,9 @@ dependencies {
|
||||
|
||||
implementation project(':tuanjieLibrary')
|
||||
implementation files('../tuanjieLibrary/libs/unity-classes.jar')
|
||||
|
||||
// BGE tokenizer (BasicTokenizer) + SimilarityManager 批量相似度测试
|
||||
implementation 'com.google.guava:guava:31.1-android'
|
||||
implementation 'org.ejml:ejml-core:0.43.1'
|
||||
implementation 'org.ejml:ejml-simple:0.43.1'
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
<uses-feature android:name="android.hardware.camera.any" />
|
||||
|
||||
<application
|
||||
android:name="com.digitalperson.DigitalPersonApp"
|
||||
android:allowBackup="true"
|
||||
android:label="@string/app_name"
|
||||
android:supportsRtl="true"
|
||||
|
||||
5
app/src/main/assets/bge_models/PLACE_RKNN_HERE.txt
Normal file
5
app/src/main/assets/bge_models/PLACE_RKNN_HERE.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
Required assets for BGE (see AppConfig.Bge):
|
||||
bge-small-zh-v1.5.rknn, vocab.txt, tokenizer.json, tokenizer_config.json
|
||||
|
||||
First run: FileHelper.copyBgeModels / BgeEmbedding.initialize / SimilarityManager.initBgeModel
|
||||
copies from assets/bge_models to internal storage.
|
||||
BIN
app/src/main/assets/bge_models/bge-small-zh-v1.5.rknn
Normal file
BIN
app/src/main/assets/bge_models/bge-small-zh-v1.5.rknn
Normal file
Binary file not shown.
21278
app/src/main/assets/bge_models/tokenizer.json
Normal file
21278
app/src/main/assets/bge_models/tokenizer.json
Normal file
File diff suppressed because it is too large
Load Diff
15
app/src/main/assets/bge_models/tokenizer_config.json
Normal file
15
app/src/main/assets/bge_models/tokenizer_config.json
Normal file
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"clean_up_tokenization_spaces": true,
|
||||
"cls_token": "[CLS]",
|
||||
"do_basic_tokenize": true,
|
||||
"do_lower_case": false,
|
||||
"mask_token": "[MASK]",
|
||||
"model_max_length": 512,
|
||||
"never_split": null,
|
||||
"pad_token": "[PAD]",
|
||||
"sep_token": "[SEP]",
|
||||
"strip_accents": null,
|
||||
"tokenize_chinese_chars": true,
|
||||
"tokenizer_class": "BertTokenizer",
|
||||
"unk_token": "[UNK]"
|
||||
}
|
||||
21128
app/src/main/assets/bge_models/vocab.txt
Normal file
21128
app/src/main/assets/bge_models/vocab.txt
Normal file
File diff suppressed because it is too large
Load Diff
504
app/src/main/cpp/BgeEngineRKNN.cpp
Normal file
504
app/src/main/cpp/BgeEngineRKNN.cpp
Normal file
@@ -0,0 +1,504 @@
|
||||
#include "BgeEngineRKNN.h"
|
||||
#include <android/log.h>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <unordered_map>
|
||||
|
||||
#define LOG_TAG "BgeEngineRKNN"
|
||||
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__)
|
||||
#define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__)
|
||||
#define LOGD(...) __android_log_print(ANDROID_LOG_DEBUG, LOG_TAG, __VA_ARGS__)
|
||||
|
||||
BgeEngineRKNN::BgeEngineRKNN() : m_ctx(0), m_initialized(false), m_embedding_dim(768) {
|
||||
// 默认嵌入维度为768,可根据模型实际情况调整
|
||||
}
|
||||
|
||||
BgeEngineRKNN::~BgeEngineRKNN() {
|
||||
freeModel();
|
||||
}
|
||||
|
||||
int BgeEngineRKNN::loadModel(const char* modelPath, const char* vocabPath) {
|
||||
if (m_initialized) {
|
||||
LOGI("Model already loaded");
|
||||
return 0;
|
||||
}
|
||||
|
||||
// 加载词汇表(如果提供了路径)
|
||||
if (vocabPath != nullptr) {
|
||||
int vocab_ret = loadVocab(vocabPath);
|
||||
if (vocab_ret != 0) {
|
||||
LOGE("Failed to load vocab: %d", vocab_ret);
|
||||
return vocab_ret;
|
||||
}
|
||||
LOGI("Vocab loaded successfully with %zu tokens", m_vocab.size());
|
||||
} else {
|
||||
LOGE("No vocab path provided, using fallback tokenization");
|
||||
}
|
||||
|
||||
int ret = initRKNN(modelPath);
|
||||
if (ret != 0) {
|
||||
LOGE("Failed to initialize RKNN: %d", ret);
|
||||
return ret;
|
||||
}
|
||||
|
||||
m_initialized = true;
|
||||
LOGI("BGE model loaded successfully");
|
||||
return 0;
|
||||
}
|
||||
|
||||
void BgeEngineRKNN::freeModel() {
|
||||
if (m_initialized) {
|
||||
deinitRKNN();
|
||||
m_initialized = false;
|
||||
LOGI("BGE model freed");
|
||||
}
|
||||
}
|
||||
|
||||
float* BgeEngineRKNN::getEmbedding(const std::string& text) {
|
||||
if (!m_initialized) {
|
||||
LOGE("Model not initialized");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// 简单的tokenize实现(需要替换为完整的tokenizer)
|
||||
std::vector<int> tokens = tokenize(text);
|
||||
if (tokens.empty()) {
|
||||
LOGE("Tokenization failed");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// 推理获取嵌入
|
||||
float* embedding = inferEmbedding(tokens);
|
||||
if (embedding == nullptr) {
|
||||
LOGE("Failed to get embedding");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return embedding;
|
||||
}
|
||||
|
||||
float BgeEngineRKNN::calculateSimilarity(const std::string& text1, const std::string& text2) {
|
||||
if (!m_initialized) {
|
||||
LOGE("Model not initialized");
|
||||
return 0.0f;
|
||||
}
|
||||
|
||||
// 获取两个文本的嵌入
|
||||
float* embedding1 = getEmbedding(text1);
|
||||
if (embedding1 == nullptr) {
|
||||
return 0.0f;
|
||||
}
|
||||
|
||||
float* embedding2 = getEmbedding(text2);
|
||||
if (embedding2 == nullptr) {
|
||||
delete[] embedding1;
|
||||
return 0.0f;
|
||||
}
|
||||
|
||||
// 归一化嵌入
|
||||
std::vector<float> normalized1 = normalizeEmbedding(embedding1, m_embedding_dim);
|
||||
std::vector<float> normalized2 = normalizeEmbedding(embedding2, m_embedding_dim);
|
||||
|
||||
// 计算余弦相似度
|
||||
float similarity = cosineSimilarity(normalized1, normalized2);
|
||||
|
||||
// 释放内存
|
||||
delete[] embedding1;
|
||||
delete[] embedding2;
|
||||
|
||||
return similarity;
|
||||
}
|
||||
|
||||
int BgeEngineRKNN::getEmbeddingDim() const {
|
||||
return m_embedding_dim;
|
||||
}
|
||||
|
||||
std::vector<int> BgeEngineRKNN::tokenize(const std::string& text) {
|
||||
std::vector<int> tokens;
|
||||
tokens.push_back(101); // [CLS] token
|
||||
|
||||
// 如果词汇表已加载,使用完整的tokenization流程
|
||||
if (!m_vocab.empty()) {
|
||||
// 步骤1: 基本分词
|
||||
std::vector<std::string> basic_tokens = basicTokenize(text);
|
||||
|
||||
// 步骤2: WordPiece分词
|
||||
std::vector<std::string> word_pieces = wordPieceTokenize(basic_tokens);
|
||||
|
||||
// 步骤3: 转换为token id
|
||||
for (const std::string& piece : word_pieces) {
|
||||
if (tokens.size() > 510) break; // 限制最大长度
|
||||
|
||||
auto it = m_vocab.find(piece);
|
||||
if (it != m_vocab.end()) {
|
||||
tokens.push_back(it->second);
|
||||
} else {
|
||||
// 对于未在词表中的token,使用[UNK] token
|
||||
auto unk_it = m_vocab.find("[UNK]");
|
||||
if (unk_it != m_vocab.end()) {
|
||||
tokens.push_back(unk_it->second);
|
||||
} else {
|
||||
// 如果[UNK]也不存在,使用0作为fallback
|
||||
tokens.push_back(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// fallback: 使用字符级分词(当词汇表未加载时)
|
||||
size_t i = 0;
|
||||
while (i < text.size() && tokens.size() <= 510) {
|
||||
unsigned char c = static_cast<unsigned char>(text[i]);
|
||||
int token_id;
|
||||
|
||||
// 处理不同长度的UTF-8字符
|
||||
if (c < 0x80) {
|
||||
// ASCII字符 (1字节)
|
||||
token_id = static_cast<int>(c) + 100;
|
||||
i++;
|
||||
} else if (c < 0xE0) {
|
||||
// 2字节UTF-8字符
|
||||
token_id = 3000 + ((c & 0x1F) << 6) + (static_cast<unsigned char>(text[i+1]) & 0x3F);
|
||||
i += 2;
|
||||
} else if (c < 0xF0) {
|
||||
// 3字节UTF-8字符(常见中文)
|
||||
token_id = 4000 + ((c & 0x0F) << 12) + ((static_cast<unsigned char>(text[i+1]) & 0x3F) << 6) + (static_cast<unsigned char>(text[i+2]) & 0x3F);
|
||||
i += 3;
|
||||
} else {
|
||||
// 4字节UTF-8字符
|
||||
token_id = 8000 + ((c & 0x07) << 18) + ((static_cast<unsigned char>(text[i+1]) & 0x3F) << 12) + ((static_cast<unsigned char>(text[i+2]) & 0x3F) << 6) + (static_cast<unsigned char>(text[i+3]) & 0x3F);
|
||||
i += 4;
|
||||
}
|
||||
|
||||
// 确保token_id在合理范围内
|
||||
token_id = token_id % 30000;
|
||||
tokens.push_back(token_id);
|
||||
}
|
||||
}
|
||||
|
||||
tokens.push_back(102); // [SEP] token
|
||||
return tokens;
|
||||
}
|
||||
|
||||
int BgeEngineRKNN::loadVocab(const char* vocabPath) {
|
||||
std::ifstream vocab_file(vocabPath);
|
||||
if (!vocab_file.is_open()) {
|
||||
LOGE("Failed to open vocab file: %s", vocabPath);
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::string line;
|
||||
int id = 0;
|
||||
while (std::getline(vocab_file, line)) {
|
||||
if (!line.empty()) {
|
||||
m_vocab[line] = id;
|
||||
id++;
|
||||
}
|
||||
}
|
||||
|
||||
vocab_file.close();
|
||||
|
||||
if (m_vocab.empty()) {
|
||||
LOGE("Vocab file is empty");
|
||||
return -2;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::vector<std::string> BgeEngineRKNN::basicTokenize(const std::string& text) {
|
||||
std::vector<std::string> tokens;
|
||||
|
||||
size_t i = 0;
|
||||
while (i < text.size()) {
|
||||
unsigned char c = static_cast<unsigned char>(text[i]);
|
||||
|
||||
if (std::isspace(c)) {
|
||||
// 跳过空格
|
||||
i++;
|
||||
} else if (c < 0x80) {
|
||||
// ASCII字符,按空格切分
|
||||
std::string token;
|
||||
while (i < text.size() && !std::isspace(static_cast<unsigned char>(text[i])) && static_cast<unsigned char>(text[i]) < 0x80) {
|
||||
token += text[i];
|
||||
i++;
|
||||
}
|
||||
if (!token.empty()) {
|
||||
tokens.push_back(token);
|
||||
}
|
||||
} else {
|
||||
// 非ASCII字符(如中文),按字符切分
|
||||
// 处理UTF-8编码的字符
|
||||
if (c < 0xE0) {
|
||||
// 2字节UTF-8字符
|
||||
if (i + 1 < text.size()) {
|
||||
tokens.push_back(text.substr(i, 2));
|
||||
i += 2;
|
||||
} else {
|
||||
i++;
|
||||
}
|
||||
} else if (c < 0xF0) {
|
||||
// 3字节UTF-8字符(常见中文)
|
||||
if (i + 2 < text.size()) {
|
||||
tokens.push_back(text.substr(i, 3));
|
||||
i += 3;
|
||||
} else {
|
||||
i++;
|
||||
}
|
||||
} else {
|
||||
// 4字节UTF-8字符
|
||||
if (i + 3 < text.size()) {
|
||||
tokens.push_back(text.substr(i, 4));
|
||||
i += 4;
|
||||
} else {
|
||||
i++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tokens;
|
||||
}
|
||||
|
||||
std::vector<std::string> BgeEngineRKNN::wordPieceTokenize(const std::vector<std::string>& basic_tokens) {
|
||||
std::vector<std::string> word_pieces;
|
||||
|
||||
for (const std::string& token : basic_tokens) {
|
||||
if (m_vocab.find(token) != m_vocab.end()) {
|
||||
word_pieces.push_back(token);
|
||||
} else {
|
||||
// 简单的字符级分词作为 fallback
|
||||
for (char c : token) {
|
||||
word_pieces.push_back(std::string(1, c));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return word_pieces;
|
||||
}
|
||||
|
||||
std::vector<float> BgeEngineRKNN::normalizeEmbedding(const float* embedding, int dim) {
|
||||
std::vector<float> normalized(dim);
|
||||
float norm = 0.0f;
|
||||
|
||||
// 计算L2范数
|
||||
for (int i = 0; i < dim; i++) {
|
||||
norm += embedding[i] * embedding[i];
|
||||
}
|
||||
norm = std::sqrt(norm);
|
||||
|
||||
// 归一化
|
||||
if (norm > 0.0f) {
|
||||
for (int i = 0; i < dim; i++) {
|
||||
normalized[i] = embedding[i] / norm;
|
||||
}
|
||||
}
|
||||
|
||||
return normalized;
|
||||
}
|
||||
|
||||
float BgeEngineRKNN::cosineSimilarity(const std::vector<float>& vec1, const std::vector<float>& vec2) {
|
||||
if (vec1.size() != vec2.size()) {
|
||||
LOGE("Vector dimension mismatch");
|
||||
return 0.0f;
|
||||
}
|
||||
|
||||
float dot_product = 0.0f;
|
||||
for (size_t i = 0; i < vec1.size(); i++) {
|
||||
dot_product += vec1[i] * vec2[i];
|
||||
}
|
||||
|
||||
return dot_product;
|
||||
}
|
||||
|
||||
int BgeEngineRKNN::initRKNN(const char* modelPath) {
|
||||
int ret = 0;
|
||||
rknn_context ctx = 0;
|
||||
|
||||
// 初始化RKNN
|
||||
ret = rknn_init(&ctx, (void*)modelPath, 0, 0, NULL);
|
||||
if (ret != RKNN_SUCC) {
|
||||
LOGE("rknn_init failed: %d", ret);
|
||||
return ret;
|
||||
}
|
||||
|
||||
m_ctx = ctx;
|
||||
|
||||
// 查询模型输入输出信息
|
||||
rknn_input_output_num io_num;
|
||||
ret = rknn_query(m_ctx, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num));
|
||||
if (ret != RKNN_SUCC) {
|
||||
LOGE("rknn_query input_output_num failed: %d", ret);
|
||||
rknn_destroy(ctx);
|
||||
return ret;
|
||||
}
|
||||
|
||||
LOGI("Input number: %d, Output number: %d", io_num.n_input, io_num.n_output);
|
||||
|
||||
// 查询所有输入属性
|
||||
for (int i = 0; i < io_num.n_input; i++) {
|
||||
rknn_tensor_attr input_attr;
|
||||
memset(&input_attr, 0, sizeof(input_attr));
|
||||
input_attr.index = i;
|
||||
ret = rknn_query(m_ctx, RKNN_QUERY_INPUT_ATTR, &input_attr, sizeof(input_attr));
|
||||
if (ret != RKNN_SUCC) {
|
||||
LOGE("rknn_query input_attr %d failed: %d", i, ret);
|
||||
rknn_destroy(ctx);
|
||||
return ret;
|
||||
}
|
||||
|
||||
LOGI("Input %d info:", i);
|
||||
LOGI(" type: %d", input_attr.type);
|
||||
LOGI(" size: %d", input_attr.size);
|
||||
LOGI(" n_dims: %d", input_attr.n_dims);
|
||||
LOGI(" dims:");
|
||||
for (int j = 0; j < input_attr.n_dims; j++) {
|
||||
LOGI(" dim[%d] = %d", j, input_attr.dims[j]);
|
||||
}
|
||||
}
|
||||
|
||||
// 查询所有输出属性
|
||||
for (int i = 0; i < io_num.n_output; i++) {
|
||||
rknn_tensor_attr output_attr;
|
||||
memset(&output_attr, 0, sizeof(output_attr));
|
||||
output_attr.index = i;
|
||||
ret = rknn_query(m_ctx, RKNN_QUERY_OUTPUT_ATTR, &output_attr, sizeof(output_attr));
|
||||
if (ret != RKNN_SUCC) {
|
||||
LOGE("rknn_query output_attr %d failed: %d", i, ret);
|
||||
rknn_destroy(ctx);
|
||||
return ret;
|
||||
}
|
||||
|
||||
// 打印输出信息
|
||||
LOGI("Output %d dims:", i);
|
||||
for (int j = 0; j < output_attr.n_dims; j++) {
|
||||
LOGI(" dim[%d] = %d", j, output_attr.dims[j]);
|
||||
}
|
||||
|
||||
// 检查是否为 output1 且维度符合要求
|
||||
if (i == 1) {
|
||||
bool is_valid_output = false;
|
||||
if (output_attr.n_dims == 2 && output_attr.dims[0] == 1 && output_attr.dims[1] == 512) {
|
||||
is_valid_output = true;
|
||||
} else if (output_attr.n_dims == 3 && output_attr.dims[0] == 1 && output_attr.dims[1] == 1 && output_attr.dims[2] == 512) {
|
||||
is_valid_output = true;
|
||||
}
|
||||
|
||||
if (is_valid_output) {
|
||||
LOGI("Output 1 has valid embedding dimensions: using Output 1");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 确定嵌入维度(暂时使用 output 0,后续可能需要修改为使用 output 1)
|
||||
rknn_tensor_attr output0_attr;
|
||||
memset(&output0_attr, 0, sizeof(output0_attr));
|
||||
output0_attr.index = 0;
|
||||
ret = rknn_query(m_ctx, RKNN_QUERY_OUTPUT_ATTR, &output0_attr, sizeof(output0_attr));
|
||||
if (ret == RKNN_SUCC && output0_attr.n_dims >= 2) {
|
||||
m_embedding_dim = output0_attr.dims[output0_attr.n_dims - 1];
|
||||
LOGI("Embedding dimension (from output 0): %d", m_embedding_dim);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
void BgeEngineRKNN::deinitRKNN() {
|
||||
if (m_ctx != 0) {
|
||||
rknn_destroy(m_ctx);
|
||||
m_ctx = 0;
|
||||
}
|
||||
}
|
||||
|
||||
float* BgeEngineRKNN::inferEmbedding(const std::vector<int>& tokens) {
|
||||
// 准备输入数据
|
||||
const int max_seq_len = 512;
|
||||
std::vector<int> padded_tokens = tokens;
|
||||
|
||||
// 填充到最大序列长度
|
||||
if (padded_tokens.size() < max_seq_len) {
|
||||
padded_tokens.resize(max_seq_len, 0); // 使用0作为padding
|
||||
} else if (padded_tokens.size() > max_seq_len) {
|
||||
padded_tokens.resize(max_seq_len); // 截断
|
||||
}
|
||||
|
||||
// 准备attention_mask:1表示真实token,0表示padding
|
||||
std::vector<int> attention_mask(max_seq_len, 0);
|
||||
for (int i = 0; i < tokens.size() && i < max_seq_len; i++) {
|
||||
attention_mask[i] = 1;
|
||||
}
|
||||
|
||||
// 准备token_type_ids:对于单句子任务,全部为0
|
||||
std::vector<int> token_type_ids(max_seq_len, 0);
|
||||
|
||||
// 准备输入张量
|
||||
rknn_input inputs[3]; // 最多3个输入:input_ids, attention_mask, token_type_ids
|
||||
memset(inputs, 0, sizeof(inputs));
|
||||
|
||||
// 输入0: input_ids
|
||||
inputs[0].index = 0;
|
||||
inputs[0].type = RKNN_TENSOR_INT32;
|
||||
inputs[0].size = sizeof(int) * max_seq_len;
|
||||
inputs[0].buf = padded_tokens.data();
|
||||
inputs[0].pass_through = 0;
|
||||
|
||||
// 输入1: attention_mask
|
||||
inputs[1].index = 1;
|
||||
inputs[1].type = RKNN_TENSOR_INT32;
|
||||
inputs[1].size = sizeof(int) * max_seq_len;
|
||||
inputs[1].buf = attention_mask.data();
|
||||
inputs[1].pass_through = 0;
|
||||
|
||||
// 输入2: token_type_ids
|
||||
inputs[2].index = 2;
|
||||
inputs[2].type = RKNN_TENSOR_INT32;
|
||||
inputs[2].size = sizeof(int) * max_seq_len;
|
||||
inputs[2].buf = token_type_ids.data();
|
||||
inputs[2].pass_through = 0;
|
||||
|
||||
// 设置输入(根据实际的输入数量)
|
||||
rknn_input_output_num io_num;
|
||||
int ret = rknn_query(m_ctx, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num));
|
||||
if (ret != RKNN_SUCC) {
|
||||
LOGE("rknn_query input_output_num failed: %d", ret);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ret = rknn_inputs_set(m_ctx, io_num.n_input, inputs);
|
||||
if (ret != RKNN_SUCC) {
|
||||
LOGE("rknn_inputs_set failed: %d", ret);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// 运行推理
|
||||
ret = rknn_run(m_ctx, NULL);
|
||||
if (ret != RKNN_SUCC) {
|
||||
LOGE("rknn_run failed: %d", ret);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// 获取输出
|
||||
rknn_output outputs[2]; // 增加到2个输出
|
||||
memset(outputs, 0, sizeof(outputs));
|
||||
outputs[0].want_float = 1;
|
||||
outputs[1].want_float = 1;
|
||||
|
||||
ret = rknn_outputs_get(m_ctx, 2, outputs, NULL);
|
||||
if (ret != RKNN_SUCC) {
|
||||
LOGE("rknn_outputs_get failed: %d", ret);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// 提取嵌入(使用output 1)
|
||||
float* embedding = new float[m_embedding_dim];
|
||||
memcpy(embedding, outputs[1].buf, sizeof(float) * m_embedding_dim);
|
||||
LOGI("Using output 1 for embedding");
|
||||
|
||||
// 释放输出
|
||||
rknn_outputs_release(m_ctx, 2, outputs);
|
||||
|
||||
return embedding;
|
||||
}
|
||||
42
app/src/main/cpp/BgeEngineRKNN.h
Normal file
42
app/src/main/cpp/BgeEngineRKNN.h
Normal file
@@ -0,0 +1,42 @@
|
||||
#ifndef BGE_ENGINE_RKNN_H
|
||||
#define BGE_ENGINE_RKNN_H
|
||||
|
||||
#include <rknn_api.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <map>
|
||||
|
||||
class BgeEngineRKNN {
|
||||
public:
|
||||
BgeEngineRKNN();
|
||||
~BgeEngineRKNN();
|
||||
|
||||
int loadModel(const char* modelPath, const char* vocabPath = nullptr);
|
||||
void freeModel();
|
||||
float* getEmbedding(const std::string& text);
|
||||
float calculateSimilarity(const std::string& text1, const std::string& text2);
|
||||
int getEmbeddingDim() const; // 获取嵌入维度
|
||||
|
||||
private:
|
||||
rknn_context m_ctx;
|
||||
bool m_initialized;
|
||||
int m_embedding_dim;
|
||||
std::map<std::string, int> m_vocab; // 词汇表映射
|
||||
|
||||
// 辅助方法
|
||||
std::vector<int> tokenize(const std::string& text);
|
||||
std::vector<float> normalizeEmbedding(const float* embedding, int dim);
|
||||
float cosineSimilarity(const std::vector<float>& vec1, const std::vector<float>& vec2);
|
||||
|
||||
// Tokenizer相关方法
|
||||
int loadVocab(const char* vocabPath);
|
||||
std::vector<std::string> basicTokenize(const std::string& text);
|
||||
std::vector<std::string> wordPieceTokenize(const std::vector<std::string>& basic_tokens);
|
||||
|
||||
// RKNN相关方法
|
||||
int initRKNN(const char* modelPath);
|
||||
void deinitRKNN();
|
||||
float* inferEmbedding(const std::vector<int>& tokens);
|
||||
};
|
||||
|
||||
#endif // BGE_ENGINE_RKNN_H
|
||||
229
app/src/main/cpp/BgeEngineRKNNJNI.cpp
Normal file
229
app/src/main/cpp/BgeEngineRKNNJNI.cpp
Normal file
@@ -0,0 +1,229 @@
|
||||
#include <jni.h>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <android/log.h>
|
||||
#include "BgeEngineRKNN.h"
|
||||
|
||||
#define LOG_TAG "BgeEngineJNI"
|
||||
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__)
|
||||
#define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__)
|
||||
|
||||
// 全局引擎实例映射
|
||||
static std::map<jlong, BgeEngineRKNN*> g_engine_map;
|
||||
static jclass g_float_array_class;
|
||||
static jmethodID g_float_array_ctor;
|
||||
|
||||
// 获取BGE引擎实例
|
||||
BgeEngineRKNN* getBgeEngineFromPtr(jlong ptr) {
|
||||
auto it = g_engine_map.find(ptr);
|
||||
if (it != g_engine_map.end()) {
|
||||
return it->second;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// JNI接口实现
|
||||
|
||||
// 创建BGE引擎
|
||||
extern "C" JNIEXPORT jlong JNICALL Java_com_digitalperson_engine_BgeEngineRKNN_createBgeEngine(
|
||||
JNIEnv* env,
|
||||
jobject thiz) {
|
||||
LOGI("Creating BGE engine");
|
||||
BgeEngineRKNN* engine = new BgeEngineRKNN();
|
||||
jlong ptr = reinterpret_cast<jlong>(engine);
|
||||
g_engine_map[ptr] = engine;
|
||||
return ptr;
|
||||
}
|
||||
|
||||
// 加载模型
|
||||
extern "C" JNIEXPORT jint JNICALL Java_com_digitalperson_engine_BgeEngineRKNN_loadModel(
|
||||
JNIEnv* env,
|
||||
jobject thiz,
|
||||
jlong ptr,
|
||||
jstring modelPath,
|
||||
jstring vocabPath) {
|
||||
LOGI("Loading BGE model with vocab");
|
||||
BgeEngineRKNN* engine = getBgeEngineFromPtr(ptr);
|
||||
if (engine == nullptr) {
|
||||
LOGE("Invalid engine pointer");
|
||||
return -1;
|
||||
}
|
||||
|
||||
const char* c_modelPath = env->GetStringUTFChars(modelPath, NULL);
|
||||
if (c_modelPath == nullptr) {
|
||||
LOGE("Failed to get model path string");
|
||||
return -1;
|
||||
}
|
||||
|
||||
const char* c_vocabPath = nullptr;
|
||||
if (vocabPath != nullptr) {
|
||||
c_vocabPath = env->GetStringUTFChars(vocabPath, NULL);
|
||||
if (c_vocabPath == nullptr) {
|
||||
LOGE("Failed to get vocab path string");
|
||||
env->ReleaseStringUTFChars(modelPath, c_modelPath);
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
int ret = engine->loadModel(c_modelPath, c_vocabPath);
|
||||
|
||||
env->ReleaseStringUTFChars(modelPath, c_modelPath);
|
||||
if (c_vocabPath != nullptr) {
|
||||
env->ReleaseStringUTFChars(vocabPath, c_vocabPath);
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
// 释放模型
|
||||
extern "C" JNIEXPORT void JNICALL Java_com_digitalperson_engine_BgeEngineRKNN_freeModel(
|
||||
JNIEnv* env,
|
||||
jobject thiz,
|
||||
jlong ptr) {
|
||||
LOGI("Freeing BGE model");
|
||||
BgeEngineRKNN* engine = getBgeEngineFromPtr(ptr);
|
||||
if (engine == nullptr) {
|
||||
LOGE("Invalid engine pointer");
|
||||
return;
|
||||
}
|
||||
|
||||
engine->freeModel();
|
||||
}
|
||||
|
||||
// 获取嵌入
|
||||
extern "C" JNIEXPORT jfloatArray JNICALL Java_com_digitalperson_engine_BgeEngineRKNN_getEmbeddingNative(
|
||||
JNIEnv* env,
|
||||
jobject thiz,
|
||||
jlong ptr,
|
||||
jstring text) {
|
||||
BgeEngineRKNN* engine = getBgeEngineFromPtr(ptr);
|
||||
if (engine == nullptr) {
|
||||
LOGE("Invalid engine pointer");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const char* c_text = env->GetStringUTFChars(text, NULL);
|
||||
if (c_text == nullptr) {
|
||||
LOGE("Failed to get text string");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::string cpp_text(c_text);
|
||||
env->ReleaseStringUTFChars(text, c_text);
|
||||
|
||||
float* embedding = engine->getEmbedding(cpp_text);
|
||||
if (embedding == nullptr) {
|
||||
LOGE("Failed to get embedding");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// 获取正确的嵌入维度
|
||||
int embedding_dim = engine->getEmbeddingDim();
|
||||
LOGI("Using embedding dimension: %d", embedding_dim);
|
||||
|
||||
// 创建Java float数组并返回
|
||||
jfloatArray j_embedding = env->NewFloatArray(embedding_dim);
|
||||
if (j_embedding == nullptr) {
|
||||
LOGE("Failed to create float array");
|
||||
delete[] embedding;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
env->SetFloatArrayRegion(j_embedding, 0, embedding_dim, embedding);
|
||||
delete[] embedding;
|
||||
return j_embedding;
|
||||
}
|
||||
|
||||
// 获取嵌入维度
|
||||
extern "C" JNIEXPORT jint JNICALL Java_com_digitalperson_engine_BgeEngineRKNN_getEmbeddingDim(
|
||||
JNIEnv* env,
|
||||
jobject thiz,
|
||||
jlong ptr) {
|
||||
BgeEngineRKNN* engine = getBgeEngineFromPtr(ptr);
|
||||
if (engine == nullptr) {
|
||||
LOGE("Invalid engine pointer");
|
||||
return -1;
|
||||
}
|
||||
|
||||
int embedding_dim = engine->getEmbeddingDim();
|
||||
LOGI("Returning embedding dimension: %d", embedding_dim);
|
||||
return embedding_dim;
|
||||
}
|
||||
|
||||
// 计算相似度
|
||||
extern "C" JNIEXPORT jfloat JNICALL Java_com_digitalperson_engine_BgeEngineRKNN_calculateSimilarityNative(
|
||||
JNIEnv* env,
|
||||
jobject thiz,
|
||||
jlong ptr,
|
||||
jstring text1,
|
||||
jstring text2) {
|
||||
LOGI("Calculating similarity");
|
||||
BgeEngineRKNN* engine = getBgeEngineFromPtr(ptr);
|
||||
if (engine == nullptr) {
|
||||
LOGE("Invalid engine pointer");
|
||||
return 0.0f;
|
||||
}
|
||||
|
||||
const char* c_text1 = env->GetStringUTFChars(text1, NULL);
|
||||
const char* c_text2 = env->GetStringUTFChars(text2, NULL);
|
||||
if (c_text1 == nullptr || c_text2 == nullptr) {
|
||||
LOGE("Failed to get text strings");
|
||||
if (c_text1 != nullptr) {
|
||||
env->ReleaseStringUTFChars(text1, c_text1);
|
||||
}
|
||||
if (c_text2 != nullptr) {
|
||||
env->ReleaseStringUTFChars(text2, c_text2);
|
||||
}
|
||||
return 0.0f;
|
||||
}
|
||||
|
||||
std::string cpp_text1(c_text1);
|
||||
std::string cpp_text2(c_text2);
|
||||
env->ReleaseStringUTFChars(text1, c_text1);
|
||||
env->ReleaseStringUTFChars(text2, c_text2);
|
||||
|
||||
float similarity = engine->calculateSimilarity(cpp_text1, cpp_text2);
|
||||
return similarity;
|
||||
}
|
||||
|
||||
|
||||
|
||||
// 初始化JNI
|
||||
jint JNI_OnLoad(JavaVM* vm, void* reserved) {
|
||||
JNIEnv* env = nullptr;
|
||||
if (vm->GetEnv((void**)&env, JNI_VERSION_1_6) != JNI_OK) {
|
||||
LOGE("Failed to get JNI environment");
|
||||
return JNI_ERR;
|
||||
}
|
||||
|
||||
// 缓存类和方法ID
|
||||
jclass float_array_cls = env->FindClass("[F");
|
||||
if (float_array_cls == nullptr) {
|
||||
LOGE("Failed to find float array class");
|
||||
return JNI_ERR;
|
||||
}
|
||||
g_float_array_class = (jclass)env->NewGlobalRef(float_array_cls);
|
||||
|
||||
return JNI_VERSION_1_6;
|
||||
}
|
||||
|
||||
// 清理JNI
|
||||
void JNI_OnUnload(JavaVM* vm, void* reserved) {
|
||||
JNIEnv* env = nullptr;
|
||||
if (vm->GetEnv((void**)&env, JNI_VERSION_1_6) != JNI_OK) {
|
||||
LOGE("Failed to get JNI environment");
|
||||
return;
|
||||
}
|
||||
|
||||
// 释放全局引用
|
||||
if (g_float_array_class != nullptr) {
|
||||
env->DeleteGlobalRef(g_float_array_class);
|
||||
g_float_array_class = nullptr;
|
||||
}
|
||||
|
||||
// 清理引擎实例映射
|
||||
for (auto& pair : g_engine_map) {
|
||||
delete pair.second;
|
||||
}
|
||||
g_engine_map.clear();
|
||||
}
|
||||
@@ -58,5 +58,21 @@ if (ANDROID)
|
||||
jnigraphics
|
||||
log
|
||||
)
|
||||
|
||||
add_library(bgeEngine SHARED
|
||||
BgeEngineRKNN.cpp
|
||||
BgeEngineRKNNJNI.cpp
|
||||
)
|
||||
|
||||
target_include_directories(bgeEngine PRIVATE
|
||||
${CMAKE_CURRENT_LIST_DIR}
|
||||
${CMAKE_CURRENT_LIST_DIR}/zipformer_headers
|
||||
${CMAKE_CURRENT_LIST_DIR}/utils
|
||||
)
|
||||
|
||||
target_link_libraries(bgeEngine
|
||||
rknnrt
|
||||
log
|
||||
)
|
||||
endif()
|
||||
|
||||
|
||||
26
app/src/main/java/com/digitalperson/DigitalPersonApp.kt
Normal file
26
app/src/main/java/com/digitalperson/DigitalPersonApp.kt
Normal file
@@ -0,0 +1,26 @@
|
||||
package com.digitalperson
|
||||
|
||||
import android.app.Application
|
||||
import android.util.Log
|
||||
import com.digitalperson.config.AppConfig
|
||||
import com.digitalperson.embedding.RefEmbeddingIndexer
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.SupervisorJob
|
||||
import kotlinx.coroutines.launch
|
||||
|
||||
class DigitalPersonApp : Application() {
|
||||
|
||||
private val appScope = CoroutineScope(SupervisorJob() + Dispatchers.Default)
|
||||
|
||||
override fun onCreate() {
|
||||
super.onCreate()
|
||||
appScope.launch {
|
||||
try {
|
||||
RefEmbeddingIndexer.runOnce(this@DigitalPersonApp)
|
||||
} catch (t: Throwable) {
|
||||
Log.e(AppConfig.TAG, "[RefEmbed] 索引任务异常", t)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -38,6 +38,7 @@ import com.digitalperson.interaction.ConversationSummaryMemory
|
||||
|
||||
import java.io.File
|
||||
import android.graphics.BitmapFactory
|
||||
import android.widget.ImageView
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.Job
|
||||
@@ -49,6 +50,7 @@ import kotlinx.coroutines.withContext
|
||||
|
||||
import com.digitalperson.onboard_testing.FaceRecognitionTest
|
||||
import com.digitalperson.onboard_testing.LLMSummaryTest
|
||||
import com.digitalperson.embedding.RefImageMatcher
|
||||
|
||||
class Live2DChatActivity : AppCompatActivity() {
|
||||
companion object {
|
||||
@@ -109,6 +111,7 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
|
||||
private lateinit var faceRecognitionTest: FaceRecognitionTest
|
||||
private lateinit var llmSummaryTest: LLMSummaryTest
|
||||
private var refMatchImageView: ImageView? = null
|
||||
|
||||
override fun onRequestPermissionsResult(
|
||||
requestCode: Int,
|
||||
@@ -160,6 +163,8 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
live2dViewId = R.id.live2d_view
|
||||
)
|
||||
|
||||
refMatchImageView = findViewById(R.id.ref_match_image)
|
||||
|
||||
cameraPreviewView = findViewById(R.id.camera_preview)
|
||||
cameraPreviewView.implementationMode = PreviewView.ImplementationMode.COMPATIBLE
|
||||
faceOverlayView = findViewById(R.id.face_overlay)
|
||||
@@ -611,6 +616,7 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
runOnUiThread {
|
||||
uiManager.appendToUi("${filteredText.orEmpty()}\n")
|
||||
}
|
||||
maybeShowMatchedRefImage(filteredText ?: response)
|
||||
}
|
||||
interactionCoordinator.onCloudFinalResponse(response)
|
||||
}
|
||||
@@ -649,6 +655,24 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
}
|
||||
}
|
||||
|
||||
private fun maybeShowMatchedRefImage(text: String) {
|
||||
val imageView = refMatchImageView ?: return
|
||||
ioScope.launch {
|
||||
val match = RefImageMatcher.findBestMatch(applicationContext, text)
|
||||
if (match == null) return@launch
|
||||
val bitmap = try {
|
||||
assets.open(match.pngAssetPath).use { BitmapFactory.decodeStream(it) }
|
||||
} catch (_: Throwable) {
|
||||
null
|
||||
}
|
||||
if (bitmap == null) return@launch
|
||||
runOnUiThread {
|
||||
imageView.setImageBitmap(bitmap)
|
||||
imageView.visibility = android.view.View.VISIBLE
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun createTtsCallback() = object : TtsController.TtsCallback {
|
||||
override fun onTtsStarted(text: String) {
|
||||
runOnUiThread {
|
||||
|
||||
@@ -25,6 +25,7 @@ import android.view.View
|
||||
import androidx.lifecycle.Lifecycle
|
||||
import androidx.lifecycle.LifecycleOwner
|
||||
import androidx.lifecycle.LifecycleRegistry
|
||||
import android.widget.ImageView
|
||||
import com.unity3d.player.UnityPlayer
|
||||
import com.unity3d.player.UnityPlayerActivity
|
||||
import com.digitalperson.audio.AudioProcessor
|
||||
@@ -47,6 +48,8 @@ import com.digitalperson.tts.TtsController
|
||||
import com.digitalperson.util.FileHelper
|
||||
import com.digitalperson.vad.VadManager
|
||||
import kotlinx.coroutines.*
|
||||
import com.digitalperson.embedding.RefImageMatcher
|
||||
import android.graphics.BitmapFactory
|
||||
|
||||
class UnityDigitalPersonActivity : UnityPlayerActivity(), LifecycleOwner {
|
||||
|
||||
@@ -108,6 +111,7 @@ class UnityDigitalPersonActivity : UnityPlayerActivity(), LifecycleOwner {
|
||||
private lateinit var holdToSpeakButton: Button
|
||||
private var recordButtonGlow: View? = null
|
||||
private var pulseAnimator: ObjectAnimator? = null
|
||||
private var refMatchImageView: ImageView? = null
|
||||
|
||||
|
||||
// 音频和AI模块
|
||||
@@ -254,6 +258,7 @@ class UnityDigitalPersonActivity : UnityPlayerActivity(), LifecycleOwner {
|
||||
chatHistoryText = chatLayout.findViewById(R.id.my_text)
|
||||
holdToSpeakButton = chatLayout.findViewById(R.id.record_button)
|
||||
recordButtonGlow = chatLayout.findViewById(R.id.record_button_glow)
|
||||
refMatchImageView = chatLayout.findViewById(R.id.ref_match_image)
|
||||
|
||||
// 根据配置设置按钮可见性
|
||||
if (AppConfig.USE_HOLD_TO_SPEAK) {
|
||||
@@ -735,6 +740,7 @@ class UnityDigitalPersonActivity : UnityPlayerActivity(), LifecycleOwner {
|
||||
val filteredText = ttsController.speakLlmResponse(response)
|
||||
android.util.Log.d("UnityDigitalPerson", "LLM response filtered: ${filteredText?.take(60)}")
|
||||
if (filteredText != null) appendChat("助手: $filteredText")
|
||||
maybeShowMatchedRefImage(filteredText ?: response)
|
||||
interactionCoordinator.onCloudFinalResponse(filteredText ?: response.trim())
|
||||
}
|
||||
|
||||
@@ -751,6 +757,25 @@ class UnityDigitalPersonActivity : UnityPlayerActivity(), LifecycleOwner {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Off the main thread, looks up the ref image best matching [text]; when a
 * match clears the matcher's threshold and its PNG decodes, shows it in
 * [refMatchImageView] on the UI thread. Silently does nothing otherwise.
 */
private fun maybeShowMatchedRefImage(text: String) {
    val imageView = refMatchImageView ?: return
    // Unity Activity already has coroutines
    // NOTE(review): this builds a throwaway, never-cancelled CoroutineScope per
    // call (unstructured concurrency). Prefer a single activity-owned scope
    // cancelled in onDestroy, as the Live2D activity does with its ioScope.
    CoroutineScope(SupervisorJob() + Dispatchers.IO).launch {
        val match = RefImageMatcher.findBestMatch(applicationContext, text)
        if (match == null) return@launch
        // Decode off the UI thread; any failure (missing asset, bad PNG) just skips the image.
        val bitmap = try {
            assets.open(match.pngAssetPath).use { BitmapFactory.decodeStream(it) }
        } catch (_: Throwable) {
            null
        }
        if (bitmap == null) return@launch
        runOnUiThread {
            imageView.setImageBitmap(bitmap)
            imageView.visibility = View.VISIBLE
        }
    }
}
|
||||
|
||||
private fun requestLocalThought(prompt: String, onResult: (String) -> Unit) {
|
||||
val local = llmManager
|
||||
if (local == null) {
|
||||
|
||||
@@ -110,6 +110,19 @@ object AppConfig {
|
||||
const val MODEL_SIZE_ESTIMATE = 500L * 1024 * 1024 // 500MB
|
||||
}
|
||||
|
||||
/** BGE-small-zh-v1.5 text embedding (RKNN), used for semantic similarity / retrieval. */
object Bge {
    // Assets subdirectory holding the model and tokenizer files.
    const val ASSET_DIR = "bge_models"
    const val MODEL_FILE = "bge-small-zh-v1.5.rknn"
}
|
||||
|
||||
/**
 * app/note/ref is packaged into the APK via an extra Gradle assets directory;
 * its root path inside assets is `ref/`.
 */
object RefCorpus {
    const val ASSETS_ROOT = "ref"
}
|
||||
|
||||
object OnboardTesting {
|
||||
// 测试人脸识别
|
||||
const val FACE_REGONITION = false
|
||||
|
||||
@@ -9,15 +9,24 @@ import com.digitalperson.data.dao.UserAnswerDao
|
||||
import com.digitalperson.data.dao.UserMemoryDao
|
||||
import com.digitalperson.data.dao.ChatMessageDao
|
||||
import com.digitalperson.data.dao.ConversationSummaryDao
|
||||
import com.digitalperson.data.dao.RefTextEmbeddingDao
|
||||
import com.digitalperson.data.entity.Question
|
||||
import com.digitalperson.data.entity.UserAnswer
|
||||
import com.digitalperson.data.entity.UserMemory
|
||||
import com.digitalperson.data.entity.ChatMessageEntity
|
||||
import com.digitalperson.data.entity.ConversationSummaryEntity
|
||||
import com.digitalperson.data.entity.RefTextEmbedding
|
||||
|
||||
@Database(
|
||||
entities = [UserMemory::class, Question::class, UserAnswer::class, ChatMessageEntity::class, ConversationSummaryEntity::class],
|
||||
version = 4,
|
||||
entities = [
|
||||
UserMemory::class,
|
||||
Question::class,
|
||||
UserAnswer::class,
|
||||
ChatMessageEntity::class,
|
||||
ConversationSummaryEntity::class,
|
||||
RefTextEmbedding::class
|
||||
],
|
||||
version = 5,
|
||||
exportSchema = false
|
||||
)
|
||||
abstract class AppDatabase : RoomDatabase() {
|
||||
@@ -26,6 +35,7 @@ abstract class AppDatabase : RoomDatabase() {
|
||||
abstract fun userAnswerDao(): UserAnswerDao
|
||||
abstract fun chatMessageDao(): ChatMessageDao
|
||||
abstract fun conversationSummaryDao(): ConversationSummaryDao
|
||||
abstract fun refTextEmbeddingDao(): RefTextEmbeddingDao
|
||||
|
||||
companion object {
|
||||
private const val DATABASE_NAME = "digital_human.db"
|
||||
|
||||
@@ -10,6 +10,17 @@ interface QuestionDao {
|
||||
@Insert
|
||||
fun insert(question: Question): Long
|
||||
|
||||
@Query(
|
||||
"""
|
||||
SELECT * FROM questions
|
||||
WHERE content = :content
|
||||
AND ((:subject IS NULL AND subject IS NULL) OR subject = :subject)
|
||||
AND ((:grade IS NULL AND grade IS NULL) OR grade = :grade)
|
||||
LIMIT 1
|
||||
"""
|
||||
)
|
||||
fun findByContentSubjectGrade(content: String, subject: String?, grade: Int?): Question?
|
||||
|
||||
@Query("SELECT * FROM questions WHERE subject = :subject ORDER BY difficulty")
|
||||
fun getQuestionsBySubject(subject: String): List<Question>
|
||||
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
package com.digitalperson.data.dao
|
||||
|
||||
import androidx.room.Dao
|
||||
import androidx.room.Insert
|
||||
import androidx.room.OnConflictStrategy
|
||||
import androidx.room.Query
|
||||
import com.digitalperson.data.entity.RefTextEmbedding
|
||||
|
||||
/** Room access to cached ref-corpus text embeddings (one row per asset path). */
@Dao
interface RefTextEmbeddingDao {

    /** Row for [path], or null when that file has never been embedded. */
    @Query("SELECT * FROM ref_text_embeddings WHERE assetPath = :path LIMIT 1")
    fun getByPath(path: String): RefTextEmbedding?

    /** Inserts or overwrites (REPLACE on conflict, e.g. the unique assetPath index). */
    @Insert(onConflict = OnConflictStrategy.REPLACE)
    fun insert(row: RefTextEmbedding): Long

    /** All cached embedding rows. */
    @Query("SELECT * FROM ref_text_embeddings")
    fun getAll(): List<RefTextEmbedding>
}
|
||||
@@ -0,0 +1,23 @@
|
||||
package com.digitalperson.data.entity
|
||||
|
||||
import androidx.room.Entity
|
||||
import androidx.room.Index
|
||||
import androidx.room.PrimaryKey
|
||||
import com.digitalperson.data.util.embeddingBytesToFloatArray
|
||||
|
||||
/**
 * Cached BGE embedding for one ref-corpus txt asset. The vector is stored as a
 * little-endian float32 blob (see [toFloatArray]); [contentHash] lets the
 * indexer skip files whose embedded body has not changed.
 */
@Entity(
    tableName = "ref_text_embeddings",
    indices = [Index(value = ["assetPath"], unique = true)]
)
data class RefTextEmbedding(
    @PrimaryKey(autoGenerate = true) val id: Long = 0,
    // Asset-relative path, e.g. ref/一年级.../xxx.txt
    val assetPath: String,
    // SHA-256 hex of the embedded body (UTF-8); used to skip unchanged files.
    val contentHash: String,
    val dim: Int,
    val embedding: ByteArray,
    val updatedAt: Long = System.currentTimeMillis()
) {
    /** Decodes the stored blob back into a float vector of length [dim]. */
    fun toFloatArray(): FloatArray = embeddingBytesToFloatArray(embedding)

    // Data classes compare array properties by reference, so the generated
    // equals/hashCode were broken for `embedding`; override with content-based
    // comparison as recommended for array-bearing data classes.
    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        if (other !is RefTextEmbedding) return false
        return id == other.id &&
            assetPath == other.assetPath &&
            contentHash == other.contentHash &&
            dim == other.dim &&
            embedding.contentEquals(other.embedding) &&
            updatedAt == other.updatedAt
    }

    override fun hashCode(): Int {
        var result = id.hashCode()
        result = 31 * result + assetPath.hashCode()
        result = 31 * result + contentHash.hashCode()
        result = 31 * result + dim
        result = 31 * result + embedding.contentHashCode()
        result = 31 * result + updatedAt.hashCode()
        return result
    }
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package com.digitalperson.data.util
|
||||
|
||||
import java.nio.ByteBuffer
|
||||
import java.nio.ByteOrder
|
||||
|
||||
/** Serializes [values] as consecutive little-endian IEEE-754 floats, 4 bytes each. */
fun floatArrayToEmbeddingBytes(values: FloatArray): ByteArray {
    val buffer = ByteBuffer.allocate(values.size * 4).order(ByteOrder.LITTLE_ENDIAN)
    values.forEach(buffer::putFloat)
    return buffer.array()
}

/** Inverse of [floatArrayToEmbeddingBytes]; trailing bytes beyond a 4-byte multiple are ignored. */
fun embeddingBytesToFloatArray(blob: ByteArray): FloatArray {
    val buffer = ByteBuffer.wrap(blob).order(ByteOrder.LITTLE_ENDIAN)
    return FloatArray(blob.size / 4) { buffer.getFloat() }
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package com.digitalperson.embedding
|
||||
|
||||
import android.content.Context
|
||||
import com.digitalperson.config.AppConfig
|
||||
import com.digitalperson.engine.BgeEngineRKNN
|
||||
import com.digitalperson.util.FileHelper
|
||||
import java.io.File
|
||||
|
||||
/**
 * Lazily loaded BGE text embedding (RKNN), used for semantic similarity /
 * retrieval between dialogue text and pre-stored annotations.
 *
 * [initialize] copies the resources under [AppConfig.Bge.ASSET_DIR] to internal
 * storage and loads [AppConfig.Bge.MODEL_FILE].
 */
object BgeEmbedding {

    // @Volatile so readers observe an engine published from another thread;
    // writes are serialized by the @Synchronized methods below.
    @Volatile
    private var engine: BgeEngineRKNN? = null

    fun isReady(): Boolean = engine?.isInitialized == true

    /**
     * Blocking (asset copy + model load); call from a background thread or a
     * [Dispatchers.IO] coroutine, not the main thread.
     *
     * @return true when the engine is (already or newly) initialized.
     */
    @Synchronized
    fun initialize(context: Context): Boolean {
        if (engine?.isInitialized == true) return true
        val dir = FileHelper.copyBgeModels(context.applicationContext) ?: return false
        val path = File(dir, AppConfig.Bge.MODEL_FILE).absolutePath
        if (!File(path).exists()) return false
        val eng = BgeEngineRKNN(context.applicationContext)
        if (!eng.initialize(path)) return false
        engine = eng
        return true
    }

    /** Releases native engine resources; a later [initialize] reloads from scratch. */
    @Synchronized
    fun release() {
        engine?.deinitialize()
        engine = null
    }

    /** Embedding vector for [text], or null when the engine is not initialized. */
    fun getEmbedding(text: String): FloatArray? = engine?.getEmbedding(text)

    /** Engine-computed similarity, or null when the engine is not initialized. */
    fun similarity(text1: String, text2: String): Float? =
        engine?.calculateSimilarity(text1, text2)

    /** Output dimensionality, or -1 before initialization. */
    fun embeddingDim(): Int = engine?.embeddingDim ?: -1
}
|
||||
@@ -0,0 +1,34 @@
|
||||
package com.digitalperson.embedding
|
||||
|
||||
import android.content.Context
|
||||
|
||||
internal object RefCorpusAssetScanner {

    /**
     * Lists every `.txt` asset path (with `/` separators) under [root],
     * including subdirectories, in sorted order.
     */
    fun listTxtFilesUnder(context: Context, root: String): List<String> {
        val assetManager = context.assets
        val found = mutableListOf<String>()
        val pending = ArrayDeque<String>()
        pending.add(root.removeSuffix("/"))

        while (pending.isNotEmpty()) {
            val dir = pending.removeFirst()
            val entries = assetManager.list(dir) ?: continue
            for (entry in entries) {
                if (entry.isEmpty()) continue
                val path = "$dir/$entry"
                val nested = assetManager.list(path)
                if (nested.isNullOrEmpty()) {
                    // Leaf: AssetManager.list() yields an empty array for plain files.
                    if (entry.endsWith(".txt", ignoreCase = true)) {
                        found.add(path)
                    }
                } else {
                    pending.add(path)
                }
            }
        }
        return found.sorted()
    }
}
|
||||
@@ -0,0 +1,173 @@
|
||||
package com.digitalperson.embedding
|
||||
|
||||
import android.content.Context
|
||||
import android.util.Log
|
||||
import com.digitalperson.config.AppConfig
|
||||
import com.digitalperson.data.AppDatabase
|
||||
import com.digitalperson.data.entity.Question
|
||||
import com.digitalperson.data.entity.RefTextEmbedding
|
||||
import com.digitalperson.data.util.floatArrayToEmbeddingBytes
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.withContext
|
||||
import java.security.MessageDigest
|
||||
|
||||
/**
 * At process start, scans every `.txt` in assets under
 * [AppConfig.RefCorpus.ASSETS_ROOT] in the background, embeds the body (minus
 * `#` comment lines) with BGE, writes the vectors to Room and fills
 * [RefEmbeddingMemoryCache]. Question-like lines are also mined into the
 * `questions` table.
 *
 * The DAOs are synchronous; the whole function runs on [Dispatchers.IO] and
 * never blocks the main thread.
 */
object RefEmbeddingIndexer {

    private const val TAG = AppConfig.TAG

    /** Idempotent: unchanged files (same content hash) are skipped on re-runs. */
    suspend fun runOnce(context: Context) = withContext(Dispatchers.IO) {
        val app = context.applicationContext
        val db = AppDatabase.getInstance(app)
        val dao = db.refTextEmbeddingDao()
        val questionDao = db.questionDao()

        if (!BgeEmbedding.initialize(app)) {
            Log.e(TAG, "[RefEmbed] BGE 初始化失败,跳过 ref 语料索引")
            return@withContext
        }

        val root = AppConfig.RefCorpus.ASSETS_ROOT
        val paths = RefCorpusAssetScanner.listTxtFilesUnder(app, root)
        Log.i(TAG, "[RefEmbed] 发现 ${paths.size} 个 txt(root=$root)")

        var skipped = 0
        var embedded = 0
        var empty = 0
        var failed = 0

        for (path in paths) {
            val raw = try {
                app.assets.open(path).bufferedReader(Charsets.UTF_8).use { it.readText() }
            } catch (e: Exception) {
                Log.w(TAG, "[RefEmbed] 读取失败 $path: ${e.message}")
                failed++
                continue
            }

            // Question mining: lines containing a question mark (half- or
            // full-width) are inserted into the questions table, deduplicated
            // by (content, subject, grade).
            val subject = extractSubjectFromRaw(raw)
            val grade = extractGradeFromPath(path)
            for (line in extractQuestionLines(raw)) {
                val content = line.trim()
                if (content.isEmpty()) continue
                val exists = questionDao.findByContentSubjectGrade(content, subject, grade)
                if (exists == null) {
                    questionDao.insert(
                        Question(
                            id = 0,
                            content = content,
                            answer = null,
                            subject = subject,
                            grade = grade,
                            difficulty = 1,
                            createdAt = System.currentTimeMillis()
                        )
                    )
                }
            }

            val embedText = RefTxtEmbedText.fromRawFileContent(raw)
            if (embedText.isEmpty()) {
                empty++
                continue
            }

            val hash = sha256Hex(embedText.toByteArray(Charsets.UTF_8))
            // Synchronous DAO call (already on the IO dispatcher).
            val existing = dao.getByPath(path)
            if (existing != null && existing.contentHash == hash) {
                // Unchanged since last run: reuse the stored vector.
                RefEmbeddingMemoryCache.put(path, normalizeL2(existing.toFloatArray()))
                skipped++
                continue
            }

            val vec = BgeEmbedding.getEmbedding(embedText)
            if (vec == null || vec.isEmpty()) {
                Log.w(TAG, "[RefEmbed] 嵌入为空 $path")
                failed++
                continue
            }

            // Store L2-normalized vectors so lookup can use a plain dot product.
            val normalized = normalizeL2(vec)
            dao.insert(
                RefTextEmbedding(
                    assetPath = path,
                    contentHash = hash,
                    dim = normalized.size,
                    embedding = floatArrayToEmbeddingBytes(normalized)
                )
            )
            RefEmbeddingMemoryCache.put(path, normalized)
            embedded++
        }

        Log.i(
            TAG,
            "[RefEmbed] 完成 embedded=$embedded skipped=$skipped empty=$empty failed=$failed cacheSize=${RefEmbeddingMemoryCache.size()}"
        )
    }

    /** Subject is the first `#`-prefixed line's text (marker stripped), or null. */
    private fun extractSubjectFromRaw(raw: String): String? {
        val line = raw.lineSequence()
            .map { it.trimEnd() }
            .firstOrNull { it.trimStart().startsWith("#") }
            ?: return null
        val s = line.trimStart().removePrefix("#").trim()
        return s.ifEmpty { null }
    }

    /** Non-comment, non-blank lines containing an ASCII '?' or full-width '?'. */
    private fun extractQuestionLines(raw: String): List<String> {
        return raw.lineSequence()
            .map { it.trimEnd() }
            .filter { it.isNotBlank() }
            .filter { !it.trimStart().startsWith("#") }
            // BUGFIX: the second alternative duplicated the ASCII '?'; it must
            // be the full-width '?' so Chinese question lines are detected too.
            .filter { it.contains('?') || it.contains('?') }
            .toList()
    }

    /** Grade parsed from a path segment like ref/一年级上-…; null when no 年级 marker. */
    private fun extractGradeFromPath(assetPath: String): Int? {
        // example: ref/一年级上-生活适应/... or ref/二年级下-...
        val idx = assetPath.indexOf("年级")
        if (idx <= 0) return null
        val prefix = assetPath.substring(0, idx)
        val cn = prefix.lastOrNull() ?: return null
        return chineseGradeToInt(cn)
    }

    /** Maps the Chinese numeral immediately before 年级 to 1..10, or null. */
    private fun chineseGradeToInt(c: Char): Int? {
        return when (c) {
            '一' -> 1
            '二' -> 2
            '三' -> 3
            '四' -> 4
            '五' -> 5
            '六' -> 6
            '七' -> 7
            '八' -> 8
            '九' -> 9
            '十' -> 10
            else -> null
        }
    }

    /** Returns v scaled to unit L2 norm; near-zero vectors are returned as a copy. */
    private fun normalizeL2(v: FloatArray): FloatArray {
        var sum = 0.0
        for (x in v) sum += (x * x).toDouble()
        val norm = kotlin.math.sqrt(sum).toFloat()
        if (norm <= 1e-12f) return v.copyOf()
        return FloatArray(v.size) { i -> v[i] / norm }
    }

    /** Lowercase hex SHA-256 of [data]. */
    private fun sha256Hex(data: ByteArray): String {
        val md = MessageDigest.getInstance("SHA-256")
        val digest = md.digest(data)
        return digest.joinToString("") { "%02x".format(it) }
    }
}
|
||||
@@ -0,0 +1,30 @@
|
||||
package com.digitalperson.embedding
|
||||
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
|
||||
/**
 * In-memory cache of ref-corpus txt embeddings, keyed by asset-relative path
 * (same key as [com.digitalperson.data.entity.RefTextEmbedding.assetPath]).
 * Vectors are defensively copied on both write and read.
 */
object RefEmbeddingMemoryCache {

    private val store = ConcurrentHashMap<String, FloatArray>()

    fun put(assetPath: String, embedding: FloatArray) {
        store[assetPath] = embedding.copyOf()
    }

    fun get(assetPath: String): FloatArray? = store[assetPath]?.copyOf()

    fun clear() {
        store.clear()
    }

    /** Read-only snapshot; every vector is a copy and safe to use. */
    fun snapshot(): Map<String, FloatArray> =
        store.mapValues { (_, vec) -> vec.copyOf() }

    fun size(): Int = store.size
}
|
||||
@@ -0,0 +1,95 @@
|
||||
package com.digitalperson.embedding
|
||||
|
||||
import android.content.Context
|
||||
import android.util.Log
|
||||
import com.digitalperson.config.AppConfig
|
||||
import kotlin.math.sqrt
|
||||
|
||||
/**
 * Result of a ref-corpus lookup: the matched txt asset, its companion png
 * (same path with a .png extension), and the cosine-similarity [score].
 */
data class RefImageMatch(
    val txtAssetPath: String,
    val pngAssetPath: String,
    val score: Float
)
|
||||
|
||||
object RefImageMatcher {

    private const val TAG = AppConfig.TAG

    /**
     * Finds the pre-embedded ref txt most similar to [text] and returns it with
     * its companion `.png` asset, or null when nothing clears [threshold], the
     * engine can't initialize, the cache is empty, or the png does not exist.
     *
     * @param threshold cosine-similarity cutoff (equivalent to a dot product
     *   because cached vectors are L2-normalized).
     */
    fun findBestMatch(
        context: Context,
        text: String,
        threshold: Float = 0.70f
    ): RefImageMatch? {
        val query = text.trim()
        if (query.isEmpty()) return null

        // Lazily bring up the embedding engine; bail out quietly on failure.
        if (!BgeEmbedding.isReady() && !BgeEmbedding.initialize(context.applicationContext)) {
            Log.w(TAG, "[RefMatch] BGE not ready, skip match")
            return null
        }

        val rawVec = BgeEmbedding.getEmbedding(query) ?: return null
        if (rawVec.isEmpty()) return null
        val queryVec = normalizeL2(rawVec)

        val candidates = RefEmbeddingMemoryCache.snapshot()
        if (candidates.isEmpty()) return null

        var bestPath: String? = null
        var bestScore = -1f
        for ((path, vec) in candidates) {
            if (vec.isEmpty() || vec.size != queryVec.size) continue
            val score = dot(queryVec, vec)
            if (score > bestScore) {
                bestScore = score
                bestPath = path
            }
        }

        val txtPath = bestPath ?: return null
        if (bestScore < threshold) return null

        val pngPath = when {
            txtPath.endsWith(".txt", ignoreCase = true) -> txtPath.dropLast(4) + ".png"
            else -> "$txtPath.png"
        }

        // Only probe for existence here; decoding is left to the caller so the
        // UI thread never does asset IO.
        val pngExists = try {
            context.assets.open(pngPath).close()
            true
        } catch (_: Throwable) {
            false
        }
        if (!pngExists) return null

        return RefImageMatch(
            txtAssetPath = txtPath,
            pngAssetPath = pngPath,
            score = bestScore
        )
    }

    private fun dot(a: FloatArray, b: FloatArray): Float {
        var acc = 0f
        for (i in a.indices) acc += a[i] * b[i]
        return acc
    }

    /** Returns v scaled to unit L2 norm; near-zero vectors come back as a copy. */
    private fun normalizeL2(v: FloatArray): FloatArray {
        var sumSq = 0.0
        for (x in v) sumSq += (x * x).toDouble()
        val norm = sqrt(sumSq).toFloat()
        return if (norm <= 1e-12f) v.copyOf() else FloatArray(v.size) { i -> v[i] / norm }
    }
}
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
package com.digitalperson.embedding
|
||||
|
||||
/**
 * Builds the embeddable text from a raw ref txt: drops lines that start with
 * `#` (leading whitespace allowed) and blank lines, joining the rest with "\n".
 */
object RefTxtEmbedText {

    fun fromRawFileContent(raw: String): String =
        raw.lineSequence()
            .map { it.trimEnd() }
            .filter { line -> line.isNotEmpty() && !line.trimStart().startsWith("#") }
            .joinToString("\n")
            .trim()
}
|
||||
@@ -0,0 +1,353 @@
|
||||
package com.digitalperson.embedding;
|
||||
|
||||
import android.content.Context;
|
||||
import android.os.Handler;
|
||||
import android.os.Looper;
|
||||
import android.util.Log;
|
||||
|
||||
import com.digitalperson.config.AppConfig;
|
||||
import com.digitalperson.engine.BgeEngineRKNN;
|
||||
import com.digitalperson.util.FileHelper;
|
||||
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class SimilarityManager {
|
||||
private static final String TAG = "SimilarityManager";
|
||||
|
||||
private Context mContext;
|
||||
private Handler mHandler;
|
||||
private BgeEngineRKNN mBgeEngine;
|
||||
|
||||
// 测试数据相关
|
||||
private List<float[]> mTestEmbeddings; // 预计算的100个句子的嵌入(float数组格式)
|
||||
private List<SimpleMatrix> mTestEmbeddingsEJML; // 预计算的100个句子的嵌入(EJML格式)
|
||||
private SimpleMatrix mTestEmbeddingsMatrix; // 所有测试嵌入的矩阵(嵌入维度 × 句子数量)
|
||||
private List<Double> mTestEmbeddingsNorms; // 预计算的100个句子的嵌入范数
|
||||
private SimpleMatrix mTestEmbeddingsNormsMatrix; // 预计算的嵌入范数向量(1 × 句子数量)
|
||||
private SimpleMatrix mTestEmbeddingsNormsReciprocalMatrix; // 预计算的嵌入范数倒数向量(1 × 句子数量)
|
||||
private List<String> mTestSentences; // 预生成的100个测试句子
|
||||
private boolean useEJMLForSimilarity = false; // 默认使用传统for循环计算相似度
|
||||
|
||||
public interface SimilarityListener {
|
||||
void onSimilarityCalculated(float similarity, long timeTaken);
|
||||
void onPerformanceTestComplete(long traditionalTime, long ejmlTime, double speedup);
|
||||
void onError(String errorMessage);
|
||||
}
|
||||
|
||||
private SimilarityListener mListener;
|
||||
|
||||
public SimilarityManager(Context context) {
|
||||
this.mContext = context;
|
||||
this.mHandler = new Handler(Looper.getMainLooper());
|
||||
}
|
||||
|
||||
public void setListener(SimilarityListener listener) {
|
||||
this.mListener = listener;
|
||||
}
|
||||
|
||||
// 初始化BGE模型
|
||||
public void initBgeModel() {
|
||||
try {
|
||||
File bgeModelDir = FileHelper.copyBgeModels(mContext);
|
||||
if (bgeModelDir == null) {
|
||||
Log.e(TAG, "BGE model directory copy failed");
|
||||
if (mListener != null) {
|
||||
mListener.onError("BGE 模型文件复制失败");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// 初始化BGE模型
|
||||
mBgeEngine = new BgeEngineRKNN(mContext);
|
||||
String modelPath = new File(bgeModelDir, AppConfig.Bge.MODEL_FILE).getAbsolutePath();
|
||||
|
||||
// 检查模型文件是否存在
|
||||
if (!new File(modelPath).exists()) {
|
||||
Log.e(TAG, "BGE model file does not exist: " + modelPath);
|
||||
if (mListener != null) {
|
||||
mListener.onError("BGE 模型文件不存在");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
boolean success = mBgeEngine.initialize(modelPath);
|
||||
if (!success) {
|
||||
Log.e(TAG, "Failed to initialize BGE model");
|
||||
if (mListener != null) {
|
||||
mListener.onError("BGE 模型初始化失败");
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
Log.e(TAG, "Exception in initBgeModel", e);
|
||||
if (mListener != null) {
|
||||
mListener.onError("初始化BGE模型异常: " + e.getMessage());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 计算BGE相似度
|
||||
public void calculateSimilarity(String text1, String text2) {
|
||||
if (mBgeEngine == null || !mBgeEngine.isInitialized()) {
|
||||
Log.w(TAG, "BGE engine not initialized, skipping similarity calculation");
|
||||
if (mListener != null) {
|
||||
mListener.onError("BGE 模型未初始化");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// 在后台线程中计算相似度
|
||||
new Thread(() -> {
|
||||
try {
|
||||
long startTime = System.currentTimeMillis();
|
||||
float similarity = mBgeEngine.calculateSimilarity(text1, text2);
|
||||
long timeTaken = System.currentTimeMillis() - startTime;
|
||||
|
||||
if (mListener != null) {
|
||||
mListener.onSimilarityCalculated(similarity, timeTaken);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
Log.e(TAG, "Exception in calculateSimilarity", e);
|
||||
if (mListener != null) {
|
||||
mListener.onError("相似度计算失败: " + e.getMessage());
|
||||
}
|
||||
}
|
||||
}).start();
|
||||
}
|
||||
|
||||
// 测试BGE性能
|
||||
public void testPerformance(String userInput) {
|
||||
if (mBgeEngine == null || !mBgeEngine.isInitialized()) {
|
||||
Log.w(TAG, "BGE engine not initialized, skipping performance test");
|
||||
if (mListener != null) {
|
||||
mListener.onError("BGE 模型未初始化");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// 在后台线程中执行测试
|
||||
new Thread(() -> {
|
||||
try {
|
||||
// 准备测试数据
|
||||
prepareTestData();
|
||||
|
||||
// 1. 计算用户输入的嵌入
|
||||
float[] userEmbedding = mBgeEngine.getEmbedding(userInput);
|
||||
if (userEmbedding == null) {
|
||||
if (mListener != null) {
|
||||
mListener.onError("无法计算用户输入的嵌入");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// 创建EJML格式的用户嵌入
|
||||
SimpleMatrix userVec = new SimpleMatrix(userEmbedding.length, 1);
|
||||
for (int i = 0; i < userEmbedding.length; i++) {
|
||||
userVec.set(i, 0, userEmbedding[i]);
|
||||
}
|
||||
|
||||
// 方法1: 使用float数组和循环计算相似度(传统方法)
|
||||
long startTime1 = System.currentTimeMillis();
|
||||
List<Float> similarities1 = new ArrayList<>();
|
||||
|
||||
// 优化传统方法:使用范数倒数和乘法代替除法
|
||||
double normUserTraditional = 0.0;
|
||||
for (float value : userEmbedding) {
|
||||
normUserTraditional += value * value;
|
||||
}
|
||||
normUserTraditional = Math.sqrt(normUserTraditional);
|
||||
double normUserReciprocalTraditional = (normUserTraditional > 1e-10) ? 1.0 / normUserTraditional : 0.0;
|
||||
|
||||
for (int i = 0; i < mTestEmbeddings.size(); i++) {
|
||||
float[] embedding = mTestEmbeddings.get(i);
|
||||
// 计算点积
|
||||
double dotProduct = 0.0;
|
||||
for (int j = 0; j < userEmbedding.length; j++) {
|
||||
dotProduct += userEmbedding[j] * embedding[j];
|
||||
}
|
||||
// 使用乘法代替除法:cos(u,v_i) = (u·v_i) × (1/||u||) × (1/||v_i||)
|
||||
double normEmbeddingReciprocal = mTestEmbeddingsNorms.get(i) > 1e-10 ? 1.0 / mTestEmbeddingsNorms.get(i) : 0.0;
|
||||
float sim = (float) (dotProduct * normUserReciprocalTraditional * normEmbeddingReciprocal);
|
||||
similarities1.add(sim);
|
||||
}
|
||||
long timeTaken1 = System.currentTimeMillis() - startTime1;
|
||||
|
||||
// 方法2: 使用EJML计算相似度(优化方法)
|
||||
long startTime2 = System.currentTimeMillis();
|
||||
List<Float> similarities2 = new ArrayList<>();
|
||||
|
||||
// 优化1: 预计算用户向量的范数和范数倒数
|
||||
double normUser = userVec.normF();
|
||||
double normUserReciprocal = (normUser > 1e-10) ? 1.0 / normUser : 0.0;
|
||||
|
||||
// 优化2: 使用批量矩阵计算一次性计算所有点积和相似度
|
||||
if (mTestEmbeddingsMatrix != null && mTestEmbeddingsNormsReciprocalMatrix != null) {
|
||||
// 使用矩阵乘法一次性计算所有点积
|
||||
SimpleMatrix dotProducts = userVec.transpose().mult(mTestEmbeddingsMatrix);
|
||||
|
||||
// 一次性计算所有相似度
|
||||
// 使用乘法代替除法:cos(u,v_i) = (u·v_i) × (1/||u||) × (1/||v_i||)
|
||||
// 步骤1: 计算 (u·v_i) × (1/||u||)
|
||||
SimpleMatrix dotProductsScaled = dotProducts.scale(normUserReciprocal);
|
||||
// 步骤2: 逐元素乘法计算最终相似度
|
||||
SimpleMatrix similaritiesMatrix = dotProductsScaled.elementMult(mTestEmbeddingsNormsReciprocalMatrix);
|
||||
|
||||
// 将结果转换为列表
|
||||
for (int i = 0; i < similaritiesMatrix.numCols(); i++) {
|
||||
similarities2.add((float) similaritiesMatrix.get(0, i));
|
||||
}
|
||||
} else if (mTestEmbeddingsMatrix != null) {
|
||||
// 降级方案1: 只有嵌入矩阵,使用批量点积 + 循环计算相似度
|
||||
SimpleMatrix dotProducts = userVec.transpose().mult(mTestEmbeddingsMatrix);
|
||||
for (int i = 0; i < dotProducts.numCols(); i++) {
|
||||
double dotProduct = dotProducts.get(0, i);
|
||||
double normEmbeddingReciprocal = mTestEmbeddingsNorms.get(i) > 1e-10 ? 1.0 / mTestEmbeddingsNorms.get(i) : 0.0;
|
||||
float sim = (float) (dotProduct * normUserReciprocal * normEmbeddingReciprocal);
|
||||
similarities2.add(sim);
|
||||
}
|
||||
} else {
|
||||
// 降级方案2: 使用循环计算(当矩阵未初始化时)
|
||||
for (int i = 0; i < mTestEmbeddingsEJML.size(); i++) {
|
||||
SimpleMatrix embedding = mTestEmbeddingsEJML.get(i);
|
||||
double dotProduct = userVec.dot(embedding);
|
||||
double normEmbeddingReciprocal = mTestEmbeddingsNorms.get(i) > 1e-10 ? 1.0 / mTestEmbeddingsNorms.get(i) : 0.0;
|
||||
float sim = (float) (dotProduct * normUserReciprocal * normEmbeddingReciprocal);
|
||||
similarities2.add(sim);
|
||||
}
|
||||
}
|
||||
long timeTaken2 = System.currentTimeMillis() - startTime2;
|
||||
|
||||
// 计算速度提升
|
||||
double speedup = (double) timeTaken1 / timeTaken2;
|
||||
|
||||
if (mListener != null) {
|
||||
mListener.onPerformanceTestComplete(timeTaken1, timeTaken2, speedup);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
Log.e(TAG, "Exception in testPerformance", e);
|
||||
if (mListener != null) {
|
||||
mListener.onError("性能测试失败: " + e.getMessage());
|
||||
}
|
||||
}
|
||||
}).start();
|
||||
}
|
||||
|
||||
// 准备测试数据
|
||||
private void prepareTestData() {
|
||||
if (mTestEmbeddings == null || mTestEmbeddings.isEmpty() ||
|
||||
mTestEmbeddingsEJML == null || mTestEmbeddingsEJML.isEmpty() ||
|
||||
mTestEmbeddingsNorms == null || mTestEmbeddingsNorms.isEmpty() ||
|
||||
mTestEmbeddingsMatrix == null ||
|
||||
mTestEmbeddingsNormsMatrix == null ||
|
||||
mTestEmbeddingsNormsReciprocalMatrix == null) {
|
||||
// 生成100个测试句子
|
||||
generateTestSentences();
|
||||
|
||||
// 预计算嵌入(float数组格式)
|
||||
mTestEmbeddings = new ArrayList<>();
|
||||
mTestEmbeddingsEJML = new ArrayList<>();
|
||||
mTestEmbeddingsNorms = new ArrayList<>();
|
||||
List<String> validTestSentences = new ArrayList<>(); // 只包含成功获取嵌入的句子
|
||||
|
||||
for (String sentence : mTestSentences) {
|
||||
float[] embedding = mBgeEngine.getEmbedding(sentence);
|
||||
if (embedding != null) {
|
||||
// 添加float数组格式的嵌入
|
||||
mTestEmbeddings.add(embedding);
|
||||
|
||||
// 添加EJML格式的嵌入
|
||||
SimpleMatrix vec = new SimpleMatrix(embedding.length, 1);
|
||||
for (int i = 0; i < embedding.length; i++) {
|
||||
vec.set(i, 0, embedding[i]);
|
||||
}
|
||||
mTestEmbeddingsEJML.add(vec);
|
||||
|
||||
// 预计算嵌入的范数
|
||||
double norm = vec.normF();
|
||||
mTestEmbeddingsNorms.add(norm);
|
||||
|
||||
// 添加到有效句子列表
|
||||
validTestSentences.add(sentence);
|
||||
}
|
||||
}
|
||||
|
||||
// 更新mTestSentences为只包含有效句子的列表,确保索引匹配
|
||||
mTestSentences = validTestSentences;
|
||||
|
||||
// 创建测试嵌入矩阵(嵌入维度 × 句子数量)
|
||||
if (!mTestEmbeddingsEJML.isEmpty()) {
|
||||
int embeddingDim = mTestEmbeddingsEJML.get(0).numRows();
|
||||
int numSentences = mTestEmbeddingsEJML.size();
|
||||
mTestEmbeddingsMatrix = new SimpleMatrix(embeddingDim, numSentences);
|
||||
|
||||
for (int i = 0; i < numSentences; i++) {
|
||||
SimpleMatrix embedding = mTestEmbeddingsEJML.get(i);
|
||||
for (int j = 0; j < embeddingDim; j++) {
|
||||
mTestEmbeddingsMatrix.set(j, i, embedding.get(j, 0));
|
||||
}
|
||||
}
|
||||
Log.d(TAG, "Created test embeddings matrix, size: " + embeddingDim + " × " + numSentences);
|
||||
}
|
||||
|
||||
// 创建测试嵌入范数矩阵(1 × 句子数量)
|
||||
if (!mTestEmbeddingsNorms.isEmpty()) {
|
||||
int numSentences = mTestEmbeddingsNorms.size();
|
||||
mTestEmbeddingsNormsMatrix = new SimpleMatrix(1, numSentences);
|
||||
mTestEmbeddingsNormsReciprocalMatrix = new SimpleMatrix(1, numSentences);
|
||||
for (int i = 0; i < numSentences; i++) {
|
||||
double norm = mTestEmbeddingsNorms.get(i);
|
||||
mTestEmbeddingsNormsMatrix.set(0, i, norm);
|
||||
// 预计算范数倒数,避免在相似度计算时重复计算
|
||||
double reciprocal = (norm > 1e-10) ? 1.0 / norm : 0.0;
|
||||
mTestEmbeddingsNormsReciprocalMatrix.set(0, i, reciprocal);
|
||||
}
|
||||
Log.d(TAG, "Created test embeddings norms matrix, size: 1 × " + numSentences);
|
||||
Log.d(TAG, "Created test embeddings norms reciprocal matrix, size: 1 × " + numSentences);
|
||||
}
|
||||
|
||||
Log.d(TAG, "Generated embeddings for " + mTestEmbeddings.size() + " test sentences");
|
||||
}
|
||||
}
|
||||
|
||||
// 生成测试句子
|
||||
private void generateTestSentences() {
|
||||
mTestSentences = new ArrayList<>();
|
||||
|
||||
// 基础词汇
|
||||
String[] words = {"手", "头", "眼睛", "鼻子", "嘴巴", "耳朵", "头发", "手指", "脚", "腿"};
|
||||
String[] adjectives = {"大", "小", "长", "短", "多", "少", "漂亮", "丑陋", "干净", "脏"};
|
||||
String[] verbs = {"有", "是", "看", "听", "说", "走", "跑", "跳", "吃", "喝"};
|
||||
String[] subjects = {"我", "你", "他", "她", "它", "我们", "你们", "他们", "这个", "那个"};
|
||||
|
||||
// 添加一些固定的测试句子
|
||||
mTestSentences.add("我的手有很多手指头");
|
||||
mTestSentences.add("他的头发很长");
|
||||
mTestSentences.add("她的眼睛很大");
|
||||
mTestSentences.add("这个鼻子很高");
|
||||
mTestSentences.add("那个嘴巴很小");
|
||||
|
||||
Log.d(TAG, "Generated " + mTestSentences.size() + " test sentences");
|
||||
}
|
||||
|
||||
// 设置是否使用EJML进行相似度计算
|
||||
public void setUseEJMLForSimilarity(boolean useEJML) {
|
||||
this.useEJMLForSimilarity = useEJML;
|
||||
}
|
||||
|
||||
// 检查BGE模型是否初始化
|
||||
public boolean isInitialized() {
|
||||
return mBgeEngine != null && mBgeEngine.isInitialized();
|
||||
}
|
||||
|
||||
// 释放BGE模型资源
|
||||
public void deinitialize() {
|
||||
if (mBgeEngine != null) {
|
||||
mBgeEngine.deinitialize();
|
||||
mBgeEngine = null;
|
||||
Log.d(TAG, "BGE engine deinitialized");
|
||||
}
|
||||
}
|
||||
}
|
||||
104
app/src/main/java/com/digitalperson/engine/BasicTokenizer.java
Normal file
104
app/src/main/java/com/digitalperson/engine/BasicTokenizer.java
Normal file
@@ -0,0 +1,104 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
package com.digitalperson.engine;
|
||||
|
||||
import com.google.common.base.Ascii;
|
||||
import com.google.common.collect.Iterables;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
/** Basic tokenization (punctuation splitting, lower casing, etc.) */
|
||||
public final class BasicTokenizer {
|
||||
private final boolean doLowerCase;
|
||||
|
||||
public BasicTokenizer(boolean doLowerCase) {
|
||||
this.doLowerCase = doLowerCase;
|
||||
}
|
||||
|
||||
public List<String> tokenize(String text) {
|
||||
String cleanedText = cleanText(text);
|
||||
|
||||
List<String> origTokens = whitespaceTokenize(cleanedText);
|
||||
|
||||
StringBuilder stringBuilder = new StringBuilder();
|
||||
for (String token : origTokens) {
|
||||
if (doLowerCase) {
|
||||
token = Ascii.toLowerCase(token);
|
||||
}
|
||||
List<String> list = runSplitOnPunc(token);
|
||||
for (String subToken : list) {
|
||||
stringBuilder.append(subToken).append(" ");
|
||||
}
|
||||
}
|
||||
return whitespaceTokenize(stringBuilder.toString());
|
||||
}
|
||||
|
||||
/* Performs invalid character removal and whitespace cleanup on text. */
|
||||
static String cleanText(String text) {
|
||||
if (text == null) {
|
||||
throw new NullPointerException("The input String is null.");
|
||||
}
|
||||
|
||||
StringBuilder stringBuilder = new StringBuilder("");
|
||||
for (int index = 0; index < text.length(); index++) {
|
||||
char ch = text.charAt(index);
|
||||
|
||||
// Skip the characters that cannot be used.
|
||||
if (CharChecker.isInvalid(ch) || CharChecker.isControl(ch)) {
|
||||
continue;
|
||||
}
|
||||
if (CharChecker.isWhitespace(ch)) {
|
||||
stringBuilder.append(" ");
|
||||
} else {
|
||||
stringBuilder.append(ch);
|
||||
}
|
||||
}
|
||||
return stringBuilder.toString();
|
||||
}
|
||||
|
||||
/* Runs basic whitespace cleaning and splitting on a piece of text. */
|
||||
static List<String> whitespaceTokenize(String text) {
|
||||
if (text == null) {
|
||||
throw new NullPointerException("The input String is null.");
|
||||
}
|
||||
return Arrays.asList(text.split(" "));
|
||||
}
|
||||
|
||||
/* Splits punctuation on a piece of text. */
|
||||
static List<String> runSplitOnPunc(String text) {
|
||||
if (text == null) {
|
||||
throw new NullPointerException("The input String is null.");
|
||||
}
|
||||
|
||||
List<String> tokens = new ArrayList<>();
|
||||
boolean startNewWord = true;
|
||||
for (int i = 0; i < text.length(); i++) {
|
||||
char ch = text.charAt(i);
|
||||
if (CharChecker.isPunctuation(ch)) {
|
||||
tokens.add(String.valueOf(ch));
|
||||
startNewWord = true;
|
||||
} else {
|
||||
if (startNewWord) {
|
||||
tokens.add("");
|
||||
startNewWord = false;
|
||||
}
|
||||
tokens.set(tokens.size() - 1, Iterables.getLast(tokens) + ch);
|
||||
}
|
||||
}
|
||||
|
||||
return tokens;
|
||||
}
|
||||
}
|
||||
263
app/src/main/java/com/digitalperson/engine/BgeEngineRKNN.java
Normal file
263
app/src/main/java/com/digitalperson/engine/BgeEngineRKNN.java
Normal file
@@ -0,0 +1,263 @@
|
||||
package com.digitalperson.engine;
|
||||
|
||||
import android.content.Context;
|
||||
import android.util.Log;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
public class BgeEngineRKNN {
|
||||
private static final String TAG = "BgeEngineRKNN";
|
||||
private final long nativePtr;
|
||||
private final Context mContext;
|
||||
private boolean mIsInitialized = false;
|
||||
private FullTokenizer mFullTokenizer = null;
|
||||
private Map<String, Integer> mVocabMap = new HashMap<>();
|
||||
|
||||
static {
|
||||
try {
|
||||
// Load dependent libraries
|
||||
System.loadLibrary("rknnrt");
|
||||
System.loadLibrary("bgeEngine");
|
||||
Log.d(TAG, "Successfully loaded librknnrt.so and libbgeEngine.so");
|
||||
} catch (UnsatisfiedLinkError e) {
|
||||
Log.e(TAG, "Failed to load native library", e);
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
public BgeEngineRKNN(Context context) {
|
||||
mContext = context;
|
||||
try {
|
||||
nativePtr = createBgeEngine();
|
||||
if (nativePtr == 0) {
|
||||
throw new RuntimeException("Failed to create native BGE engine");
|
||||
}
|
||||
} catch (UnsatisfiedLinkError e) {
|
||||
Log.e(TAG, "Failed to load native library", e);
|
||||
throw new RuntimeException("Failed to load native library: " + e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
|
||||
public boolean initialize(String modelPath) {
|
||||
// 自动查找词汇表路径
|
||||
String vocabPath = modelPath.replace("bge-small-zh-v1.5.rknn", "vocab.txt");
|
||||
return initialize(modelPath, vocabPath);
|
||||
}
|
||||
|
||||
public boolean initialize(String modelPath, String vocabPath) {
|
||||
if (mIsInitialized) {
|
||||
Log.i(TAG, "Model already initialized");
|
||||
return true;
|
||||
}
|
||||
|
||||
Log.d(TAG, "Loading BGE model: " + modelPath);
|
||||
Log.d(TAG, "Loading vocab: " + vocabPath);
|
||||
|
||||
// 加载词汇表
|
||||
if (!loadVocab(vocabPath)) {
|
||||
Log.e(TAG, "Failed to load vocab");
|
||||
return false;
|
||||
}
|
||||
|
||||
// 创建FullTokenizer实例
|
||||
mFullTokenizer = new FullTokenizer(mVocabMap, true); // true表示使用小写
|
||||
|
||||
int ret = loadModel(nativePtr, modelPath, vocabPath);
|
||||
if (ret == 0) {
|
||||
mIsInitialized = true;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Load vocabulary from file
|
||||
* @param vocabPath Path to vocabulary file
|
||||
* @return True if successful, false otherwise
|
||||
*/
|
||||
private boolean loadVocab(String vocabPath) {
|
||||
try {
|
||||
java.io.BufferedReader reader = new java.io.BufferedReader(
|
||||
new java.io.FileReader(vocabPath));
|
||||
String line;
|
||||
int index = 0;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
line = line.trim();
|
||||
if (!line.isEmpty()) {
|
||||
mVocabMap.put(line, index);
|
||||
index++;
|
||||
}
|
||||
}
|
||||
reader.close();
|
||||
Log.d(TAG, "Vocab loaded successfully with " + mVocabMap.size() + " tokens");
|
||||
return true;
|
||||
} catch (java.io.IOException e) {
|
||||
Log.e(TAG, "Failed to load vocab: " + e.getMessage());
|
||||
e.printStackTrace();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public void deinitialize() {
|
||||
if (nativePtr != 0) {
|
||||
freeModel(nativePtr);
|
||||
}
|
||||
mIsInitialized = false;
|
||||
}
|
||||
|
||||
public boolean isInitialized() {
|
||||
return mIsInitialized;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the embedding dimension of the model
|
||||
* @return Embedding dimension
|
||||
*/
|
||||
public int getEmbeddingDim() {
|
||||
if (!mIsInitialized) {
|
||||
Log.e(TAG, "Engine not initialized");
|
||||
return -1;
|
||||
}
|
||||
return getEmbeddingDim(nativePtr);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate embedding for a single text
|
||||
* @param text Input text
|
||||
* @return Embedding vector as float array
|
||||
*/
|
||||
public float[] getEmbedding(String text) {
|
||||
if (!mIsInitialized) {
|
||||
Log.e(TAG, "Engine not initialized");
|
||||
return null;
|
||||
}
|
||||
return getEmbeddingNative(nativePtr, text);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate cosine similarity between two texts
|
||||
* @param text1 First text
|
||||
* @param text2 Second text
|
||||
* @return Cosine similarity score between -1.0 and 1.0
|
||||
*/
|
||||
public float calculateSimilarity(String text1, String text2) {
|
||||
if (!mIsInitialized) {
|
||||
Log.e(TAG, "Engine not initialized");
|
||||
return 0.0f;
|
||||
}
|
||||
return calculateSimilarityNative(nativePtr, text1, text2);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get token ids for a text using FullTokenizer
|
||||
* @param text Input text
|
||||
* @return Token ids as int array
|
||||
*/
|
||||
public int[] getTokens(String text) {
|
||||
if (!mIsInitialized) {
|
||||
Log.e(TAG, "Engine not initialized");
|
||||
return null;
|
||||
}
|
||||
if (mFullTokenizer == null) {
|
||||
Log.e(TAG, "FullTokenizer not initialized");
|
||||
return null;
|
||||
}
|
||||
|
||||
// 使用FullTokenizer进行tokenization
|
||||
java.util.List<String> tokens = mFullTokenizer.tokenize(text);
|
||||
java.util.List<Integer> ids = mFullTokenizer.convertTokensToIds(tokens);
|
||||
|
||||
// 转换为int数组
|
||||
int[] result = new int[ids.size()];
|
||||
for (int i = 0; i < ids.size(); i++) {
|
||||
result[i] = ids.get(i);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate [UNK] ratio for token ids
|
||||
* @param tokens Token ids array
|
||||
* @return [UNK] ratio between 0.0 and 1.0
|
||||
*/
|
||||
public float calculateUnkRatio(int[] tokens) {
|
||||
if (!mIsInitialized) {
|
||||
Log.e(TAG, "Engine not initialized");
|
||||
return 0.0f;
|
||||
}
|
||||
if (mFullTokenizer == null || tokens == null || tokens.length == 0) {
|
||||
return 0.0f;
|
||||
}
|
||||
|
||||
// 获取[UNK]的id
|
||||
int unkId = mVocabMap.getOrDefault("[UNK]", -1);
|
||||
if (unkId == -1) {
|
||||
return 0.0f;
|
||||
}
|
||||
|
||||
// 统计[UNK]的数量
|
||||
int unkCount = 0;
|
||||
for (int token : tokens) {
|
||||
if (token == unkId) {
|
||||
unkCount++;
|
||||
}
|
||||
}
|
||||
|
||||
return (float) unkCount / tokens.length;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get BERT model inputs (input_ids, attention_mask, token_type_ids)
|
||||
* @param text Input text
|
||||
* @param maxSeqLength Maximum sequence length
|
||||
* @return Array of int arrays: [input_ids, attention_mask, token_type_ids]
|
||||
*/
|
||||
public int[][] getBertInputs(String text, int maxSeqLength) {
|
||||
if (!mIsInitialized) {
|
||||
Log.e(TAG, "Engine not initialized");
|
||||
return null;
|
||||
}
|
||||
if (mFullTokenizer == null) {
|
||||
Log.e(TAG, "FullTokenizer not initialized");
|
||||
return null;
|
||||
}
|
||||
|
||||
// 使用FullTokenizer进行tokenization
|
||||
java.util.List<String> tokens = mFullTokenizer.tokenize(text);
|
||||
java.util.List<Integer> ids = mFullTokenizer.convertTokensToIds(tokens);
|
||||
|
||||
// 准备BERT输入
|
||||
int[] inputIds = new int[maxSeqLength];
|
||||
int[] attentionMask = new int[maxSeqLength];
|
||||
int[] tokenTypeIds = new int[maxSeqLength];
|
||||
|
||||
// 填充[CLS] token
|
||||
inputIds[0] = mVocabMap.getOrDefault("[CLS]", 101);
|
||||
attentionMask[0] = 1;
|
||||
|
||||
// 填充文本token
|
||||
int textLength = ids.size();
|
||||
int maxTextLength = maxSeqLength - 2; // 减去[CLS]和[SEP]
|
||||
int actualLength = Math.min(textLength, maxTextLength);
|
||||
|
||||
for (int i = 0; i < actualLength; i++) {
|
||||
inputIds[i + 1] = ids.get(i);
|
||||
attentionMask[i + 1] = 1;
|
||||
}
|
||||
|
||||
// 填充[SEP] token
|
||||
inputIds[actualLength + 1] = mVocabMap.getOrDefault("[SEP]", 102);
|
||||
attentionMask[actualLength + 1] = 1;
|
||||
|
||||
return new int[][]{inputIds, attentionMask, tokenTypeIds};
|
||||
}
|
||||
|
||||
// Native methods
|
||||
private native long createBgeEngine();
|
||||
private native int loadModel(long nativePtr, String modelPath, String vocabPath);
|
||||
private native void freeModel(long ptr);
|
||||
private native float[] getEmbeddingNative(long ptr, String text);
|
||||
private native int getEmbeddingDim(long ptr);
|
||||
private native float calculateSimilarityNative(long ptr, String text1, String text2);
|
||||
}
|
||||
58
app/src/main/java/com/digitalperson/engine/CharChecker.java
Normal file
58
app/src/main/java/com/digitalperson/engine/CharChecker.java
Normal file
@@ -0,0 +1,58 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
package com.digitalperson.engine;
|
||||
|
||||
/** To check whether a char is whitespace/control/punctuation. */
|
||||
/** To check whether a char is whitespace/control/punctuation. */
final class CharChecker {

  private CharChecker() {}

  /** True for the NUL character and the Unicode replacement character (U+FFFD). */
  public static boolean isInvalid(char ch) {
    return ch == 0 || ch == 0xfffd;
  }

  /** True for control/format characters, excluding anything Java treats as whitespace. */
  public static boolean isControl(char ch) {
    if (Character.isWhitespace(ch)) {
      return false;
    }
    switch (Character.getType(ch)) {
      case Character.CONTROL:
      case Character.FORMAT:
        return true;
      default:
        return false;
    }
  }

  /** True for Java whitespace or any Unicode space/line/paragraph separator. */
  public static boolean isWhitespace(char ch) {
    if (Character.isWhitespace(ch)) {
      return true;
    }
    switch (Character.getType(ch)) {
      case Character.SPACE_SEPARATOR:
      case Character.LINE_SEPARATOR:
      case Character.PARAGRAPH_SEPARATOR:
        return true;
      default:
        return false;
    }
  }

  /** True for every Unicode punctuation category. */
  public static boolean isPunctuation(char ch) {
    switch (Character.getType(ch)) {
      case Character.CONNECTOR_PUNCTUATION:
      case Character.DASH_PUNCTUATION:
      case Character.START_PUNCTUATION:
      case Character.END_PUNCTUATION:
      case Character.INITIAL_QUOTE_PUNCTUATION:
      case Character.FINAL_QUOTE_PUNCTUATION:
      case Character.OTHER_PUNCTUATION:
        return true;
      default:
        return false;
    }
  }
}
|
||||
@@ -0,0 +1,52 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
package com.digitalperson.engine;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* A java realization of Bert tokenization. Original python code:
|
||||
* https://github.com/google-research/bert/blob/master/tokenization.py runs full tokenization to
|
||||
* tokenize a String into split subtokens or ids.
|
||||
*/
|
||||
public final class FullTokenizer {
|
||||
private final BasicTokenizer basicTokenizer;
|
||||
private final WordpieceTokenizer wordpieceTokenizer;
|
||||
private final Map<String, Integer> dic;
|
||||
|
||||
public FullTokenizer(Map<String, Integer> inputDic, boolean doLowerCase) {
|
||||
dic = inputDic;
|
||||
basicTokenizer = new BasicTokenizer(doLowerCase);
|
||||
wordpieceTokenizer = new WordpieceTokenizer(inputDic);
|
||||
}
|
||||
|
||||
public List<String> tokenize(String text) {
|
||||
List<String> splitTokens = new ArrayList<>();
|
||||
for (String token : basicTokenizer.tokenize(text)) {
|
||||
splitTokens.addAll(wordpieceTokenizer.tokenize(token));
|
||||
}
|
||||
return splitTokens;
|
||||
}
|
||||
|
||||
public List<Integer> convertTokensToIds(List<String> tokens) {
|
||||
List<Integer> outputIds = new ArrayList<>();
|
||||
for (String token : tokens) {
|
||||
outputIds.add(dic.get(token));
|
||||
}
|
||||
return outputIds;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,94 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
package com.digitalperson.engine;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/** Word piece tokenization to split a piece of text into its word pieces. */
|
||||
public final class WordpieceTokenizer {
|
||||
private final Map<String, Integer> dic;
|
||||
|
||||
private static final String UNKNOWN_TOKEN = "[UNK]"; // For unknown words.
|
||||
private static final int MAX_INPUTCHARS_PER_WORD = 200;
|
||||
|
||||
public WordpieceTokenizer(Map<String, Integer> vocab) {
|
||||
dic = vocab;
|
||||
}
|
||||
|
||||
/**
|
||||
* Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first
|
||||
* algorithm to perform tokenization using the given vocabulary. For example: input = "unaffable",
|
||||
* output = ["un", "##aff", "##able"].
|
||||
*
|
||||
* @param text: A single token or whitespace separated tokens. This should have already been
|
||||
* passed through `BasicTokenizer.
|
||||
* @return A list of wordpiece tokens.
|
||||
*/
|
||||
public List<String> tokenize(String text) {
|
||||
if (text == null) {
|
||||
throw new NullPointerException("The input String is null.");
|
||||
}
|
||||
|
||||
List<String> outputTokens = new ArrayList<>();
|
||||
for (String token : BasicTokenizer.whitespaceTokenize(text)) {
|
||||
|
||||
if (token.length() > MAX_INPUTCHARS_PER_WORD) {
|
||||
outputTokens.add(UNKNOWN_TOKEN);
|
||||
continue;
|
||||
}
|
||||
|
||||
boolean isBad = false; // Mark if a word cannot be tokenized into known subwords.
|
||||
int start = 0;
|
||||
List<String> subTokens = new ArrayList<>();
|
||||
|
||||
while (start < token.length()) {
|
||||
String curSubStr = "";
|
||||
|
||||
int end = token.length(); // Longer substring matches first.
|
||||
while (start < end) {
|
||||
String subStr =
|
||||
(start == 0) ? token.substring(start, end) : "##" + token.substring(start, end);
|
||||
if (dic.containsKey(subStr)) {
|
||||
curSubStr = subStr;
|
||||
break;
|
||||
}
|
||||
end--;
|
||||
}
|
||||
|
||||
// The word doesn't contain any known subwords.
|
||||
if ("".equals(curSubStr)) {
|
||||
isBad = true;
|
||||
break;
|
||||
}
|
||||
|
||||
// curSubStr is the longeset subword that can be found.
|
||||
subTokens.add(curSubStr);
|
||||
|
||||
// Proceed to tokenize the resident string.
|
||||
start = end;
|
||||
}
|
||||
|
||||
if (isBad) {
|
||||
outputTokens.add(UNKNOWN_TOKEN);
|
||||
} else {
|
||||
outputTokens.addAll(subTokens);
|
||||
}
|
||||
}
|
||||
|
||||
return outputTokens;
|
||||
}
|
||||
}
|
||||
@@ -185,7 +185,7 @@ class QuestionGenerationAgent(
|
||||
val generationPrompt = buildGenerationPrompt(prompt, userProfile)
|
||||
|
||||
// 3. 调用大模型生成题目
|
||||
generateQuestionFromLLM(generationPrompt) { generatedQuestion ->
|
||||
generateQuestionFromLLM(generationPrompt, prompt) { generatedQuestion ->
|
||||
if (generatedQuestion == null) {
|
||||
Log.w(TAG, "Failed to generate question")
|
||||
return@generateQuestionFromLLM
|
||||
@@ -301,18 +301,22 @@ class QuestionGenerationAgent(
|
||||
/**
|
||||
* 调用LLM生成题目
|
||||
*/
|
||||
private fun generateQuestionFromLLM(prompt: String, onResult: (GeneratedQuestion?) -> Unit) {
|
||||
private fun generateQuestionFromLLM(
|
||||
promptText: String,
|
||||
promptMeta: QuestionPrompt,
|
||||
onResult: (GeneratedQuestion?) -> Unit
|
||||
) {
|
||||
// 优先使用本地LLM,如果不可用则使用云端LLM
|
||||
if (llmManager != null) {
|
||||
// 使用本地LLM
|
||||
llmManager.generate(prompt) { response: String ->
|
||||
parseGeneratedQuestion(response, onResult)
|
||||
llmManager.generate(promptText) { response: String ->
|
||||
parseGeneratedQuestion(response, promptMeta, onResult)
|
||||
}
|
||||
} else if (cloudLLMGenerator != null) {
|
||||
// 使用云端LLM
|
||||
Log.d(TAG, "Using cloud LLM to generate question")
|
||||
cloudLLMGenerator.invoke(prompt) { response ->
|
||||
parseGeneratedQuestion(response, onResult)
|
||||
cloudLLMGenerator.invoke(promptText) { response ->
|
||||
parseGeneratedQuestion(response, promptMeta, onResult)
|
||||
}
|
||||
} else {
|
||||
Log.e(TAG, "No LLM available (neither local nor cloud)")
|
||||
@@ -323,16 +327,20 @@ class QuestionGenerationAgent(
|
||||
/**
|
||||
* 解析生成的题目
|
||||
*/
|
||||
private fun parseGeneratedQuestion(response: String, onResult: (GeneratedQuestion?) -> Unit) {
|
||||
private fun parseGeneratedQuestion(
|
||||
response: String,
|
||||
promptMeta: QuestionPrompt,
|
||||
onResult: (GeneratedQuestion?) -> Unit
|
||||
) {
|
||||
try {
|
||||
val json = extractJsonFromResponse(response)
|
||||
if (json != null) {
|
||||
val question = GeneratedQuestion(
|
||||
content = json.getString("content"),
|
||||
answer = json.getString("answer"),
|
||||
subject = "生活适应",
|
||||
grade = 1,
|
||||
difficulty = 1
|
||||
subject = promptMeta.subject,
|
||||
grade = promptMeta.grade,
|
||||
difficulty = promptMeta.difficulty
|
||||
)
|
||||
onResult(question)
|
||||
} else {
|
||||
|
||||
@@ -7,6 +7,7 @@ import android.util.Log
|
||||
import com.digitalperson.config.AppConfig
|
||||
import java.io.File
|
||||
import java.io.FileOutputStream
|
||||
import java.io.InputStream
|
||||
|
||||
object FileHelper {
|
||||
private const val TAG = AppConfig.TAG
|
||||
@@ -65,6 +66,54 @@ object FileHelper {
|
||||
return copyAssetsToInternal(context, AppConfig.FaceRecognition.MODEL_DIR, outDir, files)
|
||||
}
|
||||
|
||||
/**
 * Copies the BGE model/tokenizer files from assets/[AppConfig.Bge.ASSET_DIR] into
 * [Context.getFilesDir]/[AppConfig.Bge.ASSET_DIR].
 * A file is skipped when it already exists with the same length as the asset
 * (consistent with [com.digitalperson.embedding.SimilarityManager]'s behavior).
 *
 * @return the target directory on success, or null if any copy failed
 */
@JvmStatic
fun copyBgeModels(context: Context): File? {
    val assetRoot = AppConfig.Bge.ASSET_DIR
    val targetDir = File(context.filesDir, assetRoot)
    targetDir.mkdirs()

    val required = listOf(
        AppConfig.Bge.MODEL_FILE,
        "vocab.txt",
        "tokenizer.json",
        "tokenizer_config.json",
    )
    for (name in required) {
        val destination = File(targetDir, name)
        val assetPath = "$assetRoot/$name"
        try {
            // Skip only when the existing copy matches the asset's reported size.
            val upToDate = destination.exists() && context.assets.open(assetPath).use { stream ->
                val expected = assetSizeOrNegative(stream)
                expected >= 0 && destination.length() == expected
            }
            if (upToDate) continue

            context.assets.open(assetPath).use { input ->
                FileOutputStream(destination).use { output -> input.copyTo(output) }
            }
            Log.i(TAG, "Copied BGE asset: $name")
        } catch (e: Exception) {
            Log.e(TAG, "copyBgeModels failed for $name: ${e.message}", e)
            return null
        }
    }
    return targetDir
}
|
||||
|
||||
/**
 * [InputStream.available] is unreliable on some implementations; returns -1 on failure
 * (or a non-positive report) so callers force a re-copy.
 */
private fun assetSizeOrNegative(input: InputStream): Long {
    val reported = try {
        input.available()
    } catch (_: Exception) {
        return -1L
    }
    return if (reported > 0) reported.toLong() else -1L
}
|
||||
|
||||
fun ensureDir(dir: File): File {
|
||||
if (!dir.exists()) {
|
||||
val created = dir.mkdirs()
|
||||
|
||||
@@ -38,6 +38,18 @@
|
||||
android:layout_height="match_parent" />
|
||||
</FrameLayout>
|
||||
|
||||
<ImageView
|
||||
android:id="@+id/ref_match_image"
|
||||
android:layout_width="200dp"
|
||||
android:layout_height="200dp"
|
||||
android:layout_margin="12dp"
|
||||
android:background="#66000000"
|
||||
android:scaleType="fitCenter"
|
||||
android:visibility="gone"
|
||||
app:layout_constraintBottom_toTopOf="@+id/button_row"
|
||||
app:layout_constraintStart_toStartOf="parent"
|
||||
tools:visibility="visible" />
|
||||
|
||||
<ScrollView
|
||||
android:id="@+id/scroll_view"
|
||||
android:layout_width="0dp"
|
||||
|
||||
@@ -31,6 +31,18 @@
|
||||
android:textIsSelectable="true" />
|
||||
</ScrollView>
|
||||
|
||||
<ImageView
|
||||
android:id="@+id/ref_match_image"
|
||||
android:layout_width="200dp"
|
||||
android:layout_height="200dp"
|
||||
android:layout_margin="12dp"
|
||||
android:background="#66000000"
|
||||
android:scaleType="fitCenter"
|
||||
android:visibility="gone"
|
||||
app:layout_constraintBottom_toTopOf="@+id/button_row"
|
||||
app:layout_constraintStart_toStartOf="parent"
|
||||
tools:visibility="visible" />
|
||||
|
||||
<LinearLayout
|
||||
android:id="@+id/button_row"
|
||||
android:layout_width="0dp"
|
||||
|
||||
Reference in New Issue
Block a user