word similarity

This commit is contained in:
gcw_4spBpAfv
2026-04-21 23:05:59 +08:00
parent e23aaaa4ba
commit 1550783eef
36 changed files with 44822 additions and 12 deletions

View File

@@ -4,10 +4,24 @@ plugins {
id 'kotlin-kapt' id 'kotlin-kapt'
} }
kapt {
// Room uses javac stubs under kapt; keep parameter names for :bind variables.
javacOptions {
option("-parameters")
}
}
android { android {
namespace 'com.digitalperson' namespace 'com.digitalperson'
compileSdk 34 compileSdk 34
sourceSets {
main {
// app/note/ref → assets 中为 ref/...(与 AppConfig.RefCorpus.ASSETS_ROOT 一致)
assets.srcDirs = ['src/main/assets', 'note']
}
}
buildFeatures { buildFeatures {
buildConfig true buildConfig true
} }
@@ -100,4 +114,9 @@ dependencies {
implementation project(':tuanjieLibrary') implementation project(':tuanjieLibrary')
implementation files('../tuanjieLibrary/libs/unity-classes.jar') implementation files('../tuanjieLibrary/libs/unity-classes.jar')
// BGE tokenizer (BasicTokenizer) + SimilarityManager 批量相似度测试
implementation 'com.google.guava:guava:31.1-android'
implementation 'org.ejml:ejml-core:0.43.1'
implementation 'org.ejml:ejml-simple:0.43.1'
} }

View File

@@ -9,6 +9,7 @@
<uses-feature android:name="android.hardware.camera.any" /> <uses-feature android:name="android.hardware.camera.any" />
<application <application
android:name="com.digitalperson.DigitalPersonApp"
android:allowBackup="true" android:allowBackup="true"
android:label="@string/app_name" android:label="@string/app_name"
android:supportsRtl="true" android:supportsRtl="true"

View File

@@ -0,0 +1,5 @@
Required assets for BGE (see AppConfig.Bge):
bge-small-zh-v1.5.rknn, vocab.txt, tokenizer.json, tokenizer_config.json
On first run, FileHelper.copyBgeModels / BgeEmbedding.initialize / SimilarityManager.initBgeModel
copy these files from assets/bge_models to internal storage.

Binary file not shown.

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,15 @@
{
"clean_up_tokenization_spaces": true,
"cls_token": "[CLS]",
"do_basic_tokenize": true,
"do_lower_case": false,
"mask_token": "[MASK]",
"model_max_length": 512,
"never_split": null,
"pad_token": "[PAD]",
"sep_token": "[SEP]",
"strip_accents": null,
"tokenize_chinese_chars": true,
"tokenizer_class": "BertTokenizer",
"unk_token": "[UNK]"
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,504 @@
#include "BgeEngineRKNN.h"
#include <android/log.h>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <cmath>
#include <fstream>
#include <sstream>
#include <unordered_map>
#define LOG_TAG "BgeEngineRKNN"
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__)
#define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__)
#define LOGD(...) __android_log_print(ANDROID_LOG_DEBUG, LOG_TAG, __VA_ARGS__)
// Embedding dimension defaults to 768; initRKNN() overwrites it from the
// model's output-0 attributes once a model is actually loaded.
BgeEngineRKNN::BgeEngineRKNN() : m_ctx(0), m_initialized(false), m_embedding_dim(768) {
}
// Releases the RKNN context (if any) on destruction.
BgeEngineRKNN::~BgeEngineRKNN() {
    freeModel();
}
// Loads the .rknn model and, optionally, a BERT-style vocab file.
// Idempotent: calling again on a loaded engine is a no-op returning 0.
// Returns 0 on success, a loadVocab error (-1/-2), or an RKNN error code.
int BgeEngineRKNN::loadModel(const char* modelPath, const char* vocabPath) {
    if (m_initialized) {
        LOGI("Model already loaded");
        return 0;
    }
    // Load the vocabulary first (if a path was provided) so tokenization
    // failures surface before the NPU context is created.
    if (vocabPath != nullptr) {
        int vocab_ret = loadVocab(vocabPath);
        if (vocab_ret != 0) {
            LOGE("Failed to load vocab: %d", vocab_ret);
            return vocab_ret;
        }
        LOGI("Vocab loaded successfully with %zu tokens", m_vocab.size());
    } else {
        // Without a vocab, tokenize() falls back to code-point hashing.
        LOGE("No vocab path provided, using fallback tokenization");
    }
    int ret = initRKNN(modelPath);
    if (ret != 0) {
        LOGE("Failed to initialize RKNN: %d", ret);
        return ret;
    }
    m_initialized = true;
    LOGI("BGE model loaded successfully");
    return 0;
}
// Tears down the RKNN context; safe to call when nothing is loaded.
void BgeEngineRKNN::freeModel() {
    if (!m_initialized) {
        return;
    }
    deinitRKNN();
    m_initialized = false;
    LOGI("BGE model freed");
}
float* BgeEngineRKNN::getEmbedding(const std::string& text) {
if (!m_initialized) {
LOGE("Model not initialized");
return nullptr;
}
// 简单的tokenize实现需要替换为完整的tokenizer
std::vector<int> tokens = tokenize(text);
if (tokens.empty()) {
LOGE("Tokenization failed");
return nullptr;
}
// 推理获取嵌入
float* embedding = inferEmbedding(tokens);
if (embedding == nullptr) {
LOGE("Failed to get embedding");
return nullptr;
}
return embedding;
}
float BgeEngineRKNN::calculateSimilarity(const std::string& text1, const std::string& text2) {
if (!m_initialized) {
LOGE("Model not initialized");
return 0.0f;
}
// 获取两个文本的嵌入
float* embedding1 = getEmbedding(text1);
if (embedding1 == nullptr) {
return 0.0f;
}
float* embedding2 = getEmbedding(text2);
if (embedding2 == nullptr) {
delete[] embedding1;
return 0.0f;
}
// 归一化嵌入
std::vector<float> normalized1 = normalizeEmbedding(embedding1, m_embedding_dim);
std::vector<float> normalized2 = normalizeEmbedding(embedding2, m_embedding_dim);
// 计算余弦相似度
float similarity = cosineSimilarity(normalized1, normalized2);
// 释放内存
delete[] embedding1;
delete[] embedding2;
return similarity;
}
// Embedding dimension reported by the loaded model (default 768 until loaded).
int BgeEngineRKNN::getEmbeddingDim() const {
    return m_embedding_dim;
}
// Converts |text| to a BERT-style id sequence: [CLS](101) + body + [SEP](102).
// With a loaded vocab it runs basic + WordPiece splitting and id lookup;
// without one it derives deterministic pseudo ids from UTF-8 code points.
// The body is capped at 510 ids so the result fits the 512-token window.
std::vector<int> BgeEngineRKNN::tokenize(const std::string& text) {
    std::vector<int> tokens;
    tokens.push_back(101); // [CLS] token
    if (!m_vocab.empty()) {
        // Vocab path: basic split -> WordPiece -> id lookup.
        std::vector<std::string> basic_tokens = basicTokenize(text);
        std::vector<std::string> word_pieces = wordPieceTokenize(basic_tokens);
        // Resolve [UNK] once up front instead of on every miss.
        auto unk_it = m_vocab.find("[UNK]");
        const int unk_id = (unk_it != m_vocab.end()) ? unk_it->second : 0;
        for (const std::string& piece : word_pieces) {
            if (tokens.size() > 510) break; // keep room for [SEP]
            auto it = m_vocab.find(piece);
            tokens.push_back(it != m_vocab.end() ? it->second : unk_id);
        }
    } else {
        // Fallback: hash each UTF-8 code point into a pseudo id.
        const size_t n = text.size();
        size_t i = 0;
        while (i < n && tokens.size() <= 510) {
            const unsigned char c = static_cast<unsigned char>(text[i]);
            size_t seq_len;
            if (c < 0x80) {
                seq_len = 1; // ASCII
            } else if (c < 0xE0) {
                seq_len = 2; // 2-byte UTF-8
            } else if (c < 0xF0) {
                seq_len = 3; // 3-byte UTF-8 (most CJK)
            } else {
                seq_len = 4; // 4-byte UTF-8
            }
            // BUGFIX: the original indexed text[i+1..i+3] with no length
            // check, reading out of bounds when the string ended mid-sequence.
            // Stop at an incomplete trailing sequence instead.
            if (i + seq_len > n) break;
            int token_id = 0;
            switch (seq_len) {
                case 1:
                    token_id = static_cast<int>(c) + 100;
                    break;
                case 2:
                    token_id = 3000 + ((c & 0x1F) << 6)
                             + (static_cast<unsigned char>(text[i + 1]) & 0x3F);
                    break;
                case 3:
                    token_id = 4000 + ((c & 0x0F) << 12)
                             + ((static_cast<unsigned char>(text[i + 1]) & 0x3F) << 6)
                             + (static_cast<unsigned char>(text[i + 2]) & 0x3F);
                    break;
                default:
                    token_id = 8000 + ((c & 0x07) << 18)
                             + ((static_cast<unsigned char>(text[i + 1]) & 0x3F) << 12)
                             + ((static_cast<unsigned char>(text[i + 2]) & 0x3F) << 6)
                             + (static_cast<unsigned char>(text[i + 3]) & 0x3F);
                    break;
            }
            i += seq_len;
            // Clamp into a vocab-sized id range.
            token_id = token_id % 30000;
            tokens.push_back(token_id);
        }
    }
    tokens.push_back(102); // [SEP] token
    return tokens;
}
// Loads a BERT-style vocab file: one token per line, id == zero-based line
// number. Returns 0 on success, -1 if the file cannot be opened, -2 if no
// tokens were read.
int BgeEngineRKNN::loadVocab(const char* vocabPath) {
    std::ifstream vocab_file(vocabPath);
    if (!vocab_file.is_open()) {
        LOGE("Failed to open vocab file: %s", vocabPath);
        return -1;
    }
    std::string line;
    int id = 0;
    while (std::getline(vocab_file, line)) {
        // BUGFIX: tolerate CRLF vocab files — a trailing '\r' would otherwise
        // become part of every token and make all lookups miss.
        if (!line.empty() && line.back() == '\r') {
            line.pop_back();
        }
        if (!line.empty()) {
            m_vocab[line] = id;
        }
        // BUGFIX: advance the id for every line read (HuggingFace convention:
        // token id == line number), so a stray blank line does not shift the
        // ids of every token after it.
        id++;
    }
    vocab_file.close();
    if (m_vocab.empty()) {
        LOGE("Vocab file is empty");
        return -2;
    }
    return 0;
}
// Basic (pre-WordPiece) split: whitespace-delimited runs of ASCII become one
// token each; every non-ASCII UTF-8 character (e.g. CJK) becomes its own token.
std::vector<std::string> BgeEngineRKNN::basicTokenize(const std::string& text) {
    std::vector<std::string> out;
    size_t i = 0;
    while (i < text.size()) {
        const unsigned char lead = static_cast<unsigned char>(text[i]);
        if (std::isspace(lead)) {
            ++i; // skip whitespace
            continue;
        }
        if (lead < 0x80) {
            // Collect a contiguous ASCII run up to whitespace or a non-ASCII byte.
            const size_t start = i;
            while (i < text.size()) {
                const unsigned char b = static_cast<unsigned char>(text[i]);
                if (b >= 0x80 || std::isspace(b)) break;
                ++i;
            }
            out.push_back(text.substr(start, i - start));
            continue;
        }
        // Non-ASCII: emit one whole UTF-8 character (2-4 bytes) per token.
        const size_t len = (lead < 0xE0) ? 2 : (lead < 0xF0) ? 3 : 4;
        if (i + len - 1 < text.size()) {
            out.push_back(text.substr(i, len));
            i += len;
        } else {
            ++i; // truncated trailing sequence: skip the lead byte
        }
    }
    return out;
}
// Splits basic tokens into vocab sub-words using greedy longest-match
// WordPiece (BERT convention: non-initial pieces carry a "##" prefix).
// BUGFIX: the original fell back to splitting unknown tokens into single
// *bytes*, which shredded multi-byte UTF-8 characters into strings that can
// never appear in the vocab. Unknown tokens are now emitted unchanged so the
// caller maps each to a single [UNK] id.
std::vector<std::string> BgeEngineRKNN::wordPieceTokenize(const std::vector<std::string>& basic_tokens) {
    std::vector<std::string> word_pieces;
    for (const std::string& token : basic_tokens) {
        if (m_vocab.find(token) != m_vocab.end()) {
            word_pieces.push_back(token);
            continue;
        }
        // Greedy longest-match over byte offsets; a candidate that cuts a
        // multi-byte character simply misses the vocab and backs off further.
        std::vector<std::string> pieces;
        bool matched_all = true;
        size_t start = 0;
        while (start < token.size()) {
            size_t end = token.size();
            std::string best;
            while (end > start) {
                std::string candidate = token.substr(start, end - start);
                if (start > 0) {
                    candidate = "##" + candidate;
                }
                if (m_vocab.find(candidate) != m_vocab.end()) {
                    best = candidate;
                    break;
                }
                end--;
            }
            if (best.empty()) {
                matched_all = false;
                break;
            }
            pieces.push_back(best);
            start = end;
        }
        if (matched_all) {
            word_pieces.insert(word_pieces.end(), pieces.begin(), pieces.end());
        } else {
            word_pieces.push_back(token); // caller resolves this to [UNK]
        }
    }
    return word_pieces;
}
// Returns the L2-normalized copy of |embedding| (length |dim|).
// A zero vector is returned unchanged (all zeros) to avoid division by zero.
std::vector<float> BgeEngineRKNN::normalizeEmbedding(const float* embedding, int dim) {
    std::vector<float> unit(dim);
    float sum_sq = 0.0f;
    for (int i = 0; i < dim; ++i) {
        sum_sq += embedding[i] * embedding[i];
    }
    const float norm = std::sqrt(sum_sq);
    if (norm > 0.0f) {
        const float inv = 1.0f / norm;
        for (int i = 0; i < dim; ++i) {
            unit[i] = embedding[i] * inv;
        }
    }
    return unit;
}
float BgeEngineRKNN::cosineSimilarity(const std::vector<float>& vec1, const std::vector<float>& vec2) {
if (vec1.size() != vec2.size()) {
LOGE("Vector dimension mismatch");
return 0.0f;
}
float dot_product = 0.0f;
for (size_t i = 0; i < vec1.size(); i++) {
dot_product += vec1[i] * vec2[i];
}
return dot_product;
}
// Creates the RKNN context from the model file, logs all I/O tensor
// attributes, and caches the embedding dimension from output 0's last axis.
// Returns 0 on success or the RKNN error code.
// NOTE(review): on the query-failure paths m_ctx keeps the destroyed handle;
// harmless today because m_initialized stays false and freeModel() gates
// deinitRKNN() on it — confirm before reusing these helpers elsewhere.
int BgeEngineRKNN::initRKNN(const char* modelPath) {
    int ret = 0;
    rknn_context ctx = 0;
    // Initialize the RKNN runtime directly from the model file path.
    ret = rknn_init(&ctx, (void*)modelPath, 0, 0, NULL);
    if (ret != RKNN_SUCC) {
        LOGE("rknn_init failed: %d", ret);
        return ret;
    }
    m_ctx = ctx;
    // Query the number of input/output tensors.
    rknn_input_output_num io_num;
    ret = rknn_query(m_ctx, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num));
    if (ret != RKNN_SUCC) {
        LOGE("rknn_query input_output_num failed: %d", ret);
        rknn_destroy(ctx);
        return ret;
    }
    LOGI("Input number: %d, Output number: %d", io_num.n_input, io_num.n_output);
    // Log every input tensor's attributes (debug aid only).
    for (int i = 0; i < io_num.n_input; i++) {
        rknn_tensor_attr input_attr;
        memset(&input_attr, 0, sizeof(input_attr));
        input_attr.index = i;
        ret = rknn_query(m_ctx, RKNN_QUERY_INPUT_ATTR, &input_attr, sizeof(input_attr));
        if (ret != RKNN_SUCC) {
            LOGE("rknn_query input_attr %d failed: %d", i, ret);
            rknn_destroy(ctx);
            return ret;
        }
        LOGI("Input %d info:", i);
        LOGI("  type: %d", input_attr.type);
        LOGI("  size: %d", input_attr.size);
        LOGI("  n_dims: %d", input_attr.n_dims);
        LOGI("  dims:");
        for (int j = 0; j < input_attr.n_dims; j++) {
            LOGI("    dim[%d] = %d", j, input_attr.dims[j]);
        }
    }
    // Log every output tensor's attributes.
    for (int i = 0; i < io_num.n_output; i++) {
        rknn_tensor_attr output_attr;
        memset(&output_attr, 0, sizeof(output_attr));
        output_attr.index = i;
        ret = rknn_query(m_ctx, RKNN_QUERY_OUTPUT_ATTR, &output_attr, sizeof(output_attr));
        if (ret != RKNN_SUCC) {
            LOGE("rknn_query output_attr %d failed: %d", i, ret);
            rknn_destroy(ctx);
            return ret;
        }
        // Print the output shape.
        LOGI("Output %d dims:", i);
        for (int j = 0; j < output_attr.n_dims; j++) {
            LOGI("  dim[%d] = %d", j, output_attr.dims[j]);
        }
        // Sanity-check output 1's shape: (1, 512) or (1, 1, 512) is the
        // expected pooled-embedding layout. This only logs; it changes nothing.
        if (i == 1) {
            bool is_valid_output = false;
            if (output_attr.n_dims == 2 && output_attr.dims[0] == 1 && output_attr.dims[1] == 512) {
                is_valid_output = true;
            } else if (output_attr.n_dims == 3 && output_attr.dims[0] == 1 && output_attr.dims[1] == 1 && output_attr.dims[2] == 512) {
                is_valid_output = true;
            }
            if (is_valid_output) {
                LOGI("Output 1 has valid embedding dimensions: using Output 1");
            }
        }
    }
    // Cache the embedding dimension from output 0's last axis.
    // NOTE(review): inferEmbedding() copies from output *1* — confirm the two
    // outputs share the same final dimension for this model.
    rknn_tensor_attr output0_attr;
    memset(&output0_attr, 0, sizeof(output0_attr));
    output0_attr.index = 0;
    ret = rknn_query(m_ctx, RKNN_QUERY_OUTPUT_ATTR, &output0_attr, sizeof(output0_attr));
    if (ret == RKNN_SUCC && output0_attr.n_dims >= 2) {
        m_embedding_dim = output0_attr.dims[output0_attr.n_dims - 1];
        LOGI("Embedding dimension (from output 0): %d", m_embedding_dim);
    }
    return 0;
}
// Destroys the RKNN context if one is live; safe to call repeatedly.
void BgeEngineRKNN::deinitRKNN() {
    if (m_ctx != 0) {
        rknn_destroy(m_ctx);
        m_ctx = 0;
    }
}
// Runs one forward pass over the (padded/truncated) token sequence and
// returns a heap-allocated float[m_embedding_dim]; the caller owns it
// (delete[]). Returns nullptr on any RKNN failure.
float* BgeEngineRKNN::inferEmbedding(const std::vector<int>& tokens) {
    // Fixed model sequence length: shorter inputs are zero-padded, longer truncated.
    const int max_seq_len = 512;
    std::vector<int> padded_tokens = tokens;
    if (padded_tokens.size() < max_seq_len) {
        padded_tokens.resize(max_seq_len, 0); // pad with 0
    } else if (padded_tokens.size() > max_seq_len) {
        padded_tokens.resize(max_seq_len); // truncate
    }
    // attention_mask: 1 for real tokens, 0 for padding.
    std::vector<int> attention_mask(max_seq_len, 0);
    for (int i = 0; i < tokens.size() && i < max_seq_len; i++) {
        attention_mask[i] = 1;
    }
    // token_type_ids: all zeros for a single-sentence task.
    std::vector<int> token_type_ids(max_seq_len, 0);
    // Up to three inputs: input_ids, attention_mask, token_type_ids.
    rknn_input inputs[3];
    memset(inputs, 0, sizeof(inputs));
    // Input 0: input_ids
    inputs[0].index = 0;
    inputs[0].type = RKNN_TENSOR_INT32;
    inputs[0].size = sizeof(int) * max_seq_len;
    inputs[0].buf = padded_tokens.data();
    inputs[0].pass_through = 0;
    // Input 1: attention_mask
    inputs[1].index = 1;
    inputs[1].type = RKNN_TENSOR_INT32;
    inputs[1].size = sizeof(int) * max_seq_len;
    inputs[1].buf = attention_mask.data();
    inputs[1].pass_through = 0;
    // Input 2: token_type_ids
    inputs[2].index = 2;
    inputs[2].type = RKNN_TENSOR_INT32;
    inputs[2].size = sizeof(int) * max_seq_len;
    inputs[2].buf = token_type_ids.data();
    inputs[2].pass_through = 0;
    // Bind only as many inputs as the model actually declares.
    rknn_input_output_num io_num;
    int ret = rknn_query(m_ctx, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num));
    if (ret != RKNN_SUCC) {
        LOGE("rknn_query input_output_num failed: %d", ret);
        return nullptr;
    }
    ret = rknn_inputs_set(m_ctx, io_num.n_input, inputs);
    if (ret != RKNN_SUCC) {
        LOGE("rknn_inputs_set failed: %d", ret);
        return nullptr;
    }
    // Run inference.
    ret = rknn_run(m_ctx, NULL);
    if (ret != RKNN_SUCC) {
        LOGE("rknn_run failed: %d", ret);
        return nullptr;
    }
    // Fetch two outputs as floats.
    // NOTE(review): assumes the model exposes at least two outputs — confirm
    // against the actual .rknn file (io_num.n_output is not checked here).
    rknn_output outputs[2];
    memset(outputs, 0, sizeof(outputs));
    outputs[0].want_float = 1;
    outputs[1].want_float = 1;
    ret = rknn_outputs_get(m_ctx, 2, outputs, NULL);
    if (ret != RKNN_SUCC) {
        LOGE("rknn_outputs_get failed: %d", ret);
        return nullptr;
    }
    // Copy the embedding out of output 1.
    // NOTE(review): m_embedding_dim was derived from output 0 in initRKNN();
    // confirm output 1 holds at least that many floats or this memcpy over-reads.
    float* embedding = new float[m_embedding_dim];
    memcpy(embedding, outputs[1].buf, sizeof(float) * m_embedding_dim);
    LOGI("Using output 1 for embedding");
    // Release RKNN-owned output buffers.
    rknn_outputs_release(m_ctx, 2, outputs);
    return embedding;
}

View File

@@ -0,0 +1,42 @@
#ifndef BGE_ENGINE_RKNN_H
#define BGE_ENGINE_RKNN_H
#include <rknn_api.h>
#include <vector>
#include <string>
#include <map>
// On-device BGE text-embedding engine backed by the Rockchip RKNN runtime.
// Typical flow: loadModel() once, then getEmbedding()/calculateSimilarity();
// freeModel() (or destruction) releases the NPU context. Not thread-safe:
// callers must serialize access to one instance.
class BgeEngineRKNN {
public:
    BgeEngineRKNN();
    ~BgeEngineRKNN();
    // Loads the .rknn model and, optionally, a BERT vocab (one token per line).
    // Returns 0 on success; negative or RKNN error code otherwise.
    int loadModel(const char* modelPath, const char* vocabPath = nullptr);
    void freeModel();
    // Returns a heap-allocated float[getEmbeddingDim()]; caller must delete[].
    float* getEmbedding(const std::string& text);
    float calculateSimilarity(const std::string& text1, const std::string& text2);
    int getEmbeddingDim() const; // embedding dimension of the loaded model
private:
    rknn_context m_ctx;
    bool m_initialized;
    int m_embedding_dim;
    std::map<std::string, int> m_vocab; // token -> id mapping
    // Embedding helpers
    std::vector<int> tokenize(const std::string& text);
    std::vector<float> normalizeEmbedding(const float* embedding, int dim);
    float cosineSimilarity(const std::vector<float>& vec1, const std::vector<float>& vec2);
    // Tokenizer helpers
    int loadVocab(const char* vocabPath);
    std::vector<std::string> basicTokenize(const std::string& text);
    std::vector<std::string> wordPieceTokenize(const std::vector<std::string>& basic_tokens);
    // RKNN lifecycle
    int initRKNN(const char* modelPath);
    void deinitRKNN();
    float* inferEmbedding(const std::vector<int>& tokens);
};
#endif // BGE_ENGINE_RKNN_H

View File

@@ -0,0 +1,229 @@
#include <jni.h>
#include <string>
#include <map>
#include <android/log.h>
#include "BgeEngineRKNN.h"
#define LOG_TAG "BgeEngineJNI"
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__)
#define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__)
// 全局引擎实例映射
static std::map<jlong, BgeEngineRKNN*> g_engine_map;
static jclass g_float_array_class;
static jmethodID g_float_array_ctor;
// 获取BGE引擎实例
// Resolves a Java-held handle back to its engine, or nullptr if unknown.
BgeEngineRKNN* getBgeEngineFromPtr(jlong ptr) {
    const auto it = g_engine_map.find(ptr);
    return (it == g_engine_map.end()) ? nullptr : it->second;
}
// JNI接口实现
// 创建BGE引擎
// JNI: allocates a new engine, registers it in the handle map, and returns
// its opaque handle (the pointer value) to Java.
extern "C" JNIEXPORT jlong JNICALL Java_com_digitalperson_engine_BgeEngineRKNN_createBgeEngine(
    JNIEnv* env,
    jobject thiz) {
    LOGI("Creating BGE engine");
    auto* engine = new BgeEngineRKNN();
    const jlong handle = reinterpret_cast<jlong>(engine);
    g_engine_map[handle] = engine;
    return handle;
}
// 加载模型
// JNI: loads the RKNN model (plus optional vocab) into the engine at |ptr|.
// Returns the engine's status code, or -1 for a bad handle / string failure.
extern "C" JNIEXPORT jint JNICALL Java_com_digitalperson_engine_BgeEngineRKNN_loadModel(
    JNIEnv* env,
    jobject thiz,
    jlong ptr,
    jstring modelPath,
    jstring vocabPath) {
    LOGI("Loading BGE model with vocab");
    BgeEngineRKNN* engine = getBgeEngineFromPtr(ptr);
    if (engine == nullptr) {
        LOGE("Invalid engine pointer");
        return -1;
    }
    const char* c_modelPath = env->GetStringUTFChars(modelPath, NULL);
    if (c_modelPath == nullptr) {
        LOGE("Failed to get model path string");
        return -1;
    }
    // The vocab path is optional; only pin it when Java passed a non-null string.
    const char* c_vocabPath = (vocabPath != nullptr) ? env->GetStringUTFChars(vocabPath, NULL) : nullptr;
    if (vocabPath != nullptr && c_vocabPath == nullptr) {
        LOGE("Failed to get vocab path string");
        env->ReleaseStringUTFChars(modelPath, c_modelPath);
        return -1;
    }
    const int ret = engine->loadModel(c_modelPath, c_vocabPath);
    env->ReleaseStringUTFChars(modelPath, c_modelPath);
    if (c_vocabPath != nullptr) {
        env->ReleaseStringUTFChars(vocabPath, c_vocabPath);
    }
    return ret;
}
// 释放模型
// JNI: releases the engine's NPU resources. The engine object itself stays in
// g_engine_map (deleted in JNI_OnUnload), so the handle remains valid.
extern "C" JNIEXPORT void JNICALL Java_com_digitalperson_engine_BgeEngineRKNN_freeModel(
    JNIEnv* env,
    jobject thiz,
    jlong ptr) {
    LOGI("Freeing BGE model");
    BgeEngineRKNN* engine = getBgeEngineFromPtr(ptr);
    if (engine == nullptr) {
        LOGE("Invalid engine pointer");
        return;
    }
    engine->freeModel();
}
// 获取嵌入
// JNI: computes the embedding for |text| and returns it as a new Java
// float[]; returns null on any failure (bad handle, inference error, OOM).
extern "C" JNIEXPORT jfloatArray JNICALL Java_com_digitalperson_engine_BgeEngineRKNN_getEmbeddingNative(
    JNIEnv* env,
    jobject thiz,
    jlong ptr,
    jstring text) {
    BgeEngineRKNN* engine = getBgeEngineFromPtr(ptr);
    if (engine == nullptr) {
        LOGE("Invalid engine pointer");
        return nullptr;
    }
    const char* c_text = env->GetStringUTFChars(text, NULL);
    if (c_text == nullptr) {
        LOGE("Failed to get text string");
        return nullptr;
    }
    // Copy into a std::string so the JNI chars can be released immediately.
    const std::string query(c_text);
    env->ReleaseStringUTFChars(text, c_text);
    float* embedding = engine->getEmbedding(query);
    if (embedding == nullptr) {
        LOGE("Failed to get embedding");
        return nullptr;
    }
    const int dim = engine->getEmbeddingDim();
    LOGI("Using embedding dimension: %d", dim);
    jfloatArray result = env->NewFloatArray(dim);
    if (result == nullptr) {
        LOGE("Failed to create float array");
        delete[] embedding;
        return nullptr;
    }
    env->SetFloatArrayRegion(result, 0, dim, embedding);
    delete[] embedding;
    return result;
}
// 获取嵌入维度
// JNI: returns the loaded model's embedding dimension, or -1 for a bad handle.
extern "C" JNIEXPORT jint JNICALL Java_com_digitalperson_engine_BgeEngineRKNN_getEmbeddingDim(
    JNIEnv* env,
    jobject thiz,
    jlong ptr) {
    BgeEngineRKNN* engine = getBgeEngineFromPtr(ptr);
    if (engine == nullptr) {
        LOGE("Invalid engine pointer");
        return -1;
    }
    const int dim = engine->getEmbeddingDim();
    LOGI("Returning embedding dimension: %d", dim);
    return dim;
}
// 计算相似度
// JNI: cosine similarity of two texts' embeddings; 0.0f on any failure.
extern "C" JNIEXPORT jfloat JNICALL Java_com_digitalperson_engine_BgeEngineRKNN_calculateSimilarityNative(
    JNIEnv* env,
    jobject thiz,
    jlong ptr,
    jstring text1,
    jstring text2) {
    LOGI("Calculating similarity");
    BgeEngineRKNN* engine = getBgeEngineFromPtr(ptr);
    if (engine == nullptr) {
        LOGE("Invalid engine pointer");
        return 0.0f;
    }
    const char* c_text1 = env->GetStringUTFChars(text1, NULL);
    const char* c_text2 = env->GetStringUTFChars(text2, NULL);
    if (c_text1 == nullptr || c_text2 == nullptr) {
        // Release whichever pin succeeded before bailing out.
        LOGE("Failed to get text strings");
        if (c_text1 != nullptr) env->ReleaseStringUTFChars(text1, c_text1);
        if (c_text2 != nullptr) env->ReleaseStringUTFChars(text2, c_text2);
        return 0.0f;
    }
    const std::string lhs(c_text1);
    const std::string rhs(c_text2);
    env->ReleaseStringUTFChars(text1, c_text1);
    env->ReleaseStringUTFChars(text2, c_text2);
    return engine->calculateSimilarity(lhs, rhs);
}
// 初始化JNI
// Library load hook: caches a global reference to the float[] class.
// Returns JNI_VERSION_1_6 on success, JNI_ERR on any failure.
jint JNI_OnLoad(JavaVM* vm, void* reserved) {
    JNIEnv* env = nullptr;
    if (vm->GetEnv((void**)&env, JNI_VERSION_1_6) != JNI_OK) {
        LOGE("Failed to get JNI environment");
        return JNI_ERR;
    }
    jclass float_array_cls = env->FindClass("[F");
    if (float_array_cls == nullptr) {
        LOGE("Failed to find float array class");
        return JNI_ERR;
    }
    g_float_array_class = (jclass)env->NewGlobalRef(float_array_cls);
    // BUGFIX: release the local reference once the global one is cached, and
    // verify NewGlobalRef did not fail (it returns null under memory pressure).
    env->DeleteLocalRef(float_array_cls);
    if (g_float_array_class == nullptr) {
        LOGE("Failed to create global ref for float array class");
        return JNI_ERR;
    }
    return JNI_VERSION_1_6;
}
// 清理JNI
// Library unload hook: drops the cached global class reference and destroys
// every engine still registered in the handle map.
void JNI_OnUnload(JavaVM* vm, void* reserved) {
    JNIEnv* env = nullptr;
    if (vm->GetEnv((void**)&env, JNI_VERSION_1_6) != JNI_OK) {
        LOGE("Failed to get JNI environment");
        return;
    }
    // Release the cached global reference.
    if (g_float_array_class != nullptr) {
        env->DeleteGlobalRef(g_float_array_class);
        g_float_array_class = nullptr;
    }
    // Delete all engines that Java never explicitly destroyed.
    for (auto& pair : g_engine_map) {
        delete pair.second;
    }
    g_engine_map.clear();
}

View File

@@ -58,5 +58,21 @@ if (ANDROID)
jnigraphics jnigraphics
log log
) )
add_library(bgeEngine SHARED
BgeEngineRKNN.cpp
BgeEngineRKNNJNI.cpp
)
target_include_directories(bgeEngine PRIVATE
${CMAKE_CURRENT_LIST_DIR}
${CMAKE_CURRENT_LIST_DIR}/zipformer_headers
${CMAKE_CURRENT_LIST_DIR}/utils
)
target_link_libraries(bgeEngine
rknnrt
log
)
endif() endif()

View File

@@ -0,0 +1,26 @@
package com.digitalperson
import android.app.Application
import android.util.Log
import com.digitalperson.config.AppConfig
import com.digitalperson.embedding.RefEmbeddingIndexer
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.launch
/**
 * Application entry point. On process start it launches a one-shot background
 * pass of [RefEmbeddingIndexer.runOnce] over the bundled ref corpus.
 */
class DigitalPersonApp : Application() {
    // Process-lifetime scope; SupervisorJob so one failed child can't cancel siblings.
    private val appScope = CoroutineScope(SupervisorJob() + Dispatchers.Default)
    override fun onCreate() {
        super.onCreate()
        appScope.launch {
            try {
                RefEmbeddingIndexer.runOnce(this@DigitalPersonApp)
            } catch (t: Throwable) {
                // Indexing is best-effort: log and keep the app alive.
                Log.e(AppConfig.TAG, "[RefEmbed] 索引任务异常", t)
            }
        }
    }
}

View File

@@ -38,6 +38,7 @@ import com.digitalperson.interaction.ConversationSummaryMemory
import java.io.File import java.io.File
import android.graphics.BitmapFactory import android.graphics.BitmapFactory
import android.widget.ImageView
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job import kotlinx.coroutines.Job
@@ -49,6 +50,7 @@ import kotlinx.coroutines.withContext
import com.digitalperson.onboard_testing.FaceRecognitionTest import com.digitalperson.onboard_testing.FaceRecognitionTest
import com.digitalperson.onboard_testing.LLMSummaryTest import com.digitalperson.onboard_testing.LLMSummaryTest
import com.digitalperson.embedding.RefImageMatcher
class Live2DChatActivity : AppCompatActivity() { class Live2DChatActivity : AppCompatActivity() {
companion object { companion object {
@@ -109,6 +111,7 @@ class Live2DChatActivity : AppCompatActivity() {
private lateinit var faceRecognitionTest: FaceRecognitionTest private lateinit var faceRecognitionTest: FaceRecognitionTest
private lateinit var llmSummaryTest: LLMSummaryTest private lateinit var llmSummaryTest: LLMSummaryTest
private var refMatchImageView: ImageView? = null
override fun onRequestPermissionsResult( override fun onRequestPermissionsResult(
requestCode: Int, requestCode: Int,
@@ -159,6 +162,8 @@ class Live2DChatActivity : AppCompatActivity() {
speakingPlayerViewId = 0, speakingPlayerViewId = 0,
live2dViewId = R.id.live2d_view live2dViewId = R.id.live2d_view
) )
refMatchImageView = findViewById(R.id.ref_match_image)
cameraPreviewView = findViewById(R.id.camera_preview) cameraPreviewView = findViewById(R.id.camera_preview)
cameraPreviewView.implementationMode = PreviewView.ImplementationMode.COMPATIBLE cameraPreviewView.implementationMode = PreviewView.ImplementationMode.COMPATIBLE
@@ -611,6 +616,7 @@ class Live2DChatActivity : AppCompatActivity() {
runOnUiThread { runOnUiThread {
uiManager.appendToUi("${filteredText.orEmpty()}\n") uiManager.appendToUi("${filteredText.orEmpty()}\n")
} }
maybeShowMatchedRefImage(filteredText ?: response)
} }
interactionCoordinator.onCloudFinalResponse(response) interactionCoordinator.onCloudFinalResponse(response)
} }
@@ -648,6 +654,24 @@ class Live2DChatActivity : AppCompatActivity() {
onStopClicked(userInitiated = false) onStopClicked(userInitiated = false)
} }
} }
private fun maybeShowMatchedRefImage(text: String) {
val imageView = refMatchImageView ?: return
ioScope.launch {
val match = RefImageMatcher.findBestMatch(applicationContext, text)
if (match == null) return@launch
val bitmap = try {
assets.open(match.pngAssetPath).use { BitmapFactory.decodeStream(it) }
} catch (_: Throwable) {
null
}
if (bitmap == null) return@launch
runOnUiThread {
imageView.setImageBitmap(bitmap)
imageView.visibility = android.view.View.VISIBLE
}
}
}
private fun createTtsCallback() = object : TtsController.TtsCallback { private fun createTtsCallback() = object : TtsController.TtsCallback {
override fun onTtsStarted(text: String) { override fun onTtsStarted(text: String) {

View File

@@ -25,6 +25,7 @@ import android.view.View
import androidx.lifecycle.Lifecycle import androidx.lifecycle.Lifecycle
import androidx.lifecycle.LifecycleOwner import androidx.lifecycle.LifecycleOwner
import androidx.lifecycle.LifecycleRegistry import androidx.lifecycle.LifecycleRegistry
import android.widget.ImageView
import com.unity3d.player.UnityPlayer import com.unity3d.player.UnityPlayer
import com.unity3d.player.UnityPlayerActivity import com.unity3d.player.UnityPlayerActivity
import com.digitalperson.audio.AudioProcessor import com.digitalperson.audio.AudioProcessor
@@ -47,6 +48,8 @@ import com.digitalperson.tts.TtsController
import com.digitalperson.util.FileHelper import com.digitalperson.util.FileHelper
import com.digitalperson.vad.VadManager import com.digitalperson.vad.VadManager
import kotlinx.coroutines.* import kotlinx.coroutines.*
import com.digitalperson.embedding.RefImageMatcher
import android.graphics.BitmapFactory
class UnityDigitalPersonActivity : UnityPlayerActivity(), LifecycleOwner { class UnityDigitalPersonActivity : UnityPlayerActivity(), LifecycleOwner {
@@ -108,6 +111,7 @@ class UnityDigitalPersonActivity : UnityPlayerActivity(), LifecycleOwner {
private lateinit var holdToSpeakButton: Button private lateinit var holdToSpeakButton: Button
private var recordButtonGlow: View? = null private var recordButtonGlow: View? = null
private var pulseAnimator: ObjectAnimator? = null private var pulseAnimator: ObjectAnimator? = null
private var refMatchImageView: ImageView? = null
// 音频和AI模块 // 音频和AI模块
@@ -254,6 +258,7 @@ class UnityDigitalPersonActivity : UnityPlayerActivity(), LifecycleOwner {
chatHistoryText = chatLayout.findViewById(R.id.my_text) chatHistoryText = chatLayout.findViewById(R.id.my_text)
holdToSpeakButton = chatLayout.findViewById(R.id.record_button) holdToSpeakButton = chatLayout.findViewById(R.id.record_button)
recordButtonGlow = chatLayout.findViewById(R.id.record_button_glow) recordButtonGlow = chatLayout.findViewById(R.id.record_button_glow)
refMatchImageView = chatLayout.findViewById(R.id.ref_match_image)
// 根据配置设置按钮可见性 // 根据配置设置按钮可见性
if (AppConfig.USE_HOLD_TO_SPEAK) { if (AppConfig.USE_HOLD_TO_SPEAK) {
@@ -735,6 +740,7 @@ class UnityDigitalPersonActivity : UnityPlayerActivity(), LifecycleOwner {
val filteredText = ttsController.speakLlmResponse(response) val filteredText = ttsController.speakLlmResponse(response)
android.util.Log.d("UnityDigitalPerson", "LLM response filtered: ${filteredText?.take(60)}") android.util.Log.d("UnityDigitalPerson", "LLM response filtered: ${filteredText?.take(60)}")
if (filteredText != null) appendChat("助手: $filteredText") if (filteredText != null) appendChat("助手: $filteredText")
maybeShowMatchedRefImage(filteredText ?: response)
interactionCoordinator.onCloudFinalResponse(filteredText ?: response.trim()) interactionCoordinator.onCloudFinalResponse(filteredText ?: response.trim())
} }
@@ -751,6 +757,25 @@ class UnityDigitalPersonActivity : UnityPlayerActivity(), LifecycleOwner {
} }
} }
private fun maybeShowMatchedRefImage(text: String) {
val imageView = refMatchImageView ?: return
// Unity Activity already has coroutines
CoroutineScope(SupervisorJob() + Dispatchers.IO).launch {
val match = RefImageMatcher.findBestMatch(applicationContext, text)
if (match == null) return@launch
val bitmap = try {
assets.open(match.pngAssetPath).use { BitmapFactory.decodeStream(it) }
} catch (_: Throwable) {
null
}
if (bitmap == null) return@launch
runOnUiThread {
imageView.setImageBitmap(bitmap)
imageView.visibility = View.VISIBLE
}
}
}
private fun requestLocalThought(prompt: String, onResult: (String) -> Unit) { private fun requestLocalThought(prompt: String, onResult: (String) -> Unit) {
val local = llmManager val local = llmManager
if (local == null) { if (local == null) {

View File

@@ -110,6 +110,19 @@ object AppConfig {
const val MODEL_SIZE_ESTIMATE = 500L * 1024 * 1024 // 500MB const val MODEL_SIZE_ESTIMATE = 500L * 1024 * 1024 // 500MB
} }
/** BGE-small-zh-v1.5 文本嵌入RKNN用于语义相似度 / 检索。 */
object Bge {
const val ASSET_DIR = "bge_models"
const val MODEL_FILE = "bge-small-zh-v1.5.rknn"
}
/**
* app/note/ref 通过 Gradle 额外 assets 目录打入 apk 后,在 assets 中的根路径为 `ref/`。
*/
object RefCorpus {
const val ASSETS_ROOT = "ref"
}
object OnboardTesting { object OnboardTesting {
// 测试人脸识别 // 测试人脸识别
const val FACE_REGONITION = false const val FACE_REGONITION = false

View File

@@ -9,15 +9,24 @@ import com.digitalperson.data.dao.UserAnswerDao
import com.digitalperson.data.dao.UserMemoryDao import com.digitalperson.data.dao.UserMemoryDao
import com.digitalperson.data.dao.ChatMessageDao import com.digitalperson.data.dao.ChatMessageDao
import com.digitalperson.data.dao.ConversationSummaryDao import com.digitalperson.data.dao.ConversationSummaryDao
import com.digitalperson.data.dao.RefTextEmbeddingDao
import com.digitalperson.data.entity.Question import com.digitalperson.data.entity.Question
import com.digitalperson.data.entity.UserAnswer import com.digitalperson.data.entity.UserAnswer
import com.digitalperson.data.entity.UserMemory import com.digitalperson.data.entity.UserMemory
import com.digitalperson.data.entity.ChatMessageEntity import com.digitalperson.data.entity.ChatMessageEntity
import com.digitalperson.data.entity.ConversationSummaryEntity import com.digitalperson.data.entity.ConversationSummaryEntity
import com.digitalperson.data.entity.RefTextEmbedding
@Database( @Database(
entities = [UserMemory::class, Question::class, UserAnswer::class, ChatMessageEntity::class, ConversationSummaryEntity::class], entities = [
version = 4, UserMemory::class,
Question::class,
UserAnswer::class,
ChatMessageEntity::class,
ConversationSummaryEntity::class,
RefTextEmbedding::class
],
version = 5,
exportSchema = false exportSchema = false
) )
abstract class AppDatabase : RoomDatabase() { abstract class AppDatabase : RoomDatabase() {
@@ -26,6 +35,7 @@ abstract class AppDatabase : RoomDatabase() {
abstract fun userAnswerDao(): UserAnswerDao abstract fun userAnswerDao(): UserAnswerDao
abstract fun chatMessageDao(): ChatMessageDao abstract fun chatMessageDao(): ChatMessageDao
abstract fun conversationSummaryDao(): ConversationSummaryDao abstract fun conversationSummaryDao(): ConversationSummaryDao
abstract fun refTextEmbeddingDao(): RefTextEmbeddingDao
companion object { companion object {
private const val DATABASE_NAME = "digital_human.db" private const val DATABASE_NAME = "digital_human.db"

View File

@@ -9,6 +9,17 @@ import com.digitalperson.data.entity.Question
interface QuestionDao { interface QuestionDao {
@Insert @Insert
fun insert(question: Question): Long fun insert(question: Question): Long
@Query(
"""
SELECT * FROM questions
WHERE content = :content
AND ((:subject IS NULL AND subject IS NULL) OR subject = :subject)
AND ((:grade IS NULL AND grade IS NULL) OR grade = :grade)
LIMIT 1
"""
)
fun findByContentSubjectGrade(content: String, subject: String?, grade: Int?): Question?
@Query("SELECT * FROM questions WHERE subject = :subject ORDER BY difficulty") @Query("SELECT * FROM questions WHERE subject = :subject ORDER BY difficulty")
fun getQuestionsBySubject(subject: String): List<Question> fun getQuestionsBySubject(subject: String): List<Question>

View File

@@ -0,0 +1,20 @@
package com.digitalperson.data.dao
import androidx.room.Dao
import androidx.room.Insert
import androidx.room.OnConflictStrategy
import androidx.room.Query
import com.digitalperson.data.entity.RefTextEmbedding
@Dao
interface RefTextEmbeddingDao {
    /** Returns the stored embedding row for an assets path, or null if that file has not been indexed. */
    @Query("SELECT * FROM ref_text_embeddings WHERE assetPath = :path LIMIT 1")
    fun getByPath(path: String): RefTextEmbedding?

    /** Inserts or replaces on conflict (assetPath has a unique index) and returns the row id. */
    @Insert(onConflict = OnConflictStrategy.REPLACE)
    fun insert(row: RefTextEmbedding): Long

    /** Loads every indexed embedding row; used to warm the in-memory cache. */
    @Query("SELECT * FROM ref_text_embeddings")
    fun getAll(): List<RefTextEmbedding>
}

View File

@@ -0,0 +1,23 @@
package com.digitalperson.data.entity
import androidx.room.Entity
import androidx.room.Index
import androidx.room.PrimaryKey
import com.digitalperson.data.util.embeddingBytesToFloatArray
@Entity(
    tableName = "ref_text_embeddings",
    indices = [Index(value = ["assetPath"], unique = true)]
)
data class RefTextEmbedding(
    @PrimaryKey(autoGenerate = true) val id: Long = 0,
    /** Relative assets path, e.g. ref/一年级.../xxx.txt */
    val assetPath: String,
    /** SHA-256 hex of the embedded body text (UTF-8); used to skip unchanged files. */
    val contentHash: String,
    /** Number of floats encoded in [embedding]. */
    val dim: Int,
    /** Little-endian float32 blob of the (normalized) embedding vector. */
    val embedding: ByteArray,
    val updatedAt: Long = System.currentTimeMillis()
) {
    /** Decodes the stored blob back into a float vector. */
    fun toFloatArray(): FloatArray = embeddingBytesToFloatArray(embedding)

    // ByteArray compares by reference, so the data-class generated equals/hashCode
    // would treat identical blobs as different. Override for structural equality.
    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        if (other !is RefTextEmbedding) return false
        return id == other.id &&
            assetPath == other.assetPath &&
            contentHash == other.contentHash &&
            dim == other.dim &&
            embedding.contentEquals(other.embedding) &&
            updatedAt == other.updatedAt
    }

    override fun hashCode(): Int {
        var result = id.hashCode()
        result = 31 * result + assetPath.hashCode()
        result = 31 * result + contentHash.hashCode()
        result = 31 * result + dim
        result = 31 * result + embedding.contentHashCode()
        result = 31 * result + updatedAt.hashCode()
        return result
    }
}

View File

@@ -0,0 +1,18 @@
package com.digitalperson.data.util
import java.nio.ByteBuffer
import java.nio.ByteOrder
/**
 * Serializes a float vector into a little-endian byte blob (4 bytes per element),
 * suitable for storing as a Room BLOB column.
 */
fun floatArrayToEmbeddingBytes(values: FloatArray): ByteArray {
    val buffer = ByteBuffer
        .allocate(values.size * Float.SIZE_BYTES)
        .order(ByteOrder.LITTLE_ENDIAN)
    values.forEach { buffer.putFloat(it) }
    return buffer.array()
}
/**
 * Deserializes a little-endian byte blob (as produced by [floatArrayToEmbeddingBytes])
 * back into a float vector. Trailing bytes that do not form a full float are ignored.
 */
fun embeddingBytesToFloatArray(blob: ByteArray): FloatArray {
    val count = blob.size / Float.SIZE_BYTES
    val buffer = ByteBuffer.wrap(blob).order(ByteOrder.LITTLE_ENDIAN)
    val out = FloatArray(count)
    for (i in 0 until count) {
        out[i] = buffer.getFloat()
    }
    return out
}

View File

@@ -0,0 +1,48 @@
package com.digitalperson.embedding
import android.content.Context
import com.digitalperson.config.AppConfig
import com.digitalperson.engine.BgeEngineRKNN
import com.digitalperson.util.FileHelper
import java.io.File
/**
 * Lazily initialized BGE text embedding backed by RKNN. Used for semantic
 * similarity retrieval between dialogue text and pre-stored annotations.
 *
 * Initialization copies the resources under [AppConfig.Bge.ASSET_DIR] to
 * internal storage and loads [AppConfig.Bge.MODEL_FILE].
 */
object BgeEmbedding {

    @Volatile
    private var engine: BgeEngineRKNN? = null

    /** True once the native engine has been loaded successfully. */
    fun isReady(): Boolean = engine?.isInitialized == true

    /**
     * Blocks while copying assets and loading the model; call from a background
     * thread or a coroutine (e.g. Dispatchers.IO) rather than the main thread.
     *
     * @return true when the engine is (already or newly) initialized.
     */
    @Synchronized
    fun initialize(context: Context): Boolean {
        if (engine?.isInitialized == true) return true
        val appContext = context.applicationContext
        val modelDir = FileHelper.copyBgeModels(appContext) ?: return false
        val modelFile = File(modelDir, AppConfig.Bge.MODEL_FILE)
        if (!modelFile.exists()) return false
        val candidate = BgeEngineRKNN(appContext)
        if (!candidate.initialize(modelFile.absolutePath)) return false
        engine = candidate
        return true
    }

    /** Releases the native engine; safe to call when never initialized. */
    @Synchronized
    fun release() {
        engine?.deinitialize()
        engine = null
    }

    /** Embedding for [text], or null when the engine is not initialized. */
    fun getEmbedding(text: String): FloatArray? = engine?.getEmbedding(text)

    /** Cosine similarity between two texts, or null when the engine is not initialized. */
    fun similarity(text1: String, text2: String): Float? =
        engine?.calculateSimilarity(text1, text2)

    /** Embedding dimensionality, or -1 when the engine is not initialized. */
    fun embeddingDim(): Int = engine?.embeddingDim ?: -1
}

View File

@@ -0,0 +1,34 @@
package com.digitalperson.embedding
import android.content.Context
/**
 * Recursively lists every `.txt` asset path (using `/` separators) below a root
 * directory, including subdirectories. Results are sorted lexicographically.
 */
internal object RefCorpusAssetScanner {

    fun listTxtFilesUnder(context: Context, root: String): List<String> {
        val assetManager = context.assets
        val found = ArrayList<String>()

        fun visit(dir: String) {
            val entries = assetManager.list(dir) ?: return
            for (entry in entries) {
                if (entry.isEmpty()) continue
                val path = "$dir/$entry"
                val nested = assetManager.list(path)
                if (nested.isNullOrEmpty()) {
                    // Leaf: AssetManager.list returns an empty array for plain files.
                    if (entry.endsWith(".txt", ignoreCase = true)) {
                        found.add(path)
                    }
                } else {
                    visit(path)
                }
            }
        }

        visit(root.removeSuffix("/"))
        return found.sorted()
    }
}

View File

@@ -0,0 +1,173 @@
package com.digitalperson.embedding
import android.content.Context
import android.util.Log
import com.digitalperson.config.AppConfig
import com.digitalperson.data.AppDatabase
import com.digitalperson.data.entity.Question
import com.digitalperson.data.entity.RefTextEmbedding
import com.digitalperson.data.util.floatArrayToEmbeddingBytes
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import java.security.MessageDigest
/**
 * At process start, scans every `.txt` under assets [AppConfig.RefCorpus.ASSETS_ROOT]
 * (including subdirectories), skips `#` header lines, computes BGE embeddings, writes
 * them to Room and fills [RefEmbeddingMemoryCache]. Lines that look like questions
 * are also mirrored into the question bank.
 *
 * The DAOs are synchronous; the whole function runs on [Dispatchers.IO] and does not
 * block the main thread.
 */
object RefEmbeddingIndexer {

    private const val TAG = AppConfig.TAG

    suspend fun runOnce(context: Context) = withContext(Dispatchers.IO) {
        val app = context.applicationContext
        val db = AppDatabase.getInstance(app)
        val dao = db.refTextEmbeddingDao()
        val questionDao = db.questionDao()

        if (!BgeEmbedding.initialize(app)) {
            Log.e(TAG, "[RefEmbed] BGE 初始化失败,跳过 ref 语料索引")
            return@withContext
        }

        val root = AppConfig.RefCorpus.ASSETS_ROOT
        val paths = RefCorpusAssetScanner.listTxtFilesUnder(app, root)
        Log.i(TAG, "[RefEmbed] 发现 ${paths.size} 个 txtroot=$root")

        var skipped = 0
        var embedded = 0
        var empty = 0
        var failed = 0

        for (path in paths) {
            val raw = try {
                app.assets.open(path).bufferedReader(Charsets.UTF_8).use { it.readText() }
            } catch (e: Exception) {
                Log.w(TAG, "[RefEmbed] 读取失败 $path: ${e.message}")
                failed++
                continue
            }

            // Question bank: lines containing '?' (ASCII or full-width) go into `questions`.
            val subject = extractSubjectFromRaw(raw)
            val grade = extractGradeFromPath(path)
            val questionLines = extractQuestionLines(raw)
            for (line in questionLines) {
                val content = line.trim()
                if (content.isEmpty()) continue
                val exists = questionDao.findByContentSubjectGrade(content, subject, grade)
                if (exists == null) {
                    questionDao.insert(
                        Question(
                            id = 0,
                            content = content,
                            answer = null,
                            subject = subject,
                            grade = grade,
                            difficulty = 1,
                            createdAt = System.currentTimeMillis()
                        )
                    )
                }
            }

            val embedText = RefTxtEmbedText.fromRawFileContent(raw)
            if (embedText.isEmpty()) {
                empty++
                continue
            }
            val hash = sha256Hex(embedText.toByteArray(Charsets.UTF_8))

            // Synchronous DAO call (already on the IO dispatcher).
            val existing = dao.getByPath(path)
            if (existing != null && existing.contentHash == hash) {
                // Unchanged file: reuse the persisted vector, just refresh the cache.
                RefEmbeddingMemoryCache.put(path, normalizeL2(existing.toFloatArray()))
                skipped++
                continue
            }

            val vec = BgeEmbedding.getEmbedding(embedText)
            if (vec == null || vec.isEmpty()) {
                Log.w(TAG, "[RefEmbed] 嵌入为空 $path")
                failed++
                continue
            }
            val normalized = normalizeL2(vec)
            dao.insert(
                RefTextEmbedding(
                    assetPath = path,
                    contentHash = hash,
                    dim = normalized.size,
                    embedding = floatArrayToEmbeddingBytes(normalized)
                )
            )
            RefEmbeddingMemoryCache.put(path, normalized)
            embedded++
        }

        Log.i(
            TAG,
            "[RefEmbed] 完成 embedded=$embedded skipped=$skipped empty=$empty failed=$failed cacheSize=${RefEmbeddingMemoryCache.size()}"
        )
    }

    /** First `#` line of the file is treated as the subject label (text after the `#`). */
    private fun extractSubjectFromRaw(raw: String): String? {
        val line = raw.lineSequence()
            .map { it.trimEnd() }
            .firstOrNull { it.trimStart().startsWith("#") }
            ?: return null
        val s = line.trimStart().removePrefix("#").trim()
        return s.ifEmpty { null }
    }

    /** Non-comment lines containing a question mark (ASCII '?' or full-width '？'). */
    private fun extractQuestionLines(raw: String): List<String> {
        return raw.lineSequence()
            .map { it.trimEnd() }
            .filter { it.isNotBlank() }
            .filter { !it.trimStart().startsWith("#") }
            // BUGFIX: the second literal was an empty char literal (invalid Kotlin);
            // restored the full-width question mark so Chinese questions are detected.
            .filter { it.contains('?') || it.contains('？') }
            .toList()
    }

    /** Parses the grade from paths like `ref/一年级上-.../...` → 1; null when absent. */
    private fun extractGradeFromPath(assetPath: String): Int? {
        // example: ref/一年级上-生活适应/... or ref/二年级下-...
        val idx = assetPath.indexOf("年级")
        if (idx <= 0) return null
        val prefix = assetPath.substring(0, idx)
        val cn = prefix.lastOrNull() ?: return null
        return chineseGradeToInt(cn)
    }

    /** Maps a Chinese numeral character to its grade number (1..10), or null. */
    private fun chineseGradeToInt(c: Char): Int? {
        return when (c) {
            '一' -> 1
            '二' -> 2
            '三' -> 3
            '四' -> 4
            '五' -> 5
            '六' -> 6
            '七' -> 7
            '八' -> 8
            '九' -> 9
            '十' -> 10
            else -> null
        }
    }

    /** L2-normalizes a vector; returns an unchanged copy when the norm is ~0. */
    private fun normalizeL2(v: FloatArray): FloatArray {
        var sum = 0.0
        for (x in v) sum += (x * x).toDouble()
        val norm = kotlin.math.sqrt(sum).toFloat()
        if (norm <= 1e-12f) return v.copyOf()
        return FloatArray(v.size) { i -> v[i] / norm }
    }

    /** Lower-case hex SHA-256 of the given bytes. */
    private fun sha256Hex(data: ByteArray): String {
        val md = MessageDigest.getInstance("SHA-256")
        val digest = md.digest(data)
        return digest.joinToString("") { "%02x".format(it) }
    }
}

View File

@@ -0,0 +1,30 @@
package com.digitalperson.embedding
import java.util.concurrent.ConcurrentHashMap
/**
 * In-memory cache of embeddings for the ref corpus txt files. Keys are the relative
 * assets paths (matching [com.digitalperson.data.entity.RefTextEmbedding.assetPath]).
 * Vectors are defensively copied on both write and read.
 */
object RefEmbeddingMemoryCache {

    private val vectors = ConcurrentHashMap<String, FloatArray>()

    /** Stores a copy of [embedding] under [assetPath], replacing any previous value. */
    fun put(assetPath: String, embedding: FloatArray) {
        vectors[assetPath] = embedding.copyOf()
    }

    /** Returns a copy of the cached vector, or null when the path is not cached. */
    fun get(assetPath: String): FloatArray? = vectors[assetPath]?.copyOf()

    /** Removes every cached vector. */
    fun clear() {
        vectors.clear()
    }

    /** Read-only snapshot (each vector is a fresh copy and safe to use). */
    fun snapshot(): Map<String, FloatArray> =
        vectors.mapValues { entry -> entry.value.copyOf() }

    fun size(): Int = vectors.size
}

View File

@@ -0,0 +1,95 @@
package com.digitalperson.embedding
import android.content.Context
import android.util.Log
import com.digitalperson.config.AppConfig
import kotlin.math.sqrt
/** Result of matching free text against the ref corpus: the best-scoring txt and its sibling png. */
data class RefImageMatch(
    // assets path of the matched txt file
    val txtAssetPath: String,
    // assets path of the png sitting next to the txt
    val pngAssetPath: String,
    // cosine similarity of the query against the txt embedding
    val score: Float
)
object RefImageMatcher {

    private const val TAG = AppConfig.TAG

    /**
     * Finds the ref txt whose cached embedding is most similar to [text] and returns it
     * together with its sibling `.png`, or null when the score is below [threshold],
     * nothing is cached, or the png asset does not exist.
     *
     * @param threshold cosine-similarity cutoff (equals the dot product for normalized vectors)
     */
    fun findBestMatch(
        context: Context,
        text: String,
        threshold: Float = 0.70f
    ): RefImageMatch? {
        val query = text.trim()
        if (query.isEmpty()) return null

        if (!BgeEmbedding.isReady() &&
            !BgeEmbedding.initialize(context.applicationContext)
        ) {
            Log.w(TAG, "[RefMatch] BGE not ready, skip match")
            return null
        }

        val rawVec = BgeEmbedding.getEmbedding(query) ?: return null
        if (rawVec.isEmpty()) return null
        val queryVec = normalizeL2(rawVec)

        val candidates = RefEmbeddingMemoryCache.snapshot()
        if (candidates.isEmpty()) return null

        // Linear scan for the highest-scoring candidate of matching dimensionality.
        var bestPath: String? = null
        var bestScore = -1f
        for ((path, vec) in candidates) {
            if (vec.isEmpty() || vec.size != queryVec.size) continue
            val score = dot(queryVec, vec)
            if (score > bestScore) {
                bestScore = score
                bestPath = path
            }
        }

        val txtPath = bestPath ?: return null
        if (bestScore < threshold) return null

        val pngPath = if (txtPath.endsWith(".txt", ignoreCase = true)) {
            txtPath.dropLast(4) + ".png"
        } else {
            "$txtPath.png"
        }

        // Only verify existence here (no decode), to keep IO on the UI thread minimal.
        val pngExists = try {
            context.assets.open(pngPath).close()
            true
        } catch (_: Throwable) {
            false
        }
        if (!pngExists) return null

        return RefImageMatch(
            txtAssetPath = txtPath,
            pngAssetPath = pngPath,
            score = bestScore
        )
    }

    private fun dot(a: FloatArray, b: FloatArray): Float {
        var acc = 0f
        for (i in a.indices) acc += a[i] * b[i]
        return acc
    }

    /** L2-normalizes a vector; returns an unchanged copy when the norm is ~0. */
    private fun normalizeL2(v: FloatArray): FloatArray {
        var sumSq = 0.0
        for (x in v) sumSq += (x * x).toDouble()
        val norm = sqrt(sumSq).toFloat()
        if (norm <= 1e-12f) return v.copyOf()
        return FloatArray(v.size) { i -> v[i] / norm }
    }
}

View File

@@ -0,0 +1,18 @@
package com.digitalperson.embedding
/**
* 从原始 txt 中构造待嵌入文本:去掉「以 # 开头」的行(行首可含空白),其余行以换行连接。
*/
/**
 * Builds the text to embed from a raw txt file: drops lines that start with `#`
 * (ignoring leading whitespace) and empty lines, joins the rest with newlines.
 */
object RefTxtEmbedText {
    fun fromRawFileContent(raw: String): String {
        val kept = ArrayList<String>()
        for (line in raw.lineSequence()) {
            val trimmed = line.trimEnd()
            if (trimmed.isEmpty()) continue
            if (trimmed.trimStart().startsWith("#")) continue
            kept.add(trimmed)
        }
        return kept.joinToString("\n").trim()
    }
}

View File

@@ -0,0 +1,353 @@
package com.digitalperson.embedding;
import android.content.Context;
import android.os.Handler;
import android.os.Looper;
import android.util.Log;
import com.digitalperson.config.AppConfig;
import com.digitalperson.engine.BgeEngineRKNN;
import com.digitalperson.util.FileHelper;
import org.ejml.simple.SimpleMatrix;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
/**
 * Loads the BGE RKNN model and runs pairwise similarity plus a batch performance
 * comparison (plain float loops vs EJML matrix math) on background threads,
 * reporting results through {@link SimilarityListener}.
 */
public class SimilarityManager {
    private static final String TAG = "SimilarityManager";

    private Context mContext;
    private Handler mHandler; // main-looper handler; NOTE(review): assigned but never used in this class
    private BgeEngineRKNN mBgeEngine;

    // Test data
    private List<float[]> mTestEmbeddings; // precomputed sentence embeddings (float-array form)
    private List<SimpleMatrix> mTestEmbeddingsEJML; // precomputed sentence embeddings (EJML form)
    private SimpleMatrix mTestEmbeddingsMatrix; // all test embeddings as one matrix (embedding dim × sentence count)
    private List<Double> mTestEmbeddingsNorms; // precomputed embedding norms
    private SimpleMatrix mTestEmbeddingsNormsMatrix; // precomputed norm row vector (1 × sentence count)
    private SimpleMatrix mTestEmbeddingsNormsReciprocalMatrix; // precomputed reciprocal-norm row vector (1 × sentence count)
    private List<String> mTestSentences; // pregenerated test sentences
    private boolean useEJMLForSimilarity = false; // NOTE(review): written by the setter but never read anywhere in this class

    /** Callbacks for similarity results and benchmark completion; invoked on worker threads. */
    public interface SimilarityListener {
        void onSimilarityCalculated(float similarity, long timeTaken);
        void onPerformanceTestComplete(long traditionalTime, long ejmlTime, double speedup);
        void onError(String errorMessage);
    }

    private SimilarityListener mListener;

    public SimilarityManager(Context context) {
        this.mContext = context;
        this.mHandler = new Handler(Looper.getMainLooper());
    }

    public void setListener(SimilarityListener listener) {
        this.mListener = listener;
    }

    /**
     * Copies the BGE model assets to internal storage and initializes the engine.
     * Errors are reported through the listener; runs on the caller's thread.
     */
    public void initBgeModel() {
        try {
            File bgeModelDir = FileHelper.copyBgeModels(mContext);
            if (bgeModelDir == null) {
                Log.e(TAG, "BGE model directory copy failed");
                if (mListener != null) {
                    mListener.onError("BGE 模型文件复制失败");
                }
                return;
            }
            // Initialize the BGE engine
            mBgeEngine = new BgeEngineRKNN(mContext);
            String modelPath = new File(bgeModelDir, AppConfig.Bge.MODEL_FILE).getAbsolutePath();
            // Check that the model file exists before loading
            if (!new File(modelPath).exists()) {
                Log.e(TAG, "BGE model file does not exist: " + modelPath);
                if (mListener != null) {
                    mListener.onError("BGE 模型文件不存在");
                }
                return;
            }
            boolean success = mBgeEngine.initialize(modelPath);
            if (!success) {
                Log.e(TAG, "Failed to initialize BGE model");
                if (mListener != null) {
                    mListener.onError("BGE 模型初始化失败");
                }
            }
        } catch (Exception e) {
            Log.e(TAG, "Exception in initBgeModel", e);
            if (mListener != null) {
                mListener.onError("初始化BGE模型异常: " + e.getMessage());
            }
        }
    }

    /**
     * Computes BGE cosine similarity between two texts on a new background thread;
     * the result (or an error) is delivered through the listener.
     */
    public void calculateSimilarity(String text1, String text2) {
        if (mBgeEngine == null || !mBgeEngine.isInitialized()) {
            Log.w(TAG, "BGE engine not initialized, skipping similarity calculation");
            if (mListener != null) {
                mListener.onError("BGE 模型未初始化");
            }
            return;
        }
        // Compute similarity on a background thread
        new Thread(() -> {
            try {
                long startTime = System.currentTimeMillis();
                float similarity = mBgeEngine.calculateSimilarity(text1, text2);
                long timeTaken = System.currentTimeMillis() - startTime;
                if (mListener != null) {
                    mListener.onSimilarityCalculated(similarity, timeTaken);
                }
            } catch (Exception e) {
                Log.e(TAG, "Exception in calculateSimilarity", e);
                if (mListener != null) {
                    mListener.onError("相似度计算失败: " + e.getMessage());
                }
            }
        }).start();
    }

    /**
     * Benchmarks similarity of {@code userInput} against the prepared test sentences,
     * timing a traditional float-loop implementation against a batched EJML matrix
     * implementation, and reports both timings plus the speedup via the listener.
     */
    public void testPerformance(String userInput) {
        if (mBgeEngine == null || !mBgeEngine.isInitialized()) {
            Log.w(TAG, "BGE engine not initialized, skipping performance test");
            if (mListener != null) {
                mListener.onError("BGE 模型未初始化");
            }
            return;
        }
        // Run the benchmark on a background thread
        new Thread(() -> {
            try {
                // Prepare (and cache) the test embeddings/matrices
                prepareTestData();
                // 1. Compute the embedding of the user input
                float[] userEmbedding = mBgeEngine.getEmbedding(userInput);
                if (userEmbedding == null) {
                    if (mListener != null) {
                        mListener.onError("无法计算用户输入的嵌入");
                    }
                    return;
                }
                // Build the EJML column vector for the user embedding
                SimpleMatrix userVec = new SimpleMatrix(userEmbedding.length, 1);
                for (int i = 0; i < userEmbedding.length; i++) {
                    userVec.set(i, 0, userEmbedding[i]);
                }

                // Method 1: float arrays with explicit loops (traditional)
                long startTime1 = System.currentTimeMillis();
                List<Float> similarities1 = new ArrayList<>();
                // Optimization: precompute the reciprocal norm and multiply instead of dividing
                double normUserTraditional = 0.0;
                for (float value : userEmbedding) {
                    normUserTraditional += value * value;
                }
                normUserTraditional = Math.sqrt(normUserTraditional);
                double normUserReciprocalTraditional = (normUserTraditional > 1e-10) ? 1.0 / normUserTraditional : 0.0;
                for (int i = 0; i < mTestEmbeddings.size(); i++) {
                    float[] embedding = mTestEmbeddings.get(i);
                    // Dot product
                    double dotProduct = 0.0;
                    for (int j = 0; j < userEmbedding.length; j++) {
                        dotProduct += userEmbedding[j] * embedding[j];
                    }
                    // Multiply by reciprocals: cos(u,v_i) = (u·v_i) × (1/||u||) × (1/||v_i||)
                    double normEmbeddingReciprocal = mTestEmbeddingsNorms.get(i) > 1e-10 ? 1.0 / mTestEmbeddingsNorms.get(i) : 0.0;
                    float sim = (float) (dotProduct * normUserReciprocalTraditional * normEmbeddingReciprocal);
                    similarities1.add(sim);
                }
                long timeTaken1 = System.currentTimeMillis() - startTime1;

                // Method 2: EJML matrix computation (optimized)
                long startTime2 = System.currentTimeMillis();
                List<Float> similarities2 = new ArrayList<>();
                // Optimization 1: precompute the user vector's norm and its reciprocal
                double normUser = userVec.normF();
                double normUserReciprocal = (normUser > 1e-10) ? 1.0 / normUser : 0.0;
                // Optimization 2: batch matrix math — all dot products and similarities at once
                if (mTestEmbeddingsMatrix != null && mTestEmbeddingsNormsReciprocalMatrix != null) {
                    // One matrix multiply yields every dot product
                    SimpleMatrix dotProducts = userVec.transpose().mult(mTestEmbeddingsMatrix);
                    // All similarities at once, using reciprocal multiplication:
                    // cos(u,v_i) = (u·v_i) × (1/||u||) × (1/||v_i||)
                    // Step 1: (u·v_i) × (1/||u||)
                    SimpleMatrix dotProductsScaled = dotProducts.scale(normUserReciprocal);
                    // Step 2: element-wise multiply by the reciprocal norms
                    SimpleMatrix similaritiesMatrix = dotProductsScaled.elementMult(mTestEmbeddingsNormsReciprocalMatrix);
                    // Copy the row vector into the result list
                    for (int i = 0; i < similaritiesMatrix.numCols(); i++) {
                        similarities2.add((float) similaritiesMatrix.get(0, i));
                    }
                } else if (mTestEmbeddingsMatrix != null) {
                    // Fallback 1: embeddings matrix only — batch dot products, per-element similarity
                    SimpleMatrix dotProducts = userVec.transpose().mult(mTestEmbeddingsMatrix);
                    for (int i = 0; i < dotProducts.numCols(); i++) {
                        double dotProduct = dotProducts.get(0, i);
                        double normEmbeddingReciprocal = mTestEmbeddingsNorms.get(i) > 1e-10 ? 1.0 / mTestEmbeddingsNorms.get(i) : 0.0;
                        float sim = (float) (dotProduct * normUserReciprocal * normEmbeddingReciprocal);
                        similarities2.add(sim);
                    }
                } else {
                    // Fallback 2: per-vector loop (when the matrices were not built)
                    for (int i = 0; i < mTestEmbeddingsEJML.size(); i++) {
                        SimpleMatrix embedding = mTestEmbeddingsEJML.get(i);
                        double dotProduct = userVec.dot(embedding);
                        double normEmbeddingReciprocal = mTestEmbeddingsNorms.get(i) > 1e-10 ? 1.0 / mTestEmbeddingsNorms.get(i) : 0.0;
                        float sim = (float) (dotProduct * normUserReciprocal * normEmbeddingReciprocal);
                        similarities2.add(sim);
                    }
                }
                long timeTaken2 = System.currentTimeMillis() - startTime2;

                // Speedup ratio of the traditional path over the EJML path
                double speedup = (double) timeTaken1 / timeTaken2;
                if (mListener != null) {
                    mListener.onPerformanceTestComplete(timeTaken1, timeTaken2, speedup);
                }
            } catch (Exception e) {
                Log.e(TAG, "Exception in testPerformance", e);
                if (mListener != null) {
                    mListener.onError("性能测试失败: " + e.getMessage());
                }
            }
        }).start();
    }

    /**
     * Lazily builds the cached test data: embeddings in float/EJML form, their norms,
     * and the batch matrices. Skipped when everything is already populated.
     */
    private void prepareTestData() {
        if (mTestEmbeddings == null || mTestEmbeddings.isEmpty() ||
                mTestEmbeddingsEJML == null || mTestEmbeddingsEJML.isEmpty() ||
                mTestEmbeddingsNorms == null || mTestEmbeddingsNorms.isEmpty() ||
                mTestEmbeddingsMatrix == null ||
                mTestEmbeddingsNormsMatrix == null ||
                mTestEmbeddingsNormsReciprocalMatrix == null) {
            // Generate the test sentences
            generateTestSentences();
            // Precompute embeddings (float-array form)
            mTestEmbeddings = new ArrayList<>();
            mTestEmbeddingsEJML = new ArrayList<>();
            mTestEmbeddingsNorms = new ArrayList<>();
            List<String> validTestSentences = new ArrayList<>(); // only sentences whose embedding succeeded
            for (String sentence : mTestSentences) {
                float[] embedding = mBgeEngine.getEmbedding(sentence);
                if (embedding != null) {
                    // float-array form
                    mTestEmbeddings.add(embedding);
                    // EJML column-vector form
                    SimpleMatrix vec = new SimpleMatrix(embedding.length, 1);
                    for (int i = 0; i < embedding.length; i++) {
                        vec.set(i, 0, embedding[i]);
                    }
                    mTestEmbeddingsEJML.add(vec);
                    // Precompute the embedding's norm
                    double norm = vec.normF();
                    mTestEmbeddingsNorms.add(norm);
                    // Track the sentence as valid
                    validTestSentences.add(sentence);
                }
            }
            // Keep only the valid sentences so indices line up with the embeddings
            mTestSentences = validTestSentences;
            // Build the embeddings matrix (embedding dim × sentence count)
            if (!mTestEmbeddingsEJML.isEmpty()) {
                int embeddingDim = mTestEmbeddingsEJML.get(0).numRows();
                int numSentences = mTestEmbeddingsEJML.size();
                mTestEmbeddingsMatrix = new SimpleMatrix(embeddingDim, numSentences);
                for (int i = 0; i < numSentences; i++) {
                    SimpleMatrix embedding = mTestEmbeddingsEJML.get(i);
                    for (int j = 0; j < embeddingDim; j++) {
                        mTestEmbeddingsMatrix.set(j, i, embedding.get(j, 0));
                    }
                }
                Log.d(TAG, "Created test embeddings matrix, size: " + embeddingDim + " × " + numSentences);
            }
            // Build the norm matrices (1 × sentence count)
            if (!mTestEmbeddingsNorms.isEmpty()) {
                int numSentences = mTestEmbeddingsNorms.size();
                mTestEmbeddingsNormsMatrix = new SimpleMatrix(1, numSentences);
                mTestEmbeddingsNormsReciprocalMatrix = new SimpleMatrix(1, numSentences);
                for (int i = 0; i < numSentences; i++) {
                    double norm = mTestEmbeddingsNorms.get(i);
                    mTestEmbeddingsNormsMatrix.set(0, i, norm);
                    // Precompute the reciprocal so similarity never divides
                    double reciprocal = (norm > 1e-10) ? 1.0 / norm : 0.0;
                    mTestEmbeddingsNormsReciprocalMatrix.set(0, i, reciprocal);
                }
                Log.d(TAG, "Created test embeddings norms matrix, size: 1 × " + numSentences);
                Log.d(TAG, "Created test embeddings norms reciprocal matrix, size: 1 × " + numSentences);
            }
            Log.d(TAG, "Generated embeddings for " + mTestEmbeddings.size() + " test sentences");
        }
    }

    /**
     * Populates {@code mTestSentences} with fixed sample sentences.
     * NOTE(review): the vocabulary arrays below contain many empty-string entries —
     * the Chinese literals appear to have been lost in transfer; they are also never
     * read after being declared. Confirm against the original source.
     */
    private void generateTestSentences() {
        mTestSentences = new ArrayList<>();
        // Basic vocabulary (currently unused)
        String[] words = {"", "", "眼睛", "鼻子", "嘴巴", "耳朵", "头发", "手指", "", ""};
        String[] adjectives = {"", "", "", "", "", "", "漂亮", "丑陋", "干净", ""};
        String[] verbs = {"", "", "", "", "", "", "", "", "", ""};
        String[] subjects = {"", "", "", "", "", "我们", "你们", "他们", "这个", "那个"};
        // Fixed test sentences
        mTestSentences.add("我的手有很多手指头");
        mTestSentences.add("他的头发很长");
        mTestSentences.add("她的眼睛很大");
        mTestSentences.add("这个鼻子很高");
        mTestSentences.add("那个嘴巴很小");
        Log.d(TAG, "Generated " + mTestSentences.size() + " test sentences");
    }

    /** Toggles the EJML similarity path. NOTE(review): the flag is never read in this class. */
    public void setUseEJMLForSimilarity(boolean useEJML) {
        this.useEJMLForSimilarity = useEJML;
    }

    /** Whether the BGE engine has been created and initialized. */
    public boolean isInitialized() {
        return mBgeEngine != null && mBgeEngine.isInitialized();
    }

    /** Releases the BGE engine's native resources. */
    public void deinitialize() {
        if (mBgeEngine != null) {
            mBgeEngine.deinitialize();
            mBgeEngine = null;
            Log.d(TAG, "BGE engine deinitialized");
        }
    }
}

View File

@@ -0,0 +1,104 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package com.digitalperson.engine;
import com.google.common.base.Ascii;
import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/** Basic tokenization (punctuation splitting, lower casing, etc.) */
public final class BasicTokenizer {
  private final boolean doLowerCase;

  public BasicTokenizer(boolean doLowerCase) {
    this.doLowerCase = doLowerCase;
  }

  /** Cleans, whitespace-splits, optionally lower-cases, and splits punctuation off each token. */
  public List<String> tokenize(String text) {
    List<String> rawTokens = whitespaceTokenize(cleanText(text));
    StringBuilder joined = new StringBuilder();
    for (String rawToken : rawTokens) {
      // ASCII-only lower-casing, matching the reference BERT implementation.
      String token = doLowerCase ? Ascii.toLowerCase(rawToken) : rawToken;
      for (String piece : runSplitOnPunc(token)) {
        joined.append(piece).append(" ");
      }
    }
    return whitespaceTokenize(joined.toString());
  }

  /* Performs invalid character removal and whitespace cleanup on text. */
  static String cleanText(String text) {
    if (text == null) {
      throw new NullPointerException("The input String is null.");
    }
    StringBuilder cleaned = new StringBuilder();
    for (int i = 0; i < text.length(); i++) {
      char ch = text.charAt(i);
      // Drop characters that cannot be used.
      if (CharChecker.isInvalid(ch) || CharChecker.isControl(ch)) {
        continue;
      }
      cleaned.append(CharChecker.isWhitespace(ch) ? ' ' : ch);
    }
    return cleaned.toString();
  }

  /* Runs basic whitespace cleaning and splitting on a piece of text. */
  static List<String> whitespaceTokenize(String text) {
    if (text == null) {
      throw new NullPointerException("The input String is null.");
    }
    return Arrays.asList(text.split(" "));
  }

  /* Splits punctuation on a piece of text. */
  static List<String> runSplitOnPunc(String text) {
    if (text == null) {
      throw new NullPointerException("The input String is null.");
    }
    List<String> pieces = new ArrayList<>();
    boolean startNewWord = true;
    for (int i = 0; i < text.length(); i++) {
      char ch = text.charAt(i);
      if (CharChecker.isPunctuation(ch)) {
        // Each punctuation character becomes its own token.
        pieces.add(String.valueOf(ch));
        startNewWord = true;
      } else {
        if (startNewWord) {
          pieces.add("");
          startNewWord = false;
        }
        int last = pieces.size() - 1;
        pieces.set(last, pieces.get(last) + ch);
      }
    }
    return pieces;
  }
}

View File

@@ -0,0 +1,263 @@
package com.digitalperson.engine;
import android.content.Context;
import android.util.Log;
import java.util.HashMap;
import java.util.Map;
/**
 * JNI wrapper around the native BGE RKNN engine. Loads librknnrt/libbgeEngine,
 * owns the native handle, loads the WordPiece vocabulary, and exposes embedding,
 * similarity, and tokenization helpers.
 */
public class BgeEngineRKNN {
    private static final String TAG = "BgeEngineRKNN";
    private final long nativePtr;
    private final Context mContext;
    private boolean mIsInitialized = false;
    private FullTokenizer mFullTokenizer = null;
    private Map<String, Integer> mVocabMap = new HashMap<>();

    static {
        try {
            // Load dependent native libraries; rknnrt must be loaded first.
            System.loadLibrary("rknnrt");
            System.loadLibrary("bgeEngine");
            Log.d(TAG, "Successfully loaded librknnrt.so and libbgeEngine.so");
        } catch (UnsatisfiedLinkError e) {
            Log.e(TAG, "Failed to load native library", e);
            throw e;
        }
    }

    public BgeEngineRKNN(Context context) {
        mContext = context;
        try {
            nativePtr = createBgeEngine();
            if (nativePtr == 0) {
                throw new RuntimeException("Failed to create native BGE engine");
            }
        } catch (UnsatisfiedLinkError e) {
            Log.e(TAG, "Failed to load native library", e);
            throw new RuntimeException("Failed to load native library: " + e.getMessage(), e);
        }
    }

    /**
     * Initializes with the vocab assumed to sit next to the model file
     * ("bge-small-zh-v1.5.rknn" → "vocab.txt").
     */
    public boolean initialize(String modelPath) {
        String vocabPath = modelPath.replace("bge-small-zh-v1.5.rknn", "vocab.txt");
        return initialize(modelPath, vocabPath);
    }

    /**
     * Loads the vocabulary and the native model. Idempotent: returns true
     * immediately when already initialized.
     */
    public boolean initialize(String modelPath, String vocabPath) {
        if (mIsInitialized) {
            Log.i(TAG, "Model already initialized");
            return true;
        }
        Log.d(TAG, "Loading BGE model: " + modelPath);
        Log.d(TAG, "Loading vocab: " + vocabPath);
        // Load the vocabulary first; the tokenizer needs it.
        if (!loadVocab(vocabPath)) {
            Log.e(TAG, "Failed to load vocab");
            return false;
        }
        // Build the FullTokenizer (true = lower-case input).
        mFullTokenizer = new FullTokenizer(mVocabMap, true);
        int ret = loadModel(nativePtr, modelPath, vocabPath);
        if (ret == 0) {
            mIsInitialized = true;
            return true;
        }
        return false;
    }

    /**
     * Loads the vocabulary file (one token per line; token id = running index of
     * non-empty lines).
     *
     * BUGFIX: the original used FileReader, which decodes with the platform default
     * charset and leaked the reader when readLine threw. The vocab contains Chinese
     * tokens, so we now decode as UTF-8 explicitly inside try-with-resources.
     * NOTE(review): empty lines are skipped without consuming an id — confirm the
     * vocab file has none, otherwise ids would shift relative to the model.
     *
     * @param vocabPath Path to vocabulary file
     * @return True if successful, false otherwise
     */
    private boolean loadVocab(String vocabPath) {
        try (java.io.BufferedReader reader = new java.io.BufferedReader(
                new java.io.InputStreamReader(
                        new java.io.FileInputStream(vocabPath),
                        java.nio.charset.StandardCharsets.UTF_8))) {
            String line;
            int index = 0;
            while ((line = reader.readLine()) != null) {
                line = line.trim();
                if (!line.isEmpty()) {
                    mVocabMap.put(line, index);
                    index++;
                }
            }
            Log.d(TAG, "Vocab loaded successfully with " + mVocabMap.size() + " tokens");
            return true;
        } catch (java.io.IOException e) {
            Log.e(TAG, "Failed to load vocab: " + e.getMessage());
            e.printStackTrace();
            return false;
        }
    }

    /** Frees the native model; safe to call repeatedly on this instance's handle. */
    public void deinitialize() {
        if (nativePtr != 0) {
            freeModel(nativePtr);
        }
        mIsInitialized = false;
    }

    public boolean isInitialized() {
        return mIsInitialized;
    }

    /**
     * Get the embedding dimension of the model
     * @return Embedding dimension, or -1 when not initialized
     */
    public int getEmbeddingDim() {
        if (!mIsInitialized) {
            Log.e(TAG, "Engine not initialized");
            return -1;
        }
        return getEmbeddingDim(nativePtr);
    }

    /**
     * Calculate embedding for a single text
     * @param text Input text
     * @return Embedding vector as float array, or null when not initialized
     */
    public float[] getEmbedding(String text) {
        if (!mIsInitialized) {
            Log.e(TAG, "Engine not initialized");
            return null;
        }
        return getEmbeddingNative(nativePtr, text);
    }

    /**
     * Calculate cosine similarity between two texts
     * @param text1 First text
     * @param text2 Second text
     * @return Cosine similarity score between -1.0 and 1.0 (0.0 when not initialized)
     */
    public float calculateSimilarity(String text1, String text2) {
        if (!mIsInitialized) {
            Log.e(TAG, "Engine not initialized");
            return 0.0f;
        }
        return calculateSimilarityNative(nativePtr, text1, text2);
    }

    /**
     * Get token ids for a text using FullTokenizer
     * @param text Input text
     * @return Token ids as int array, or null when not initialized
     */
    public int[] getTokens(String text) {
        if (!mIsInitialized) {
            Log.e(TAG, "Engine not initialized");
            return null;
        }
        if (mFullTokenizer == null) {
            Log.e(TAG, "FullTokenizer not initialized");
            return null;
        }
        // Tokenize and map tokens to vocabulary ids.
        java.util.List<String> tokens = mFullTokenizer.tokenize(text);
        java.util.List<Integer> ids = mFullTokenizer.convertTokensToIds(tokens);
        int[] result = new int[ids.size()];
        for (int i = 0; i < ids.size(); i++) {
            result[i] = ids.get(i);
        }
        return result;
    }

    /**
     * Calculate [UNK] ratio for token ids
     * @param tokens Token ids array
     * @return [UNK] ratio between 0.0 and 1.0
     */
    public float calculateUnkRatio(int[] tokens) {
        if (!mIsInitialized) {
            Log.e(TAG, "Engine not initialized");
            return 0.0f;
        }
        if (mFullTokenizer == null || tokens == null || tokens.length == 0) {
            return 0.0f;
        }
        // Resolve the [UNK] id from the vocabulary.
        int unkId = mVocabMap.getOrDefault("[UNK]", -1);
        if (unkId == -1) {
            return 0.0f;
        }
        int unkCount = 0;
        for (int token : tokens) {
            if (token == unkId) {
                unkCount++;
            }
        }
        return (float) unkCount / tokens.length;
    }

    /**
     * Get BERT model inputs (input_ids, attention_mask, token_type_ids).
     * Layout: [CLS] + up to (maxSeqLength - 2) text tokens + [SEP], zero-padded.
     * token_type_ids are all zero (single-sentence input).
     *
     * @param text Input text
     * @param maxSeqLength Maximum sequence length
     * @return Array of int arrays: [input_ids, attention_mask, token_type_ids]
     */
    public int[][] getBertInputs(String text, int maxSeqLength) {
        if (!mIsInitialized) {
            Log.e(TAG, "Engine not initialized");
            return null;
        }
        if (mFullTokenizer == null) {
            Log.e(TAG, "FullTokenizer not initialized");
            return null;
        }
        java.util.List<String> tokens = mFullTokenizer.tokenize(text);
        java.util.List<Integer> ids = mFullTokenizer.convertTokensToIds(tokens);
        int[] inputIds = new int[maxSeqLength];
        int[] attentionMask = new int[maxSeqLength];
        int[] tokenTypeIds = new int[maxSeqLength];
        // [CLS] token (falls back to the standard BERT id 101 when absent).
        inputIds[0] = mVocabMap.getOrDefault("[CLS]", 101);
        attentionMask[0] = 1;
        // Text tokens, truncated to leave room for [CLS] and [SEP].
        int textLength = ids.size();
        int maxTextLength = maxSeqLength - 2;
        int actualLength = Math.min(textLength, maxTextLength);
        for (int i = 0; i < actualLength; i++) {
            inputIds[i + 1] = ids.get(i);
            attentionMask[i + 1] = 1;
        }
        // [SEP] token (falls back to the standard BERT id 102 when absent).
        inputIds[actualLength + 1] = mVocabMap.getOrDefault("[SEP]", 102);
        attentionMask[actualLength + 1] = 1;
        return new int[][]{inputIds, attentionMask, tokenTypeIds};
    }

    // Native methods
    private native long createBgeEngine();
    private native int loadModel(long nativePtr, String modelPath, String vocabPath);
    private native void freeModel(long ptr);
    private native float[] getEmbeddingNative(long ptr, String text);
    private native int getEmbeddingDim(long ptr);
    private native float calculateSimilarityNative(long ptr, String text1, String text2);
}

View File

@@ -0,0 +1,58 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package com.digitalperson.engine;
/** To check whether a char is whitespace/control/punctuation. */
final class CharChecker {

  private CharChecker() {}

  /** True for the NUL character and the Unicode replacement character. */
  public static boolean isInvalid(char ch) {
    return ch == 0 || ch == 0xfffd;
  }

  /** True for control/format characters, excluding anything Java counts as whitespace. */
  public static boolean isControl(char ch) {
    if (Character.isWhitespace(ch)) {
      return false;
    }
    switch (Character.getType(ch)) {
      case Character.CONTROL:
      case Character.FORMAT:
        return true;
      default:
        return false;
    }
  }

  /** True for Java whitespace plus the Unicode space/line/paragraph separators. */
  public static boolean isWhitespace(char ch) {
    if (Character.isWhitespace(ch)) {
      return true;
    }
    switch (Character.getType(ch)) {
      case Character.SPACE_SEPARATOR:
      case Character.LINE_SEPARATOR:
      case Character.PARAGRAPH_SEPARATOR:
        return true;
      default:
        return false;
    }
  }

  /** True for any Unicode punctuation category. */
  public static boolean isPunctuation(char ch) {
    switch (Character.getType(ch)) {
      case Character.CONNECTOR_PUNCTUATION:
      case Character.DASH_PUNCTUATION:
      case Character.START_PUNCTUATION:
      case Character.END_PUNCTUATION:
      case Character.INITIAL_QUOTE_PUNCTUATION:
      case Character.FINAL_QUOTE_PUNCTUATION:
      case Character.OTHER_PUNCTUATION:
        return true;
      default:
        return false;
    }
  }
}

View File

@@ -0,0 +1,52 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package com.digitalperson.engine;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
/**
 * A Java realization of BERT tokenization. Original python code:
 * https://github.com/google-research/bert/blob/master/tokenization.py
 * Runs full tokenization (basic + wordpiece) to split a String into subtokens
 * or their vocabulary ids.
 */
public final class FullTokenizer {
  private final BasicTokenizer basicTokenizer;
  private final WordpieceTokenizer wordpieceTokenizer;
  private final Map<String, Integer> dic;

  // Conventional BERT id for "[UNK]"; used only if the vocab itself lacks the
  // token (mirrors the getOrDefault("[SEP]", 102) convention used by the engine).
  private static final int FALLBACK_UNK_ID = 100;

  /**
   * @param inputDic vocabulary mapping token -> id; also used by the wordpiece step.
   * @param doLowerCase whether the basic tokenizer lower-cases input first.
   */
  public FullTokenizer(Map<String, Integer> inputDic, boolean doLowerCase) {
    dic = inputDic;
    basicTokenizer = new BasicTokenizer(doLowerCase);
    wordpieceTokenizer = new WordpieceTokenizer(inputDic);
  }

  /** Tokenizes {@code text} into wordpiece subtokens (basic split, then wordpiece split). */
  public List<String> tokenize(String text) {
    List<String> splitTokens = new ArrayList<>();
    for (String token : basicTokenizer.tokenize(text)) {
      splitTokens.addAll(wordpieceTokenizer.tokenize(token));
    }
    return splitTokens;
  }

  /**
   * Maps tokens to vocabulary ids. Out-of-vocabulary tokens map to the
   * "[UNK]" id instead of {@code null}.
   *
   * <p>Bug fix: the previous implementation added {@code dic.get(token)}
   * directly, inserting {@code null} for unknown tokens and causing a deferred
   * NullPointerException when the id was later unboxed (e.g. while filling the
   * int[] model inputs).
   */
  public List<Integer> convertTokensToIds(List<String> tokens) {
    List<Integer> outputIds = new ArrayList<>(tokens.size());
    Integer unkId = dic.get("[UNK]");
    int unknown = (unkId != null) ? unkId : FALLBACK_UNK_ID;
    for (String token : tokens) {
      Integer id = dic.get(token);
      outputIds.add(id != null ? id : unknown);
    }
    return outputIds;
  }
}

View File

@@ -0,0 +1,94 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package com.digitalperson.engine;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
/** Word piece tokenization to split a piece of text into its word pieces. */
public final class WordpieceTokenizer {
  private static final String UNKNOWN_TOKEN = "[UNK]"; // Emitted for unknown words.
  private static final int MAX_INPUTCHARS_PER_WORD = 200;

  private final Map<String, Integer> dic;

  public WordpieceTokenizer(Map<String, Integer> vocab) {
    dic = vocab;
  }

  /**
   * Tokenizes text into word pieces with a greedy longest-match-first algorithm
   * against the vocabulary. Example: "unaffable" -> ["un", "##aff", "##able"].
   *
   * @param text a single token or whitespace-separated tokens, already passed
   *     through {@code BasicTokenizer}.
   * @return the list of wordpiece tokens.
   * @throws NullPointerException if {@code text} is null.
   */
  public List<String> tokenize(String text) {
    if (text == null) {
      throw new NullPointerException("The input String is null.");
    }
    List<String> outputTokens = new ArrayList<>();
    for (String word : BasicTokenizer.whitespaceTokenize(text)) {
      // Overly long words are never split; they map straight to [UNK].
      if (word.length() > MAX_INPUTCHARS_PER_WORD) {
        outputTokens.add(UNKNOWN_TOKEN);
        continue;
      }
      List<String> pieces = new ArrayList<>();
      int begin = 0;
      boolean matched = true;
      while (matched && begin < word.length()) {
        matched = false;
        // Try the longest remaining substring first, shrinking from the right.
        for (int end = word.length(); end > begin; end--) {
          String candidate = (begin == 0)
              ? word.substring(begin, end)
              : "##" + word.substring(begin, end); // Non-initial pieces carry the "##" prefix.
          if (dic.containsKey(candidate)) {
            pieces.add(candidate);
            begin = end; // Continue with the remainder of the word.
            matched = true;
            break;
          }
        }
      }
      if (matched) {
        outputTokens.addAll(pieces);
      } else {
        // Some position had no known subword: discard partial pieces, emit [UNK].
        outputTokens.add(UNKNOWN_TOKEN);
      }
    }
    return outputTokens;
  }
}

View File

@@ -185,7 +185,7 @@ class QuestionGenerationAgent(
val generationPrompt = buildGenerationPrompt(prompt, userProfile) val generationPrompt = buildGenerationPrompt(prompt, userProfile)
// 3. 调用大模型生成题目 // 3. 调用大模型生成题目
generateQuestionFromLLM(generationPrompt) { generatedQuestion -> generateQuestionFromLLM(generationPrompt, prompt) { generatedQuestion ->
if (generatedQuestion == null) { if (generatedQuestion == null) {
Log.w(TAG, "Failed to generate question") Log.w(TAG, "Failed to generate question")
return@generateQuestionFromLLM return@generateQuestionFromLLM
@@ -301,18 +301,22 @@ class QuestionGenerationAgent(
/** /**
* 调用LLM生成题目 * 调用LLM生成题目
*/ */
private fun generateQuestionFromLLM(prompt: String, onResult: (GeneratedQuestion?) -> Unit) { private fun generateQuestionFromLLM(
promptText: String,
promptMeta: QuestionPrompt,
onResult: (GeneratedQuestion?) -> Unit
) {
// 优先使用本地LLM如果不可用则使用云端LLM // 优先使用本地LLM如果不可用则使用云端LLM
if (llmManager != null) { if (llmManager != null) {
// 使用本地LLM // 使用本地LLM
llmManager.generate(prompt) { response: String -> llmManager.generate(promptText) { response: String ->
parseGeneratedQuestion(response, onResult) parseGeneratedQuestion(response, promptMeta, onResult)
} }
} else if (cloudLLMGenerator != null) { } else if (cloudLLMGenerator != null) {
// 使用云端LLM // 使用云端LLM
Log.d(TAG, "Using cloud LLM to generate question") Log.d(TAG, "Using cloud LLM to generate question")
cloudLLMGenerator.invoke(prompt) { response -> cloudLLMGenerator.invoke(promptText) { response ->
parseGeneratedQuestion(response, onResult) parseGeneratedQuestion(response, promptMeta, onResult)
} }
} else { } else {
Log.e(TAG, "No LLM available (neither local nor cloud)") Log.e(TAG, "No LLM available (neither local nor cloud)")
@@ -323,16 +327,20 @@ class QuestionGenerationAgent(
/** /**
* 解析生成的题目 * 解析生成的题目
*/ */
private fun parseGeneratedQuestion(response: String, onResult: (GeneratedQuestion?) -> Unit) { private fun parseGeneratedQuestion(
response: String,
promptMeta: QuestionPrompt,
onResult: (GeneratedQuestion?) -> Unit
) {
try { try {
val json = extractJsonFromResponse(response) val json = extractJsonFromResponse(response)
if (json != null) { if (json != null) {
val question = GeneratedQuestion( val question = GeneratedQuestion(
content = json.getString("content"), content = json.getString("content"),
answer = json.getString("answer"), answer = json.getString("answer"),
subject = "生活适应", subject = promptMeta.subject,
grade = 1, grade = promptMeta.grade,
difficulty = 1 difficulty = promptMeta.difficulty
) )
onResult(question) onResult(question)
} else { } else {

View File

@@ -7,6 +7,7 @@ import android.util.Log
import com.digitalperson.config.AppConfig import com.digitalperson.config.AppConfig
import java.io.File import java.io.File
import java.io.FileOutputStream import java.io.FileOutputStream
import java.io.InputStream
object FileHelper { object FileHelper {
private const val TAG = AppConfig.TAG private const val TAG = AppConfig.TAG
@@ -64,6 +65,54 @@ object FileHelper {
val files = arrayOf(AppConfig.FaceRecognition.MODEL_NAME) val files = arrayOf(AppConfig.FaceRecognition.MODEL_NAME)
return copyAssetsToInternal(context, AppConfig.FaceRecognition.MODEL_DIR, outDir, files) return copyAssetsToInternal(context, AppConfig.FaceRecognition.MODEL_DIR, outDir, files)
} }
/**
 * Copies the BGE model files from assets/[AppConfig.Bge.ASSET_DIR] into
 * [Context.getFilesDir]/[AppConfig.Bge.ASSET_DIR].
 * A file is skipped when it already exists and its length matches the asset's
 * (same behavior as [com.digitalperson.embedding.SimilarityManager]).
 *
 * @return the destination directory, or null if any file failed to copy.
 */
@JvmStatic
fun copyBgeModels(context: Context): File? {
val assetDir = AppConfig.Bge.ASSET_DIR
val modelDir = File(context.filesDir, assetDir).apply { mkdirs() }
// Model binary plus the tokenizer files it is shipped with.
val files = arrayOf(
AppConfig.Bge.MODEL_FILE,
"vocab.txt",
"tokenizer.json",
"tokenizer_config.json"
)
for (name in files) {
val outFile = File(modelDir, name)
val assetPath = "$assetDir/$name"
try {
var skip = false
if (outFile.exists()) {
// Compare on-disk length with the asset's reported size; -1 forces a re-copy.
context.assets.open(assetPath).use { input ->
val assetSize = assetSizeOrNegative(input)
if (assetSize >= 0 && outFile.length() == assetSize) skip = true
}
}
if (skip) continue
context.assets.open(assetPath).use { input ->
FileOutputStream(outFile).use { output -> input.copyTo(output) }
}
Log.i(TAG, "Copied BGE asset: $name")
} catch (e: Exception) {
// Abort on the first failure: a partial model directory is unusable.
Log.e(TAG, "copyBgeModels failed for $name: ${e.message}", e)
return null
}
}
return modelDir
}

/**
 * [InputStream.available] is unreliable on some implementations; returns -1 on
 * failure (or a non-positive size) so callers fall back to re-copying the asset.
 */
private fun assetSizeOrNegative(input: InputStream): Long {
return try {
val n = input.available()
if (n > 0) n.toLong() else -1L
} catch (_: Exception) {
-1L
}
}
fun ensureDir(dir: File): File { fun ensureDir(dir: File): File {
if (!dir.exists()) { if (!dir.exists()) {

View File

@@ -38,6 +38,18 @@
android:layout_height="match_parent" /> android:layout_height="match_parent" />
</FrameLayout> </FrameLayout>
<ImageView
android:id="@+id/ref_match_image"
android:layout_width="200dp"
android:layout_height="200dp"
android:layout_margin="12dp"
android:background="#66000000"
android:scaleType="fitCenter"
android:visibility="gone"
app:layout_constraintBottom_toTopOf="@+id/button_row"
app:layout_constraintStart_toStartOf="parent"
tools:visibility="visible" />
<ScrollView <ScrollView
android:id="@+id/scroll_view" android:id="@+id/scroll_view"
android:layout_width="0dp" android:layout_width="0dp"

View File

@@ -31,6 +31,18 @@
android:textIsSelectable="true" /> android:textIsSelectable="true" />
</ScrollView> </ScrollView>
<ImageView
android:id="@+id/ref_match_image"
android:layout_width="200dp"
android:layout_height="200dp"
android:layout_margin="12dp"
android:background="#66000000"
android:scaleType="fitCenter"
android:visibility="gone"
app:layout_constraintBottom_toTopOf="@+id/button_row"
app:layout_constraintStart_toStartOf="parent"
tools:visibility="visible" />
<LinearLayout <LinearLayout
android:id="@+id/button_row" android:id="@+id/button_row"
android:layout_width="0dp" android:layout_width="0dp"