word similarity

This commit is contained in:
gcw_4spBpAfv
2026-04-21 23:05:59 +08:00
parent e23aaaa4ba
commit 1550783eef
36 changed files with 44822 additions and 12 deletions

View File

@@ -0,0 +1,26 @@
package com.digitalperson
import android.app.Application
import android.util.Log
import com.digitalperson.config.AppConfig
import com.digitalperson.embedding.RefEmbeddingIndexer
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.launch
/**
 * Application entry point: kicks off a one-shot background indexing pass of the bundled
 * reference corpus ([RefEmbeddingIndexer.runOnce]) as soon as the process starts.
 */
class DigitalPersonApp : Application() {
    // Process-wide scope, intentionally never cancelled (lives as long as the process).
    // SupervisorJob keeps a failed child from cancelling siblings launched later.
    private val appScope = CoroutineScope(SupervisorJob() + Dispatchers.Default)

    override fun onCreate() {
        super.onCreate()
        // Fire-and-forget: the indexer reads assets and Room off the main thread.
        // Throwable is caught (not just Exception) so a native/OOM failure in the
        // embedding engine cannot crash app startup.
        appScope.launch {
            try {
                RefEmbeddingIndexer.runOnce(this@DigitalPersonApp)
            } catch (t: Throwable) {
                Log.e(AppConfig.TAG, "[RefEmbed] 索引任务异常", t)
            }
        }
    }
}

View File

@@ -38,6 +38,7 @@ import com.digitalperson.interaction.ConversationSummaryMemory
import java.io.File
import android.graphics.BitmapFactory
import android.widget.ImageView
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
@@ -49,6 +50,7 @@ import kotlinx.coroutines.withContext
import com.digitalperson.onboard_testing.FaceRecognitionTest
import com.digitalperson.onboard_testing.LLMSummaryTest
import com.digitalperson.embedding.RefImageMatcher
class Live2DChatActivity : AppCompatActivity() {
companion object {
@@ -109,6 +111,7 @@ class Live2DChatActivity : AppCompatActivity() {
private lateinit var faceRecognitionTest: FaceRecognitionTest
private lateinit var llmSummaryTest: LLMSummaryTest
private var refMatchImageView: ImageView? = null
override fun onRequestPermissionsResult(
requestCode: Int,
@@ -159,6 +162,8 @@ class Live2DChatActivity : AppCompatActivity() {
speakingPlayerViewId = 0,
live2dViewId = R.id.live2d_view
)
refMatchImageView = findViewById(R.id.ref_match_image)
cameraPreviewView = findViewById(R.id.camera_preview)
cameraPreviewView.implementationMode = PreviewView.ImplementationMode.COMPATIBLE
@@ -611,6 +616,7 @@ class Live2DChatActivity : AppCompatActivity() {
runOnUiThread {
uiManager.appendToUi("${filteredText.orEmpty()}\n")
}
maybeShowMatchedRefImage(filteredText ?: response)
}
interactionCoordinator.onCloudFinalResponse(response)
}
@@ -648,6 +654,24 @@ class Live2DChatActivity : AppCompatActivity() {
onStopClicked(userInitiated = false)
}
}
/**
 * Finds the reference-corpus txt entry most similar to [text] and, when the match clears
 * the similarity threshold and has a sibling PNG, shows that PNG in [refMatchImageView].
 *
 * Matching and bitmap decoding run on [ioScope]; only the ImageView update is posted to
 * the UI thread. A null match, missing asset or decode failure silently shows nothing.
 */
private fun maybeShowMatchedRefImage(text: String) {
    val imageView = refMatchImageView ?: return // layout variant without the view
    ioScope.launch {
        val match = RefImageMatcher.findBestMatch(applicationContext, text)
        if (match == null) return@launch
        // Decode off the main thread; a missing/corrupt asset degrades to "no image".
        val bitmap = try {
            assets.open(match.pngAssetPath).use { BitmapFactory.decodeStream(it) }
        } catch (_: Throwable) {
            null
        }
        if (bitmap == null) return@launch
        runOnUiThread {
            imageView.setImageBitmap(bitmap)
            imageView.visibility = android.view.View.VISIBLE
        }
    }
}
private fun createTtsCallback() = object : TtsController.TtsCallback {
override fun onTtsStarted(text: String) {

View File

@@ -25,6 +25,7 @@ import android.view.View
import androidx.lifecycle.Lifecycle
import androidx.lifecycle.LifecycleOwner
import androidx.lifecycle.LifecycleRegistry
import android.widget.ImageView
import com.unity3d.player.UnityPlayer
import com.unity3d.player.UnityPlayerActivity
import com.digitalperson.audio.AudioProcessor
@@ -47,6 +48,8 @@ import com.digitalperson.tts.TtsController
import com.digitalperson.util.FileHelper
import com.digitalperson.vad.VadManager
import kotlinx.coroutines.*
import com.digitalperson.embedding.RefImageMatcher
import android.graphics.BitmapFactory
class UnityDigitalPersonActivity : UnityPlayerActivity(), LifecycleOwner {
@@ -108,6 +111,7 @@ class UnityDigitalPersonActivity : UnityPlayerActivity(), LifecycleOwner {
private lateinit var holdToSpeakButton: Button
private var recordButtonGlow: View? = null
private var pulseAnimator: ObjectAnimator? = null
private var refMatchImageView: ImageView? = null
// 音频和AI模块
@@ -254,6 +258,7 @@ class UnityDigitalPersonActivity : UnityPlayerActivity(), LifecycleOwner {
chatHistoryText = chatLayout.findViewById(R.id.my_text)
holdToSpeakButton = chatLayout.findViewById(R.id.record_button)
recordButtonGlow = chatLayout.findViewById(R.id.record_button_glow)
refMatchImageView = chatLayout.findViewById(R.id.ref_match_image)
// 根据配置设置按钮可见性
if (AppConfig.USE_HOLD_TO_SPEAK) {
@@ -735,6 +740,7 @@ class UnityDigitalPersonActivity : UnityPlayerActivity(), LifecycleOwner {
val filteredText = ttsController.speakLlmResponse(response)
android.util.Log.d("UnityDigitalPerson", "LLM response filtered: ${filteredText?.take(60)}")
if (filteredText != null) appendChat("助手: $filteredText")
maybeShowMatchedRefImage(filteredText ?: response)
interactionCoordinator.onCloudFinalResponse(filteredText ?: response.trim())
}
@@ -751,6 +757,25 @@ class UnityDigitalPersonActivity : UnityPlayerActivity(), LifecycleOwner {
}
}
/**
 * Unity-activity variant of the reference-image lookup: finds the best-matching corpus
 * txt for [text] and, if its sibling PNG exists, decodes and displays it in
 * [refMatchImageView].
 *
 * NOTE(review): this constructs a brand-new CoroutineScope(SupervisorJob() + Dispatchers.IO)
 * on every call and never cancels it, so each invocation leaks a Job and the work can
 * outlive the Activity. Prefer one Activity-scoped CoroutineScope cancelled in onDestroy
 * (as the Live2D activity does with its ioScope).
 */
private fun maybeShowMatchedRefImage(text: String) {
    val imageView = refMatchImageView ?: return // view absent from the inflated layout
    // Unity Activity already has coroutines
    CoroutineScope(SupervisorJob() + Dispatchers.IO).launch {
        val match = RefImageMatcher.findBestMatch(applicationContext, text)
        if (match == null) return@launch
        // Decode on IO; any asset/decoding failure degrades to "no image shown".
        val bitmap = try {
            assets.open(match.pngAssetPath).use { BitmapFactory.decodeStream(it) }
        } catch (_: Throwable) {
            null
        }
        if (bitmap == null) return@launch
        runOnUiThread {
            imageView.setImageBitmap(bitmap)
            imageView.visibility = View.VISIBLE
        }
    }
}
private fun requestLocalThought(prompt: String, onResult: (String) -> Unit) {
val local = llmManager
if (local == null) {

View File

@@ -110,6 +110,19 @@ object AppConfig {
const val MODEL_SIZE_ESTIMATE = 500L * 1024 * 1024 // 500MB
}
/** BGE-small-zh-v1.5 text-embedding model (RKNN), used for semantic similarity / retrieval. */
object Bge {
    // Directory inside assets/ holding the RKNN model (and, per the engine, its vocab file).
    const val ASSET_DIR = "bge_models"
    const val MODEL_FILE = "bge-small-zh-v1.5.rknn"
}

/**
 * app/note/ref is packed into the APK via an extra Gradle assets directory; inside
 * assets its root path is `ref/`.
 */
object RefCorpus {
    const val ASSETS_ROOT = "ref"
}
object OnboardTesting {
// 测试人脸识别
const val FACE_REGONITION = false

View File

@@ -9,15 +9,24 @@ import com.digitalperson.data.dao.UserAnswerDao
import com.digitalperson.data.dao.UserMemoryDao
import com.digitalperson.data.dao.ChatMessageDao
import com.digitalperson.data.dao.ConversationSummaryDao
import com.digitalperson.data.dao.RefTextEmbeddingDao
import com.digitalperson.data.entity.Question
import com.digitalperson.data.entity.UserAnswer
import com.digitalperson.data.entity.UserMemory
import com.digitalperson.data.entity.ChatMessageEntity
import com.digitalperson.data.entity.ConversationSummaryEntity
import com.digitalperson.data.entity.RefTextEmbedding
@Database(
entities = [UserMemory::class, Question::class, UserAnswer::class, ChatMessageEntity::class, ConversationSummaryEntity::class],
version = 4,
entities = [
UserMemory::class,
Question::class,
UserAnswer::class,
ChatMessageEntity::class,
ConversationSummaryEntity::class,
RefTextEmbedding::class
],
version = 5,
exportSchema = false
)
abstract class AppDatabase : RoomDatabase() {
@@ -26,6 +35,7 @@ abstract class AppDatabase : RoomDatabase() {
abstract fun userAnswerDao(): UserAnswerDao
abstract fun chatMessageDao(): ChatMessageDao
abstract fun conversationSummaryDao(): ConversationSummaryDao
abstract fun refTextEmbeddingDao(): RefTextEmbeddingDao
companion object {
private const val DATABASE_NAME = "digital_human.db"

View File

@@ -9,6 +9,17 @@ import com.digitalperson.data.entity.Question
interface QuestionDao {
@Insert
fun insert(question: Question): Long
/**
 * Returns the single question matching [content] together with an exact subject/grade
 * pair, or null when no such row exists.
 *
 * The `(:x IS NULL AND col IS NULL) OR col = :x` pattern is required because plain
 * `col = :x` never matches NULL in SQL; with it, a null argument matches only rows
 * whose column is also NULL.
 */
@Query(
    """
    SELECT * FROM questions
    WHERE content = :content
    AND ((:subject IS NULL AND subject IS NULL) OR subject = :subject)
    AND ((:grade IS NULL AND grade IS NULL) OR grade = :grade)
    LIMIT 1
    """
)
fun findByContentSubjectGrade(content: String, subject: String?, grade: Int?): Question?
@Query("SELECT * FROM questions WHERE subject = :subject ORDER BY difficulty")
fun getQuestionsBySubject(subject: String): List<Question>

View File

@@ -0,0 +1,20 @@
package com.digitalperson.data.dao
import androidx.room.Dao
import androidx.room.Insert
import androidx.room.OnConflictStrategy
import androidx.room.Query
import com.digitalperson.data.entity.RefTextEmbedding
@Dao
interface RefTextEmbeddingDao {

    /** Returns the cached embedding row for an assets path, or null when not yet indexed. */
    @Query("SELECT * FROM ref_text_embeddings WHERE assetPath = :path LIMIT 1")
    fun getByPath(path: String): RefTextEmbedding?

    /** Upsert keyed by the unique assetPath index (REPLACE strategy); returns the row id. */
    @Insert(onConflict = OnConflictStrategy.REPLACE)
    fun insert(row: RefTextEmbedding): Long

    /** Returns every stored embedding row. */
    @Query("SELECT * FROM ref_text_embeddings")
    fun getAll(): List<RefTextEmbedding>
}

View File

@@ -0,0 +1,23 @@
package com.digitalperson.data.entity
import androidx.room.Entity
import androidx.room.Index
import androidx.room.PrimaryKey
import com.digitalperson.data.util.embeddingBytesToFloatArray
/**
 * Room row caching the BGE embedding of one reference-corpus txt file, keyed by its
 * unique assets path.
 */
@Entity(
    tableName = "ref_text_embeddings",
    indices = [Index(value = ["assetPath"], unique = true)]
)
data class RefTextEmbedding(
    @PrimaryKey(autoGenerate = true) val id: Long = 0,
    /** Path relative to assets/, e.g. ref/一年级.../xxx.txt */
    val assetPath: String,
    /** SHA-256 hex of the embedded body (UTF-8); used to skip unchanged files. */
    val contentHash: String,
    /** Number of float components encoded in [embedding]. */
    val dim: Int,
    /** Little-endian float32 blob, dim * 4 bytes. */
    val embedding: ByteArray,
    val updatedAt: Long = System.currentTimeMillis()
) {
    /** Decodes [embedding] back into a FloatArray. */
    fun toFloatArray(): FloatArray = embeddingBytesToFloatArray(embedding)

    // A generated data-class equals() compares ByteArray by reference, which breaks value
    // semantics (and distinctBy/Set usage) for rows holding blobs — override both members
    // to compare the blob by content.
    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        if (other !is RefTextEmbedding) return false
        return id == other.id &&
            assetPath == other.assetPath &&
            contentHash == other.contentHash &&
            dim == other.dim &&
            embedding.contentEquals(other.embedding) &&
            updatedAt == other.updatedAt
    }

    override fun hashCode(): Int {
        var result = id.hashCode()
        result = 31 * result + assetPath.hashCode()
        result = 31 * result + contentHash.hashCode()
        result = 31 * result + dim
        result = 31 * result + embedding.contentHashCode()
        result = 31 * result + updatedAt.hashCode()
        return result
    }
}

View File

@@ -0,0 +1,18 @@
package com.digitalperson.data.util
import java.nio.ByteBuffer
import java.nio.ByteOrder
/**
 * Serializes [values] as a little-endian float32 blob (4 bytes per element) — the
 * storage format used for embedding columns.
 */
fun floatArrayToEmbeddingBytes(values: FloatArray): ByteArray {
    val buffer = ByteBuffer.allocate(values.size * 4).order(ByteOrder.LITTLE_ENDIAN)
    // Bulk-copy through a float view instead of per-element putFloat calls.
    buffer.asFloatBuffer().put(values)
    return buffer.array()
}
/**
 * Decodes a little-endian float32 blob produced by [floatArrayToEmbeddingBytes].
 *
 * @throws IllegalArgumentException if [blob] is not a whole number of 4-byte floats —
 *         previously such a blob had its trailing bytes silently dropped, which would
 *         mask a corrupted row.
 */
fun embeddingBytesToFloatArray(blob: ByteArray): FloatArray {
    require(blob.size % 4 == 0) { "embedding blob length ${blob.size} is not a multiple of 4" }
    val bb = ByteBuffer.wrap(blob).order(ByteOrder.LITTLE_ENDIAN)
    val out = FloatArray(blob.size / 4)
    bb.asFloatBuffer().get(out)
    return out
}

View File

@@ -0,0 +1,48 @@
package com.digitalperson.embedding
import android.content.Context
import com.digitalperson.config.AppConfig
import com.digitalperson.engine.BgeEngineRKNN
import com.digitalperson.util.FileHelper
import java.io.File
/**
 * Lazily-initialized, process-wide BGE text-embedding engine (RKNN) used for semantic
 * similarity / retrieval over conversation text.
 *
 * [initialize] copies the resources under [AppConfig.Bge.ASSET_DIR] into internal storage
 * and loads [AppConfig.Bge.MODEL_FILE]. It blocks the calling thread, so invoke it from a
 * background thread or an IO-dispatched coroutine.
 */
object BgeEmbedding {

    @Volatile
    private var engine: BgeEngineRKNN? = null

    /** True once [initialize] has succeeded and the engine is usable. */
    fun isReady(): Boolean = engine?.isInitialized == true

    /** Idempotent blocking init; returns false when the model copy or load fails. */
    @Synchronized
    fun initialize(context: Context): Boolean {
        if (engine?.isInitialized == true) return true
        val appContext = context.applicationContext
        val modelDir = FileHelper.copyBgeModels(appContext) ?: return false
        val modelFile = File(modelDir, AppConfig.Bge.MODEL_FILE)
        if (!modelFile.exists()) return false
        val candidate = BgeEngineRKNN(appContext)
        if (!candidate.initialize(modelFile.absolutePath)) return false
        engine = candidate
        return true
    }

    /** Releases native resources; a later [initialize] may re-create the engine. */
    @Synchronized
    fun release() {
        engine?.deinitialize()
        engine = null
    }

    /** Embedding of [text], or null while the engine is not initialized. */
    fun getEmbedding(text: String): FloatArray? = engine?.getEmbedding(text)

    /** Similarity of two texts, or null while the engine is not initialized. */
    fun similarity(text1: String, text2: String): Float? =
        engine?.calculateSimilarity(text1, text2)

    /** Output dimensionality, or -1 while uninitialized. */
    fun embeddingDim(): Int = engine?.embeddingDim ?: -1
}

View File

@@ -0,0 +1,34 @@
package com.digitalperson.embedding
import android.content.Context
internal object RefCorpusAssetScanner {

    /**
     * Recursively lists every `.txt` under [root] (including subdirectories) as
     * `/`-separated assets paths, sorted lexicographically.
     *
     * AssetManager has no isDirectory check, so an entry whose own `list()` result is
     * null or empty is treated as a file (same heuristic as the recursive original).
     */
    fun listTxtFilesUnder(context: Context, root: String): List<String> {
        val assetManager = context.assets
        val found = mutableListOf<String>()
        // Iterative breadth-first walk instead of a recursive local function.
        val pending = ArrayDeque<String>()
        pending.addLast(root.removeSuffix("/"))
        while (pending.isNotEmpty()) {
            val dir = pending.removeFirst()
            val entries = assetManager.list(dir) ?: continue
            for (entry in entries) {
                if (entry.isEmpty()) continue
                val path = "$dir/$entry"
                val nested = assetManager.list(path)
                if (nested.isNullOrEmpty()) {
                    if (entry.endsWith(".txt", ignoreCase = true)) {
                        found += path
                    }
                } else {
                    pending.addLast(path)
                }
            }
        }
        return found.sorted()
    }
}

View File

@@ -0,0 +1,173 @@
package com.digitalperson.embedding
import android.content.Context
import android.util.Log
import com.digitalperson.config.AppConfig
import com.digitalperson.data.AppDatabase
import com.digitalperson.data.entity.Question
import com.digitalperson.data.entity.RefTextEmbedding
import com.digitalperson.data.util.floatArrayToEmbeddingBytes
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import java.security.MessageDigest
/**
 * One-shot startup indexer: scans every `.txt` under assets/[AppConfig.RefCorpus.ASSETS_ROOT],
 * strips `#` comment lines, embeds the remaining body with BGE, persists the vectors to Room
 * and fills [RefEmbeddingMemoryCache]. Question-looking lines are additionally mirrored into
 * the questions table.
 *
 * The DAOs are synchronous; the whole function runs on [Dispatchers.IO] and never blocks the
 * main thread.
 */
object RefEmbeddingIndexer {

    private const val TAG = AppConfig.TAG

    suspend fun runOnce(context: Context) = withContext(Dispatchers.IO) {
        val app = context.applicationContext
        val db = AppDatabase.getInstance(app)
        val dao = db.refTextEmbeddingDao()
        val questionDao = db.questionDao()
        // Without a working embedding engine there is nothing useful to index.
        if (!BgeEmbedding.initialize(app)) {
            Log.e(TAG, "[RefEmbed] BGE 初始化失败,跳过 ref 语料索引")
            return@withContext
        }
        val root = AppConfig.RefCorpus.ASSETS_ROOT
        val paths = RefCorpusAssetScanner.listTxtFilesUnder(app, root)
        Log.i(TAG, "[RefEmbed] 发现 ${paths.size} 个 txt,root=$root")
        var skipped = 0
        var embedded = 0
        var empty = 0
        var failed = 0
        for (path in paths) {
            val raw = try {
                app.assets.open(path).bufferedReader(Charsets.UTF_8).use { it.readText() }
            } catch (e: Exception) {
                Log.w(TAG, "[RefEmbed] 读取失败 $path: ${e.message}")
                failed++
                continue
            }
            // Question bank: non-comment lines containing a question mark become rows in
            // the questions table, de-duplicated on (content, subject, grade).
            val subject = extractSubjectFromRaw(raw)
            val grade = extractGradeFromPath(path)
            val questionLines = extractQuestionLines(raw)
            for (line in questionLines) {
                val content = line.trim()
                if (content.isEmpty()) continue
                val exists = questionDao.findByContentSubjectGrade(content, subject, grade)
                if (exists == null) {
                    questionDao.insert(
                        Question(
                            id = 0,
                            content = content,
                            answer = null,
                            subject = subject,
                            grade = grade,
                            difficulty = 1,
                            createdAt = System.currentTimeMillis()
                        )
                    )
                }
            }
            val embedText = RefTxtEmbedText.fromRawFileContent(raw)
            if (embedText.isEmpty()) {
                empty++
                continue
            }
            val hash = sha256Hex(embedText.toByteArray(Charsets.UTF_8))
            // Synchronous DAO call — safe, we are already on the IO dispatcher.
            val existing = dao.getByPath(path)
            if (existing != null && existing.contentHash == hash) {
                // Unchanged file: reuse the stored vector, just refresh the memory cache.
                RefEmbeddingMemoryCache.put(path, normalizeL2(existing.toFloatArray()))
                skipped++
                continue
            }
            val vec = BgeEmbedding.getEmbedding(embedText)
            if (vec == null || vec.isEmpty()) {
                Log.w(TAG, "[RefEmbed] 嵌入为空 $path")
                failed++
                continue
            }
            // Store normalized vectors so lookups can use a plain dot product as cosine.
            val normalized = normalizeL2(vec)
            dao.insert(
                RefTextEmbedding(
                    assetPath = path,
                    contentHash = hash,
                    dim = normalized.size,
                    embedding = floatArrayToEmbeddingBytes(normalized)
                )
            )
            RefEmbeddingMemoryCache.put(path, normalized)
            embedded++
        }
        Log.i(
            TAG,
            "[RefEmbed] 完成 embedded=$embedded skipped=$skipped empty=$empty failed=$failed cacheSize=${RefEmbeddingMemoryCache.size()}"
        )
    }

    /** First `#` line of the file with the marker stripped, used as the subject; null when absent or blank. */
    private fun extractSubjectFromRaw(raw: String): String? {
        val line = raw.lineSequence()
            .map { it.trimEnd() }
            .firstOrNull { it.trimStart().startsWith("#") }
            ?: return null
        val s = line.trimStart().removePrefix("#").trim()
        return s.ifEmpty { null }
    }

    /**
     * Non-comment, non-blank lines that look like questions, i.e. contain an ASCII '?' or
     * a fullwidth '?'. (Fix: the second literal was an empty char literal '' — invalid
     * Kotlin, evidently a garbled fullwidth question mark.)
     */
    private fun extractQuestionLines(raw: String): List<String> {
        return raw.lineSequence()
            .map { it.trimEnd() }
            .filter { it.isNotBlank() }
            .filter { !it.trimStart().startsWith("#") }
            .filter { it.contains('?') || it.contains('?') }
            .toList()
    }

    /**
     * Parses the grade from paths like `ref/一年级上-生活适应/...`: the Chinese numeral
     * immediately before "年级" maps to an Int; null when the pattern is absent.
     */
    private fun extractGradeFromPath(assetPath: String): Int? {
        // example: ref/一年级上-生活适应/... or ref/二年级下-...
        val idx = assetPath.indexOf("年级")
        if (idx <= 0) return null
        val prefix = assetPath.substring(0, idx)
        val cn = prefix.lastOrNull() ?: return null
        return chineseGradeToInt(cn)
    }

    /** Maps a single Chinese numeral 一..十 to 1..10; null for anything else. */
    private fun chineseGradeToInt(c: Char): Int? {
        return when (c) {
            '一' -> 1
            '二' -> 2
            '三' -> 3
            '四' -> 4
            '五' -> 5
            '六' -> 6
            '七' -> 7
            '八' -> 8
            '九' -> 9
            '十' -> 10
            else -> null
        }
    }

    /** L2-normalizes [v]; a (near-)zero vector is returned as an unchanged copy. */
    private fun normalizeL2(v: FloatArray): FloatArray {
        var sum = 0.0
        for (x in v) sum += (x * x).toDouble()
        val norm = kotlin.math.sqrt(sum).toFloat()
        if (norm <= 1e-12f) return v.copyOf()
        return FloatArray(v.size) { i -> v[i] / norm }
    }

    /** Hex-encoded SHA-256 of [data]. */
    private fun sha256Hex(data: ByteArray): String {
        val md = MessageDigest.getInstance("SHA-256")
        val digest = md.digest(data)
        return digest.joinToString("") { "%02x".format(it) }
    }
}

View File

@@ -0,0 +1,30 @@
package com.digitalperson.embedding
import java.util.concurrent.ConcurrentHashMap
/**
 * In-memory cache of reference-corpus txt embeddings, keyed by assets-relative path
 * (the same key as [com.digitalperson.data.entity.RefTextEmbedding.assetPath]).
 *
 * Vectors are defensively copied on both write and read, so callers can never mutate
 * a cached array.
 */
object RefEmbeddingMemoryCache {

    private val vectors = ConcurrentHashMap<String, FloatArray>()

    /** Stores a copy of [embedding] under [assetPath], replacing any previous value. */
    fun put(assetPath: String, embedding: FloatArray) {
        vectors[assetPath] = embedding.copyOf()
    }

    /** Returns a copy of the cached vector, or null when [assetPath] is unknown. */
    fun get(assetPath: String): FloatArray? = vectors[assetPath]?.copyOf()

    /** Drops every cached vector. */
    fun clear() {
        vectors.clear()
    }

    /** Read-only snapshot (vectors are copies and safe to use freely). */
    fun snapshot(): Map<String, FloatArray> =
        vectors.entries.associate { (path, vec) -> path to vec.copyOf() }

    /** Number of cached entries. */
    fun size(): Int = vectors.size
}

View File

@@ -0,0 +1,95 @@
package com.digitalperson.embedding
import android.content.Context
import android.util.Log
import com.digitalperson.config.AppConfig
import kotlin.math.sqrt
/**
 * Result of matching conversation text against the reference corpus.
 *
 * @property txtAssetPath assets path of the best-matching txt file
 * @property pngAssetPath sibling PNG path (existence verified by the matcher)
 * @property score similarity score of the query against the matched embedding
 */
data class RefImageMatch(
    val txtAssetPath: String,
    val pngAssetPath: String,
    val score: Float
)
/**
 * Matches free-form conversation text against the pre-embedded reference corpus
 * ([RefEmbeddingMemoryCache]) and resolves the best hit to its sibling PNG for display.
 */
object RefImageMatcher {

    private const val TAG = AppConfig.TAG

    /**
     * Returns the best-scoring corpus entry whose sibling PNG exists, or null when the
     * engine is unavailable, the cache is empty, the best score is below [threshold], or
     * the PNG is missing.
     *
     * Blocking (embedding inference + asset probing): call from a background thread.
     *
     * @param threshold cosine-similarity cutoff (equivalent to the dot product because
     *        vectors are L2-normalized)
     */
    fun findBestMatch(
        context: Context,
        text: String,
        threshold: Float = 0.70f
    ): RefImageMatch? {
        val query = text.trim()
        if (query.isEmpty()) return null
        // Lazily bring the BGE engine up if the startup indexer has not done so yet.
        if (!BgeEmbedding.isReady()) {
            val ok = BgeEmbedding.initialize(context.applicationContext)
            if (!ok) {
                Log.w(TAG, "[RefMatch] BGE not ready, skip match")
                return null
            }
        }
        val q = BgeEmbedding.getEmbedding(query) ?: return null
        if (q.isEmpty()) return null
        val qn = normalizeL2(q)
        val vectors = RefEmbeddingMemoryCache.snapshot()
        if (vectors.isEmpty()) return null
        // Linear scan; cached vectors are normalized by the indexer, so dot == cosine.
        var bestPath: String? = null
        var bestScore = -1f
        for ((path, v) in vectors) {
            if (v.isEmpty() || v.size != qn.size) continue // skip dimension mismatches
            val score = dot(qn, v)
            if (score > bestScore) {
                bestScore = score
                bestPath = path
            }
        }
        val txtPath = bestPath ?: return null
        if (bestScore < threshold) return null
        // foo.txt -> foo.png (same directory, same basename).
        val pngPath = if (txtPath.endsWith(".txt", ignoreCase = true)) {
            txtPath.dropLast(4) + ".png"
        } else {
            "$txtPath.png"
        }
        // Do not decode here — only check existence, to keep bitmap I/O off callers
        // that might be close to the UI thread.
        val exists = try {
            context.assets.open(pngPath).close()
            true
        } catch (_: Throwable) {
            false
        }
        if (!exists) return null
        return RefImageMatch(
            txtAssetPath = txtPath,
            pngAssetPath = pngPath,
            score = bestScore
        )
    }

    /** Plain dot product; equal lengths are guaranteed by the caller's size check. */
    private fun dot(a: FloatArray, b: FloatArray): Float {
        var s = 0f
        for (i in a.indices) s += a[i] * b[i]
        return s
    }

    /** L2-normalizes [v]; a (near-)zero vector is returned as an unchanged copy. */
    private fun normalizeL2(v: FloatArray): FloatArray {
        var sum = 0.0
        for (x in v) sum += (x * x).toDouble()
        val norm = sqrt(sum).toFloat()
        if (norm <= 1e-12f) return v.copyOf()
        return FloatArray(v.size) { i -> v[i] / norm }
    }
}

View File

@@ -0,0 +1,18 @@
package com.digitalperson.embedding
/**
 * Builds the text fed to the embedding model from a raw txt file: drops lines starting
 * with `#` (leading whitespace allowed) and blank lines, joins the rest with `\n`, and
 * trims the final result.
 */
object RefTxtEmbedText {

    fun fromRawFileContent(raw: String): String {
        val kept = ArrayList<String>()
        for (line in raw.lines()) {
            val normalized = line.trimEnd()
            if (normalized.isEmpty()) continue
            if (normalized.trimStart().startsWith("#")) continue
            kept.add(normalized)
        }
        return kept.joinToString(separator = "\n").trim()
    }
}

View File

@@ -0,0 +1,353 @@
package com.digitalperson.embedding;
import android.content.Context;
import android.os.Handler;
import android.os.Looper;
import android.util.Log;
import com.digitalperson.config.AppConfig;
import com.digitalperson.engine.BgeEngineRKNN;
import com.digitalperson.util.FileHelper;
import org.ejml.simple.SimpleMatrix;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
/**
 * Benchmark/demo wrapper around {@link BgeEngineRKNN}: computes text-pair similarity on a
 * background thread and compares a plain-loop cosine implementation against EJML batch
 * matrix math over a set of pre-embedded test sentences. Results and errors are delivered
 * asynchronously through {@link SimilarityListener}.
 */
public class SimilarityManager {
    private static final String TAG = "SimilarityManager";

    private Context mContext;
    // Main-looper handler. NOTE(review): never used in the code visible here — confirm before relying on it.
    private Handler mHandler;
    private BgeEngineRKNN mBgeEngine;

    // ---- Test data ----
    private List<float[]> mTestEmbeddings; // precomputed test-sentence embeddings (float[] form)
    private List<SimpleMatrix> mTestEmbeddingsEJML; // the same embeddings as EJML column vectors
    private SimpleMatrix mTestEmbeddingsMatrix; // all embeddings as one matrix (embeddingDim x numSentences)
    private List<Double> mTestEmbeddingsNorms; // precomputed L2 norms, one per sentence
    private SimpleMatrix mTestEmbeddingsNormsMatrix; // norms as a 1 x numSentences row vector
    private SimpleMatrix mTestEmbeddingsNormsReciprocalMatrix; // reciprocal norms (1/||v||), 1 x numSentences
    private List<String> mTestSentences; // generated test sentences
    // Defaults to the traditional for-loop path. NOTE(review): written by the setter but never read here.
    private boolean useEJMLForSimilarity = false;

    /** Callbacks for similarity results, benchmark results and errors (messages are user-facing, in Chinese). */
    public interface SimilarityListener {
        void onSimilarityCalculated(float similarity, long timeTaken);
        void onPerformanceTestComplete(long traditionalTime, long ejmlTime, double speedup);
        void onError(String errorMessage);
    }

    private SimilarityListener mListener;

    public SimilarityManager(Context context) {
        this.mContext = context;
        this.mHandler = new Handler(Looper.getMainLooper());
    }

    public void setListener(SimilarityListener listener) {
        this.mListener = listener;
    }

    /**
     * Copies the bundled BGE model files into internal storage and initializes the RKNN engine.
     * Failures are reported via {@link SimilarityListener#onError(String)}.
     * NOTE(review): runs synchronously on the calling thread; the model load may be slow.
     */
    public void initBgeModel() {
        try {
            File bgeModelDir = FileHelper.copyBgeModels(mContext);
            if (bgeModelDir == null) {
                Log.e(TAG, "BGE model directory copy failed");
                if (mListener != null) {
                    mListener.onError("BGE 模型文件复制失败");
                }
                return;
            }
            // Initialize the BGE engine.
            mBgeEngine = new BgeEngineRKNN(mContext);
            String modelPath = new File(bgeModelDir, AppConfig.Bge.MODEL_FILE).getAbsolutePath();
            // Verify the model file actually exists before attempting to load it.
            if (!new File(modelPath).exists()) {
                Log.e(TAG, "BGE model file does not exist: " + modelPath);
                if (mListener != null) {
                    mListener.onError("BGE 模型文件不存在");
                }
                return;
            }
            boolean success = mBgeEngine.initialize(modelPath);
            if (!success) {
                Log.e(TAG, "Failed to initialize BGE model");
                if (mListener != null) {
                    mListener.onError("BGE 模型初始化失败");
                }
            }
        } catch (Exception e) {
            Log.e(TAG, "Exception in initBgeModel", e);
            if (mListener != null) {
                mListener.onError("初始化BGE模型异常: " + e.getMessage());
            }
        }
    }

    /**
     * Computes BGE similarity between {@code text1} and {@code text2} on a fresh background
     * thread and reports the score plus elapsed milliseconds to the listener.
     */
    public void calculateSimilarity(String text1, String text2) {
        if (mBgeEngine == null || !mBgeEngine.isInitialized()) {
            Log.w(TAG, "BGE engine not initialized, skipping similarity calculation");
            if (mListener != null) {
                mListener.onError("BGE 模型未初始化");
            }
            return;
        }
        // Run on a background thread — the native inference call is blocking.
        new Thread(() -> {
            try {
                long startTime = System.currentTimeMillis();
                float similarity = mBgeEngine.calculateSimilarity(text1, text2);
                long timeTaken = System.currentTimeMillis() - startTime;
                if (mListener != null) {
                    mListener.onSimilarityCalculated(similarity, timeTaken);
                }
            } catch (Exception e) {
                Log.e(TAG, "Exception in calculateSimilarity", e);
                if (mListener != null) {
                    mListener.onError("相似度计算失败: " + e.getMessage());
                }
            }
        }).start();
    }

    /**
     * Benchmarks cosine similarity of {@code userInput} against the prepared test embeddings,
     * timing a traditional per-element loop vs. EJML batch matrix math, then reports both
     * timings and the speedup factor to the listener.
     */
    public void testPerformance(String userInput) {
        if (mBgeEngine == null || !mBgeEngine.isInitialized()) {
            Log.w(TAG, "BGE engine not initialized, skipping performance test");
            if (mListener != null) {
                mListener.onError("BGE 模型未初始化");
            }
            return;
        }
        // Run the whole benchmark on a background thread.
        new Thread(() -> {
            try {
                // Prepare (or reuse) the test embeddings and derived matrices.
                prepareTestData();
                // 1. Embed the user's input.
                float[] userEmbedding = mBgeEngine.getEmbedding(userInput);
                if (userEmbedding == null) {
                    if (mListener != null) {
                        mListener.onError("无法计算用户输入的嵌入");
                    }
                    return;
                }
                // EJML column-vector form of the user embedding.
                SimpleMatrix userVec = new SimpleMatrix(userEmbedding.length, 1);
                for (int i = 0; i < userEmbedding.length; i++) {
                    userVec.set(i, 0, userEmbedding[i]);
                }
                // Method 1: float arrays + explicit loops (traditional baseline).
                long startTime1 = System.currentTimeMillis();
                List<Float> similarities1 = new ArrayList<>();
                // Optimization: precompute the reciprocal norm so the inner loop multiplies
                // instead of dividing.
                double normUserTraditional = 0.0;
                for (float value : userEmbedding) {
                    normUserTraditional += value * value;
                }
                normUserTraditional = Math.sqrt(normUserTraditional);
                double normUserReciprocalTraditional = (normUserTraditional > 1e-10) ? 1.0 / normUserTraditional : 0.0;
                for (int i = 0; i < mTestEmbeddings.size(); i++) {
                    float[] embedding = mTestEmbeddings.get(i);
                    // Dot product.
                    double dotProduct = 0.0;
                    for (int j = 0; j < userEmbedding.length; j++) {
                        dotProduct += userEmbedding[j] * embedding[j];
                    }
                    // cos(u,v_i) = (u·v_i) × (1/||u||) × (1/||v_i||) — multiplication instead of division.
                    double normEmbeddingReciprocal = mTestEmbeddingsNorms.get(i) > 1e-10 ? 1.0 / mTestEmbeddingsNorms.get(i) : 0.0;
                    float sim = (float) (dotProduct * normUserReciprocalTraditional * normEmbeddingReciprocal);
                    similarities1.add(sim);
                }
                long timeTaken1 = System.currentTimeMillis() - startTime1;
                // Method 2: EJML (optimized path).
                long startTime2 = System.currentTimeMillis();
                List<Float> similarities2 = new ArrayList<>();
                // Optimization 1: precompute the user vector's norm and reciprocal norm.
                double normUser = userVec.normF();
                double normUserReciprocal = (normUser > 1e-10) ? 1.0 / normUser : 0.0;
                // Optimization 2: one batched matrix multiply computes every dot product at once.
                if (mTestEmbeddingsMatrix != null && mTestEmbeddingsNormsReciprocalMatrix != null) {
                    // All dot products in a single (1 x dim) * (dim x n) multiply.
                    SimpleMatrix dotProducts = userVec.transpose().mult(mTestEmbeddingsMatrix);
                    // All similarities at once:
                    // cos(u,v_i) = (u·v_i) × (1/||u||) × (1/||v_i||)
                    // Step 1: scale by (1/||u||).
                    SimpleMatrix dotProductsScaled = dotProducts.scale(normUserReciprocal);
                    // Step 2: element-wise multiply by the reciprocal norms.
                    SimpleMatrix similaritiesMatrix = dotProductsScaled.elementMult(mTestEmbeddingsNormsReciprocalMatrix);
                    // Copy results into the list.
                    for (int i = 0; i < similaritiesMatrix.numCols(); i++) {
                        similarities2.add((float) similaritiesMatrix.get(0, i));
                    }
                } else if (mTestEmbeddingsMatrix != null) {
                    // Fallback 1: embeddings matrix only — batch dot products + per-element norms.
                    SimpleMatrix dotProducts = userVec.transpose().mult(mTestEmbeddingsMatrix);
                    for (int i = 0; i < dotProducts.numCols(); i++) {
                        double dotProduct = dotProducts.get(0, i);
                        double normEmbeddingReciprocal = mTestEmbeddingsNorms.get(i) > 1e-10 ? 1.0 / mTestEmbeddingsNorms.get(i) : 0.0;
                        float sim = (float) (dotProduct * normUserReciprocal * normEmbeddingReciprocal);
                        similarities2.add(sim);
                    }
                } else {
                    // Fallback 2: per-vector loop (when the matrices were never built).
                    for (int i = 0; i < mTestEmbeddingsEJML.size(); i++) {
                        SimpleMatrix embedding = mTestEmbeddingsEJML.get(i);
                        double dotProduct = userVec.dot(embedding);
                        double normEmbeddingReciprocal = mTestEmbeddingsNorms.get(i) > 1e-10 ? 1.0 / mTestEmbeddingsNorms.get(i) : 0.0;
                        float sim = (float) (dotProduct * normUserReciprocal * normEmbeddingReciprocal);
                        similarities2.add(sim);
                    }
                }
                long timeTaken2 = System.currentTimeMillis() - startTime2;
                // Speedup factor. NOTE(review): divides by timeTaken2, which can be 0 ms — yields Infinity.
                double speedup = (double) timeTaken1 / timeTaken2;
                if (mListener != null) {
                    mListener.onPerformanceTestComplete(timeTaken1, timeTaken2, speedup);
                }
            } catch (Exception e) {
                Log.e(TAG, "Exception in testPerformance", e);
                if (mListener != null) {
                    mListener.onError("性能测试失败: " + e.getMessage());
                }
            }
        }).start();
    }

    /**
     * Lazily builds the benchmark fixtures: embeds the generated sentences, keeps float[],
     * EJML-vector and batched-matrix representations plus per-vector norms and reciprocal
     * norms. No-op when everything is already populated.
     */
    private void prepareTestData() {
        if (mTestEmbeddings == null || mTestEmbeddings.isEmpty() ||
                mTestEmbeddingsEJML == null || mTestEmbeddingsEJML.isEmpty() ||
                mTestEmbeddingsNorms == null || mTestEmbeddingsNorms.isEmpty() ||
                mTestEmbeddingsMatrix == null ||
                mTestEmbeddingsNormsMatrix == null ||
                mTestEmbeddingsNormsReciprocalMatrix == null) {
            // Generate the test sentences.
            generateTestSentences();
            // Precompute embeddings (float[] form).
            mTestEmbeddings = new ArrayList<>();
            mTestEmbeddingsEJML = new ArrayList<>();
            mTestEmbeddingsNorms = new ArrayList<>();
            List<String> validTestSentences = new ArrayList<>(); // only sentences whose embedding succeeded
            for (String sentence : mTestSentences) {
                float[] embedding = mBgeEngine.getEmbedding(sentence);
                if (embedding != null) {
                    // float[] form.
                    mTestEmbeddings.add(embedding);
                    // EJML column-vector form.
                    SimpleMatrix vec = new SimpleMatrix(embedding.length, 1);
                    for (int i = 0; i < embedding.length; i++) {
                        vec.set(i, 0, embedding[i]);
                    }
                    mTestEmbeddingsEJML.add(vec);
                    // Precompute the L2 norm.
                    double norm = vec.normF();
                    mTestEmbeddingsNorms.add(norm);
                    // Keep the sentence so list indices stay aligned.
                    validTestSentences.add(sentence);
                }
            }
            // Replace mTestSentences with only the valid sentences (keeps indices aligned).
            mTestSentences = validTestSentences;
            // Build the batched embeddings matrix (embeddingDim x numSentences).
            if (!mTestEmbeddingsEJML.isEmpty()) {
                int embeddingDim = mTestEmbeddingsEJML.get(0).numRows();
                int numSentences = mTestEmbeddingsEJML.size();
                mTestEmbeddingsMatrix = new SimpleMatrix(embeddingDim, numSentences);
                for (int i = 0; i < numSentences; i++) {
                    SimpleMatrix embedding = mTestEmbeddingsEJML.get(i);
                    for (int j = 0; j < embeddingDim; j++) {
                        mTestEmbeddingsMatrix.set(j, i, embedding.get(j, 0));
                    }
                }
                Log.d(TAG, "Created test embeddings matrix, size: " + embeddingDim + " × " + numSentences);
            }
            // Build the norm matrices (1 x numSentences).
            if (!mTestEmbeddingsNorms.isEmpty()) {
                int numSentences = mTestEmbeddingsNorms.size();
                mTestEmbeddingsNormsMatrix = new SimpleMatrix(1, numSentences);
                mTestEmbeddingsNormsReciprocalMatrix = new SimpleMatrix(1, numSentences);
                for (int i = 0; i < numSentences; i++) {
                    double norm = mTestEmbeddingsNorms.get(i);
                    mTestEmbeddingsNormsMatrix.set(0, i, norm);
                    // Precompute the reciprocal so the similarity pass avoids repeated division.
                    double reciprocal = (norm > 1e-10) ? 1.0 / norm : 0.0;
                    mTestEmbeddingsNormsReciprocalMatrix.set(0, i, reciprocal);
                }
                Log.d(TAG, "Created test embeddings norms matrix, size: 1 × " + numSentences);
                Log.d(TAG, "Created test embeddings norms reciprocal matrix, size: 1 × " + numSentences);
            }
            Log.d(TAG, "Generated embeddings for " + mTestEmbeddings.size() + " test sentences");
        }
    }

    /**
     * Populates {@link #mTestSentences}. NOTE(review): the four word arrays below are never
     * read in the visible code, and several of their entries are empty strings — likely a
     * combination-generation step (and the Chinese characters) was lost; only the five fixed
     * sentences are produced. Confirm against the original file.
     */
    private void generateTestSentences() {
        mTestSentences = new ArrayList<>();
        // Basic vocabulary (unused here — see NOTE above).
        String[] words = {"", "", "眼睛", "鼻子", "嘴巴", "耳朵", "头发", "手指", "", ""};
        String[] adjectives = {"", "", "", "", "", "", "漂亮", "丑陋", "干净", ""};
        String[] verbs = {"", "", "", "", "", "", "", "", "", ""};
        String[] subjects = {"", "", "", "", "", "我们", "你们", "他们", "这个", "那个"};
        // A few fixed test sentences.
        mTestSentences.add("我的手有很多手指头");
        mTestSentences.add("他的头发很长");
        mTestSentences.add("她的眼睛很大");
        mTestSentences.add("这个鼻子很高");
        mTestSentences.add("那个嘴巴很小");
        Log.d(TAG, "Generated " + mTestSentences.size() + " test sentences");
    }

    /** Selects the EJML path for similarity. NOTE(review): the flag is never read in the visible code. */
    public void setUseEJMLForSimilarity(boolean useEJML) {
        this.useEJMLForSimilarity = useEJML;
    }

    /** True once {@link #initBgeModel()} has successfully initialized the engine. */
    public boolean isInitialized() {
        return mBgeEngine != null && mBgeEngine.isInitialized();
    }

    /** Releases the native BGE engine; safe to call when never initialized. */
    public void deinitialize() {
        if (mBgeEngine != null) {
            mBgeEngine.deinitialize();
            mBgeEngine = null;
            Log.d(TAG, "BGE engine deinitialized");
        }
    }
}

View File

@@ -0,0 +1,104 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package com.digitalperson.engine;
import com.google.common.base.Ascii;
import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
 * Basic tokenization (punctuation splitting, lower casing, etc.) — vendored from the
 * TensorFlow Lite BERT examples (see the Apache 2.0 header above). Kept byte-compatible
 * with the upstream behavior; character classes come from the project's CharChecker.
 */
public final class BasicTokenizer {
    // When true, tokens are ASCII-lowercased before punctuation splitting.
    private final boolean doLowerCase;

    /** @param doLowerCase whether to lowercase tokens (ASCII-only, via Guava {@link Ascii}). */
    public BasicTokenizer(boolean doLowerCase) {
        this.doLowerCase = doLowerCase;
    }

    /**
     * Tokenizes {@code text}: removes invalid/control characters, normalizes whitespace,
     * optionally lowercases, then splits each whitespace token on punctuation.
     */
    public List<String> tokenize(String text) {
        String cleanedText = cleanText(text);
        List<String> origTokens = whitespaceTokenize(cleanedText);
        StringBuilder stringBuilder = new StringBuilder();
        for (String token : origTokens) {
            if (doLowerCase) {
                // Ascii.toLowerCase deliberately touches only A-Z, leaving CJK etc. intact.
                token = Ascii.toLowerCase(token);
            }
            List<String> list = runSplitOnPunc(token);
            for (String subToken : list) {
                stringBuilder.append(subToken).append(" ");
            }
        }
        return whitespaceTokenize(stringBuilder.toString());
    }

    /* Performs invalid character removal and whitespace cleanup on text. */
    static String cleanText(String text) {
        if (text == null) {
            throw new NullPointerException("The input String is null.");
        }
        StringBuilder stringBuilder = new StringBuilder("");
        for (int index = 0; index < text.length(); index++) {
            char ch = text.charAt(index);
            // Skip the characters that cannot be used.
            if (CharChecker.isInvalid(ch) || CharChecker.isControl(ch)) {
                continue;
            }
            // Collapse every whitespace character into a single plain space.
            if (CharChecker.isWhitespace(ch)) {
                stringBuilder.append(" ");
            } else {
                stringBuilder.append(ch);
            }
        }
        return stringBuilder.toString();
    }

    /* Runs basic whitespace cleaning and splitting on a piece of text. */
    static List<String> whitespaceTokenize(String text) {
        if (text == null) {
            throw new NullPointerException("The input String is null.");
        }
        return Arrays.asList(text.split(" "));
    }

    /* Splits punctuation on a piece of text: each punctuation char becomes its own token,
     * and runs of non-punctuation chars are accumulated into the current token. */
    static List<String> runSplitOnPunc(String text) {
        if (text == null) {
            throw new NullPointerException("The input String is null.");
        }
        List<String> tokens = new ArrayList<>();
        boolean startNewWord = true;
        for (int i = 0; i < text.length(); i++) {
            char ch = text.charAt(i);
            if (CharChecker.isPunctuation(ch)) {
                tokens.add(String.valueOf(ch));
                startNewWord = true;
            } else {
                if (startNewWord) {
                    tokens.add("");
                    startNewWord = false;
                }
                // Append the char to the token currently being built.
                tokens.set(tokens.size() - 1, Iterables.getLast(tokens) + ch);
            }
        }
        return tokens;
    }
}

View File

@@ -0,0 +1,263 @@
package com.digitalperson.engine;
import android.content.Context;
import android.util.Log;
import java.util.HashMap;
import java.util.Map;
/**
 * JNI wrapper around the RKNN build of the BGE (bge-small-zh-v1.5) text-embedding model.
 *
 * <p>Owns a native engine handle created in the constructor and released by
 * {@link #deinitialize()}. Tokenization happens on the Java side via {@link FullTokenizer};
 * embedding and similarity computation are delegated to native code.
 *
 * <p>Not thread-safe: callers must serialize access to a single instance.
 */
public class BgeEngineRKNN {
    private static final String TAG = "BgeEngineRKNN";
    // Fallback ids for the special tokens when the vocab lacks them (BERT conventions).
    private static final int DEFAULT_CLS_ID = 101;
    private static final int DEFAULT_SEP_ID = 102;

    private final long nativePtr;
    private final Context mContext;
    private boolean mIsInitialized = false;
    // Guards against double-freeing the native handle when deinitialize() is called twice.
    private boolean mReleased = false;
    private FullTokenizer mFullTokenizer = null;
    private final Map<String, Integer> mVocabMap = new HashMap<>();

    static {
        try {
            // Load dependent libraries.
            System.loadLibrary("rknnrt");
            System.loadLibrary("bgeEngine");
            Log.d(TAG, "Successfully loaded librknnrt.so and libbgeEngine.so");
        } catch (UnsatisfiedLinkError e) {
            Log.e(TAG, "Failed to load native library", e);
            throw e;
        }
    }

    /**
     * Creates the native engine handle.
     *
     * @throws RuntimeException when the native engine cannot be created or the
     *     native library fails to link.
     */
    public BgeEngineRKNN(Context context) {
        mContext = context;
        try {
            nativePtr = createBgeEngine();
            if (nativePtr == 0) {
                throw new RuntimeException("Failed to create native BGE engine");
            }
        } catch (UnsatisfiedLinkError e) {
            Log.e(TAG, "Failed to load native library", e);
            throw new RuntimeException("Failed to load native library: " + e.getMessage(), e);
        }
    }

    /**
     * Initializes with a model path, looking for {@code vocab.txt} in the model's
     * directory. (The old filename-replace approach silently produced a wrong vocab
     * path for any model file not named exactly "bge-small-zh-v1.5.rknn".)
     */
    public boolean initialize(String modelPath) {
        java.io.File modelFile = new java.io.File(modelPath);
        String vocabPath = new java.io.File(modelFile.getParentFile(), "vocab.txt").getPath();
        return initialize(modelPath, vocabPath);
    }

    /**
     * Loads the vocabulary and the native model.
     *
     * @return true when both the vocab and the model loaded successfully (or the
     *     engine was already initialized).
     */
    public boolean initialize(String modelPath, String vocabPath) {
        if (mIsInitialized) {
            Log.i(TAG, "Model already initialized");
            return true;
        }
        Log.d(TAG, "Loading BGE model: " + modelPath);
        Log.d(TAG, "Loading vocab: " + vocabPath);
        if (!loadVocab(vocabPath)) {
            Log.e(TAG, "Failed to load vocab");
            return false;
        }
        // true => lower-case input before tokenizing, matching the exported model.
        mFullTokenizer = new FullTokenizer(mVocabMap, true);
        int ret = loadModel(nativePtr, modelPath, vocabPath);
        if (ret == 0) {
            mIsInitialized = true;
            return true;
        }
        return false;
    }

    /**
     * Load vocabulary from file. Non-empty lines are assigned consecutive ids in
     * file order (blank lines are skipped and do NOT consume an id).
     *
     * @param vocabPath Path to vocabulary file
     * @return True if successful, false otherwise
     */
    private boolean loadVocab(String vocabPath) {
        mVocabMap.clear(); // allow a clean retry after a partial failure
        // Read as UTF-8 explicitly: the vocab contains Chinese tokens and the old
        // FileReader used the platform default charset. try-with-resources also
        // closes the reader when readLine() throws (the old code leaked it).
        try (java.io.BufferedReader reader = new java.io.BufferedReader(
                new java.io.InputStreamReader(
                        new java.io.FileInputStream(vocabPath),
                        java.nio.charset.StandardCharsets.UTF_8))) {
            String line;
            int index = 0;
            while ((line = reader.readLine()) != null) {
                line = line.trim();
                if (!line.isEmpty()) {
                    mVocabMap.put(line, index);
                    index++;
                }
            }
            Log.d(TAG, "Vocab loaded successfully with " + mVocabMap.size() + " tokens");
            return true;
        } catch (java.io.IOException e) {
            Log.e(TAG, "Failed to load vocab: " + e.getMessage(), e);
            return false;
        }
    }

    /**
     * Releases the native handle. Idempotent: the previous version called
     * freeModel() on every invocation, double-freeing the handle when called twice.
     */
    public void deinitialize() {
        if (!mReleased && nativePtr != 0) {
            freeModel(nativePtr);
            mReleased = true;
        }
        mIsInitialized = false;
    }

    public boolean isInitialized() {
        return mIsInitialized;
    }

    /**
     * Get the embedding dimension of the model
     * @return Embedding dimension, or -1 when the engine is not initialized
     */
    public int getEmbeddingDim() {
        if (!mIsInitialized) {
            Log.e(TAG, "Engine not initialized");
            return -1;
        }
        return getEmbeddingDim(nativePtr);
    }

    /**
     * Calculate embedding for a single text
     * @param text Input text
     * @return Embedding vector as float array, or null when not initialized
     */
    public float[] getEmbedding(String text) {
        if (!mIsInitialized) {
            Log.e(TAG, "Engine not initialized");
            return null;
        }
        return getEmbeddingNative(nativePtr, text);
    }

    /**
     * Calculate cosine similarity between two texts
     * @param text1 First text
     * @param text2 Second text
     * @return Cosine similarity score between -1.0 and 1.0 (0.0 when not initialized)
     */
    public float calculateSimilarity(String text1, String text2) {
        if (!mIsInitialized) {
            Log.e(TAG, "Engine not initialized");
            return 0.0f;
        }
        return calculateSimilarityNative(nativePtr, text1, text2);
    }

    /**
     * Get token ids for a text using FullTokenizer
     * @param text Input text
     * @return Token ids as int array, or null when not initialized
     */
    public int[] getTokens(String text) {
        if (!mIsInitialized) {
            Log.e(TAG, "Engine not initialized");
            return null;
        }
        if (mFullTokenizer == null) {
            Log.e(TAG, "FullTokenizer not initialized");
            return null;
        }
        java.util.List<String> tokens = mFullTokenizer.tokenize(text);
        java.util.List<Integer> ids = mFullTokenizer.convertTokensToIds(tokens);
        int[] result = new int[ids.size()];
        for (int i = 0; i < ids.size(); i++) {
            result[i] = ids.get(i);
        }
        return result;
    }

    /**
     * Calculate [UNK] ratio for token ids
     * @param tokens Token ids array
     * @return [UNK] ratio between 0.0 and 1.0 (0.0 for null/empty input or when
     *     the vocab has no [UNK] entry)
     */
    public float calculateUnkRatio(int[] tokens) {
        if (!mIsInitialized) {
            Log.e(TAG, "Engine not initialized");
            return 0.0f;
        }
        if (mFullTokenizer == null || tokens == null || tokens.length == 0) {
            return 0.0f;
        }
        int unkId = mVocabMap.getOrDefault("[UNK]", -1);
        if (unkId == -1) {
            return 0.0f;
        }
        int unkCount = 0;
        for (int token : tokens) {
            if (token == unkId) {
                unkCount++;
            }
        }
        return (float) unkCount / tokens.length;
    }

    /**
     * Get BERT model inputs (input_ids, attention_mask, token_type_ids).
     * Layout: [CLS] + up to (maxSeqLength - 2) text tokens + [SEP], zero-padded.
     *
     * @param text Input text
     * @param maxSeqLength Maximum sequence length; must be >= 2 to hold [CLS] and [SEP]
     *     (the old code crashed with ArrayIndexOutOfBoundsException for smaller values)
     * @return Array of int arrays: [input_ids, attention_mask, token_type_ids],
     *     or null when the engine/tokenizer is not initialized
     */
    public int[][] getBertInputs(String text, int maxSeqLength) {
        if (!mIsInitialized) {
            Log.e(TAG, "Engine not initialized");
            return null;
        }
        if (mFullTokenizer == null) {
            Log.e(TAG, "FullTokenizer not initialized");
            return null;
        }
        if (maxSeqLength < 2) {
            throw new IllegalArgumentException(
                    "maxSeqLength must be >= 2 to hold [CLS] and [SEP], got " + maxSeqLength);
        }
        java.util.List<String> tokens = mFullTokenizer.tokenize(text);
        java.util.List<Integer> ids = mFullTokenizer.convertTokensToIds(tokens);
        int[] inputIds = new int[maxSeqLength];
        int[] attentionMask = new int[maxSeqLength];
        // token_type_ids stays all-zero: single-sentence input.
        int[] tokenTypeIds = new int[maxSeqLength];
        // [CLS] goes first.
        inputIds[0] = mVocabMap.getOrDefault("[CLS]", DEFAULT_CLS_ID);
        attentionMask[0] = 1;
        // Text tokens, truncated to leave room for [CLS] and [SEP].
        int actualLength = Math.min(ids.size(), maxSeqLength - 2);
        for (int i = 0; i < actualLength; i++) {
            inputIds[i + 1] = ids.get(i);
            attentionMask[i + 1] = 1;
        }
        // [SEP] closes the sequence.
        inputIds[actualLength + 1] = mVocabMap.getOrDefault("[SEP]", DEFAULT_SEP_ID);
        attentionMask[actualLength + 1] = 1;
        return new int[][]{inputIds, attentionMask, tokenTypeIds};
    }

    // Native methods
    private native long createBgeEngine();
    private native int loadModel(long nativePtr, String modelPath, String vocabPath);
    private native void freeModel(long ptr);
    private native float[] getEmbeddingNative(long ptr, String text);
    private native int getEmbeddingDim(long ptr);
    private native float calculateSimilarityNative(long ptr, String text1, String text2);
}

View File

@@ -0,0 +1,58 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package com.digitalperson.engine;
/** To check whether a char is whitespace/control/punctuation. */
final class CharChecker {

    /** True for NUL (0) and the Unicode replacement character U+FFFD. */
    public static boolean isInvalid(char ch) {
        return ch == '\u0000' || ch == '\ufffd';
    }

    /** True for control/format characters, excluding anything Java treats as whitespace. */
    public static boolean isControl(char ch) {
        if (Character.isWhitespace(ch)) {
            return false;
        }
        switch (Character.getType(ch)) {
            case Character.CONTROL:
            case Character.FORMAT:
                return true;
            default:
                return false;
        }
    }

    /** True for Java whitespace plus the Unicode space/line/paragraph separator categories. */
    public static boolean isWhitespace(char ch) {
        if (Character.isWhitespace(ch)) {
            return true;
        }
        switch (Character.getType(ch)) {
            case Character.SPACE_SEPARATOR:
            case Character.LINE_SEPARATOR:
            case Character.PARAGRAPH_SEPARATOR:
                return true;
            default:
                return false;
        }
    }

    /** True for every Unicode punctuation category. */
    public static boolean isPunctuation(char ch) {
        switch (Character.getType(ch)) {
            case Character.CONNECTOR_PUNCTUATION:
            case Character.DASH_PUNCTUATION:
            case Character.START_PUNCTUATION:
            case Character.END_PUNCTUATION:
            case Character.INITIAL_QUOTE_PUNCTUATION:
            case Character.FINAL_QUOTE_PUNCTUATION:
            case Character.OTHER_PUNCTUATION:
                return true;
            default:
                return false;
        }
    }

    private CharChecker() {}
}

View File

@@ -0,0 +1,52 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package com.digitalperson.engine;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
/**
* A java realization of Bert tokenization. Original python code:
* https://github.com/google-research/bert/blob/master/tokenization.py runs full tokenization to
* tokenize a String into split subtokens or ids.
*/
public final class FullTokenizer {
    /** Special token used for out-of-vocabulary pieces (mirrors WordpieceTokenizer). */
    private static final String UNKNOWN_TOKEN = "[UNK]";

    private final BasicTokenizer basicTokenizer;
    private final WordpieceTokenizer wordpieceTokenizer;
    private final Map<String, Integer> dic;

    public FullTokenizer(Map<String, Integer> inputDic, boolean doLowerCase) {
        dic = inputDic;
        basicTokenizer = new BasicTokenizer(doLowerCase);
        wordpieceTokenizer = new WordpieceTokenizer(inputDic);
    }

    /** Basic-tokenizes then wordpiece-tokenizes {@code text} into subword tokens. */
    public List<String> tokenize(String text) {
        List<String> splitTokens = new ArrayList<>();
        for (String token : basicTokenizer.tokenize(text)) {
            splitTokens.addAll(wordpieceTokenizer.tokenize(token));
        }
        return splitTokens;
    }

    /**
     * Maps tokens to vocabulary ids. Tokens missing from the vocabulary fall back
     * to the [UNK] id: the old code inserted {@code null} for them, which later
     * caused an unboxing NullPointerException in callers such as
     * BgeEngineRKNN.getTokens(). If the vocabulary has no [UNK] entry either,
     * {@code null} is kept, preserving the original behavior.
     */
    public List<Integer> convertTokensToIds(List<String> tokens) {
        Integer unkId = dic.get(UNKNOWN_TOKEN);
        List<Integer> outputIds = new ArrayList<>();
        for (String token : tokens) {
            Integer id = dic.get(token);
            outputIds.add(id != null ? id : unkId);
        }
        return outputIds;
    }
}

View File

@@ -0,0 +1,94 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package com.digitalperson.engine;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
/** Word piece tokenization to split a piece of text into its word pieces. */
public final class WordpieceTokenizer {
    private static final String UNKNOWN_TOKEN = "[UNK]"; // For unknown words.
    private static final int MAX_INPUTCHARS_PER_WORD = 200;

    private final Map<String, Integer> dic;

    public WordpieceTokenizer(Map<String, Integer> vocab) {
        dic = vocab;
    }

    /**
     * Tokenizes a piece of text into its word pieces, using a greedy
     * longest-match-first lookup against the vocabulary. For example:
     * input = "unaffable", output = ["un", "##aff", "##able"].
     *
     * @param text a single token or whitespace separated tokens, already passed
     *     through {@code BasicTokenizer}
     * @return a list of wordpiece tokens; words that cannot be fully decomposed
     *     map to [UNK]
     */
    public List<String> tokenize(String text) {
        if (text == null) {
            throw new NullPointerException("The input String is null.");
        }
        List<String> outputTokens = new ArrayList<>();
        for (String word : BasicTokenizer.whitespaceTokenize(text)) {
            // Overly long words are never looked up; they map straight to [UNK].
            if (word.length() > MAX_INPUTCHARS_PER_WORD) {
                outputTokens.add(UNKNOWN_TOKEN);
                continue;
            }
            List<String> pieces = splitIntoKnownPieces(word);
            if (pieces == null) {
                outputTokens.add(UNKNOWN_TOKEN); // no full decomposition exists
            } else {
                outputTokens.addAll(pieces);
            }
        }
        return outputTokens;
    }

    /**
     * Greedily splits {@code word} into vocabulary pieces; pieces after the first
     * carry the "##" continuation prefix. Returns null as soon as some remaining
     * suffix has no known piece.
     */
    private List<String> splitIntoKnownPieces(String word) {
        List<String> pieces = new ArrayList<>();
        int start = 0;
        while (start < word.length()) {
            String piece = longestKnownPiece(word, start);
            if (piece == null) {
                return null;
            }
            pieces.add(piece);
            // Advance past the consumed characters; the "##" prefix is not part of the word.
            start += piece.length() - (start == 0 ? 0 : 2);
        }
        return pieces;
    }

    /** Longest vocabulary entry covering word[start..), or null when none matches. */
    private String longestKnownPiece(String word, int start) {
        for (int end = word.length(); end > start; end--) {
            String candidate =
                    (start == 0) ? word.substring(start, end) : "##" + word.substring(start, end);
            if (dic.containsKey(candidate)) {
                return candidate;
            }
        }
        return null;
    }
}

View File

@@ -185,7 +185,7 @@ class QuestionGenerationAgent(
val generationPrompt = buildGenerationPrompt(prompt, userProfile)
// 3. 调用大模型生成题目
generateQuestionFromLLM(generationPrompt) { generatedQuestion ->
generateQuestionFromLLM(generationPrompt, prompt) { generatedQuestion ->
if (generatedQuestion == null) {
Log.w(TAG, "Failed to generate question")
return@generateQuestionFromLLM
@@ -301,18 +301,22 @@ class QuestionGenerationAgent(
/**
* 调用LLM生成题目
*/
private fun generateQuestionFromLLM(prompt: String, onResult: (GeneratedQuestion?) -> Unit) {
private fun generateQuestionFromLLM(
promptText: String,
promptMeta: QuestionPrompt,
onResult: (GeneratedQuestion?) -> Unit
) {
// 优先使用本地LLM如果不可用则使用云端LLM
if (llmManager != null) {
// 使用本地LLM
llmManager.generate(prompt) { response: String ->
parseGeneratedQuestion(response, onResult)
llmManager.generate(promptText) { response: String ->
parseGeneratedQuestion(response, promptMeta, onResult)
}
} else if (cloudLLMGenerator != null) {
// 使用云端LLM
Log.d(TAG, "Using cloud LLM to generate question")
cloudLLMGenerator.invoke(prompt) { response ->
parseGeneratedQuestion(response, onResult)
cloudLLMGenerator.invoke(promptText) { response ->
parseGeneratedQuestion(response, promptMeta, onResult)
}
} else {
Log.e(TAG, "No LLM available (neither local nor cloud)")
@@ -323,16 +327,20 @@ class QuestionGenerationAgent(
/**
* 解析生成的题目
*/
private fun parseGeneratedQuestion(response: String, onResult: (GeneratedQuestion?) -> Unit) {
private fun parseGeneratedQuestion(
response: String,
promptMeta: QuestionPrompt,
onResult: (GeneratedQuestion?) -> Unit
) {
try {
val json = extractJsonFromResponse(response)
if (json != null) {
val question = GeneratedQuestion(
content = json.getString("content"),
answer = json.getString("answer"),
subject = "生活适应",
grade = 1,
difficulty = 1
subject = promptMeta.subject,
grade = promptMeta.grade,
difficulty = promptMeta.difficulty
)
onResult(question)
} else {

View File

@@ -7,6 +7,7 @@ import android.util.Log
import com.digitalperson.config.AppConfig
import java.io.File
import java.io.FileOutputStream
import java.io.InputStream
object FileHelper {
private const val TAG = AppConfig.TAG
@@ -64,6 +65,54 @@ object FileHelper {
val files = arrayOf(AppConfig.FaceRecognition.MODEL_NAME)
return copyAssetsToInternal(context, AppConfig.FaceRecognition.MODEL_DIR, outDir, files)
}
/**
 * Copies the BGE model files from assets into [Context.getFilesDir]/[AppConfig.Bge.ASSET_DIR].
 * A file is skipped when it already exists with the same length as the asset
 * (same behavior as [com.digitalperson.embedding.SimilarityManager]).
 *
 * @return the destination directory, or null when any file failed to copy.
 */
@JvmStatic
fun copyBgeModels(context: Context): File? {
    val assetDir = AppConfig.Bge.ASSET_DIR
    val targetDir = File(context.filesDir, assetDir).apply { mkdirs() }
    val assetNames = listOf(
        AppConfig.Bge.MODEL_FILE,
        "vocab.txt",
        "tokenizer.json",
        "tokenizer_config.json",
    )
    for (name in assetNames) {
        val assetPath = "$assetDir/$name"
        val destFile = File(targetDir, name)
        try {
            // Up-to-date only when the file exists AND its length matches the asset's.
            val upToDate = destFile.exists() && context.assets.open(assetPath).use { probe ->
                val assetSize = assetSizeOrNegative(probe)
                assetSize >= 0 && destFile.length() == assetSize
            }
            if (!upToDate) {
                context.assets.open(assetPath).use { input ->
                    destFile.outputStream().use { output -> input.copyTo(output) }
                }
                Log.i(TAG, "Copied BGE asset: $name")
            }
        } catch (e: Exception) {
            Log.e(TAG, "copyBgeModels failed for $name: ${e.message}", e)
            return null
        }
    }
    return targetDir
}
/**
 * [InputStream.available] is unreliable on some implementations; returns -1L on
 * failure (or a non-positive reported size) so callers fall back to re-copying.
 */
private fun assetSizeOrNegative(input: InputStream): Long = try {
    input.available().takeIf { it > 0 }?.toLong() ?: -1L
} catch (_: Exception) {
    -1L
}
fun ensureDir(dir: File): File {
if (!dir.exists()) {