word similarity
This commit is contained in:
@@ -4,10 +4,24 @@ plugins {
|
|||||||
id 'kotlin-kapt'
|
id 'kotlin-kapt'
|
||||||
}
|
}
|
||||||
|
|
||||||
|
kapt {
|
||||||
|
// Room uses javac stubs under kapt; keep parameter names for :bind variables.
|
||||||
|
javacOptions {
|
||||||
|
option("-parameters")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
android {
|
android {
|
||||||
namespace 'com.digitalperson'
|
namespace 'com.digitalperson'
|
||||||
compileSdk 34
|
compileSdk 34
|
||||||
|
|
||||||
|
sourceSets {
|
||||||
|
main {
|
||||||
|
// app/note/ref → assets 中为 ref/...(与 AppConfig.RefCorpus.ASSETS_ROOT 一致)
|
||||||
|
assets.srcDirs = ['src/main/assets', 'note']
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
buildFeatures {
|
buildFeatures {
|
||||||
buildConfig true
|
buildConfig true
|
||||||
}
|
}
|
||||||
@@ -100,4 +114,9 @@ dependencies {
|
|||||||
|
|
||||||
implementation project(':tuanjieLibrary')
|
implementation project(':tuanjieLibrary')
|
||||||
implementation files('../tuanjieLibrary/libs/unity-classes.jar')
|
implementation files('../tuanjieLibrary/libs/unity-classes.jar')
|
||||||
|
|
||||||
|
// BGE tokenizer (BasicTokenizer) + SimilarityManager 批量相似度测试
|
||||||
|
implementation 'com.google.guava:guava:31.1-android'
|
||||||
|
implementation 'org.ejml:ejml-core:0.43.1'
|
||||||
|
implementation 'org.ejml:ejml-simple:0.43.1'
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@
|
|||||||
<uses-feature android:name="android.hardware.camera.any" />
|
<uses-feature android:name="android.hardware.camera.any" />
|
||||||
|
|
||||||
<application
|
<application
|
||||||
|
android:name="com.digitalperson.DigitalPersonApp"
|
||||||
android:allowBackup="true"
|
android:allowBackup="true"
|
||||||
android:label="@string/app_name"
|
android:label="@string/app_name"
|
||||||
android:supportsRtl="true"
|
android:supportsRtl="true"
|
||||||
|
|||||||
5
app/src/main/assets/bge_models/PLACE_RKNN_HERE.txt
Normal file
5
app/src/main/assets/bge_models/PLACE_RKNN_HERE.txt
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
Required assets for BGE (see AppConfig.Bge):
|
||||||
|
bge-small-zh-v1.5.rknn, vocab.txt, tokenizer.json, tokenizer_config.json
|
||||||
|
|
||||||
|
First run: FileHelper.copyBgeModels / BgeEmbedding.initialize / SimilarityManager.initBgeModel
|
||||||
|
copies from assets/bge_models to internal storage.
|
||||||
BIN
app/src/main/assets/bge_models/bge-small-zh-v1.5.rknn
Normal file
BIN
app/src/main/assets/bge_models/bge-small-zh-v1.5.rknn
Normal file
Binary file not shown.
21278
app/src/main/assets/bge_models/tokenizer.json
Normal file
21278
app/src/main/assets/bge_models/tokenizer.json
Normal file
File diff suppressed because it is too large
Load Diff
15
app/src/main/assets/bge_models/tokenizer_config.json
Normal file
15
app/src/main/assets/bge_models/tokenizer_config.json
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
{
|
||||||
|
"clean_up_tokenization_spaces": true,
|
||||||
|
"cls_token": "[CLS]",
|
||||||
|
"do_basic_tokenize": true,
|
||||||
|
"do_lower_case": false,
|
||||||
|
"mask_token": "[MASK]",
|
||||||
|
"model_max_length": 512,
|
||||||
|
"never_split": null,
|
||||||
|
"pad_token": "[PAD]",
|
||||||
|
"sep_token": "[SEP]",
|
||||||
|
"strip_accents": null,
|
||||||
|
"tokenize_chinese_chars": true,
|
||||||
|
"tokenizer_class": "BertTokenizer",
|
||||||
|
"unk_token": "[UNK]"
|
||||||
|
}
|
||||||
21128
app/src/main/assets/bge_models/vocab.txt
Normal file
21128
app/src/main/assets/bge_models/vocab.txt
Normal file
File diff suppressed because it is too large
Load Diff
504
app/src/main/cpp/BgeEngineRKNN.cpp
Normal file
504
app/src/main/cpp/BgeEngineRKNN.cpp
Normal file
@@ -0,0 +1,504 @@
|
|||||||
|
#include "BgeEngineRKNN.h"
|
||||||
|
#include <android/log.h>
|
||||||
|
#include <cstring>
|
||||||
|
#include <iostream>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
|
#include <fstream>
|
||||||
|
#include <sstream>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
#define LOG_TAG "BgeEngineRKNN"
|
||||||
|
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__)
|
||||||
|
#define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__)
|
||||||
|
#define LOGD(...) __android_log_print(ANDROID_LOG_DEBUG, LOG_TAG, __VA_ARGS__)
|
||||||
|
|
||||||
|
BgeEngineRKNN::BgeEngineRKNN() : m_ctx(0), m_initialized(false), m_embedding_dim(768) {
|
||||||
|
// 默认嵌入维度为768,可根据模型实际情况调整
|
||||||
|
}
|
||||||
|
|
||||||
|
BgeEngineRKNN::~BgeEngineRKNN() {
|
||||||
|
freeModel();
|
||||||
|
}
|
||||||
|
|
||||||
|
int BgeEngineRKNN::loadModel(const char* modelPath, const char* vocabPath) {
|
||||||
|
if (m_initialized) {
|
||||||
|
LOGI("Model already loaded");
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 加载词汇表(如果提供了路径)
|
||||||
|
if (vocabPath != nullptr) {
|
||||||
|
int vocab_ret = loadVocab(vocabPath);
|
||||||
|
if (vocab_ret != 0) {
|
||||||
|
LOGE("Failed to load vocab: %d", vocab_ret);
|
||||||
|
return vocab_ret;
|
||||||
|
}
|
||||||
|
LOGI("Vocab loaded successfully with %zu tokens", m_vocab.size());
|
||||||
|
} else {
|
||||||
|
LOGE("No vocab path provided, using fallback tokenization");
|
||||||
|
}
|
||||||
|
|
||||||
|
int ret = initRKNN(modelPath);
|
||||||
|
if (ret != 0) {
|
||||||
|
LOGE("Failed to initialize RKNN: %d", ret);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
m_initialized = true;
|
||||||
|
LOGI("BGE model loaded successfully");
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void BgeEngineRKNN::freeModel() {
|
||||||
|
if (m_initialized) {
|
||||||
|
deinitRKNN();
|
||||||
|
m_initialized = false;
|
||||||
|
LOGI("BGE model freed");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
float* BgeEngineRKNN::getEmbedding(const std::string& text) {
|
||||||
|
if (!m_initialized) {
|
||||||
|
LOGE("Model not initialized");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 简单的tokenize实现(需要替换为完整的tokenizer)
|
||||||
|
std::vector<int> tokens = tokenize(text);
|
||||||
|
if (tokens.empty()) {
|
||||||
|
LOGE("Tokenization failed");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 推理获取嵌入
|
||||||
|
float* embedding = inferEmbedding(tokens);
|
||||||
|
if (embedding == nullptr) {
|
||||||
|
LOGE("Failed to get embedding");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
return embedding;
|
||||||
|
}
|
||||||
|
|
||||||
|
float BgeEngineRKNN::calculateSimilarity(const std::string& text1, const std::string& text2) {
|
||||||
|
if (!m_initialized) {
|
||||||
|
LOGE("Model not initialized");
|
||||||
|
return 0.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取两个文本的嵌入
|
||||||
|
float* embedding1 = getEmbedding(text1);
|
||||||
|
if (embedding1 == nullptr) {
|
||||||
|
return 0.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
float* embedding2 = getEmbedding(text2);
|
||||||
|
if (embedding2 == nullptr) {
|
||||||
|
delete[] embedding1;
|
||||||
|
return 0.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 归一化嵌入
|
||||||
|
std::vector<float> normalized1 = normalizeEmbedding(embedding1, m_embedding_dim);
|
||||||
|
std::vector<float> normalized2 = normalizeEmbedding(embedding2, m_embedding_dim);
|
||||||
|
|
||||||
|
// 计算余弦相似度
|
||||||
|
float similarity = cosineSimilarity(normalized1, normalized2);
|
||||||
|
|
||||||
|
// 释放内存
|
||||||
|
delete[] embedding1;
|
||||||
|
delete[] embedding2;
|
||||||
|
|
||||||
|
return similarity;
|
||||||
|
}
|
||||||
|
|
||||||
|
int BgeEngineRKNN::getEmbeddingDim() const {
|
||||||
|
return m_embedding_dim;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> BgeEngineRKNN::tokenize(const std::string& text) {
|
||||||
|
std::vector<int> tokens;
|
||||||
|
tokens.push_back(101); // [CLS] token
|
||||||
|
|
||||||
|
// 如果词汇表已加载,使用完整的tokenization流程
|
||||||
|
if (!m_vocab.empty()) {
|
||||||
|
// 步骤1: 基本分词
|
||||||
|
std::vector<std::string> basic_tokens = basicTokenize(text);
|
||||||
|
|
||||||
|
// 步骤2: WordPiece分词
|
||||||
|
std::vector<std::string> word_pieces = wordPieceTokenize(basic_tokens);
|
||||||
|
|
||||||
|
// 步骤3: 转换为token id
|
||||||
|
for (const std::string& piece : word_pieces) {
|
||||||
|
if (tokens.size() > 510) break; // 限制最大长度
|
||||||
|
|
||||||
|
auto it = m_vocab.find(piece);
|
||||||
|
if (it != m_vocab.end()) {
|
||||||
|
tokens.push_back(it->second);
|
||||||
|
} else {
|
||||||
|
// 对于未在词表中的token,使用[UNK] token
|
||||||
|
auto unk_it = m_vocab.find("[UNK]");
|
||||||
|
if (unk_it != m_vocab.end()) {
|
||||||
|
tokens.push_back(unk_it->second);
|
||||||
|
} else {
|
||||||
|
// 如果[UNK]也不存在,使用0作为fallback
|
||||||
|
tokens.push_back(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// fallback: 使用字符级分词(当词汇表未加载时)
|
||||||
|
size_t i = 0;
|
||||||
|
while (i < text.size() && tokens.size() <= 510) {
|
||||||
|
unsigned char c = static_cast<unsigned char>(text[i]);
|
||||||
|
int token_id;
|
||||||
|
|
||||||
|
// 处理不同长度的UTF-8字符
|
||||||
|
if (c < 0x80) {
|
||||||
|
// ASCII字符 (1字节)
|
||||||
|
token_id = static_cast<int>(c) + 100;
|
||||||
|
i++;
|
||||||
|
} else if (c < 0xE0) {
|
||||||
|
// 2字节UTF-8字符
|
||||||
|
token_id = 3000 + ((c & 0x1F) << 6) + (static_cast<unsigned char>(text[i+1]) & 0x3F);
|
||||||
|
i += 2;
|
||||||
|
} else if (c < 0xF0) {
|
||||||
|
// 3字节UTF-8字符(常见中文)
|
||||||
|
token_id = 4000 + ((c & 0x0F) << 12) + ((static_cast<unsigned char>(text[i+1]) & 0x3F) << 6) + (static_cast<unsigned char>(text[i+2]) & 0x3F);
|
||||||
|
i += 3;
|
||||||
|
} else {
|
||||||
|
// 4字节UTF-8字符
|
||||||
|
token_id = 8000 + ((c & 0x07) << 18) + ((static_cast<unsigned char>(text[i+1]) & 0x3F) << 12) + ((static_cast<unsigned char>(text[i+2]) & 0x3F) << 6) + (static_cast<unsigned char>(text[i+3]) & 0x3F);
|
||||||
|
i += 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 确保token_id在合理范围内
|
||||||
|
token_id = token_id % 30000;
|
||||||
|
tokens.push_back(token_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens.push_back(102); // [SEP] token
|
||||||
|
return tokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
int BgeEngineRKNN::loadVocab(const char* vocabPath) {
|
||||||
|
std::ifstream vocab_file(vocabPath);
|
||||||
|
if (!vocab_file.is_open()) {
|
||||||
|
LOGE("Failed to open vocab file: %s", vocabPath);
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string line;
|
||||||
|
int id = 0;
|
||||||
|
while (std::getline(vocab_file, line)) {
|
||||||
|
if (!line.empty()) {
|
||||||
|
m_vocab[line] = id;
|
||||||
|
id++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
vocab_file.close();
|
||||||
|
|
||||||
|
if (m_vocab.empty()) {
|
||||||
|
LOGE("Vocab file is empty");
|
||||||
|
return -2;
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::string> BgeEngineRKNN::basicTokenize(const std::string& text) {
|
||||||
|
std::vector<std::string> tokens;
|
||||||
|
|
||||||
|
size_t i = 0;
|
||||||
|
while (i < text.size()) {
|
||||||
|
unsigned char c = static_cast<unsigned char>(text[i]);
|
||||||
|
|
||||||
|
if (std::isspace(c)) {
|
||||||
|
// 跳过空格
|
||||||
|
i++;
|
||||||
|
} else if (c < 0x80) {
|
||||||
|
// ASCII字符,按空格切分
|
||||||
|
std::string token;
|
||||||
|
while (i < text.size() && !std::isspace(static_cast<unsigned char>(text[i])) && static_cast<unsigned char>(text[i]) < 0x80) {
|
||||||
|
token += text[i];
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
if (!token.empty()) {
|
||||||
|
tokens.push_back(token);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 非ASCII字符(如中文),按字符切分
|
||||||
|
// 处理UTF-8编码的字符
|
||||||
|
if (c < 0xE0) {
|
||||||
|
// 2字节UTF-8字符
|
||||||
|
if (i + 1 < text.size()) {
|
||||||
|
tokens.push_back(text.substr(i, 2));
|
||||||
|
i += 2;
|
||||||
|
} else {
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
} else if (c < 0xF0) {
|
||||||
|
// 3字节UTF-8字符(常见中文)
|
||||||
|
if (i + 2 < text.size()) {
|
||||||
|
tokens.push_back(text.substr(i, 3));
|
||||||
|
i += 3;
|
||||||
|
} else {
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 4字节UTF-8字符
|
||||||
|
if (i + 3 < text.size()) {
|
||||||
|
tokens.push_back(text.substr(i, 4));
|
||||||
|
i += 4;
|
||||||
|
} else {
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::string> BgeEngineRKNN::wordPieceTokenize(const std::vector<std::string>& basic_tokens) {
|
||||||
|
std::vector<std::string> word_pieces;
|
||||||
|
|
||||||
|
for (const std::string& token : basic_tokens) {
|
||||||
|
if (m_vocab.find(token) != m_vocab.end()) {
|
||||||
|
word_pieces.push_back(token);
|
||||||
|
} else {
|
||||||
|
// 简单的字符级分词作为 fallback
|
||||||
|
for (char c : token) {
|
||||||
|
word_pieces.push_back(std::string(1, c));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return word_pieces;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> BgeEngineRKNN::normalizeEmbedding(const float* embedding, int dim) {
|
||||||
|
std::vector<float> normalized(dim);
|
||||||
|
float norm = 0.0f;
|
||||||
|
|
||||||
|
// 计算L2范数
|
||||||
|
for (int i = 0; i < dim; i++) {
|
||||||
|
norm += embedding[i] * embedding[i];
|
||||||
|
}
|
||||||
|
norm = std::sqrt(norm);
|
||||||
|
|
||||||
|
// 归一化
|
||||||
|
if (norm > 0.0f) {
|
||||||
|
for (int i = 0; i < dim; i++) {
|
||||||
|
normalized[i] = embedding[i] / norm;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return normalized;
|
||||||
|
}
|
||||||
|
|
||||||
|
float BgeEngineRKNN::cosineSimilarity(const std::vector<float>& vec1, const std::vector<float>& vec2) {
|
||||||
|
if (vec1.size() != vec2.size()) {
|
||||||
|
LOGE("Vector dimension mismatch");
|
||||||
|
return 0.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
float dot_product = 0.0f;
|
||||||
|
for (size_t i = 0; i < vec1.size(); i++) {
|
||||||
|
dot_product += vec1[i] * vec2[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
return dot_product;
|
||||||
|
}
|
||||||
|
|
||||||
|
int BgeEngineRKNN::initRKNN(const char* modelPath) {
|
||||||
|
int ret = 0;
|
||||||
|
rknn_context ctx = 0;
|
||||||
|
|
||||||
|
// 初始化RKNN
|
||||||
|
ret = rknn_init(&ctx, (void*)modelPath, 0, 0, NULL);
|
||||||
|
if (ret != RKNN_SUCC) {
|
||||||
|
LOGE("rknn_init failed: %d", ret);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
m_ctx = ctx;
|
||||||
|
|
||||||
|
// 查询模型输入输出信息
|
||||||
|
rknn_input_output_num io_num;
|
||||||
|
ret = rknn_query(m_ctx, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num));
|
||||||
|
if (ret != RKNN_SUCC) {
|
||||||
|
LOGE("rknn_query input_output_num failed: %d", ret);
|
||||||
|
rknn_destroy(ctx);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
LOGI("Input number: %d, Output number: %d", io_num.n_input, io_num.n_output);
|
||||||
|
|
||||||
|
// 查询所有输入属性
|
||||||
|
for (int i = 0; i < io_num.n_input; i++) {
|
||||||
|
rknn_tensor_attr input_attr;
|
||||||
|
memset(&input_attr, 0, sizeof(input_attr));
|
||||||
|
input_attr.index = i;
|
||||||
|
ret = rknn_query(m_ctx, RKNN_QUERY_INPUT_ATTR, &input_attr, sizeof(input_attr));
|
||||||
|
if (ret != RKNN_SUCC) {
|
||||||
|
LOGE("rknn_query input_attr %d failed: %d", i, ret);
|
||||||
|
rknn_destroy(ctx);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
LOGI("Input %d info:", i);
|
||||||
|
LOGI(" type: %d", input_attr.type);
|
||||||
|
LOGI(" size: %d", input_attr.size);
|
||||||
|
LOGI(" n_dims: %d", input_attr.n_dims);
|
||||||
|
LOGI(" dims:");
|
||||||
|
for (int j = 0; j < input_attr.n_dims; j++) {
|
||||||
|
LOGI(" dim[%d] = %d", j, input_attr.dims[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 查询所有输出属性
|
||||||
|
for (int i = 0; i < io_num.n_output; i++) {
|
||||||
|
rknn_tensor_attr output_attr;
|
||||||
|
memset(&output_attr, 0, sizeof(output_attr));
|
||||||
|
output_attr.index = i;
|
||||||
|
ret = rknn_query(m_ctx, RKNN_QUERY_OUTPUT_ATTR, &output_attr, sizeof(output_attr));
|
||||||
|
if (ret != RKNN_SUCC) {
|
||||||
|
LOGE("rknn_query output_attr %d failed: %d", i, ret);
|
||||||
|
rknn_destroy(ctx);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 打印输出信息
|
||||||
|
LOGI("Output %d dims:", i);
|
||||||
|
for (int j = 0; j < output_attr.n_dims; j++) {
|
||||||
|
LOGI(" dim[%d] = %d", j, output_attr.dims[j]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否为 output1 且维度符合要求
|
||||||
|
if (i == 1) {
|
||||||
|
bool is_valid_output = false;
|
||||||
|
if (output_attr.n_dims == 2 && output_attr.dims[0] == 1 && output_attr.dims[1] == 512) {
|
||||||
|
is_valid_output = true;
|
||||||
|
} else if (output_attr.n_dims == 3 && output_attr.dims[0] == 1 && output_attr.dims[1] == 1 && output_attr.dims[2] == 512) {
|
||||||
|
is_valid_output = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (is_valid_output) {
|
||||||
|
LOGI("Output 1 has valid embedding dimensions: using Output 1");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 确定嵌入维度(暂时使用 output 0,后续可能需要修改为使用 output 1)
|
||||||
|
rknn_tensor_attr output0_attr;
|
||||||
|
memset(&output0_attr, 0, sizeof(output0_attr));
|
||||||
|
output0_attr.index = 0;
|
||||||
|
ret = rknn_query(m_ctx, RKNN_QUERY_OUTPUT_ATTR, &output0_attr, sizeof(output0_attr));
|
||||||
|
if (ret == RKNN_SUCC && output0_attr.n_dims >= 2) {
|
||||||
|
m_embedding_dim = output0_attr.dims[output0_attr.n_dims - 1];
|
||||||
|
LOGI("Embedding dimension (from output 0): %d", m_embedding_dim);
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void BgeEngineRKNN::deinitRKNN() {
|
||||||
|
if (m_ctx != 0) {
|
||||||
|
rknn_destroy(m_ctx);
|
||||||
|
m_ctx = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
float* BgeEngineRKNN::inferEmbedding(const std::vector<int>& tokens) {
|
||||||
|
// 准备输入数据
|
||||||
|
const int max_seq_len = 512;
|
||||||
|
std::vector<int> padded_tokens = tokens;
|
||||||
|
|
||||||
|
// 填充到最大序列长度
|
||||||
|
if (padded_tokens.size() < max_seq_len) {
|
||||||
|
padded_tokens.resize(max_seq_len, 0); // 使用0作为padding
|
||||||
|
} else if (padded_tokens.size() > max_seq_len) {
|
||||||
|
padded_tokens.resize(max_seq_len); // 截断
|
||||||
|
}
|
||||||
|
|
||||||
|
// 准备attention_mask:1表示真实token,0表示padding
|
||||||
|
std::vector<int> attention_mask(max_seq_len, 0);
|
||||||
|
for (int i = 0; i < tokens.size() && i < max_seq_len; i++) {
|
||||||
|
attention_mask[i] = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 准备token_type_ids:对于单句子任务,全部为0
|
||||||
|
std::vector<int> token_type_ids(max_seq_len, 0);
|
||||||
|
|
||||||
|
// 准备输入张量
|
||||||
|
rknn_input inputs[3]; // 最多3个输入:input_ids, attention_mask, token_type_ids
|
||||||
|
memset(inputs, 0, sizeof(inputs));
|
||||||
|
|
||||||
|
// 输入0: input_ids
|
||||||
|
inputs[0].index = 0;
|
||||||
|
inputs[0].type = RKNN_TENSOR_INT32;
|
||||||
|
inputs[0].size = sizeof(int) * max_seq_len;
|
||||||
|
inputs[0].buf = padded_tokens.data();
|
||||||
|
inputs[0].pass_through = 0;
|
||||||
|
|
||||||
|
// 输入1: attention_mask
|
||||||
|
inputs[1].index = 1;
|
||||||
|
inputs[1].type = RKNN_TENSOR_INT32;
|
||||||
|
inputs[1].size = sizeof(int) * max_seq_len;
|
||||||
|
inputs[1].buf = attention_mask.data();
|
||||||
|
inputs[1].pass_through = 0;
|
||||||
|
|
||||||
|
// 输入2: token_type_ids
|
||||||
|
inputs[2].index = 2;
|
||||||
|
inputs[2].type = RKNN_TENSOR_INT32;
|
||||||
|
inputs[2].size = sizeof(int) * max_seq_len;
|
||||||
|
inputs[2].buf = token_type_ids.data();
|
||||||
|
inputs[2].pass_through = 0;
|
||||||
|
|
||||||
|
// 设置输入(根据实际的输入数量)
|
||||||
|
rknn_input_output_num io_num;
|
||||||
|
int ret = rknn_query(m_ctx, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num));
|
||||||
|
if (ret != RKNN_SUCC) {
|
||||||
|
LOGE("rknn_query input_output_num failed: %d", ret);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
ret = rknn_inputs_set(m_ctx, io_num.n_input, inputs);
|
||||||
|
if (ret != RKNN_SUCC) {
|
||||||
|
LOGE("rknn_inputs_set failed: %d", ret);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 运行推理
|
||||||
|
ret = rknn_run(m_ctx, NULL);
|
||||||
|
if (ret != RKNN_SUCC) {
|
||||||
|
LOGE("rknn_run failed: %d", ret);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取输出
|
||||||
|
rknn_output outputs[2]; // 增加到2个输出
|
||||||
|
memset(outputs, 0, sizeof(outputs));
|
||||||
|
outputs[0].want_float = 1;
|
||||||
|
outputs[1].want_float = 1;
|
||||||
|
|
||||||
|
ret = rknn_outputs_get(m_ctx, 2, outputs, NULL);
|
||||||
|
if (ret != RKNN_SUCC) {
|
||||||
|
LOGE("rknn_outputs_get failed: %d", ret);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 提取嵌入(使用output 1)
|
||||||
|
float* embedding = new float[m_embedding_dim];
|
||||||
|
memcpy(embedding, outputs[1].buf, sizeof(float) * m_embedding_dim);
|
||||||
|
LOGI("Using output 1 for embedding");
|
||||||
|
|
||||||
|
// 释放输出
|
||||||
|
rknn_outputs_release(m_ctx, 2, outputs);
|
||||||
|
|
||||||
|
return embedding;
|
||||||
|
}
|
||||||
42
app/src/main/cpp/BgeEngineRKNN.h
Normal file
42
app/src/main/cpp/BgeEngineRKNN.h
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
#ifndef BGE_ENGINE_RKNN_H
|
||||||
|
#define BGE_ENGINE_RKNN_H
|
||||||
|
|
||||||
|
#include <rknn_api.h>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
|
class BgeEngineRKNN {
|
||||||
|
public:
|
||||||
|
BgeEngineRKNN();
|
||||||
|
~BgeEngineRKNN();
|
||||||
|
|
||||||
|
int loadModel(const char* modelPath, const char* vocabPath = nullptr);
|
||||||
|
void freeModel();
|
||||||
|
float* getEmbedding(const std::string& text);
|
||||||
|
float calculateSimilarity(const std::string& text1, const std::string& text2);
|
||||||
|
int getEmbeddingDim() const; // 获取嵌入维度
|
||||||
|
|
||||||
|
private:
|
||||||
|
rknn_context m_ctx;
|
||||||
|
bool m_initialized;
|
||||||
|
int m_embedding_dim;
|
||||||
|
std::map<std::string, int> m_vocab; // 词汇表映射
|
||||||
|
|
||||||
|
// 辅助方法
|
||||||
|
std::vector<int> tokenize(const std::string& text);
|
||||||
|
std::vector<float> normalizeEmbedding(const float* embedding, int dim);
|
||||||
|
float cosineSimilarity(const std::vector<float>& vec1, const std::vector<float>& vec2);
|
||||||
|
|
||||||
|
// Tokenizer相关方法
|
||||||
|
int loadVocab(const char* vocabPath);
|
||||||
|
std::vector<std::string> basicTokenize(const std::string& text);
|
||||||
|
std::vector<std::string> wordPieceTokenize(const std::vector<std::string>& basic_tokens);
|
||||||
|
|
||||||
|
// RKNN相关方法
|
||||||
|
int initRKNN(const char* modelPath);
|
||||||
|
void deinitRKNN();
|
||||||
|
float* inferEmbedding(const std::vector<int>& tokens);
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // BGE_ENGINE_RKNN_H
|
||||||
229
app/src/main/cpp/BgeEngineRKNNJNI.cpp
Normal file
229
app/src/main/cpp/BgeEngineRKNNJNI.cpp
Normal file
@@ -0,0 +1,229 @@
|
|||||||
|
#include <jni.h>
|
||||||
|
#include <string>
|
||||||
|
#include <map>
|
||||||
|
#include <android/log.h>
|
||||||
|
#include "BgeEngineRKNN.h"
|
||||||
|
|
||||||
|
#define LOG_TAG "BgeEngineJNI"
|
||||||
|
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__)
|
||||||
|
#define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__)
|
||||||
|
|
||||||
|
// 全局引擎实例映射
|
||||||
|
static std::map<jlong, BgeEngineRKNN*> g_engine_map;
|
||||||
|
static jclass g_float_array_class;
|
||||||
|
static jmethodID g_float_array_ctor;
|
||||||
|
|
||||||
|
// 获取BGE引擎实例
|
||||||
|
BgeEngineRKNN* getBgeEngineFromPtr(jlong ptr) {
|
||||||
|
auto it = g_engine_map.find(ptr);
|
||||||
|
if (it != g_engine_map.end()) {
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// JNI接口实现
|
||||||
|
|
||||||
|
// 创建BGE引擎
|
||||||
|
extern "C" JNIEXPORT jlong JNICALL Java_com_digitalperson_engine_BgeEngineRKNN_createBgeEngine(
|
||||||
|
JNIEnv* env,
|
||||||
|
jobject thiz) {
|
||||||
|
LOGI("Creating BGE engine");
|
||||||
|
BgeEngineRKNN* engine = new BgeEngineRKNN();
|
||||||
|
jlong ptr = reinterpret_cast<jlong>(engine);
|
||||||
|
g_engine_map[ptr] = engine;
|
||||||
|
return ptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 加载模型
|
||||||
|
extern "C" JNIEXPORT jint JNICALL Java_com_digitalperson_engine_BgeEngineRKNN_loadModel(
|
||||||
|
JNIEnv* env,
|
||||||
|
jobject thiz,
|
||||||
|
jlong ptr,
|
||||||
|
jstring modelPath,
|
||||||
|
jstring vocabPath) {
|
||||||
|
LOGI("Loading BGE model with vocab");
|
||||||
|
BgeEngineRKNN* engine = getBgeEngineFromPtr(ptr);
|
||||||
|
if (engine == nullptr) {
|
||||||
|
LOGE("Invalid engine pointer");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
const char* c_modelPath = env->GetStringUTFChars(modelPath, NULL);
|
||||||
|
if (c_modelPath == nullptr) {
|
||||||
|
LOGE("Failed to get model path string");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
const char* c_vocabPath = nullptr;
|
||||||
|
if (vocabPath != nullptr) {
|
||||||
|
c_vocabPath = env->GetStringUTFChars(vocabPath, NULL);
|
||||||
|
if (c_vocabPath == nullptr) {
|
||||||
|
LOGE("Failed to get vocab path string");
|
||||||
|
env->ReleaseStringUTFChars(modelPath, c_modelPath);
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int ret = engine->loadModel(c_modelPath, c_vocabPath);
|
||||||
|
|
||||||
|
env->ReleaseStringUTFChars(modelPath, c_modelPath);
|
||||||
|
if (c_vocabPath != nullptr) {
|
||||||
|
env->ReleaseStringUTFChars(vocabPath, c_vocabPath);
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 释放模型
|
||||||
|
extern "C" JNIEXPORT void JNICALL Java_com_digitalperson_engine_BgeEngineRKNN_freeModel(
|
||||||
|
JNIEnv* env,
|
||||||
|
jobject thiz,
|
||||||
|
jlong ptr) {
|
||||||
|
LOGI("Freeing BGE model");
|
||||||
|
BgeEngineRKNN* engine = getBgeEngineFromPtr(ptr);
|
||||||
|
if (engine == nullptr) {
|
||||||
|
LOGE("Invalid engine pointer");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
engine->freeModel();
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取嵌入
|
||||||
|
extern "C" JNIEXPORT jfloatArray JNICALL Java_com_digitalperson_engine_BgeEngineRKNN_getEmbeddingNative(
|
||||||
|
JNIEnv* env,
|
||||||
|
jobject thiz,
|
||||||
|
jlong ptr,
|
||||||
|
jstring text) {
|
||||||
|
BgeEngineRKNN* engine = getBgeEngineFromPtr(ptr);
|
||||||
|
if (engine == nullptr) {
|
||||||
|
LOGE("Invalid engine pointer");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
const char* c_text = env->GetStringUTFChars(text, NULL);
|
||||||
|
if (c_text == nullptr) {
|
||||||
|
LOGE("Failed to get text string");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string cpp_text(c_text);
|
||||||
|
env->ReleaseStringUTFChars(text, c_text);
|
||||||
|
|
||||||
|
float* embedding = engine->getEmbedding(cpp_text);
|
||||||
|
if (embedding == nullptr) {
|
||||||
|
LOGE("Failed to get embedding");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取正确的嵌入维度
|
||||||
|
int embedding_dim = engine->getEmbeddingDim();
|
||||||
|
LOGI("Using embedding dimension: %d", embedding_dim);
|
||||||
|
|
||||||
|
// 创建Java float数组并返回
|
||||||
|
jfloatArray j_embedding = env->NewFloatArray(embedding_dim);
|
||||||
|
if (j_embedding == nullptr) {
|
||||||
|
LOGE("Failed to create float array");
|
||||||
|
delete[] embedding;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
env->SetFloatArrayRegion(j_embedding, 0, embedding_dim, embedding);
|
||||||
|
delete[] embedding;
|
||||||
|
return j_embedding;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取嵌入维度
|
||||||
|
extern "C" JNIEXPORT jint JNICALL Java_com_digitalperson_engine_BgeEngineRKNN_getEmbeddingDim(
|
||||||
|
JNIEnv* env,
|
||||||
|
jobject thiz,
|
||||||
|
jlong ptr) {
|
||||||
|
BgeEngineRKNN* engine = getBgeEngineFromPtr(ptr);
|
||||||
|
if (engine == nullptr) {
|
||||||
|
LOGE("Invalid engine pointer");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
int embedding_dim = engine->getEmbeddingDim();
|
||||||
|
LOGI("Returning embedding dimension: %d", embedding_dim);
|
||||||
|
return embedding_dim;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算相似度
|
||||||
|
extern "C" JNIEXPORT jfloat JNICALL Java_com_digitalperson_engine_BgeEngineRKNN_calculateSimilarityNative(
|
||||||
|
JNIEnv* env,
|
||||||
|
jobject thiz,
|
||||||
|
jlong ptr,
|
||||||
|
jstring text1,
|
||||||
|
jstring text2) {
|
||||||
|
LOGI("Calculating similarity");
|
||||||
|
BgeEngineRKNN* engine = getBgeEngineFromPtr(ptr);
|
||||||
|
if (engine == nullptr) {
|
||||||
|
LOGE("Invalid engine pointer");
|
||||||
|
return 0.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
const char* c_text1 = env->GetStringUTFChars(text1, NULL);
|
||||||
|
const char* c_text2 = env->GetStringUTFChars(text2, NULL);
|
||||||
|
if (c_text1 == nullptr || c_text2 == nullptr) {
|
||||||
|
LOGE("Failed to get text strings");
|
||||||
|
if (c_text1 != nullptr) {
|
||||||
|
env->ReleaseStringUTFChars(text1, c_text1);
|
||||||
|
}
|
||||||
|
if (c_text2 != nullptr) {
|
||||||
|
env->ReleaseStringUTFChars(text2, c_text2);
|
||||||
|
}
|
||||||
|
return 0.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string cpp_text1(c_text1);
|
||||||
|
std::string cpp_text2(c_text2);
|
||||||
|
env->ReleaseStringUTFChars(text1, c_text1);
|
||||||
|
env->ReleaseStringUTFChars(text2, c_text2);
|
||||||
|
|
||||||
|
float similarity = engine->calculateSimilarity(cpp_text1, cpp_text2);
|
||||||
|
return similarity;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// 初始化JNI
|
||||||
|
jint JNI_OnLoad(JavaVM* vm, void* reserved) {
|
||||||
|
JNIEnv* env = nullptr;
|
||||||
|
if (vm->GetEnv((void**)&env, JNI_VERSION_1_6) != JNI_OK) {
|
||||||
|
LOGE("Failed to get JNI environment");
|
||||||
|
return JNI_ERR;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 缓存类和方法ID
|
||||||
|
jclass float_array_cls = env->FindClass("[F");
|
||||||
|
if (float_array_cls == nullptr) {
|
||||||
|
LOGE("Failed to find float array class");
|
||||||
|
return JNI_ERR;
|
||||||
|
}
|
||||||
|
g_float_array_class = (jclass)env->NewGlobalRef(float_array_cls);
|
||||||
|
|
||||||
|
return JNI_VERSION_1_6;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 清理JNI
|
||||||
|
void JNI_OnUnload(JavaVM* vm, void* reserved) {
|
||||||
|
JNIEnv* env = nullptr;
|
||||||
|
if (vm->GetEnv((void**)&env, JNI_VERSION_1_6) != JNI_OK) {
|
||||||
|
LOGE("Failed to get JNI environment");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 释放全局引用
|
||||||
|
if (g_float_array_class != nullptr) {
|
||||||
|
env->DeleteGlobalRef(g_float_array_class);
|
||||||
|
g_float_array_class = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 清理引擎实例映射
|
||||||
|
for (auto& pair : g_engine_map) {
|
||||||
|
delete pair.second;
|
||||||
|
}
|
||||||
|
g_engine_map.clear();
|
||||||
|
}
|
||||||
@@ -58,5 +58,21 @@ if (ANDROID)
|
|||||||
jnigraphics
|
jnigraphics
|
||||||
log
|
log
|
||||||
)
|
)
|
||||||
|
|
||||||
|
add_library(bgeEngine SHARED
|
||||||
|
BgeEngineRKNN.cpp
|
||||||
|
BgeEngineRKNNJNI.cpp
|
||||||
|
)
|
||||||
|
|
||||||
|
target_include_directories(bgeEngine PRIVATE
|
||||||
|
${CMAKE_CURRENT_LIST_DIR}
|
||||||
|
${CMAKE_CURRENT_LIST_DIR}/zipformer_headers
|
||||||
|
${CMAKE_CURRENT_LIST_DIR}/utils
|
||||||
|
)
|
||||||
|
|
||||||
|
target_link_libraries(bgeEngine
|
||||||
|
rknnrt
|
||||||
|
log
|
||||||
|
)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|||||||
26
app/src/main/java/com/digitalperson/DigitalPersonApp.kt
Normal file
26
app/src/main/java/com/digitalperson/DigitalPersonApp.kt
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
package com.digitalperson
|
||||||
|
|
||||||
|
import android.app.Application
|
||||||
|
import android.util.Log
|
||||||
|
import com.digitalperson.config.AppConfig
|
||||||
|
import com.digitalperson.embedding.RefEmbeddingIndexer
|
||||||
|
import kotlinx.coroutines.CoroutineScope
|
||||||
|
import kotlinx.coroutines.Dispatchers
|
||||||
|
import kotlinx.coroutines.SupervisorJob
|
||||||
|
import kotlinx.coroutines.launch
|
||||||
|
|
||||||
|
class DigitalPersonApp : Application() {
|
||||||
|
|
||||||
|
private val appScope = CoroutineScope(SupervisorJob() + Dispatchers.Default)
|
||||||
|
|
||||||
|
override fun onCreate() {
|
||||||
|
super.onCreate()
|
||||||
|
appScope.launch {
|
||||||
|
try {
|
||||||
|
RefEmbeddingIndexer.runOnce(this@DigitalPersonApp)
|
||||||
|
} catch (t: Throwable) {
|
||||||
|
Log.e(AppConfig.TAG, "[RefEmbed] 索引任务异常", t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -38,6 +38,7 @@ import com.digitalperson.interaction.ConversationSummaryMemory
|
|||||||
|
|
||||||
import java.io.File
|
import java.io.File
|
||||||
import android.graphics.BitmapFactory
|
import android.graphics.BitmapFactory
|
||||||
|
import android.widget.ImageView
|
||||||
import kotlinx.coroutines.CoroutineScope
|
import kotlinx.coroutines.CoroutineScope
|
||||||
import kotlinx.coroutines.Dispatchers
|
import kotlinx.coroutines.Dispatchers
|
||||||
import kotlinx.coroutines.Job
|
import kotlinx.coroutines.Job
|
||||||
@@ -49,6 +50,7 @@ import kotlinx.coroutines.withContext
|
|||||||
|
|
||||||
import com.digitalperson.onboard_testing.FaceRecognitionTest
|
import com.digitalperson.onboard_testing.FaceRecognitionTest
|
||||||
import com.digitalperson.onboard_testing.LLMSummaryTest
|
import com.digitalperson.onboard_testing.LLMSummaryTest
|
||||||
|
import com.digitalperson.embedding.RefImageMatcher
|
||||||
|
|
||||||
class Live2DChatActivity : AppCompatActivity() {
|
class Live2DChatActivity : AppCompatActivity() {
|
||||||
companion object {
|
companion object {
|
||||||
@@ -109,6 +111,7 @@ class Live2DChatActivity : AppCompatActivity() {
|
|||||||
|
|
||||||
private lateinit var faceRecognitionTest: FaceRecognitionTest
|
private lateinit var faceRecognitionTest: FaceRecognitionTest
|
||||||
private lateinit var llmSummaryTest: LLMSummaryTest
|
private lateinit var llmSummaryTest: LLMSummaryTest
|
||||||
|
private var refMatchImageView: ImageView? = null
|
||||||
|
|
||||||
override fun onRequestPermissionsResult(
|
override fun onRequestPermissionsResult(
|
||||||
requestCode: Int,
|
requestCode: Int,
|
||||||
@@ -160,6 +163,8 @@ class Live2DChatActivity : AppCompatActivity() {
|
|||||||
live2dViewId = R.id.live2d_view
|
live2dViewId = R.id.live2d_view
|
||||||
)
|
)
|
||||||
|
|
||||||
|
refMatchImageView = findViewById(R.id.ref_match_image)
|
||||||
|
|
||||||
cameraPreviewView = findViewById(R.id.camera_preview)
|
cameraPreviewView = findViewById(R.id.camera_preview)
|
||||||
cameraPreviewView.implementationMode = PreviewView.ImplementationMode.COMPATIBLE
|
cameraPreviewView.implementationMode = PreviewView.ImplementationMode.COMPATIBLE
|
||||||
faceOverlayView = findViewById(R.id.face_overlay)
|
faceOverlayView = findViewById(R.id.face_overlay)
|
||||||
@@ -611,6 +616,7 @@ class Live2DChatActivity : AppCompatActivity() {
|
|||||||
runOnUiThread {
|
runOnUiThread {
|
||||||
uiManager.appendToUi("${filteredText.orEmpty()}\n")
|
uiManager.appendToUi("${filteredText.orEmpty()}\n")
|
||||||
}
|
}
|
||||||
|
maybeShowMatchedRefImage(filteredText ?: response)
|
||||||
}
|
}
|
||||||
interactionCoordinator.onCloudFinalResponse(response)
|
interactionCoordinator.onCloudFinalResponse(response)
|
||||||
}
|
}
|
||||||
@@ -649,6 +655,24 @@ class Live2DChatActivity : AppCompatActivity() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private fun maybeShowMatchedRefImage(text: String) {
|
||||||
|
val imageView = refMatchImageView ?: return
|
||||||
|
ioScope.launch {
|
||||||
|
val match = RefImageMatcher.findBestMatch(applicationContext, text)
|
||||||
|
if (match == null) return@launch
|
||||||
|
val bitmap = try {
|
||||||
|
assets.open(match.pngAssetPath).use { BitmapFactory.decodeStream(it) }
|
||||||
|
} catch (_: Throwable) {
|
||||||
|
null
|
||||||
|
}
|
||||||
|
if (bitmap == null) return@launch
|
||||||
|
runOnUiThread {
|
||||||
|
imageView.setImageBitmap(bitmap)
|
||||||
|
imageView.visibility = android.view.View.VISIBLE
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private fun createTtsCallback() = object : TtsController.TtsCallback {
|
private fun createTtsCallback() = object : TtsController.TtsCallback {
|
||||||
override fun onTtsStarted(text: String) {
|
override fun onTtsStarted(text: String) {
|
||||||
runOnUiThread {
|
runOnUiThread {
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import android.view.View
|
|||||||
import androidx.lifecycle.Lifecycle
|
import androidx.lifecycle.Lifecycle
|
||||||
import androidx.lifecycle.LifecycleOwner
|
import androidx.lifecycle.LifecycleOwner
|
||||||
import androidx.lifecycle.LifecycleRegistry
|
import androidx.lifecycle.LifecycleRegistry
|
||||||
|
import android.widget.ImageView
|
||||||
import com.unity3d.player.UnityPlayer
|
import com.unity3d.player.UnityPlayer
|
||||||
import com.unity3d.player.UnityPlayerActivity
|
import com.unity3d.player.UnityPlayerActivity
|
||||||
import com.digitalperson.audio.AudioProcessor
|
import com.digitalperson.audio.AudioProcessor
|
||||||
@@ -47,6 +48,8 @@ import com.digitalperson.tts.TtsController
|
|||||||
import com.digitalperson.util.FileHelper
|
import com.digitalperson.util.FileHelper
|
||||||
import com.digitalperson.vad.VadManager
|
import com.digitalperson.vad.VadManager
|
||||||
import kotlinx.coroutines.*
|
import kotlinx.coroutines.*
|
||||||
|
import com.digitalperson.embedding.RefImageMatcher
|
||||||
|
import android.graphics.BitmapFactory
|
||||||
|
|
||||||
class UnityDigitalPersonActivity : UnityPlayerActivity(), LifecycleOwner {
|
class UnityDigitalPersonActivity : UnityPlayerActivity(), LifecycleOwner {
|
||||||
|
|
||||||
@@ -108,6 +111,7 @@ class UnityDigitalPersonActivity : UnityPlayerActivity(), LifecycleOwner {
|
|||||||
private lateinit var holdToSpeakButton: Button
|
private lateinit var holdToSpeakButton: Button
|
||||||
private var recordButtonGlow: View? = null
|
private var recordButtonGlow: View? = null
|
||||||
private var pulseAnimator: ObjectAnimator? = null
|
private var pulseAnimator: ObjectAnimator? = null
|
||||||
|
private var refMatchImageView: ImageView? = null
|
||||||
|
|
||||||
|
|
||||||
// 音频和AI模块
|
// 音频和AI模块
|
||||||
@@ -254,6 +258,7 @@ class UnityDigitalPersonActivity : UnityPlayerActivity(), LifecycleOwner {
|
|||||||
chatHistoryText = chatLayout.findViewById(R.id.my_text)
|
chatHistoryText = chatLayout.findViewById(R.id.my_text)
|
||||||
holdToSpeakButton = chatLayout.findViewById(R.id.record_button)
|
holdToSpeakButton = chatLayout.findViewById(R.id.record_button)
|
||||||
recordButtonGlow = chatLayout.findViewById(R.id.record_button_glow)
|
recordButtonGlow = chatLayout.findViewById(R.id.record_button_glow)
|
||||||
|
refMatchImageView = chatLayout.findViewById(R.id.ref_match_image)
|
||||||
|
|
||||||
// 根据配置设置按钮可见性
|
// 根据配置设置按钮可见性
|
||||||
if (AppConfig.USE_HOLD_TO_SPEAK) {
|
if (AppConfig.USE_HOLD_TO_SPEAK) {
|
||||||
@@ -735,6 +740,7 @@ class UnityDigitalPersonActivity : UnityPlayerActivity(), LifecycleOwner {
|
|||||||
val filteredText = ttsController.speakLlmResponse(response)
|
val filteredText = ttsController.speakLlmResponse(response)
|
||||||
android.util.Log.d("UnityDigitalPerson", "LLM response filtered: ${filteredText?.take(60)}")
|
android.util.Log.d("UnityDigitalPerson", "LLM response filtered: ${filteredText?.take(60)}")
|
||||||
if (filteredText != null) appendChat("助手: $filteredText")
|
if (filteredText != null) appendChat("助手: $filteredText")
|
||||||
|
maybeShowMatchedRefImage(filteredText ?: response)
|
||||||
interactionCoordinator.onCloudFinalResponse(filteredText ?: response.trim())
|
interactionCoordinator.onCloudFinalResponse(filteredText ?: response.trim())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -751,6 +757,25 @@ class UnityDigitalPersonActivity : UnityPlayerActivity(), LifecycleOwner {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private fun maybeShowMatchedRefImage(text: String) {
|
||||||
|
val imageView = refMatchImageView ?: return
|
||||||
|
// Unity Activity already has coroutines
|
||||||
|
CoroutineScope(SupervisorJob() + Dispatchers.IO).launch {
|
||||||
|
val match = RefImageMatcher.findBestMatch(applicationContext, text)
|
||||||
|
if (match == null) return@launch
|
||||||
|
val bitmap = try {
|
||||||
|
assets.open(match.pngAssetPath).use { BitmapFactory.decodeStream(it) }
|
||||||
|
} catch (_: Throwable) {
|
||||||
|
null
|
||||||
|
}
|
||||||
|
if (bitmap == null) return@launch
|
||||||
|
runOnUiThread {
|
||||||
|
imageView.setImageBitmap(bitmap)
|
||||||
|
imageView.visibility = View.VISIBLE
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private fun requestLocalThought(prompt: String, onResult: (String) -> Unit) {
|
private fun requestLocalThought(prompt: String, onResult: (String) -> Unit) {
|
||||||
val local = llmManager
|
val local = llmManager
|
||||||
if (local == null) {
|
if (local == null) {
|
||||||
|
|||||||
@@ -110,6 +110,19 @@ object AppConfig {
|
|||||||
const val MODEL_SIZE_ESTIMATE = 500L * 1024 * 1024 // 500MB
|
const val MODEL_SIZE_ESTIMATE = 500L * 1024 * 1024 // 500MB
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** BGE-small-zh-v1.5 文本嵌入(RKNN),用于语义相似度 / 检索。 */
|
||||||
|
object Bge {
|
||||||
|
const val ASSET_DIR = "bge_models"
|
||||||
|
const val MODEL_FILE = "bge-small-zh-v1.5.rknn"
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* app/note/ref 通过 Gradle 额外 assets 目录打入 apk 后,在 assets 中的根路径为 `ref/`。
|
||||||
|
*/
|
||||||
|
object RefCorpus {
|
||||||
|
const val ASSETS_ROOT = "ref"
|
||||||
|
}
|
||||||
|
|
||||||
object OnboardTesting {
|
object OnboardTesting {
|
||||||
// 测试人脸识别
|
// 测试人脸识别
|
||||||
const val FACE_REGONITION = false
|
const val FACE_REGONITION = false
|
||||||
|
|||||||
@@ -9,15 +9,24 @@ import com.digitalperson.data.dao.UserAnswerDao
|
|||||||
import com.digitalperson.data.dao.UserMemoryDao
|
import com.digitalperson.data.dao.UserMemoryDao
|
||||||
import com.digitalperson.data.dao.ChatMessageDao
|
import com.digitalperson.data.dao.ChatMessageDao
|
||||||
import com.digitalperson.data.dao.ConversationSummaryDao
|
import com.digitalperson.data.dao.ConversationSummaryDao
|
||||||
|
import com.digitalperson.data.dao.RefTextEmbeddingDao
|
||||||
import com.digitalperson.data.entity.Question
|
import com.digitalperson.data.entity.Question
|
||||||
import com.digitalperson.data.entity.UserAnswer
|
import com.digitalperson.data.entity.UserAnswer
|
||||||
import com.digitalperson.data.entity.UserMemory
|
import com.digitalperson.data.entity.UserMemory
|
||||||
import com.digitalperson.data.entity.ChatMessageEntity
|
import com.digitalperson.data.entity.ChatMessageEntity
|
||||||
import com.digitalperson.data.entity.ConversationSummaryEntity
|
import com.digitalperson.data.entity.ConversationSummaryEntity
|
||||||
|
import com.digitalperson.data.entity.RefTextEmbedding
|
||||||
|
|
||||||
@Database(
|
@Database(
|
||||||
entities = [UserMemory::class, Question::class, UserAnswer::class, ChatMessageEntity::class, ConversationSummaryEntity::class],
|
entities = [
|
||||||
version = 4,
|
UserMemory::class,
|
||||||
|
Question::class,
|
||||||
|
UserAnswer::class,
|
||||||
|
ChatMessageEntity::class,
|
||||||
|
ConversationSummaryEntity::class,
|
||||||
|
RefTextEmbedding::class
|
||||||
|
],
|
||||||
|
version = 5,
|
||||||
exportSchema = false
|
exportSchema = false
|
||||||
)
|
)
|
||||||
abstract class AppDatabase : RoomDatabase() {
|
abstract class AppDatabase : RoomDatabase() {
|
||||||
@@ -26,6 +35,7 @@ abstract class AppDatabase : RoomDatabase() {
|
|||||||
abstract fun userAnswerDao(): UserAnswerDao
|
abstract fun userAnswerDao(): UserAnswerDao
|
||||||
abstract fun chatMessageDao(): ChatMessageDao
|
abstract fun chatMessageDao(): ChatMessageDao
|
||||||
abstract fun conversationSummaryDao(): ConversationSummaryDao
|
abstract fun conversationSummaryDao(): ConversationSummaryDao
|
||||||
|
abstract fun refTextEmbeddingDao(): RefTextEmbeddingDao
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
private const val DATABASE_NAME = "digital_human.db"
|
private const val DATABASE_NAME = "digital_human.db"
|
||||||
|
|||||||
@@ -10,6 +10,17 @@ interface QuestionDao {
|
|||||||
@Insert
|
@Insert
|
||||||
fun insert(question: Question): Long
|
fun insert(question: Question): Long
|
||||||
|
|
||||||
|
@Query(
|
||||||
|
"""
|
||||||
|
SELECT * FROM questions
|
||||||
|
WHERE content = :content
|
||||||
|
AND ((:subject IS NULL AND subject IS NULL) OR subject = :subject)
|
||||||
|
AND ((:grade IS NULL AND grade IS NULL) OR grade = :grade)
|
||||||
|
LIMIT 1
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
fun findByContentSubjectGrade(content: String, subject: String?, grade: Int?): Question?
|
||||||
|
|
||||||
@Query("SELECT * FROM questions WHERE subject = :subject ORDER BY difficulty")
|
@Query("SELECT * FROM questions WHERE subject = :subject ORDER BY difficulty")
|
||||||
fun getQuestionsBySubject(subject: String): List<Question>
|
fun getQuestionsBySubject(subject: String): List<Question>
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,20 @@
|
|||||||
|
package com.digitalperson.data.dao
|
||||||
|
|
||||||
|
import androidx.room.Dao
|
||||||
|
import androidx.room.Insert
|
||||||
|
import androidx.room.OnConflictStrategy
|
||||||
|
import androidx.room.Query
|
||||||
|
import com.digitalperson.data.entity.RefTextEmbedding
|
||||||
|
|
||||||
|
@Dao
|
||||||
|
interface RefTextEmbeddingDao {
|
||||||
|
|
||||||
|
@Query("SELECT * FROM ref_text_embeddings WHERE assetPath = :path LIMIT 1")
|
||||||
|
fun getByPath(path: String): RefTextEmbedding?
|
||||||
|
|
||||||
|
@Insert(onConflict = OnConflictStrategy.REPLACE)
|
||||||
|
fun insert(row: RefTextEmbedding): Long
|
||||||
|
|
||||||
|
@Query("SELECT * FROM ref_text_embeddings")
|
||||||
|
fun getAll(): List<RefTextEmbedding>
|
||||||
|
}
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
package com.digitalperson.data.entity
|
||||||
|
|
||||||
|
import androidx.room.Entity
|
||||||
|
import androidx.room.Index
|
||||||
|
import androidx.room.PrimaryKey
|
||||||
|
import com.digitalperson.data.util.embeddingBytesToFloatArray
|
||||||
|
|
||||||
|
@Entity(
|
||||||
|
tableName = "ref_text_embeddings",
|
||||||
|
indices = [Index(value = ["assetPath"], unique = true)]
|
||||||
|
)
|
||||||
|
data class RefTextEmbedding(
|
||||||
|
@PrimaryKey(autoGenerate = true) val id: Long = 0,
|
||||||
|
/** assets 相对路径,如 ref/一年级.../xxx.txt */
|
||||||
|
val assetPath: String,
|
||||||
|
/** 参与嵌入的正文 UTF-8 的 SHA-256 十六进制,用于跳过未变更文件 */
|
||||||
|
val contentHash: String,
|
||||||
|
val dim: Int,
|
||||||
|
val embedding: ByteArray,
|
||||||
|
val updatedAt: Long = System.currentTimeMillis()
|
||||||
|
) {
|
||||||
|
fun toFloatArray(): FloatArray = embeddingBytesToFloatArray(embedding)
|
||||||
|
}
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
package com.digitalperson.data.util
|
||||||
|
|
||||||
|
import java.nio.ByteBuffer
|
||||||
|
import java.nio.ByteOrder
|
||||||
|
|
||||||
|
fun floatArrayToEmbeddingBytes(values: FloatArray): ByteArray {
|
||||||
|
val bb = ByteBuffer.allocate(values.size * 4).order(ByteOrder.LITTLE_ENDIAN)
|
||||||
|
for (v in values) {
|
||||||
|
bb.putFloat(v)
|
||||||
|
}
|
||||||
|
return bb.array()
|
||||||
|
}
|
||||||
|
|
||||||
|
fun embeddingBytesToFloatArray(blob: ByteArray): FloatArray {
|
||||||
|
val bb = ByteBuffer.wrap(blob).order(ByteOrder.LITTLE_ENDIAN)
|
||||||
|
val n = blob.size / 4
|
||||||
|
return FloatArray(n) { bb.getFloat() }
|
||||||
|
}
|
||||||
@@ -0,0 +1,48 @@
|
|||||||
|
package com.digitalperson.embedding
|
||||||
|
|
||||||
|
import android.content.Context
|
||||||
|
import com.digitalperson.config.AppConfig
|
||||||
|
import com.digitalperson.engine.BgeEngineRKNN
|
||||||
|
import com.digitalperson.util.FileHelper
|
||||||
|
import java.io.File
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 懒加载 BGE 文本嵌入(RKNN)。用于对话文本与预存标注的语义相似度检索。
|
||||||
|
*
|
||||||
|
* 初始化会复制 [AppConfig.Bge.ASSET_DIR] 下资源到内部存储并加载 [AppConfig.Bge.MODEL_FILE]。
|
||||||
|
*/
|
||||||
|
object BgeEmbedding {
|
||||||
|
|
||||||
|
@Volatile
|
||||||
|
private var engine: BgeEngineRKNN? = null
|
||||||
|
|
||||||
|
fun isReady(): Boolean = engine?.isInitialized == true
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 在主线程调用会阻塞;建议在后台线程或协程 [Dispatchers.IO] 中调用。
|
||||||
|
*/
|
||||||
|
@Synchronized
|
||||||
|
fun initialize(context: Context): Boolean {
|
||||||
|
if (engine?.isInitialized == true) return true
|
||||||
|
val dir = FileHelper.copyBgeModels(context.applicationContext) ?: return false
|
||||||
|
val path = File(dir, AppConfig.Bge.MODEL_FILE).absolutePath
|
||||||
|
if (!File(path).exists()) return false
|
||||||
|
val eng = BgeEngineRKNN(context.applicationContext)
|
||||||
|
if (!eng.initialize(path)) return false
|
||||||
|
engine = eng
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
@Synchronized
|
||||||
|
fun release() {
|
||||||
|
engine?.deinitialize()
|
||||||
|
engine = null
|
||||||
|
}
|
||||||
|
|
||||||
|
fun getEmbedding(text: String): FloatArray? = engine?.getEmbedding(text)
|
||||||
|
|
||||||
|
fun similarity(text1: String, text2: String): Float? =
|
||||||
|
engine?.calculateSimilarity(text1, text2)
|
||||||
|
|
||||||
|
fun embeddingDim(): Int = engine?.embeddingDim ?: -1
|
||||||
|
}
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
package com.digitalperson.embedding
|
||||||
|
|
||||||
|
import android.content.Context
|
||||||
|
|
||||||
|
internal object RefCorpusAssetScanner {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 递归列出 [root] 目录下(含子目录)所有 `.txt` 的 assets 路径,使用 `/` 分隔。
|
||||||
|
*/
|
||||||
|
fun listTxtFilesUnder(context: Context, root: String): List<String> {
|
||||||
|
val am = context.assets
|
||||||
|
val out = ArrayList<String>()
|
||||||
|
|
||||||
|
fun walk(prefix: String) {
|
||||||
|
val children = am.list(prefix) ?: return
|
||||||
|
if (children.isEmpty()) return
|
||||||
|
for (child in children) {
|
||||||
|
if (child.isEmpty()) continue
|
||||||
|
val full = "$prefix/$child"
|
||||||
|
val grand = am.list(full)
|
||||||
|
if (grand.isNullOrEmpty()) {
|
||||||
|
if (child.endsWith(".txt", ignoreCase = true)) {
|
||||||
|
out.add(full)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
walk(full)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
walk(root.removeSuffix("/"))
|
||||||
|
return out.sorted()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,173 @@
|
|||||||
|
package com.digitalperson.embedding
|
||||||
|
|
||||||
|
import android.content.Context
|
||||||
|
import android.util.Log
|
||||||
|
import com.digitalperson.config.AppConfig
|
||||||
|
import com.digitalperson.data.AppDatabase
|
||||||
|
import com.digitalperson.data.entity.Question
|
||||||
|
import com.digitalperson.data.entity.RefTextEmbedding
|
||||||
|
import com.digitalperson.data.util.floatArrayToEmbeddingBytes
|
||||||
|
import kotlinx.coroutines.Dispatchers
|
||||||
|
import kotlinx.coroutines.withContext
|
||||||
|
import java.security.MessageDigest
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 开机(进程启动)时在后台扫描 assets 下 [AppConfig.RefCorpus.ASSETS_ROOT] 中全部 `.txt`,
|
||||||
|
* 跳过 `#` 行后做 BGE 嵌入,写入 Room 并填充 [RefEmbeddingMemoryCache]。
|
||||||
|
*
|
||||||
|
* DAO 为同步接口,整个函数在 [Dispatchers.IO] 上运行,不阻塞主线程。
|
||||||
|
*/
|
||||||
|
object RefEmbeddingIndexer {
|
||||||
|
|
||||||
|
private const val TAG = AppConfig.TAG
|
||||||
|
|
||||||
|
suspend fun runOnce(context: Context) = withContext(Dispatchers.IO) {
|
||||||
|
val app = context.applicationContext
|
||||||
|
val db = AppDatabase.getInstance(app)
|
||||||
|
val dao = db.refTextEmbeddingDao()
|
||||||
|
val questionDao = db.questionDao()
|
||||||
|
|
||||||
|
if (!BgeEmbedding.initialize(app)) {
|
||||||
|
Log.e(TAG, "[RefEmbed] BGE 初始化失败,跳过 ref 语料索引")
|
||||||
|
return@withContext
|
||||||
|
}
|
||||||
|
|
||||||
|
val root = AppConfig.RefCorpus.ASSETS_ROOT
|
||||||
|
val paths = RefCorpusAssetScanner.listTxtFilesUnder(app, root)
|
||||||
|
Log.i(TAG, "[RefEmbed] 发现 ${paths.size} 个 txt(root=$root)")
|
||||||
|
|
||||||
|
var skipped = 0
|
||||||
|
var embedded = 0
|
||||||
|
var empty = 0
|
||||||
|
var failed = 0
|
||||||
|
|
||||||
|
for (path in paths) {
|
||||||
|
val raw = try {
|
||||||
|
app.assets.open(path).bufferedReader(Charsets.UTF_8).use { it.readText() }
|
||||||
|
} catch (e: Exception) {
|
||||||
|
Log.w(TAG, "[RefEmbed] 读取失败 $path: ${e.message}")
|
||||||
|
failed++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 题库:遇到包含 ?/? 的行,写入 questions
|
||||||
|
val subject = extractSubjectFromRaw(raw)
|
||||||
|
val grade = extractGradeFromPath(path)
|
||||||
|
val questionLines = extractQuestionLines(raw)
|
||||||
|
for (line in questionLines) {
|
||||||
|
val content = line.trim()
|
||||||
|
if (content.isEmpty()) continue
|
||||||
|
val exists = questionDao.findByContentSubjectGrade(content, subject, grade)
|
||||||
|
if (exists == null) {
|
||||||
|
questionDao.insert(
|
||||||
|
Question(
|
||||||
|
id = 0,
|
||||||
|
content = content,
|
||||||
|
answer = null,
|
||||||
|
subject = subject,
|
||||||
|
grade = grade,
|
||||||
|
difficulty = 1,
|
||||||
|
createdAt = System.currentTimeMillis()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val embedText = RefTxtEmbedText.fromRawFileContent(raw)
|
||||||
|
if (embedText.isEmpty()) {
|
||||||
|
empty++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
val hash = sha256Hex(embedText.toByteArray(Charsets.UTF_8))
|
||||||
|
// 同步 DAO 调用(已在 IO 线程)
|
||||||
|
val existing = dao.getByPath(path)
|
||||||
|
if (existing != null && existing.contentHash == hash) {
|
||||||
|
RefEmbeddingMemoryCache.put(path, normalizeL2(existing.toFloatArray()))
|
||||||
|
skipped++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
val vec = BgeEmbedding.getEmbedding(embedText)
|
||||||
|
if (vec == null || vec.isEmpty()) {
|
||||||
|
Log.w(TAG, "[RefEmbed] 嵌入为空 $path")
|
||||||
|
failed++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
val normalized = normalizeL2(vec)
|
||||||
|
dao.insert(
|
||||||
|
RefTextEmbedding(
|
||||||
|
assetPath = path,
|
||||||
|
contentHash = hash,
|
||||||
|
dim = normalized.size,
|
||||||
|
embedding = floatArrayToEmbeddingBytes(normalized)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
RefEmbeddingMemoryCache.put(path, normalized)
|
||||||
|
embedded++
|
||||||
|
}
|
||||||
|
|
||||||
|
Log.i(
|
||||||
|
TAG,
|
||||||
|
"[RefEmbed] 完成 embedded=$embedded skipped=$skipped empty=$empty failed=$failed cacheSize=${RefEmbeddingMemoryCache.size()}"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun extractSubjectFromRaw(raw: String): String? {
|
||||||
|
val line = raw.lineSequence()
|
||||||
|
.map { it.trimEnd() }
|
||||||
|
.firstOrNull { it.trimStart().startsWith("#") }
|
||||||
|
?: return null
|
||||||
|
val s = line.trimStart().removePrefix("#").trim()
|
||||||
|
return s.ifEmpty { null }
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun extractQuestionLines(raw: String): List<String> {
|
||||||
|
return raw.lineSequence()
|
||||||
|
.map { it.trimEnd() }
|
||||||
|
.filter { it.isNotBlank() }
|
||||||
|
.filter { !it.trimStart().startsWith("#") }
|
||||||
|
.filter { it.contains('?') || it.contains('?') }
|
||||||
|
.toList()
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun extractGradeFromPath(assetPath: String): Int? {
|
||||||
|
// example: ref/一年级上-生活适应/... or ref/二年级下-...
|
||||||
|
val idx = assetPath.indexOf("年级")
|
||||||
|
if (idx <= 0) return null
|
||||||
|
val prefix = assetPath.substring(0, idx)
|
||||||
|
val cn = prefix.lastOrNull() ?: return null
|
||||||
|
return chineseGradeToInt(cn)
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun chineseGradeToInt(c: Char): Int? {
|
||||||
|
return when (c) {
|
||||||
|
'一' -> 1
|
||||||
|
'二' -> 2
|
||||||
|
'三' -> 3
|
||||||
|
'四' -> 4
|
||||||
|
'五' -> 5
|
||||||
|
'六' -> 6
|
||||||
|
'七' -> 7
|
||||||
|
'八' -> 8
|
||||||
|
'九' -> 9
|
||||||
|
'十' -> 10
|
||||||
|
else -> null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun normalizeL2(v: FloatArray): FloatArray {
|
||||||
|
var sum = 0.0
|
||||||
|
for (x in v) sum += (x * x).toDouble()
|
||||||
|
val norm = kotlin.math.sqrt(sum).toFloat()
|
||||||
|
if (norm <= 1e-12f) return v.copyOf()
|
||||||
|
return FloatArray(v.size) { i -> v[i] / norm }
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun sha256Hex(data: ByteArray): String {
|
||||||
|
val md = MessageDigest.getInstance("SHA-256")
|
||||||
|
val digest = md.digest(data)
|
||||||
|
return digest.joinToString("") { "%02x".format(it) }
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
package com.digitalperson.embedding
|
||||||
|
|
||||||
|
import java.util.concurrent.ConcurrentHashMap
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ref 语料 txt 嵌入的内存缓存,键为 assets 相对路径(与数据库 [com.digitalperson.data.entity.RefTextEmbedding.assetPath] 一致)。
|
||||||
|
*/
|
||||||
|
object RefEmbeddingMemoryCache {
|
||||||
|
|
||||||
|
private val vectors = ConcurrentHashMap<String, FloatArray>()
|
||||||
|
|
||||||
|
fun put(assetPath: String, embedding: FloatArray) {
|
||||||
|
vectors[assetPath] = embedding.copyOf()
|
||||||
|
}
|
||||||
|
|
||||||
|
fun get(assetPath: String): FloatArray? {
|
||||||
|
val v = vectors[assetPath] ?: return null
|
||||||
|
return v.copyOf()
|
||||||
|
}
|
||||||
|
|
||||||
|
fun clear() {
|
||||||
|
vectors.clear()
|
||||||
|
}
|
||||||
|
|
||||||
|
/** 只读快照(向量已为拷贝,可安全使用)。 */
|
||||||
|
fun snapshot(): Map<String, FloatArray> =
|
||||||
|
vectors.mapValues { it.value.copyOf() }
|
||||||
|
|
||||||
|
fun size(): Int = vectors.size
|
||||||
|
}
|
||||||
@@ -0,0 +1,95 @@
|
|||||||
|
package com.digitalperson.embedding
|
||||||
|
|
||||||
|
import android.content.Context
|
||||||
|
import android.util.Log
|
||||||
|
import com.digitalperson.config.AppConfig
|
||||||
|
import kotlin.math.sqrt
|
||||||
|
|
||||||
|
data class RefImageMatch(
|
||||||
|
val txtAssetPath: String,
|
||||||
|
val pngAssetPath: String,
|
||||||
|
val score: Float
|
||||||
|
)
|
||||||
|
|
||||||
|
object RefImageMatcher {
|
||||||
|
|
||||||
|
private const val TAG = AppConfig.TAG
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param threshold 余弦相似度阈值(向量已归一化时等价于 dot product)。
|
||||||
|
*/
|
||||||
|
fun findBestMatch(
|
||||||
|
context: Context,
|
||||||
|
text: String,
|
||||||
|
threshold: Float = 0.70f
|
||||||
|
): RefImageMatch? {
|
||||||
|
val query = text.trim()
|
||||||
|
if (query.isEmpty()) return null
|
||||||
|
|
||||||
|
if (!BgeEmbedding.isReady()) {
|
||||||
|
val ok = BgeEmbedding.initialize(context.applicationContext)
|
||||||
|
if (!ok) {
|
||||||
|
Log.w(TAG, "[RefMatch] BGE not ready, skip match")
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val q = BgeEmbedding.getEmbedding(query) ?: return null
|
||||||
|
if (q.isEmpty()) return null
|
||||||
|
val qn = normalizeL2(q)
|
||||||
|
|
||||||
|
val vectors = RefEmbeddingMemoryCache.snapshot()
|
||||||
|
if (vectors.isEmpty()) return null
|
||||||
|
|
||||||
|
var bestPath: String? = null
|
||||||
|
var bestScore = -1f
|
||||||
|
|
||||||
|
for ((path, v) in vectors) {
|
||||||
|
if (v.isEmpty() || v.size != qn.size) continue
|
||||||
|
val score = dot(qn, v)
|
||||||
|
if (score > bestScore) {
|
||||||
|
bestScore = score
|
||||||
|
bestPath = path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val txtPath = bestPath ?: return null
|
||||||
|
if (bestScore < threshold) return null
|
||||||
|
|
||||||
|
val pngPath = if (txtPath.endsWith(".txt", ignoreCase = true)) {
|
||||||
|
txtPath.dropLast(4) + ".png"
|
||||||
|
} else {
|
||||||
|
"$txtPath.png"
|
||||||
|
}
|
||||||
|
|
||||||
|
// 不在这里 decode,只检查是否存在,避免 UI 线程 IO。
|
||||||
|
val exists = try {
|
||||||
|
context.assets.open(pngPath).close()
|
||||||
|
true
|
||||||
|
} catch (_: Throwable) {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
if (!exists) return null
|
||||||
|
|
||||||
|
return RefImageMatch(
|
||||||
|
txtAssetPath = txtPath,
|
||||||
|
pngAssetPath = pngPath,
|
||||||
|
score = bestScore
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun dot(a: FloatArray, b: FloatArray): Float {
|
||||||
|
var s = 0f
|
||||||
|
for (i in a.indices) s += a[i] * b[i]
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun normalizeL2(v: FloatArray): FloatArray {
|
||||||
|
var sum = 0.0
|
||||||
|
for (x in v) sum += (x * x).toDouble()
|
||||||
|
val norm = sqrt(sum).toFloat()
|
||||||
|
if (norm <= 1e-12f) return v.copyOf()
|
||||||
|
return FloatArray(v.size) { i -> v[i] / norm }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
package com.digitalperson.embedding
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 从原始 txt 中构造待嵌入文本:去掉「以 # 开头」的行(行首可含空白),其余行以换行连接。
|
||||||
|
*/
|
||||||
|
object RefTxtEmbedText {
|
||||||
|
|
||||||
|
fun fromRawFileContent(raw: String): String {
|
||||||
|
return raw.lineSequence()
|
||||||
|
.map { it.trimEnd() }
|
||||||
|
.filter { line ->
|
||||||
|
if (line.isEmpty()) return@filter false
|
||||||
|
!line.trimStart().startsWith("#")
|
||||||
|
}
|
||||||
|
.joinToString("\n")
|
||||||
|
.trim()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,353 @@
|
|||||||
|
package com.digitalperson.embedding;
|
||||||
|
|
||||||
|
import android.content.Context;
|
||||||
|
import android.os.Handler;
|
||||||
|
import android.os.Looper;
|
||||||
|
import android.util.Log;
|
||||||
|
|
||||||
|
import com.digitalperson.config.AppConfig;
|
||||||
|
import com.digitalperson.engine.BgeEngineRKNN;
|
||||||
|
import com.digitalperson.util.FileHelper;
|
||||||
|
|
||||||
|
import org.ejml.simple.SimpleMatrix;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class SimilarityManager {
|
||||||
|
private static final String TAG = "SimilarityManager";
|
||||||
|
|
||||||
|
private Context mContext;
|
||||||
|
private Handler mHandler;
|
||||||
|
private BgeEngineRKNN mBgeEngine;
|
||||||
|
|
||||||
|
// 测试数据相关
|
||||||
|
private List<float[]> mTestEmbeddings; // 预计算的100个句子的嵌入(float数组格式)
|
||||||
|
private List<SimpleMatrix> mTestEmbeddingsEJML; // 预计算的100个句子的嵌入(EJML格式)
|
||||||
|
private SimpleMatrix mTestEmbeddingsMatrix; // 所有测试嵌入的矩阵(嵌入维度 × 句子数量)
|
||||||
|
private List<Double> mTestEmbeddingsNorms; // 预计算的100个句子的嵌入范数
|
||||||
|
private SimpleMatrix mTestEmbeddingsNormsMatrix; // 预计算的嵌入范数向量(1 × 句子数量)
|
||||||
|
private SimpleMatrix mTestEmbeddingsNormsReciprocalMatrix; // 预计算的嵌入范数倒数向量(1 × 句子数量)
|
||||||
|
private List<String> mTestSentences; // 预生成的100个测试句子
|
||||||
|
private boolean useEJMLForSimilarity = false; // 默认使用传统for循环计算相似度
|
||||||
|
|
||||||
|
public interface SimilarityListener {
|
||||||
|
void onSimilarityCalculated(float similarity, long timeTaken);
|
||||||
|
void onPerformanceTestComplete(long traditionalTime, long ejmlTime, double speedup);
|
||||||
|
void onError(String errorMessage);
|
||||||
|
}
|
||||||
|
|
||||||
|
private SimilarityListener mListener;
|
||||||
|
|
||||||
|
public SimilarityManager(Context context) {
|
||||||
|
this.mContext = context;
|
||||||
|
this.mHandler = new Handler(Looper.getMainLooper());
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setListener(SimilarityListener listener) {
|
||||||
|
this.mListener = listener;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 初始化BGE模型
|
||||||
|
public void initBgeModel() {
|
||||||
|
try {
|
||||||
|
File bgeModelDir = FileHelper.copyBgeModels(mContext);
|
||||||
|
if (bgeModelDir == null) {
|
||||||
|
Log.e(TAG, "BGE model directory copy failed");
|
||||||
|
if (mListener != null) {
|
||||||
|
mListener.onError("BGE 模型文件复制失败");
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 初始化BGE模型
|
||||||
|
mBgeEngine = new BgeEngineRKNN(mContext);
|
||||||
|
String modelPath = new File(bgeModelDir, AppConfig.Bge.MODEL_FILE).getAbsolutePath();
|
||||||
|
|
||||||
|
// 检查模型文件是否存在
|
||||||
|
if (!new File(modelPath).exists()) {
|
||||||
|
Log.e(TAG, "BGE model file does not exist: " + modelPath);
|
||||||
|
if (mListener != null) {
|
||||||
|
mListener.onError("BGE 模型文件不存在");
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
boolean success = mBgeEngine.initialize(modelPath);
|
||||||
|
if (!success) {
|
||||||
|
Log.e(TAG, "Failed to initialize BGE model");
|
||||||
|
if (mListener != null) {
|
||||||
|
mListener.onError("BGE 模型初始化失败");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
Log.e(TAG, "Exception in initBgeModel", e);
|
||||||
|
if (mListener != null) {
|
||||||
|
mListener.onError("初始化BGE模型异常: " + e.getMessage());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算BGE相似度
|
||||||
|
public void calculateSimilarity(String text1, String text2) {
|
||||||
|
if (mBgeEngine == null || !mBgeEngine.isInitialized()) {
|
||||||
|
Log.w(TAG, "BGE engine not initialized, skipping similarity calculation");
|
||||||
|
if (mListener != null) {
|
||||||
|
mListener.onError("BGE 模型未初始化");
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 在后台线程中计算相似度
|
||||||
|
new Thread(() -> {
|
||||||
|
try {
|
||||||
|
long startTime = System.currentTimeMillis();
|
||||||
|
float similarity = mBgeEngine.calculateSimilarity(text1, text2);
|
||||||
|
long timeTaken = System.currentTimeMillis() - startTime;
|
||||||
|
|
||||||
|
if (mListener != null) {
|
||||||
|
mListener.onSimilarityCalculated(similarity, timeTaken);
|
||||||
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
Log.e(TAG, "Exception in calculateSimilarity", e);
|
||||||
|
if (mListener != null) {
|
||||||
|
mListener.onError("相似度计算失败: " + e.getMessage());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}).start();
|
||||||
|
}
|
||||||
|
|
||||||
|
// 测试BGE性能
|
||||||
|
public void testPerformance(String userInput) {
|
||||||
|
if (mBgeEngine == null || !mBgeEngine.isInitialized()) {
|
||||||
|
Log.w(TAG, "BGE engine not initialized, skipping performance test");
|
||||||
|
if (mListener != null) {
|
||||||
|
mListener.onError("BGE 模型未初始化");
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 在后台线程中执行测试
|
||||||
|
new Thread(() -> {
|
||||||
|
try {
|
||||||
|
// 准备测试数据
|
||||||
|
prepareTestData();
|
||||||
|
|
||||||
|
// 1. 计算用户输入的嵌入
|
||||||
|
float[] userEmbedding = mBgeEngine.getEmbedding(userInput);
|
||||||
|
if (userEmbedding == null) {
|
||||||
|
if (mListener != null) {
|
||||||
|
mListener.onError("无法计算用户输入的嵌入");
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建EJML格式的用户嵌入
|
||||||
|
SimpleMatrix userVec = new SimpleMatrix(userEmbedding.length, 1);
|
||||||
|
for (int i = 0; i < userEmbedding.length; i++) {
|
||||||
|
userVec.set(i, 0, userEmbedding[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 方法1: 使用float数组和循环计算相似度(传统方法)
|
||||||
|
long startTime1 = System.currentTimeMillis();
|
||||||
|
List<Float> similarities1 = new ArrayList<>();
|
||||||
|
|
||||||
|
// 优化传统方法:使用范数倒数和乘法代替除法
|
||||||
|
double normUserTraditional = 0.0;
|
||||||
|
for (float value : userEmbedding) {
|
||||||
|
normUserTraditional += value * value;
|
||||||
|
}
|
||||||
|
normUserTraditional = Math.sqrt(normUserTraditional);
|
||||||
|
double normUserReciprocalTraditional = (normUserTraditional > 1e-10) ? 1.0 / normUserTraditional : 0.0;
|
||||||
|
|
||||||
|
for (int i = 0; i < mTestEmbeddings.size(); i++) {
|
||||||
|
float[] embedding = mTestEmbeddings.get(i);
|
||||||
|
// 计算点积
|
||||||
|
double dotProduct = 0.0;
|
||||||
|
for (int j = 0; j < userEmbedding.length; j++) {
|
||||||
|
dotProduct += userEmbedding[j] * embedding[j];
|
||||||
|
}
|
||||||
|
// 使用乘法代替除法:cos(u,v_i) = (u·v_i) × (1/||u||) × (1/||v_i||)
|
||||||
|
double normEmbeddingReciprocal = mTestEmbeddingsNorms.get(i) > 1e-10 ? 1.0 / mTestEmbeddingsNorms.get(i) : 0.0;
|
||||||
|
float sim = (float) (dotProduct * normUserReciprocalTraditional * normEmbeddingReciprocal);
|
||||||
|
similarities1.add(sim);
|
||||||
|
}
|
||||||
|
long timeTaken1 = System.currentTimeMillis() - startTime1;
|
||||||
|
|
||||||
|
// 方法2: 使用EJML计算相似度(优化方法)
|
||||||
|
long startTime2 = System.currentTimeMillis();
|
||||||
|
List<Float> similarities2 = new ArrayList<>();
|
||||||
|
|
||||||
|
// 优化1: 预计算用户向量的范数和范数倒数
|
||||||
|
double normUser = userVec.normF();
|
||||||
|
double normUserReciprocal = (normUser > 1e-10) ? 1.0 / normUser : 0.0;
|
||||||
|
|
||||||
|
// 优化2: 使用批量矩阵计算一次性计算所有点积和相似度
|
||||||
|
if (mTestEmbeddingsMatrix != null && mTestEmbeddingsNormsReciprocalMatrix != null) {
|
||||||
|
// 使用矩阵乘法一次性计算所有点积
|
||||||
|
SimpleMatrix dotProducts = userVec.transpose().mult(mTestEmbeddingsMatrix);
|
||||||
|
|
||||||
|
// 一次性计算所有相似度
|
||||||
|
// 使用乘法代替除法:cos(u,v_i) = (u·v_i) × (1/||u||) × (1/||v_i||)
|
||||||
|
// 步骤1: 计算 (u·v_i) × (1/||u||)
|
||||||
|
SimpleMatrix dotProductsScaled = dotProducts.scale(normUserReciprocal);
|
||||||
|
// 步骤2: 逐元素乘法计算最终相似度
|
||||||
|
SimpleMatrix similaritiesMatrix = dotProductsScaled.elementMult(mTestEmbeddingsNormsReciprocalMatrix);
|
||||||
|
|
||||||
|
// 将结果转换为列表
|
||||||
|
for (int i = 0; i < similaritiesMatrix.numCols(); i++) {
|
||||||
|
similarities2.add((float) similaritiesMatrix.get(0, i));
|
||||||
|
}
|
||||||
|
} else if (mTestEmbeddingsMatrix != null) {
|
||||||
|
// 降级方案1: 只有嵌入矩阵,使用批量点积 + 循环计算相似度
|
||||||
|
SimpleMatrix dotProducts = userVec.transpose().mult(mTestEmbeddingsMatrix);
|
||||||
|
for (int i = 0; i < dotProducts.numCols(); i++) {
|
||||||
|
double dotProduct = dotProducts.get(0, i);
|
||||||
|
double normEmbeddingReciprocal = mTestEmbeddingsNorms.get(i) > 1e-10 ? 1.0 / mTestEmbeddingsNorms.get(i) : 0.0;
|
||||||
|
float sim = (float) (dotProduct * normUserReciprocal * normEmbeddingReciprocal);
|
||||||
|
similarities2.add(sim);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 降级方案2: 使用循环计算(当矩阵未初始化时)
|
||||||
|
for (int i = 0; i < mTestEmbeddingsEJML.size(); i++) {
|
||||||
|
SimpleMatrix embedding = mTestEmbeddingsEJML.get(i);
|
||||||
|
double dotProduct = userVec.dot(embedding);
|
||||||
|
double normEmbeddingReciprocal = mTestEmbeddingsNorms.get(i) > 1e-10 ? 1.0 / mTestEmbeddingsNorms.get(i) : 0.0;
|
||||||
|
float sim = (float) (dotProduct * normUserReciprocal * normEmbeddingReciprocal);
|
||||||
|
similarities2.add(sim);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
long timeTaken2 = System.currentTimeMillis() - startTime2;
|
||||||
|
|
||||||
|
// 计算速度提升
|
||||||
|
double speedup = (double) timeTaken1 / timeTaken2;
|
||||||
|
|
||||||
|
if (mListener != null) {
|
||||||
|
mListener.onPerformanceTestComplete(timeTaken1, timeTaken2, speedup);
|
||||||
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
Log.e(TAG, "Exception in testPerformance", e);
|
||||||
|
if (mListener != null) {
|
||||||
|
mListener.onError("性能测试失败: " + e.getMessage());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}).start();
|
||||||
|
}
|
||||||
|
|
||||||
|
// 准备测试数据
|
||||||
|
private void prepareTestData() {
|
||||||
|
if (mTestEmbeddings == null || mTestEmbeddings.isEmpty() ||
|
||||||
|
mTestEmbeddingsEJML == null || mTestEmbeddingsEJML.isEmpty() ||
|
||||||
|
mTestEmbeddingsNorms == null || mTestEmbeddingsNorms.isEmpty() ||
|
||||||
|
mTestEmbeddingsMatrix == null ||
|
||||||
|
mTestEmbeddingsNormsMatrix == null ||
|
||||||
|
mTestEmbeddingsNormsReciprocalMatrix == null) {
|
||||||
|
// 生成100个测试句子
|
||||||
|
generateTestSentences();
|
||||||
|
|
||||||
|
// 预计算嵌入(float数组格式)
|
||||||
|
mTestEmbeddings = new ArrayList<>();
|
||||||
|
mTestEmbeddingsEJML = new ArrayList<>();
|
||||||
|
mTestEmbeddingsNorms = new ArrayList<>();
|
||||||
|
List<String> validTestSentences = new ArrayList<>(); // 只包含成功获取嵌入的句子
|
||||||
|
|
||||||
|
for (String sentence : mTestSentences) {
|
||||||
|
float[] embedding = mBgeEngine.getEmbedding(sentence);
|
||||||
|
if (embedding != null) {
|
||||||
|
// 添加float数组格式的嵌入
|
||||||
|
mTestEmbeddings.add(embedding);
|
||||||
|
|
||||||
|
// 添加EJML格式的嵌入
|
||||||
|
SimpleMatrix vec = new SimpleMatrix(embedding.length, 1);
|
||||||
|
for (int i = 0; i < embedding.length; i++) {
|
||||||
|
vec.set(i, 0, embedding[i]);
|
||||||
|
}
|
||||||
|
mTestEmbeddingsEJML.add(vec);
|
||||||
|
|
||||||
|
// 预计算嵌入的范数
|
||||||
|
double norm = vec.normF();
|
||||||
|
mTestEmbeddingsNorms.add(norm);
|
||||||
|
|
||||||
|
// 添加到有效句子列表
|
||||||
|
validTestSentences.add(sentence);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新mTestSentences为只包含有效句子的列表,确保索引匹配
|
||||||
|
mTestSentences = validTestSentences;
|
||||||
|
|
||||||
|
// 创建测试嵌入矩阵(嵌入维度 × 句子数量)
|
||||||
|
if (!mTestEmbeddingsEJML.isEmpty()) {
|
||||||
|
int embeddingDim = mTestEmbeddingsEJML.get(0).numRows();
|
||||||
|
int numSentences = mTestEmbeddingsEJML.size();
|
||||||
|
mTestEmbeddingsMatrix = new SimpleMatrix(embeddingDim, numSentences);
|
||||||
|
|
||||||
|
for (int i = 0; i < numSentences; i++) {
|
||||||
|
SimpleMatrix embedding = mTestEmbeddingsEJML.get(i);
|
||||||
|
for (int j = 0; j < embeddingDim; j++) {
|
||||||
|
mTestEmbeddingsMatrix.set(j, i, embedding.get(j, 0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Log.d(TAG, "Created test embeddings matrix, size: " + embeddingDim + " × " + numSentences);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建测试嵌入范数矩阵(1 × 句子数量)
|
||||||
|
if (!mTestEmbeddingsNorms.isEmpty()) {
|
||||||
|
int numSentences = mTestEmbeddingsNorms.size();
|
||||||
|
mTestEmbeddingsNormsMatrix = new SimpleMatrix(1, numSentences);
|
||||||
|
mTestEmbeddingsNormsReciprocalMatrix = new SimpleMatrix(1, numSentences);
|
||||||
|
for (int i = 0; i < numSentences; i++) {
|
||||||
|
double norm = mTestEmbeddingsNorms.get(i);
|
||||||
|
mTestEmbeddingsNormsMatrix.set(0, i, norm);
|
||||||
|
// 预计算范数倒数,避免在相似度计算时重复计算
|
||||||
|
double reciprocal = (norm > 1e-10) ? 1.0 / norm : 0.0;
|
||||||
|
mTestEmbeddingsNormsReciprocalMatrix.set(0, i, reciprocal);
|
||||||
|
}
|
||||||
|
Log.d(TAG, "Created test embeddings norms matrix, size: 1 × " + numSentences);
|
||||||
|
Log.d(TAG, "Created test embeddings norms reciprocal matrix, size: 1 × " + numSentences);
|
||||||
|
}
|
||||||
|
|
||||||
|
Log.d(TAG, "Generated embeddings for " + mTestEmbeddings.size() + " test sentences");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成测试句子
|
||||||
|
private void generateTestSentences() {
|
||||||
|
mTestSentences = new ArrayList<>();
|
||||||
|
|
||||||
|
// 基础词汇
|
||||||
|
String[] words = {"手", "头", "眼睛", "鼻子", "嘴巴", "耳朵", "头发", "手指", "脚", "腿"};
|
||||||
|
String[] adjectives = {"大", "小", "长", "短", "多", "少", "漂亮", "丑陋", "干净", "脏"};
|
||||||
|
String[] verbs = {"有", "是", "看", "听", "说", "走", "跑", "跳", "吃", "喝"};
|
||||||
|
String[] subjects = {"我", "你", "他", "她", "它", "我们", "你们", "他们", "这个", "那个"};
|
||||||
|
|
||||||
|
// 添加一些固定的测试句子
|
||||||
|
mTestSentences.add("我的手有很多手指头");
|
||||||
|
mTestSentences.add("他的头发很长");
|
||||||
|
mTestSentences.add("她的眼睛很大");
|
||||||
|
mTestSentences.add("这个鼻子很高");
|
||||||
|
mTestSentences.add("那个嘴巴很小");
|
||||||
|
|
||||||
|
Log.d(TAG, "Generated " + mTestSentences.size() + " test sentences");
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置是否使用EJML进行相似度计算
|
||||||
|
public void setUseEJMLForSimilarity(boolean useEJML) {
|
||||||
|
this.useEJMLForSimilarity = useEJML;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查BGE模型是否初始化
|
||||||
|
public boolean isInitialized() {
|
||||||
|
return mBgeEngine != null && mBgeEngine.isInitialized();
|
||||||
|
}
|
||||||
|
|
||||||
|
// 释放BGE模型资源
|
||||||
|
public void deinitialize() {
|
||||||
|
if (mBgeEngine != null) {
|
||||||
|
mBgeEngine.deinitialize();
|
||||||
|
mBgeEngine = null;
|
||||||
|
Log.d(TAG, "BGE engine deinitialized");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
104
app/src/main/java/com/digitalperson/engine/BasicTokenizer.java
Normal file
104
app/src/main/java/com/digitalperson/engine/BasicTokenizer.java
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
package com.digitalperson.engine;
|
||||||
|
|
||||||
|
import com.google.common.base.Ascii;
|
||||||
|
import com.google.common.collect.Iterables;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/** Basic tokenization (punctuation splitting, lower casing, etc.) */
|
||||||
|
public final class BasicTokenizer {
|
||||||
|
private final boolean doLowerCase;
|
||||||
|
|
||||||
|
public BasicTokenizer(boolean doLowerCase) {
|
||||||
|
this.doLowerCase = doLowerCase;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<String> tokenize(String text) {
|
||||||
|
String cleanedText = cleanText(text);
|
||||||
|
|
||||||
|
List<String> origTokens = whitespaceTokenize(cleanedText);
|
||||||
|
|
||||||
|
StringBuilder stringBuilder = new StringBuilder();
|
||||||
|
for (String token : origTokens) {
|
||||||
|
if (doLowerCase) {
|
||||||
|
token = Ascii.toLowerCase(token);
|
||||||
|
}
|
||||||
|
List<String> list = runSplitOnPunc(token);
|
||||||
|
for (String subToken : list) {
|
||||||
|
stringBuilder.append(subToken).append(" ");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return whitespaceTokenize(stringBuilder.toString());
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Performs invalid character removal and whitespace cleanup on text. */
|
||||||
|
static String cleanText(String text) {
|
||||||
|
if (text == null) {
|
||||||
|
throw new NullPointerException("The input String is null.");
|
||||||
|
}
|
||||||
|
|
||||||
|
StringBuilder stringBuilder = new StringBuilder("");
|
||||||
|
for (int index = 0; index < text.length(); index++) {
|
||||||
|
char ch = text.charAt(index);
|
||||||
|
|
||||||
|
// Skip the characters that cannot be used.
|
||||||
|
if (CharChecker.isInvalid(ch) || CharChecker.isControl(ch)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (CharChecker.isWhitespace(ch)) {
|
||||||
|
stringBuilder.append(" ");
|
||||||
|
} else {
|
||||||
|
stringBuilder.append(ch);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return stringBuilder.toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Runs basic whitespace cleaning and splitting on a piece of text. */
|
||||||
|
static List<String> whitespaceTokenize(String text) {
|
||||||
|
if (text == null) {
|
||||||
|
throw new NullPointerException("The input String is null.");
|
||||||
|
}
|
||||||
|
return Arrays.asList(text.split(" "));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Splits punctuation on a piece of text. */
|
||||||
|
static List<String> runSplitOnPunc(String text) {
|
||||||
|
if (text == null) {
|
||||||
|
throw new NullPointerException("The input String is null.");
|
||||||
|
}
|
||||||
|
|
||||||
|
List<String> tokens = new ArrayList<>();
|
||||||
|
boolean startNewWord = true;
|
||||||
|
for (int i = 0; i < text.length(); i++) {
|
||||||
|
char ch = text.charAt(i);
|
||||||
|
if (CharChecker.isPunctuation(ch)) {
|
||||||
|
tokens.add(String.valueOf(ch));
|
||||||
|
startNewWord = true;
|
||||||
|
} else {
|
||||||
|
if (startNewWord) {
|
||||||
|
tokens.add("");
|
||||||
|
startNewWord = false;
|
||||||
|
}
|
||||||
|
tokens.set(tokens.size() - 1, Iterables.getLast(tokens) + ch);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokens;
|
||||||
|
}
|
||||||
|
}
|
||||||
263
app/src/main/java/com/digitalperson/engine/BgeEngineRKNN.java
Normal file
263
app/src/main/java/com/digitalperson/engine/BgeEngineRKNN.java
Normal file
@@ -0,0 +1,263 @@
|
|||||||
|
package com.digitalperson.engine;
|
||||||
|
|
||||||
|
import android.content.Context;
|
||||||
|
import android.util.Log;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
public class BgeEngineRKNN {
|
||||||
|
private static final String TAG = "BgeEngineRKNN";
|
||||||
|
private final long nativePtr;
|
||||||
|
private final Context mContext;
|
||||||
|
private boolean mIsInitialized = false;
|
||||||
|
private FullTokenizer mFullTokenizer = null;
|
||||||
|
private Map<String, Integer> mVocabMap = new HashMap<>();
|
||||||
|
|
||||||
|
static {
|
||||||
|
try {
|
||||||
|
// Load dependent libraries
|
||||||
|
System.loadLibrary("rknnrt");
|
||||||
|
System.loadLibrary("bgeEngine");
|
||||||
|
Log.d(TAG, "Successfully loaded librknnrt.so and libbgeEngine.so");
|
||||||
|
} catch (UnsatisfiedLinkError e) {
|
||||||
|
Log.e(TAG, "Failed to load native library", e);
|
||||||
|
throw e;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public BgeEngineRKNN(Context context) {
|
||||||
|
mContext = context;
|
||||||
|
try {
|
||||||
|
nativePtr = createBgeEngine();
|
||||||
|
if (nativePtr == 0) {
|
||||||
|
throw new RuntimeException("Failed to create native BGE engine");
|
||||||
|
}
|
||||||
|
} catch (UnsatisfiedLinkError e) {
|
||||||
|
Log.e(TAG, "Failed to load native library", e);
|
||||||
|
throw new RuntimeException("Failed to load native library: " + e.getMessage(), e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean initialize(String modelPath) {
|
||||||
|
// 自动查找词汇表路径
|
||||||
|
String vocabPath = modelPath.replace("bge-small-zh-v1.5.rknn", "vocab.txt");
|
||||||
|
return initialize(modelPath, vocabPath);
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean initialize(String modelPath, String vocabPath) {
|
||||||
|
if (mIsInitialized) {
|
||||||
|
Log.i(TAG, "Model already initialized");
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
Log.d(TAG, "Loading BGE model: " + modelPath);
|
||||||
|
Log.d(TAG, "Loading vocab: " + vocabPath);
|
||||||
|
|
||||||
|
// 加载词汇表
|
||||||
|
if (!loadVocab(vocabPath)) {
|
||||||
|
Log.e(TAG, "Failed to load vocab");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建FullTokenizer实例
|
||||||
|
mFullTokenizer = new FullTokenizer(mVocabMap, true); // true表示使用小写
|
||||||
|
|
||||||
|
int ret = loadModel(nativePtr, modelPath, vocabPath);
|
||||||
|
if (ret == 0) {
|
||||||
|
mIsInitialized = true;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Load vocabulary from file
|
||||||
|
* @param vocabPath Path to vocabulary file
|
||||||
|
* @return True if successful, false otherwise
|
||||||
|
*/
|
||||||
|
private boolean loadVocab(String vocabPath) {
|
||||||
|
try {
|
||||||
|
java.io.BufferedReader reader = new java.io.BufferedReader(
|
||||||
|
new java.io.FileReader(vocabPath));
|
||||||
|
String line;
|
||||||
|
int index = 0;
|
||||||
|
while ((line = reader.readLine()) != null) {
|
||||||
|
line = line.trim();
|
||||||
|
if (!line.isEmpty()) {
|
||||||
|
mVocabMap.put(line, index);
|
||||||
|
index++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
reader.close();
|
||||||
|
Log.d(TAG, "Vocab loaded successfully with " + mVocabMap.size() + " tokens");
|
||||||
|
return true;
|
||||||
|
} catch (java.io.IOException e) {
|
||||||
|
Log.e(TAG, "Failed to load vocab: " + e.getMessage());
|
||||||
|
e.printStackTrace();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void deinitialize() {
|
||||||
|
if (nativePtr != 0) {
|
||||||
|
freeModel(nativePtr);
|
||||||
|
}
|
||||||
|
mIsInitialized = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean isInitialized() {
|
||||||
|
return mIsInitialized;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the embedding dimension of the model
|
||||||
|
* @return Embedding dimension
|
||||||
|
*/
|
||||||
|
public int getEmbeddingDim() {
|
||||||
|
if (!mIsInitialized) {
|
||||||
|
Log.e(TAG, "Engine not initialized");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
return getEmbeddingDim(nativePtr);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculate embedding for a single text
|
||||||
|
* @param text Input text
|
||||||
|
* @return Embedding vector as float array
|
||||||
|
*/
|
||||||
|
public float[] getEmbedding(String text) {
|
||||||
|
if (!mIsInitialized) {
|
||||||
|
Log.e(TAG, "Engine not initialized");
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return getEmbeddingNative(nativePtr, text);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculate cosine similarity between two texts
|
||||||
|
* @param text1 First text
|
||||||
|
* @param text2 Second text
|
||||||
|
* @return Cosine similarity score between -1.0 and 1.0
|
||||||
|
*/
|
||||||
|
public float calculateSimilarity(String text1, String text2) {
|
||||||
|
if (!mIsInitialized) {
|
||||||
|
Log.e(TAG, "Engine not initialized");
|
||||||
|
return 0.0f;
|
||||||
|
}
|
||||||
|
return calculateSimilarityNative(nativePtr, text1, text2);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get token ids for a text using FullTokenizer
|
||||||
|
* @param text Input text
|
||||||
|
* @return Token ids as int array
|
||||||
|
*/
|
||||||
|
public int[] getTokens(String text) {
|
||||||
|
if (!mIsInitialized) {
|
||||||
|
Log.e(TAG, "Engine not initialized");
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
if (mFullTokenizer == null) {
|
||||||
|
Log.e(TAG, "FullTokenizer not initialized");
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用FullTokenizer进行tokenization
|
||||||
|
java.util.List<String> tokens = mFullTokenizer.tokenize(text);
|
||||||
|
java.util.List<Integer> ids = mFullTokenizer.convertTokensToIds(tokens);
|
||||||
|
|
||||||
|
// 转换为int数组
|
||||||
|
int[] result = new int[ids.size()];
|
||||||
|
for (int i = 0; i < ids.size(); i++) {
|
||||||
|
result[i] = ids.get(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculate [UNK] ratio for token ids
|
||||||
|
* @param tokens Token ids array
|
||||||
|
* @return [UNK] ratio between 0.0 and 1.0
|
||||||
|
*/
|
||||||
|
public float calculateUnkRatio(int[] tokens) {
|
||||||
|
if (!mIsInitialized) {
|
||||||
|
Log.e(TAG, "Engine not initialized");
|
||||||
|
return 0.0f;
|
||||||
|
}
|
||||||
|
if (mFullTokenizer == null || tokens == null || tokens.length == 0) {
|
||||||
|
return 0.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取[UNK]的id
|
||||||
|
int unkId = mVocabMap.getOrDefault("[UNK]", -1);
|
||||||
|
if (unkId == -1) {
|
||||||
|
return 0.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 统计[UNK]的数量
|
||||||
|
int unkCount = 0;
|
||||||
|
for (int token : tokens) {
|
||||||
|
if (token == unkId) {
|
||||||
|
unkCount++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return (float) unkCount / tokens.length;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get BERT model inputs (input_ids, attention_mask, token_type_ids)
|
||||||
|
* @param text Input text
|
||||||
|
* @param maxSeqLength Maximum sequence length
|
||||||
|
* @return Array of int arrays: [input_ids, attention_mask, token_type_ids]
|
||||||
|
*/
|
||||||
|
public int[][] getBertInputs(String text, int maxSeqLength) {
|
||||||
|
if (!mIsInitialized) {
|
||||||
|
Log.e(TAG, "Engine not initialized");
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
if (mFullTokenizer == null) {
|
||||||
|
Log.e(TAG, "FullTokenizer not initialized");
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用FullTokenizer进行tokenization
|
||||||
|
java.util.List<String> tokens = mFullTokenizer.tokenize(text);
|
||||||
|
java.util.List<Integer> ids = mFullTokenizer.convertTokensToIds(tokens);
|
||||||
|
|
||||||
|
// 准备BERT输入
|
||||||
|
int[] inputIds = new int[maxSeqLength];
|
||||||
|
int[] attentionMask = new int[maxSeqLength];
|
||||||
|
int[] tokenTypeIds = new int[maxSeqLength];
|
||||||
|
|
||||||
|
// 填充[CLS] token
|
||||||
|
inputIds[0] = mVocabMap.getOrDefault("[CLS]", 101);
|
||||||
|
attentionMask[0] = 1;
|
||||||
|
|
||||||
|
// 填充文本token
|
||||||
|
int textLength = ids.size();
|
||||||
|
int maxTextLength = maxSeqLength - 2; // 减去[CLS]和[SEP]
|
||||||
|
int actualLength = Math.min(textLength, maxTextLength);
|
||||||
|
|
||||||
|
for (int i = 0; i < actualLength; i++) {
|
||||||
|
inputIds[i + 1] = ids.get(i);
|
||||||
|
attentionMask[i + 1] = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 填充[SEP] token
|
||||||
|
inputIds[actualLength + 1] = mVocabMap.getOrDefault("[SEP]", 102);
|
||||||
|
attentionMask[actualLength + 1] = 1;
|
||||||
|
|
||||||
|
return new int[][]{inputIds, attentionMask, tokenTypeIds};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Native methods
|
||||||
|
private native long createBgeEngine();
|
||||||
|
private native int loadModel(long nativePtr, String modelPath, String vocabPath);
|
||||||
|
private native void freeModel(long ptr);
|
||||||
|
private native float[] getEmbeddingNative(long ptr, String text);
|
||||||
|
private native int getEmbeddingDim(long ptr);
|
||||||
|
private native float calculateSimilarityNative(long ptr, String text1, String text2);
|
||||||
|
}
|
||||||
58
app/src/main/java/com/digitalperson/engine/CharChecker.java
Normal file
58
app/src/main/java/com/digitalperson/engine/CharChecker.java
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
package com.digitalperson.engine;
|
||||||
|
|
||||||
|
/** To check whether a char is whitespace/control/punctuation. */
|
||||||
|
final class CharChecker {
|
||||||
|
|
||||||
|
/** To judge whether it's an empty or unknown character. */
|
||||||
|
public static boolean isInvalid(char ch) {
|
||||||
|
return (ch == 0 || ch == 0xfffd);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** To judge whether it's a control character(exclude whitespace). */
|
||||||
|
public static boolean isControl(char ch) {
|
||||||
|
if (Character.isWhitespace(ch)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
int type = Character.getType(ch);
|
||||||
|
return (type == Character.CONTROL || type == Character.FORMAT);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** To judge whether it can be regarded as a whitespace. */
|
||||||
|
public static boolean isWhitespace(char ch) {
|
||||||
|
if (Character.isWhitespace(ch)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
int type = Character.getType(ch);
|
||||||
|
return (type == Character.SPACE_SEPARATOR
|
||||||
|
|| type == Character.LINE_SEPARATOR
|
||||||
|
|| type == Character.PARAGRAPH_SEPARATOR);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** To judge whether it's a punctuation. */
|
||||||
|
public static boolean isPunctuation(char ch) {
|
||||||
|
int type = Character.getType(ch);
|
||||||
|
return (type == Character.CONNECTOR_PUNCTUATION
|
||||||
|
|| type == Character.DASH_PUNCTUATION
|
||||||
|
|| type == Character.START_PUNCTUATION
|
||||||
|
|| type == Character.END_PUNCTUATION
|
||||||
|
|| type == Character.INITIAL_QUOTE_PUNCTUATION
|
||||||
|
|| type == Character.FINAL_QUOTE_PUNCTUATION
|
||||||
|
|| type == Character.OTHER_PUNCTUATION);
|
||||||
|
}
|
||||||
|
|
||||||
|
private CharChecker() {}
|
||||||
|
}
|
||||||
@@ -0,0 +1,52 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
package com.digitalperson.engine;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A java realization of Bert tokenization. Original python code:
|
||||||
|
* https://github.com/google-research/bert/blob/master/tokenization.py runs full tokenization to
|
||||||
|
* tokenize a String into split subtokens or ids.
|
||||||
|
*/
|
||||||
|
public final class FullTokenizer {
|
||||||
|
private final BasicTokenizer basicTokenizer;
|
||||||
|
private final WordpieceTokenizer wordpieceTokenizer;
|
||||||
|
private final Map<String, Integer> dic;
|
||||||
|
|
||||||
|
public FullTokenizer(Map<String, Integer> inputDic, boolean doLowerCase) {
|
||||||
|
dic = inputDic;
|
||||||
|
basicTokenizer = new BasicTokenizer(doLowerCase);
|
||||||
|
wordpieceTokenizer = new WordpieceTokenizer(inputDic);
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<String> tokenize(String text) {
|
||||||
|
List<String> splitTokens = new ArrayList<>();
|
||||||
|
for (String token : basicTokenizer.tokenize(text)) {
|
||||||
|
splitTokens.addAll(wordpieceTokenizer.tokenize(token));
|
||||||
|
}
|
||||||
|
return splitTokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<Integer> convertTokensToIds(List<String> tokens) {
|
||||||
|
List<Integer> outputIds = new ArrayList<>();
|
||||||
|
for (String token : tokens) {
|
||||||
|
outputIds.add(dic.get(token));
|
||||||
|
}
|
||||||
|
return outputIds;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,94 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
package com.digitalperson.engine;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/** Word piece tokenization to split a piece of text into its word pieces. */
|
||||||
|
public final class WordpieceTokenizer {
|
||||||
|
private final Map<String, Integer> dic;
|
||||||
|
|
||||||
|
private static final String UNKNOWN_TOKEN = "[UNK]"; // For unknown words.
|
||||||
|
private static final int MAX_INPUTCHARS_PER_WORD = 200;
|
||||||
|
|
||||||
|
public WordpieceTokenizer(Map<String, Integer> vocab) {
|
||||||
|
dic = vocab;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first
|
||||||
|
* algorithm to perform tokenization using the given vocabulary. For example: input = "unaffable",
|
||||||
|
* output = ["un", "##aff", "##able"].
|
||||||
|
*
|
||||||
|
* @param text: A single token or whitespace separated tokens. This should have already been
|
||||||
|
* passed through `BasicTokenizer.
|
||||||
|
* @return A list of wordpiece tokens.
|
||||||
|
*/
|
||||||
|
public List<String> tokenize(String text) {
|
||||||
|
if (text == null) {
|
||||||
|
throw new NullPointerException("The input String is null.");
|
||||||
|
}
|
||||||
|
|
||||||
|
List<String> outputTokens = new ArrayList<>();
|
||||||
|
for (String token : BasicTokenizer.whitespaceTokenize(text)) {
|
||||||
|
|
||||||
|
if (token.length() > MAX_INPUTCHARS_PER_WORD) {
|
||||||
|
outputTokens.add(UNKNOWN_TOKEN);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
boolean isBad = false; // Mark if a word cannot be tokenized into known subwords.
|
||||||
|
int start = 0;
|
||||||
|
List<String> subTokens = new ArrayList<>();
|
||||||
|
|
||||||
|
while (start < token.length()) {
|
||||||
|
String curSubStr = "";
|
||||||
|
|
||||||
|
int end = token.length(); // Longer substring matches first.
|
||||||
|
while (start < end) {
|
||||||
|
String subStr =
|
||||||
|
(start == 0) ? token.substring(start, end) : "##" + token.substring(start, end);
|
||||||
|
if (dic.containsKey(subStr)) {
|
||||||
|
curSubStr = subStr;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
end--;
|
||||||
|
}
|
||||||
|
|
||||||
|
// The word doesn't contain any known subwords.
|
||||||
|
if ("".equals(curSubStr)) {
|
||||||
|
isBad = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// curSubStr is the longeset subword that can be found.
|
||||||
|
subTokens.add(curSubStr);
|
||||||
|
|
||||||
|
// Proceed to tokenize the resident string.
|
||||||
|
start = end;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isBad) {
|
||||||
|
outputTokens.add(UNKNOWN_TOKEN);
|
||||||
|
} else {
|
||||||
|
outputTokens.addAll(subTokens);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return outputTokens;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -185,7 +185,7 @@ class QuestionGenerationAgent(
|
|||||||
val generationPrompt = buildGenerationPrompt(prompt, userProfile)
|
val generationPrompt = buildGenerationPrompt(prompt, userProfile)
|
||||||
|
|
||||||
// 3. 调用大模型生成题目
|
// 3. 调用大模型生成题目
|
||||||
generateQuestionFromLLM(generationPrompt) { generatedQuestion ->
|
generateQuestionFromLLM(generationPrompt, prompt) { generatedQuestion ->
|
||||||
if (generatedQuestion == null) {
|
if (generatedQuestion == null) {
|
||||||
Log.w(TAG, "Failed to generate question")
|
Log.w(TAG, "Failed to generate question")
|
||||||
return@generateQuestionFromLLM
|
return@generateQuestionFromLLM
|
||||||
@@ -301,18 +301,22 @@ class QuestionGenerationAgent(
|
|||||||
/**
|
/**
|
||||||
* 调用LLM生成题目
|
* 调用LLM生成题目
|
||||||
*/
|
*/
|
||||||
private fun generateQuestionFromLLM(prompt: String, onResult: (GeneratedQuestion?) -> Unit) {
|
private fun generateQuestionFromLLM(
|
||||||
|
promptText: String,
|
||||||
|
promptMeta: QuestionPrompt,
|
||||||
|
onResult: (GeneratedQuestion?) -> Unit
|
||||||
|
) {
|
||||||
// 优先使用本地LLM,如果不可用则使用云端LLM
|
// 优先使用本地LLM,如果不可用则使用云端LLM
|
||||||
if (llmManager != null) {
|
if (llmManager != null) {
|
||||||
// 使用本地LLM
|
// 使用本地LLM
|
||||||
llmManager.generate(prompt) { response: String ->
|
llmManager.generate(promptText) { response: String ->
|
||||||
parseGeneratedQuestion(response, onResult)
|
parseGeneratedQuestion(response, promptMeta, onResult)
|
||||||
}
|
}
|
||||||
} else if (cloudLLMGenerator != null) {
|
} else if (cloudLLMGenerator != null) {
|
||||||
// 使用云端LLM
|
// 使用云端LLM
|
||||||
Log.d(TAG, "Using cloud LLM to generate question")
|
Log.d(TAG, "Using cloud LLM to generate question")
|
||||||
cloudLLMGenerator.invoke(prompt) { response ->
|
cloudLLMGenerator.invoke(promptText) { response ->
|
||||||
parseGeneratedQuestion(response, onResult)
|
parseGeneratedQuestion(response, promptMeta, onResult)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
Log.e(TAG, "No LLM available (neither local nor cloud)")
|
Log.e(TAG, "No LLM available (neither local nor cloud)")
|
||||||
@@ -323,16 +327,20 @@ class QuestionGenerationAgent(
|
|||||||
/**
|
/**
|
||||||
* 解析生成的题目
|
* 解析生成的题目
|
||||||
*/
|
*/
|
||||||
private fun parseGeneratedQuestion(response: String, onResult: (GeneratedQuestion?) -> Unit) {
|
private fun parseGeneratedQuestion(
|
||||||
|
response: String,
|
||||||
|
promptMeta: QuestionPrompt,
|
||||||
|
onResult: (GeneratedQuestion?) -> Unit
|
||||||
|
) {
|
||||||
try {
|
try {
|
||||||
val json = extractJsonFromResponse(response)
|
val json = extractJsonFromResponse(response)
|
||||||
if (json != null) {
|
if (json != null) {
|
||||||
val question = GeneratedQuestion(
|
val question = GeneratedQuestion(
|
||||||
content = json.getString("content"),
|
content = json.getString("content"),
|
||||||
answer = json.getString("answer"),
|
answer = json.getString("answer"),
|
||||||
subject = "生活适应",
|
subject = promptMeta.subject,
|
||||||
grade = 1,
|
grade = promptMeta.grade,
|
||||||
difficulty = 1
|
difficulty = promptMeta.difficulty
|
||||||
)
|
)
|
||||||
onResult(question)
|
onResult(question)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import android.util.Log
|
|||||||
import com.digitalperson.config.AppConfig
|
import com.digitalperson.config.AppConfig
|
||||||
import java.io.File
|
import java.io.File
|
||||||
import java.io.FileOutputStream
|
import java.io.FileOutputStream
|
||||||
|
import java.io.InputStream
|
||||||
|
|
||||||
object FileHelper {
|
object FileHelper {
|
||||||
private const val TAG = AppConfig.TAG
|
private const val TAG = AppConfig.TAG
|
||||||
@@ -65,6 +66,54 @@ object FileHelper {
|
|||||||
return copyAssetsToInternal(context, AppConfig.FaceRecognition.MODEL_DIR, outDir, files)
|
return copyAssetsToInternal(context, AppConfig.FaceRecognition.MODEL_DIR, outDir, files)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 将 BGE 相关文件从 assets 复制到 [Context.getFilesDir]/[AppConfig.Bge.ASSET_DIR]。
|
||||||
|
* 若已存在且长度与 asset 一致则跳过(与 [com.digitalperson.embedding.SimilarityManager] 行为一致)。
|
||||||
|
*/
|
||||||
|
@JvmStatic
|
||||||
|
fun copyBgeModels(context: Context): File? {
|
||||||
|
val assetDir = AppConfig.Bge.ASSET_DIR
|
||||||
|
val modelDir = File(context.filesDir, assetDir).apply { mkdirs() }
|
||||||
|
val files = arrayOf(
|
||||||
|
AppConfig.Bge.MODEL_FILE,
|
||||||
|
"vocab.txt",
|
||||||
|
"tokenizer.json",
|
||||||
|
"tokenizer_config.json"
|
||||||
|
)
|
||||||
|
for (name in files) {
|
||||||
|
val outFile = File(modelDir, name)
|
||||||
|
val assetPath = "$assetDir/$name"
|
||||||
|
try {
|
||||||
|
var skip = false
|
||||||
|
if (outFile.exists()) {
|
||||||
|
context.assets.open(assetPath).use { input ->
|
||||||
|
val assetSize = assetSizeOrNegative(input)
|
||||||
|
if (assetSize >= 0 && outFile.length() == assetSize) skip = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (skip) continue
|
||||||
|
context.assets.open(assetPath).use { input ->
|
||||||
|
FileOutputStream(outFile).use { output -> input.copyTo(output) }
|
||||||
|
}
|
||||||
|
Log.i(TAG, "Copied BGE asset: $name")
|
||||||
|
} catch (e: Exception) {
|
||||||
|
Log.e(TAG, "copyBgeModels failed for $name: ${e.message}", e)
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return modelDir
|
||||||
|
}
|
||||||
|
|
||||||
|
/** [InputStream.available] 在部分实现上不可靠;失败时返回 -1,强制重拷。 */
|
||||||
|
private fun assetSizeOrNegative(input: InputStream): Long {
|
||||||
|
return try {
|
||||||
|
val n = input.available()
|
||||||
|
if (n > 0) n.toLong() else -1L
|
||||||
|
} catch (_: Exception) {
|
||||||
|
-1L
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fun ensureDir(dir: File): File {
|
fun ensureDir(dir: File): File {
|
||||||
if (!dir.exists()) {
|
if (!dir.exists()) {
|
||||||
val created = dir.mkdirs()
|
val created = dir.mkdirs()
|
||||||
|
|||||||
@@ -38,6 +38,18 @@
|
|||||||
android:layout_height="match_parent" />
|
android:layout_height="match_parent" />
|
||||||
</FrameLayout>
|
</FrameLayout>
|
||||||
|
|
||||||
|
<ImageView
|
||||||
|
android:id="@+id/ref_match_image"
|
||||||
|
android:layout_width="200dp"
|
||||||
|
android:layout_height="200dp"
|
||||||
|
android:layout_margin="12dp"
|
||||||
|
android:background="#66000000"
|
||||||
|
android:scaleType="fitCenter"
|
||||||
|
android:visibility="gone"
|
||||||
|
app:layout_constraintBottom_toTopOf="@+id/button_row"
|
||||||
|
app:layout_constraintStart_toStartOf="parent"
|
||||||
|
tools:visibility="visible" />
|
||||||
|
|
||||||
<ScrollView
|
<ScrollView
|
||||||
android:id="@+id/scroll_view"
|
android:id="@+id/scroll_view"
|
||||||
android:layout_width="0dp"
|
android:layout_width="0dp"
|
||||||
|
|||||||
@@ -31,6 +31,18 @@
|
|||||||
android:textIsSelectable="true" />
|
android:textIsSelectable="true" />
|
||||||
</ScrollView>
|
</ScrollView>
|
||||||
|
|
||||||
|
<ImageView
|
||||||
|
android:id="@+id/ref_match_image"
|
||||||
|
android:layout_width="200dp"
|
||||||
|
android:layout_height="200dp"
|
||||||
|
android:layout_margin="12dp"
|
||||||
|
android:background="#66000000"
|
||||||
|
android:scaleType="fitCenter"
|
||||||
|
android:visibility="gone"
|
||||||
|
app:layout_constraintBottom_toTopOf="@+id/button_row"
|
||||||
|
app:layout_constraintStart_toStartOf="parent"
|
||||||
|
tools:visibility="visible" />
|
||||||
|
|
||||||
<LinearLayout
|
<LinearLayout
|
||||||
android:id="@+id/button_row"
|
android:id="@+id/button_row"
|
||||||
android:layout_width="0dp"
|
android:layout_width="0dp"
|
||||||
|
|||||||
Reference in New Issue
Block a user