word similarity
This commit is contained in:
26
app/src/main/java/com/digitalperson/DigitalPersonApp.kt
Normal file
26
app/src/main/java/com/digitalperson/DigitalPersonApp.kt
Normal file
@@ -0,0 +1,26 @@
|
||||
package com.digitalperson
|
||||
|
||||
import android.app.Application
|
||||
import android.util.Log
|
||||
import com.digitalperson.config.AppConfig
|
||||
import com.digitalperson.embedding.RefEmbeddingIndexer
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.SupervisorJob
|
||||
import kotlinx.coroutines.launch
|
||||
|
||||
class DigitalPersonApp : Application() {

    // Application-lifetime scope: SupervisorJob so a failed child coroutine does not
    // cancel its siblings; Default dispatcher (runOnce itself hops to Dispatchers.IO).
    private val appScope = CoroutineScope(SupervisorJob() + Dispatchers.Default)

    override fun onCreate() {
        super.onCreate()
        // Kick off the one-shot ref-corpus embedding index in the background so app
        // startup is not blocked; any failure is logged and otherwise swallowed.
        appScope.launch {
            try {
                RefEmbeddingIndexer.runOnce(this@DigitalPersonApp)
            } catch (t: Throwable) {
                Log.e(AppConfig.TAG, "[RefEmbed] 索引任务异常", t)
            }
        }
    }
}
|
||||
@@ -38,6 +38,7 @@ import com.digitalperson.interaction.ConversationSummaryMemory
|
||||
|
||||
import java.io.File
|
||||
import android.graphics.BitmapFactory
|
||||
import android.widget.ImageView
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.Job
|
||||
@@ -49,6 +50,7 @@ import kotlinx.coroutines.withContext
|
||||
|
||||
import com.digitalperson.onboard_testing.FaceRecognitionTest
|
||||
import com.digitalperson.onboard_testing.LLMSummaryTest
|
||||
import com.digitalperson.embedding.RefImageMatcher
|
||||
|
||||
class Live2DChatActivity : AppCompatActivity() {
|
||||
companion object {
|
||||
@@ -109,6 +111,7 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
|
||||
private lateinit var faceRecognitionTest: FaceRecognitionTest
|
||||
private lateinit var llmSummaryTest: LLMSummaryTest
|
||||
private var refMatchImageView: ImageView? = null
|
||||
|
||||
override fun onRequestPermissionsResult(
|
||||
requestCode: Int,
|
||||
@@ -159,6 +162,8 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
speakingPlayerViewId = 0,
|
||||
live2dViewId = R.id.live2d_view
|
||||
)
|
||||
|
||||
refMatchImageView = findViewById(R.id.ref_match_image)
|
||||
|
||||
cameraPreviewView = findViewById(R.id.camera_preview)
|
||||
cameraPreviewView.implementationMode = PreviewView.ImplementationMode.COMPATIBLE
|
||||
@@ -611,6 +616,7 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
runOnUiThread {
|
||||
uiManager.appendToUi("${filteredText.orEmpty()}\n")
|
||||
}
|
||||
maybeShowMatchedRefImage(filteredText ?: response)
|
||||
}
|
||||
interactionCoordinator.onCloudFinalResponse(response)
|
||||
}
|
||||
@@ -648,6 +654,24 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
onStopClicked(userInitiated = false)
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Looks up the ref-corpus image best matching [text] and, when one clears the
 * matcher's threshold, decodes it off the UI thread and shows it in
 * [refMatchImageView]. No-op when the view is absent, no match is found, or
 * the asset fails to decode.
 */
private fun maybeShowMatchedRefImage(text: String) {
    val imageView = refMatchImageView ?: return
    ioScope.launch {
        val match = RefImageMatcher.findBestMatch(applicationContext, text)
        if (match == null) return@launch
        // Decode here on the IO coroutine so the UI thread only receives a ready bitmap.
        val bitmap = try {
            assets.open(match.pngAssetPath).use { BitmapFactory.decodeStream(it) }
        } catch (_: Throwable) {
            null
        }
        if (bitmap == null) return@launch
        runOnUiThread {
            imageView.setImageBitmap(bitmap)
            imageView.visibility = android.view.View.VISIBLE
        }
    }
}
|
||||
|
||||
private fun createTtsCallback() = object : TtsController.TtsCallback {
|
||||
override fun onTtsStarted(text: String) {
|
||||
|
||||
@@ -25,6 +25,7 @@ import android.view.View
|
||||
import androidx.lifecycle.Lifecycle
|
||||
import androidx.lifecycle.LifecycleOwner
|
||||
import androidx.lifecycle.LifecycleRegistry
|
||||
import android.widget.ImageView
|
||||
import com.unity3d.player.UnityPlayer
|
||||
import com.unity3d.player.UnityPlayerActivity
|
||||
import com.digitalperson.audio.AudioProcessor
|
||||
@@ -47,6 +48,8 @@ import com.digitalperson.tts.TtsController
|
||||
import com.digitalperson.util.FileHelper
|
||||
import com.digitalperson.vad.VadManager
|
||||
import kotlinx.coroutines.*
|
||||
import com.digitalperson.embedding.RefImageMatcher
|
||||
import android.graphics.BitmapFactory
|
||||
|
||||
class UnityDigitalPersonActivity : UnityPlayerActivity(), LifecycleOwner {
|
||||
|
||||
@@ -108,6 +111,7 @@ class UnityDigitalPersonActivity : UnityPlayerActivity(), LifecycleOwner {
|
||||
private lateinit var holdToSpeakButton: Button
|
||||
private var recordButtonGlow: View? = null
|
||||
private var pulseAnimator: ObjectAnimator? = null
|
||||
private var refMatchImageView: ImageView? = null
|
||||
|
||||
|
||||
// 音频和AI模块
|
||||
@@ -254,6 +258,7 @@ class UnityDigitalPersonActivity : UnityPlayerActivity(), LifecycleOwner {
|
||||
chatHistoryText = chatLayout.findViewById(R.id.my_text)
|
||||
holdToSpeakButton = chatLayout.findViewById(R.id.record_button)
|
||||
recordButtonGlow = chatLayout.findViewById(R.id.record_button_glow)
|
||||
refMatchImageView = chatLayout.findViewById(R.id.ref_match_image)
|
||||
|
||||
// 根据配置设置按钮可见性
|
||||
if (AppConfig.USE_HOLD_TO_SPEAK) {
|
||||
@@ -735,6 +740,7 @@ class UnityDigitalPersonActivity : UnityPlayerActivity(), LifecycleOwner {
|
||||
val filteredText = ttsController.speakLlmResponse(response)
|
||||
android.util.Log.d("UnityDigitalPerson", "LLM response filtered: ${filteredText?.take(60)}")
|
||||
if (filteredText != null) appendChat("助手: $filteredText")
|
||||
maybeShowMatchedRefImage(filteredText ?: response)
|
||||
interactionCoordinator.onCloudFinalResponse(filteredText ?: response.trim())
|
||||
}
|
||||
|
||||
@@ -751,6 +757,25 @@ class UnityDigitalPersonActivity : UnityPlayerActivity(), LifecycleOwner {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Looks up the ref-corpus image best matching [text] and, when one clears the
 * matcher's threshold, decodes it off the UI thread and shows it in
 * [refMatchImageView]. No-op when the view is absent, no match is found, or
 * the asset fails to decode.
 */
private fun maybeShowMatchedRefImage(text: String) {
    val imageView = refMatchImageView ?: return
    // NOTE(review): a fresh CoroutineScope is created per call and never cancelled,
    // so these jobs outlive the Activity; a single lifecycle-tied scope (as the
    // Live2D activity's ioScope does) would avoid the leak — confirm and refactor.
    CoroutineScope(SupervisorJob() + Dispatchers.IO).launch {
        val match = RefImageMatcher.findBestMatch(applicationContext, text)
        if (match == null) return@launch
        // Decode here on the IO coroutine so the UI thread only receives a ready bitmap.
        val bitmap = try {
            assets.open(match.pngAssetPath).use { BitmapFactory.decodeStream(it) }
        } catch (_: Throwable) {
            null
        }
        if (bitmap == null) return@launch
        runOnUiThread {
            imageView.setImageBitmap(bitmap)
            imageView.visibility = View.VISIBLE
        }
    }
}
|
||||
|
||||
private fun requestLocalThought(prompt: String, onResult: (String) -> Unit) {
|
||||
val local = llmManager
|
||||
if (local == null) {
|
||||
|
||||
@@ -110,6 +110,19 @@ object AppConfig {
|
||||
const val MODEL_SIZE_ESTIMATE = 500L * 1024 * 1024 // 500MB
|
||||
}
|
||||
|
||||
/** BGE-small-zh-v1.5 text embedding (RKNN), used for semantic similarity / retrieval. */
object Bge {
    // Assets subdirectory holding the bundled model files.
    const val ASSET_DIR = "bge_models"
    const val MODEL_FILE = "bge-small-zh-v1.5.rknn"
}
|
||||
|
||||
/**
 * app/note/ref is packed into the apk via an extra Gradle assets directory; its
 * root inside assets is therefore `ref/`.
 */
object RefCorpus {
    const val ASSETS_ROOT = "ref"
}
|
||||
|
||||
object OnboardTesting {
|
||||
// 测试人脸识别
|
||||
const val FACE_REGONITION = false
|
||||
|
||||
@@ -9,15 +9,24 @@ import com.digitalperson.data.dao.UserAnswerDao
|
||||
import com.digitalperson.data.dao.UserMemoryDao
|
||||
import com.digitalperson.data.dao.ChatMessageDao
|
||||
import com.digitalperson.data.dao.ConversationSummaryDao
|
||||
import com.digitalperson.data.dao.RefTextEmbeddingDao
|
||||
import com.digitalperson.data.entity.Question
|
||||
import com.digitalperson.data.entity.UserAnswer
|
||||
import com.digitalperson.data.entity.UserMemory
|
||||
import com.digitalperson.data.entity.ChatMessageEntity
|
||||
import com.digitalperson.data.entity.ConversationSummaryEntity
|
||||
import com.digitalperson.data.entity.RefTextEmbedding
|
||||
|
||||
@Database(
|
||||
entities = [UserMemory::class, Question::class, UserAnswer::class, ChatMessageEntity::class, ConversationSummaryEntity::class],
|
||||
version = 4,
|
||||
entities = [
|
||||
UserMemory::class,
|
||||
Question::class,
|
||||
UserAnswer::class,
|
||||
ChatMessageEntity::class,
|
||||
ConversationSummaryEntity::class,
|
||||
RefTextEmbedding::class
|
||||
],
|
||||
version = 5,
|
||||
exportSchema = false
|
||||
)
|
||||
abstract class AppDatabase : RoomDatabase() {
|
||||
@@ -26,6 +35,7 @@ abstract class AppDatabase : RoomDatabase() {
|
||||
abstract fun userAnswerDao(): UserAnswerDao
|
||||
abstract fun chatMessageDao(): ChatMessageDao
|
||||
abstract fun conversationSummaryDao(): ConversationSummaryDao
|
||||
abstract fun refTextEmbeddingDao(): RefTextEmbeddingDao
|
||||
|
||||
companion object {
|
||||
private const val DATABASE_NAME = "digital_human.db"
|
||||
|
||||
@@ -9,6 +9,17 @@ import com.digitalperson.data.entity.Question
|
||||
interface QuestionDao {
|
||||
@Insert
|
||||
fun insert(question: Question): Long
|
||||
|
||||
@Query(
|
||||
"""
|
||||
SELECT * FROM questions
|
||||
WHERE content = :content
|
||||
AND ((:subject IS NULL AND subject IS NULL) OR subject = :subject)
|
||||
AND ((:grade IS NULL AND grade IS NULL) OR grade = :grade)
|
||||
LIMIT 1
|
||||
"""
|
||||
)
|
||||
fun findByContentSubjectGrade(content: String, subject: String?, grade: Int?): Question?
|
||||
|
||||
@Query("SELECT * FROM questions WHERE subject = :subject ORDER BY difficulty")
|
||||
fun getQuestionsBySubject(subject: String): List<Question>
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
package com.digitalperson.data.dao
|
||||
|
||||
import androidx.room.Dao
|
||||
import androidx.room.Insert
|
||||
import androidx.room.OnConflictStrategy
|
||||
import androidx.room.Query
|
||||
import com.digitalperson.data.entity.RefTextEmbedding
|
||||
|
||||
@Dao
interface RefTextEmbeddingDao {

    /** Returns the cached embedding row for one assets path, or null when absent. */
    @Query("SELECT * FROM ref_text_embeddings WHERE assetPath = :path LIMIT 1")
    fun getByPath(path: String): RefTextEmbedding?

    /** Inserts a row; REPLACE overwrites an existing row on the unique assetPath index. */
    @Insert(onConflict = OnConflictStrategy.REPLACE)
    fun insert(row: RefTextEmbedding): Long

    /** Returns every cached embedding row. */
    @Query("SELECT * FROM ref_text_embeddings")
    fun getAll(): List<RefTextEmbedding>
}
|
||||
@@ -0,0 +1,23 @@
|
||||
package com.digitalperson.data.entity
|
||||
|
||||
import androidx.room.Entity
|
||||
import androidx.room.Index
|
||||
import androidx.room.PrimaryKey
|
||||
import com.digitalperson.data.util.embeddingBytesToFloatArray
|
||||
|
||||
@Entity(
    tableName = "ref_text_embeddings",
    indices = [Index(value = ["assetPath"], unique = true)]
)
data class RefTextEmbedding(
    @PrimaryKey(autoGenerate = true) val id: Long = 0,
    /** Assets-relative path, e.g. ref/一年级.../xxx.txt */
    val assetPath: String,
    /** SHA-256 hex of the embedded body's UTF-8 bytes, used to skip unchanged files. */
    val contentHash: String,
    /** Number of float32 components stored in [embedding]. */
    val dim: Int,
    /** Little-endian float32 blob of the embedding vector. */
    val embedding: ByteArray,
    val updatedAt: Long = System.currentTimeMillis()
) {
    /** Decodes the stored blob back into a FloatArray. */
    fun toFloatArray(): FloatArray = embeddingBytesToFloatArray(embedding)

    // Fix: the compiler-generated data-class equals/hashCode compare a ByteArray
    // property by reference, so two rows with identical bytes were unequal.
    // Compare the blob by content instead.
    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        if (other !is RefTextEmbedding) return false
        return id == other.id &&
            assetPath == other.assetPath &&
            contentHash == other.contentHash &&
            dim == other.dim &&
            embedding.contentEquals(other.embedding) &&
            updatedAt == other.updatedAt
    }

    override fun hashCode(): Int {
        var result = id.hashCode()
        result = 31 * result + assetPath.hashCode()
        result = 31 * result + contentHash.hashCode()
        result = 31 * result + dim
        result = 31 * result + embedding.contentHashCode()
        result = 31 * result + updatedAt.hashCode()
        return result
    }
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package com.digitalperson.data.util
|
||||
|
||||
import java.nio.ByteBuffer
|
||||
import java.nio.ByteOrder
|
||||
|
||||
/** Serializes [values] as a little-endian IEEE-754 float32 blob, 4 bytes per element. */
fun floatArrayToEmbeddingBytes(values: FloatArray): ByteArray {
    val out = ByteArray(values.size * Float.SIZE_BYTES)
    val buffer = ByteBuffer.wrap(out).order(ByteOrder.LITTLE_ENDIAN)
    values.forEach { buffer.putFloat(it) }
    return out
}
|
||||
|
||||
/**
 * Decodes a little-endian float32 blob back into a FloatArray.
 * Any trailing bytes beyond a multiple of 4 are ignored, as in a plain size/4 loop.
 */
fun embeddingBytesToFloatArray(blob: ByteArray): FloatArray {
    val floats = ByteBuffer.wrap(blob).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer()
    val result = FloatArray(floats.remaining())
    floats.get(result)
    return result
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package com.digitalperson.embedding
|
||||
|
||||
import android.content.Context
|
||||
import com.digitalperson.config.AppConfig
|
||||
import com.digitalperson.engine.BgeEngineRKNN
|
||||
import com.digitalperson.util.FileHelper
|
||||
import java.io.File
|
||||
|
||||
/**
 * Lazily-loaded BGE text embedding (RKNN), used for semantic-similarity retrieval
 * between conversation text and pre-stored annotations.
 *
 * Initialization copies the resources under [AppConfig.Bge.ASSET_DIR] into internal
 * storage and loads [AppConfig.Bge.MODEL_FILE].
 */
object BgeEmbedding {

    // Volatile so readers (isReady/getEmbedding/...) observe the engine published by
    // initialize() without taking the monitor; all writes are inside @Synchronized members.
    @Volatile
    private var engine: BgeEngineRKNN? = null

    fun isReady(): Boolean = engine?.isInitialized == true

    /**
     * Blocks when called on the main thread; prefer a background thread or a
     * coroutine on [Dispatchers.IO].
     *
     * @return true when the engine is (or already was) initialized; false when the
     *         model copy failed, the model file is missing, or engine init failed.
     */
    @Synchronized
    fun initialize(context: Context): Boolean {
        if (engine?.isInitialized == true) return true
        val dir = FileHelper.copyBgeModels(context.applicationContext) ?: return false
        val path = File(dir, AppConfig.Bge.MODEL_FILE).absolutePath
        if (!File(path).exists()) return false
        val eng = BgeEngineRKNN(context.applicationContext)
        if (!eng.initialize(path)) return false
        // Publish only a fully-initialized engine.
        engine = eng
        return true
    }

    /** Releases the native engine; a later [initialize] may load it again. */
    @Synchronized
    fun release() {
        engine?.deinitialize()
        engine = null
    }

    /** Embedding vector for [text], or null when the engine is not initialized. */
    fun getEmbedding(text: String): FloatArray? = engine?.getEmbedding(text)

    /** Similarity of the two texts, or null when the engine is not initialized. */
    fun similarity(text1: String, text2: String): Float? =
        engine?.calculateSimilarity(text1, text2)

    /** Embedding dimensionality, or -1 when the engine is not initialized. */
    fun embeddingDim(): Int = engine?.embeddingDim ?: -1
}
|
||||
@@ -0,0 +1,34 @@
|
||||
package com.digitalperson.embedding
|
||||
|
||||
import android.content.Context
|
||||
|
||||
internal object RefCorpusAssetScanner {

    /**
     * Lists the assets paths (separated by `/`) of every `.txt` file under [root],
     * including subdirectories. The result is sorted.
     */
    fun listTxtFilesUnder(context: Context, root: String): List<String> {
        val assets = context.assets
        val found = ArrayList<String>()
        // Iterative traversal with an explicit work queue instead of recursion;
        // output is sorted at the end, so visit order does not matter.
        val pending = ArrayDeque<String>()
        pending.addLast(root.removeSuffix("/"))

        while (pending.isNotEmpty()) {
            val dir = pending.removeFirst()
            val children = assets.list(dir) ?: continue
            for (name in children) {
                if (name.isEmpty()) continue
                val path = "$dir/$name"
                // AssetManager has no isDirectory: an entry with no children is a file.
                val nested = assets.list(path)
                if (nested.isNullOrEmpty()) {
                    if (name.endsWith(".txt", ignoreCase = true)) {
                        found.add(path)
                    }
                } else {
                    pending.addLast(path)
                }
            }
        }
        return found.sorted()
    }
}
|
||||
@@ -0,0 +1,173 @@
|
||||
package com.digitalperson.embedding
|
||||
|
||||
import android.content.Context
|
||||
import android.util.Log
|
||||
import com.digitalperson.config.AppConfig
|
||||
import com.digitalperson.data.AppDatabase
|
||||
import com.digitalperson.data.entity.Question
|
||||
import com.digitalperson.data.entity.RefTextEmbedding
|
||||
import com.digitalperson.data.util.floatArrayToEmbeddingBytes
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.withContext
|
||||
import java.security.MessageDigest
|
||||
|
||||
/**
 * At process start, scans every `.txt` under assets [AppConfig.RefCorpus.ASSETS_ROOT]
 * in the background, embeds each file's content (after dropping `#` lines) with BGE,
 * writes the vectors into Room, and fills [RefEmbeddingMemoryCache].
 *
 * The DAOs are synchronous; the whole function runs on [Dispatchers.IO] and does not
 * block the main thread.
 */
object RefEmbeddingIndexer {

    private const val TAG = AppConfig.TAG

    suspend fun runOnce(context: Context) = withContext(Dispatchers.IO) {
        val app = context.applicationContext
        val db = AppDatabase.getInstance(app)
        val dao = db.refTextEmbeddingDao()
        val questionDao = db.questionDao()

        // Without a working BGE engine there is nothing to index.
        if (!BgeEmbedding.initialize(app)) {
            Log.e(TAG, "[RefEmbed] BGE 初始化失败,跳过 ref 语料索引")
            return@withContext
        }

        val root = AppConfig.RefCorpus.ASSETS_ROOT
        val paths = RefCorpusAssetScanner.listTxtFilesUnder(app, root)
        Log.i(TAG, "[RefEmbed] 发现 ${paths.size} 个 txt(root=$root)")

        // Per-file outcome counters, reported in the final summary log line.
        var skipped = 0
        var embedded = 0
        var empty = 0
        var failed = 0

        for (path in paths) {
            val raw = try {
                app.assets.open(path).bufferedReader(Charsets.UTF_8).use { it.readText() }
            } catch (e: Exception) {
                Log.w(TAG, "[RefEmbed] 读取失败 $path: ${e.message}")
                failed++
                continue
            }

            // Question bank: lines containing a question mark are inserted into `questions`.
            val subject = extractSubjectFromRaw(raw)
            val grade = extractGradeFromPath(path)
            val questionLines = extractQuestionLines(raw)
            for (line in questionLines) {
                val content = line.trim()
                if (content.isEmpty()) continue
                // Dedup: only insert when no row with the same content/subject/grade exists.
                val exists = questionDao.findByContentSubjectGrade(content, subject, grade)
                if (exists == null) {
                    questionDao.insert(
                        Question(
                            id = 0,
                            content = content,
                            answer = null,
                            subject = subject,
                            grade = grade,
                            difficulty = 1,
                            createdAt = System.currentTimeMillis()
                        )
                    )
                }
            }

            val embedText = RefTxtEmbedText.fromRawFileContent(raw)
            if (embedText.isEmpty()) {
                empty++
                continue
            }

            val hash = sha256Hex(embedText.toByteArray(Charsets.UTF_8))
            // Synchronous DAO call — fine, we are already on the IO dispatcher.
            val existing = dao.getByPath(path)
            if (existing != null && existing.contentHash == hash) {
                // Unchanged file: reuse the stored vector, just refresh the memory cache.
                RefEmbeddingMemoryCache.put(path, normalizeL2(existing.toFloatArray()))
                skipped++
                continue
            }

            val vec = BgeEmbedding.getEmbedding(embedText)
            if (vec == null || vec.isEmpty()) {
                Log.w(TAG, "[RefEmbed] 嵌入为空 $path")
                failed++
                continue
            }

            // Store L2-normalized vectors so lookup can use a plain dot product.
            val normalized = normalizeL2(vec)
            dao.insert(
                RefTextEmbedding(
                    assetPath = path,
                    contentHash = hash,
                    dim = normalized.size,
                    embedding = floatArrayToEmbeddingBytes(normalized)
                )
            )
            RefEmbeddingMemoryCache.put(path, normalized)
            embedded++
        }

        Log.i(
            TAG,
            "[RefEmbed] 完成 embedded=$embedded skipped=$skipped empty=$empty failed=$failed cacheSize=${RefEmbeddingMemoryCache.size()}"
        )
    }

    /** Subject = text after `#` on the first comment line, or null when absent/blank. */
    private fun extractSubjectFromRaw(raw: String): String? {
        val line = raw.lineSequence()
            .map { it.trimEnd() }
            .firstOrNull { it.trimStart().startsWith("#") }
            ?: return null
        val s = line.trimStart().removePrefix("#").trim()
        return s.ifEmpty { null }
    }

    /** Non-comment, non-blank lines that contain a question mark. */
    private fun extractQuestionLines(raw: String): List<String> {
        return raw.lineSequence()
            .map { it.trimEnd() }
            .filter { it.isNotBlank() }
            .filter { !it.trimStart().startsWith("#") }
            // NOTE(review): both branches appear to test the same character; one side
            // was presumably meant to be fullwidth '?' (U+FF1F) — verify the encoding.
            .filter { it.contains('?') || it.contains('?') }
            .toList()
    }

    /** Grade parsed from a path like `ref/一年级上-.../...`; null when not derivable. */
    private fun extractGradeFromPath(assetPath: String): Int? {
        // example: ref/一年级上-生活适应/... or ref/二年级下-...
        val idx = assetPath.indexOf("年级")
        if (idx <= 0) return null
        val prefix = assetPath.substring(0, idx)
        // The character immediately before "年级" is the Chinese numeral for the grade.
        val cn = prefix.lastOrNull() ?: return null
        return chineseGradeToInt(cn)
    }

    /** Maps Chinese numerals 一..十 to 1..10; null for anything else. */
    private fun chineseGradeToInt(c: Char): Int? {
        return when (c) {
            '一' -> 1
            '二' -> 2
            '三' -> 3
            '四' -> 4
            '五' -> 5
            '六' -> 6
            '七' -> 7
            '八' -> 8
            '九' -> 9
            '十' -> 10
            else -> null
        }
    }

    /** Returns v scaled to unit L2 norm; (near-)zero vectors come back as an unscaled copy. */
    private fun normalizeL2(v: FloatArray): FloatArray {
        var sum = 0.0
        for (x in v) sum += (x * x).toDouble()
        val norm = kotlin.math.sqrt(sum).toFloat()
        if (norm <= 1e-12f) return v.copyOf()
        return FloatArray(v.size) { i -> v[i] / norm }
    }

    /** Lowercase hex SHA-256 of [data]. */
    private fun sha256Hex(data: ByteArray): String {
        val md = MessageDigest.getInstance("SHA-256")
        val digest = md.digest(data)
        return digest.joinToString("") { "%02x".format(it) }
    }
}
|
||||
@@ -0,0 +1,30 @@
|
||||
package com.digitalperson.embedding
|
||||
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
|
||||
/**
 * In-memory cache of ref-corpus txt embeddings, keyed by assets-relative path
 * (matching [com.digitalperson.data.entity.RefTextEmbedding.assetPath] in the database).
 */
object RefEmbeddingMemoryCache {

    private val store = ConcurrentHashMap<String, FloatArray>()

    /** Stores a defensive copy so later mutation by the caller cannot corrupt the cache. */
    fun put(assetPath: String, embedding: FloatArray) {
        store[assetPath] = embedding.copyOf()
    }

    /** Returns a defensive copy of the cached vector, or null when absent. */
    fun get(assetPath: String): FloatArray? = store[assetPath]?.copyOf()

    fun clear() {
        store.clear()
    }

    /** Read-only snapshot; every vector is already a copy and safe to use. */
    fun snapshot(): Map<String, FloatArray> = buildMap {
        for ((path, vector) in store) {
            put(path, vector.copyOf())
        }
    }

    fun size(): Int = store.size
}
|
||||
@@ -0,0 +1,95 @@
|
||||
package com.digitalperson.embedding
|
||||
|
||||
import android.content.Context
|
||||
import android.util.Log
|
||||
import com.digitalperson.config.AppConfig
|
||||
import kotlin.math.sqrt
|
||||
|
||||
/** One ref-corpus hit: the matched txt asset, its sibling png asset, and the score. */
data class RefImageMatch(
    val txtAssetPath: String,
    val pngAssetPath: String,
    val score: Float
)
|
||||
|
||||
object RefImageMatcher {

    private const val TAG = AppConfig.TAG

    /**
     * Finds the ref-corpus txt whose cached embedding best matches [text] and returns
     * it together with the sibling `.png` asset path; null when the engine/cache is
     * unavailable, nothing clears [threshold], or the png asset does not exist.
     *
     * @param threshold cosine-similarity cutoff (equal to the dot product when the
     *        cached vectors are L2-normalized, as the indexer stores them).
     */
    fun findBestMatch(
        context: Context,
        text: String,
        threshold: Float = 0.70f
    ): RefImageMatch? {
        val query = text.trim()
        if (query.isEmpty()) return null

        // Lazily bring the BGE engine up if the startup indexer has not done so yet.
        if (!BgeEmbedding.isReady()) {
            val ok = BgeEmbedding.initialize(context.applicationContext)
            if (!ok) {
                Log.w(TAG, "[RefMatch] BGE not ready, skip match")
                return null
            }
        }

        val q = BgeEmbedding.getEmbedding(query) ?: return null
        if (q.isEmpty()) return null
        val qn = normalizeL2(q)

        val vectors = RefEmbeddingMemoryCache.snapshot()
        if (vectors.isEmpty()) return null

        // Linear scan for the highest dot product (== cosine for normalized vectors).
        var bestPath: String? = null
        var bestScore = -1f

        for ((path, v) in vectors) {
            // Skip empty or dimension-mismatched vectors defensively.
            if (v.isEmpty() || v.size != qn.size) continue
            val score = dot(qn, v)
            if (score > bestScore) {
                bestScore = score
                bestPath = path
            }
        }

        val txtPath = bestPath ?: return null
        if (bestScore < threshold) return null

        // Image convention: same asset path as the txt, with a .png extension.
        val pngPath = if (txtPath.endsWith(".txt", ignoreCase = true)) {
            txtPath.dropLast(4) + ".png"
        } else {
            "$txtPath.png"
        }

        // Do not decode here; only probe existence so callers avoid UI-thread I/O.
        val exists = try {
            context.assets.open(pngPath).close()
            true
        } catch (_: Throwable) {
            false
        }
        if (!exists) return null

        return RefImageMatch(
            txtAssetPath = txtPath,
            pngAssetPath = pngPath,
            score = bestScore
        )
    }

    /** Plain dot product over [a]'s indices; equal sizes are checked by the caller. */
    private fun dot(a: FloatArray, b: FloatArray): Float {
        var s = 0f
        for (i in a.indices) s += a[i] * b[i]
        return s
    }

    /** Returns v scaled to unit L2 norm; (near-)zero vectors come back as an unscaled copy. */
    private fun normalizeL2(v: FloatArray): FloatArray {
        var sum = 0.0
        for (x in v) sum += (x * x).toDouble()
        val norm = sqrt(sum).toFloat()
        if (norm <= 1e-12f) return v.copyOf()
        return FloatArray(v.size) { i -> v[i] / norm }
    }
}
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
package com.digitalperson.embedding
|
||||
|
||||
/**
 * Builds the text to embed from a raw txt file: drops every line beginning with `#`
 * (leading whitespace allowed) and every blank line, joining the rest with newlines.
 */
object RefTxtEmbedText {

    fun fromRawFileContent(raw: String): String =
        raw.lines()
            .asSequence()
            .map(String::trimEnd)
            .filterNot { it.isEmpty() || it.trimStart().startsWith("#") }
            .joinToString(separator = "\n")
            .trim()
}
|
||||
@@ -0,0 +1,353 @@
|
||||
package com.digitalperson.embedding;
|
||||
|
||||
import android.content.Context;
|
||||
import android.os.Handler;
|
||||
import android.os.Looper;
|
||||
import android.util.Log;
|
||||
|
||||
import com.digitalperson.config.AppConfig;
|
||||
import com.digitalperson.engine.BgeEngineRKNN;
|
||||
import com.digitalperson.util.FileHelper;
|
||||
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class SimilarityManager {
|
||||
private static final String TAG = "SimilarityManager";
|
||||
|
||||
private Context mContext;
|
||||
private Handler mHandler;
|
||||
private BgeEngineRKNN mBgeEngine;
|
||||
|
||||
// 测试数据相关
|
||||
private List<float[]> mTestEmbeddings; // 预计算的100个句子的嵌入(float数组格式)
|
||||
private List<SimpleMatrix> mTestEmbeddingsEJML; // 预计算的100个句子的嵌入(EJML格式)
|
||||
private SimpleMatrix mTestEmbeddingsMatrix; // 所有测试嵌入的矩阵(嵌入维度 × 句子数量)
|
||||
private List<Double> mTestEmbeddingsNorms; // 预计算的100个句子的嵌入范数
|
||||
private SimpleMatrix mTestEmbeddingsNormsMatrix; // 预计算的嵌入范数向量(1 × 句子数量)
|
||||
private SimpleMatrix mTestEmbeddingsNormsReciprocalMatrix; // 预计算的嵌入范数倒数向量(1 × 句子数量)
|
||||
private List<String> mTestSentences; // 预生成的100个测试句子
|
||||
private boolean useEJMLForSimilarity = false; // 默认使用传统for循环计算相似度
|
||||
|
||||
public interface SimilarityListener {
|
||||
void onSimilarityCalculated(float similarity, long timeTaken);
|
||||
void onPerformanceTestComplete(long traditionalTime, long ejmlTime, double speedup);
|
||||
void onError(String errorMessage);
|
||||
}
|
||||
|
||||
private SimilarityListener mListener;
|
||||
|
||||
public SimilarityManager(Context context) {
|
||||
this.mContext = context;
|
||||
this.mHandler = new Handler(Looper.getMainLooper());
|
||||
}
|
||||
|
||||
public void setListener(SimilarityListener listener) {
|
||||
this.mListener = listener;
|
||||
}
|
||||
|
||||
// Copies the bundled BGE model files into internal storage and initializes the RKNN
// engine; every failure is reported to the listener instead of being thrown.
public void initBgeModel() {
    try {
        File bgeModelDir = FileHelper.copyBgeModels(mContext);
        if (bgeModelDir == null) {
            Log.e(TAG, "BGE model directory copy failed");
            if (mListener != null) {
                mListener.onError("BGE 模型文件复制失败");
            }
            return;
        }

        // Create the engine and resolve the model path inside the copied directory.
        mBgeEngine = new BgeEngineRKNN(mContext);
        String modelPath = new File(bgeModelDir, AppConfig.Bge.MODEL_FILE).getAbsolutePath();

        // Verify the model file actually exists before handing it to the engine.
        if (!new File(modelPath).exists()) {
            Log.e(TAG, "BGE model file does not exist: " + modelPath);
            if (mListener != null) {
                mListener.onError("BGE 模型文件不存在");
            }
            return;
        }

        boolean success = mBgeEngine.initialize(modelPath);
        if (!success) {
            Log.e(TAG, "Failed to initialize BGE model");
            if (mListener != null) {
                mListener.onError("BGE 模型初始化失败");
            }
        }
    } catch (Exception e) {
        Log.e(TAG, "Exception in initBgeModel", e);
        if (mListener != null) {
            mListener.onError("初始化BGE模型异常: " + e.getMessage());
        }
    }
}
|
||||
|
||||
// Computes BGE similarity between two texts on a background thread and reports the
// result (plus elapsed wall time in ms) through the listener.
public void calculateSimilarity(String text1, String text2) {
    if (mBgeEngine == null || !mBgeEngine.isInitialized()) {
        Log.w(TAG, "BGE engine not initialized, skipping similarity calculation");
        if (mListener != null) {
            mListener.onError("BGE 模型未初始化");
        }
        return;
    }

    // Run off the caller's thread; the engine call may be slow.
    new Thread(() -> {
        try {
            long startTime = System.currentTimeMillis();
            float similarity = mBgeEngine.calculateSimilarity(text1, text2);
            long timeTaken = System.currentTimeMillis() - startTime;

            if (mListener != null) {
                mListener.onSimilarityCalculated(similarity, timeTaken);
            }
        } catch (Exception e) {
            Log.e(TAG, "Exception in calculateSimilarity", e);
            if (mListener != null) {
                mListener.onError("相似度计算失败: " + e.getMessage());
            }
        }
    }).start();
}
|
||||
|
||||
// Benchmarks two cosine-similarity implementations (plain float-array loops vs EJML
// batched matrix math) over the precomputed test-sentence embeddings, and reports
// both timings and their ratio through the listener.
public void testPerformance(String userInput) {
    if (mBgeEngine == null || !mBgeEngine.isInitialized()) {
        Log.w(TAG, "BGE engine not initialized, skipping performance test");
        if (mListener != null) {
            mListener.onError("BGE 模型未初始化");
        }
        return;
    }

    // Run the whole benchmark off the caller's thread.
    new Thread(() -> {
        try {
            // Lazily build the test sentences, embeddings, norms, and matrices.
            prepareTestData();

            // 1. Embed the user's input.
            float[] userEmbedding = mBgeEngine.getEmbedding(userInput);
            if (userEmbedding == null) {
                if (mListener != null) {
                    mListener.onError("无法计算用户输入的嵌入");
                }
                return;
            }

            // Column-vector copy of the user embedding for the EJML path.
            SimpleMatrix userVec = new SimpleMatrix(userEmbedding.length, 1);
            for (int i = 0; i < userEmbedding.length; i++) {
                userVec.set(i, 0, userEmbedding[i]);
            }

            // Method 1: float arrays + explicit loops (traditional baseline).
            long startTime1 = System.currentTimeMillis();
            List<Float> similarities1 = new ArrayList<>();

            // Optimization: precompute 1/||u|| once and multiply instead of dividing.
            double normUserTraditional = 0.0;
            for (float value : userEmbedding) {
                normUserTraditional += value * value;
            }
            normUserTraditional = Math.sqrt(normUserTraditional);
            double normUserReciprocalTraditional = (normUserTraditional > 1e-10) ? 1.0 / normUserTraditional : 0.0;

            for (int i = 0; i < mTestEmbeddings.size(); i++) {
                float[] embedding = mTestEmbeddings.get(i);
                // Dot product u·v_i.
                double dotProduct = 0.0;
                for (int j = 0; j < userEmbedding.length; j++) {
                    dotProduct += userEmbedding[j] * embedding[j];
                }
                // cos(u,v_i) = (u·v_i) × (1/||u||) × (1/||v_i||) — multiplication only.
                double normEmbeddingReciprocal = mTestEmbeddingsNorms.get(i) > 1e-10 ? 1.0 / mTestEmbeddingsNorms.get(i) : 0.0;
                float sim = (float) (dotProduct * normUserReciprocalTraditional * normEmbeddingReciprocal);
                similarities1.add(sim);
            }
            long timeTaken1 = System.currentTimeMillis() - startTime1;

            // Method 2: EJML (optimized path).
            long startTime2 = System.currentTimeMillis();
            List<Float> similarities2 = new ArrayList<>();

            // Opt 1: precompute the user vector's norm and its reciprocal.
            double normUser = userVec.normF();
            double normUserReciprocal = (normUser > 1e-10) ? 1.0 / normUser : 0.0;

            // Opt 2: one batched matrix multiply for all dot products and similarities.
            if (mTestEmbeddingsMatrix != null && mTestEmbeddingsNormsReciprocalMatrix != null) {
                // All dot products at once: (1×dim) · (dim×n) = (1×n).
                SimpleMatrix dotProducts = userVec.transpose().mult(mTestEmbeddingsMatrix);

                // cos(u,v_i) = (u·v_i) × (1/||u||) × (1/||v_i||)
                // Step 1: scale the dot products by 1/||u||.
                SimpleMatrix dotProductsScaled = dotProducts.scale(normUserReciprocal);
                // Step 2: element-wise multiply by the reciprocal-norm row vector.
                SimpleMatrix similaritiesMatrix = dotProductsScaled.elementMult(mTestEmbeddingsNormsReciprocalMatrix);

                // Copy the resulting row vector into the output list.
                for (int i = 0; i < similaritiesMatrix.numCols(); i++) {
                    similarities2.add((float) similaritiesMatrix.get(0, i));
                }
            } else if (mTestEmbeddingsMatrix != null) {
                // Fallback 1: batched dot products + per-element reciprocal-norm loop.
                SimpleMatrix dotProducts = userVec.transpose().mult(mTestEmbeddingsMatrix);
                for (int i = 0; i < dotProducts.numCols(); i++) {
                    double dotProduct = dotProducts.get(0, i);
                    double normEmbeddingReciprocal = mTestEmbeddingsNorms.get(i) > 1e-10 ? 1.0 / mTestEmbeddingsNorms.get(i) : 0.0;
                    float sim = (float) (dotProduct * normUserReciprocal * normEmbeddingReciprocal);
                    similarities2.add(sim);
                }
            } else {
                // Fallback 2: plain per-vector loop (matrices not initialized).
                for (int i = 0; i < mTestEmbeddingsEJML.size(); i++) {
                    SimpleMatrix embedding = mTestEmbeddingsEJML.get(i);
                    double dotProduct = userVec.dot(embedding);
                    double normEmbeddingReciprocal = mTestEmbeddingsNorms.get(i) > 1e-10 ? 1.0 / mTestEmbeddingsNorms.get(i) : 0.0;
                    float sim = (float) (dotProduct * normUserReciprocal * normEmbeddingReciprocal);
                    similarities2.add(sim);
                }
            }
            long timeTaken2 = System.currentTimeMillis() - startTime2;

            // Speedup of the EJML path relative to the baseline.
            // NOTE(review): timeTaken2 can be 0 ms, making this Infinity/NaN — guard
            // if the value ever feeds arithmetic beyond display.
            double speedup = (double) timeTaken1 / timeTaken2;

            if (mListener != null) {
                mListener.onPerformanceTestComplete(timeTaken1, timeTaken2, speedup);
            }
        } catch (Exception e) {
            Log.e(TAG, "Exception in testPerformance", e);
            if (mListener != null) {
                mListener.onError("性能测试失败: " + e.getMessage());
            }
        }
    }).start();
}
|
||||
|
||||
// 准备测试数据
|
||||
private void prepareTestData() {
|
||||
if (mTestEmbeddings == null || mTestEmbeddings.isEmpty() ||
|
||||
mTestEmbeddingsEJML == null || mTestEmbeddingsEJML.isEmpty() ||
|
||||
mTestEmbeddingsNorms == null || mTestEmbeddingsNorms.isEmpty() ||
|
||||
mTestEmbeddingsMatrix == null ||
|
||||
mTestEmbeddingsNormsMatrix == null ||
|
||||
mTestEmbeddingsNormsReciprocalMatrix == null) {
|
||||
// 生成100个测试句子
|
||||
generateTestSentences();
|
||||
|
||||
// 预计算嵌入(float数组格式)
|
||||
mTestEmbeddings = new ArrayList<>();
|
||||
mTestEmbeddingsEJML = new ArrayList<>();
|
||||
mTestEmbeddingsNorms = new ArrayList<>();
|
||||
List<String> validTestSentences = new ArrayList<>(); // 只包含成功获取嵌入的句子
|
||||
|
||||
for (String sentence : mTestSentences) {
|
||||
float[] embedding = mBgeEngine.getEmbedding(sentence);
|
||||
if (embedding != null) {
|
||||
// 添加float数组格式的嵌入
|
||||
mTestEmbeddings.add(embedding);
|
||||
|
||||
// 添加EJML格式的嵌入
|
||||
SimpleMatrix vec = new SimpleMatrix(embedding.length, 1);
|
||||
for (int i = 0; i < embedding.length; i++) {
|
||||
vec.set(i, 0, embedding[i]);
|
||||
}
|
||||
mTestEmbeddingsEJML.add(vec);
|
||||
|
||||
// 预计算嵌入的范数
|
||||
double norm = vec.normF();
|
||||
mTestEmbeddingsNorms.add(norm);
|
||||
|
||||
// 添加到有效句子列表
|
||||
validTestSentences.add(sentence);
|
||||
}
|
||||
}
|
||||
|
||||
// 更新mTestSentences为只包含有效句子的列表,确保索引匹配
|
||||
mTestSentences = validTestSentences;
|
||||
|
||||
// 创建测试嵌入矩阵(嵌入维度 × 句子数量)
|
||||
if (!mTestEmbeddingsEJML.isEmpty()) {
|
||||
int embeddingDim = mTestEmbeddingsEJML.get(0).numRows();
|
||||
int numSentences = mTestEmbeddingsEJML.size();
|
||||
mTestEmbeddingsMatrix = new SimpleMatrix(embeddingDim, numSentences);
|
||||
|
||||
for (int i = 0; i < numSentences; i++) {
|
||||
SimpleMatrix embedding = mTestEmbeddingsEJML.get(i);
|
||||
for (int j = 0; j < embeddingDim; j++) {
|
||||
mTestEmbeddingsMatrix.set(j, i, embedding.get(j, 0));
|
||||
}
|
||||
}
|
||||
Log.d(TAG, "Created test embeddings matrix, size: " + embeddingDim + " × " + numSentences);
|
||||
}
|
||||
|
||||
// 创建测试嵌入范数矩阵(1 × 句子数量)
|
||||
if (!mTestEmbeddingsNorms.isEmpty()) {
|
||||
int numSentences = mTestEmbeddingsNorms.size();
|
||||
mTestEmbeddingsNormsMatrix = new SimpleMatrix(1, numSentences);
|
||||
mTestEmbeddingsNormsReciprocalMatrix = new SimpleMatrix(1, numSentences);
|
||||
for (int i = 0; i < numSentences; i++) {
|
||||
double norm = mTestEmbeddingsNorms.get(i);
|
||||
mTestEmbeddingsNormsMatrix.set(0, i, norm);
|
||||
// 预计算范数倒数,避免在相似度计算时重复计算
|
||||
double reciprocal = (norm > 1e-10) ? 1.0 / norm : 0.0;
|
||||
mTestEmbeddingsNormsReciprocalMatrix.set(0, i, reciprocal);
|
||||
}
|
||||
Log.d(TAG, "Created test embeddings norms matrix, size: 1 × " + numSentences);
|
||||
Log.d(TAG, "Created test embeddings norms reciprocal matrix, size: 1 × " + numSentences);
|
||||
}
|
||||
|
||||
Log.d(TAG, "Generated embeddings for " + mTestEmbeddings.size() + " test sentences");
|
||||
}
|
||||
}
|
||||
|
||||
// 生成测试句子
|
||||
private void generateTestSentences() {
|
||||
mTestSentences = new ArrayList<>();
|
||||
|
||||
// 基础词汇
|
||||
String[] words = {"手", "头", "眼睛", "鼻子", "嘴巴", "耳朵", "头发", "手指", "脚", "腿"};
|
||||
String[] adjectives = {"大", "小", "长", "短", "多", "少", "漂亮", "丑陋", "干净", "脏"};
|
||||
String[] verbs = {"有", "是", "看", "听", "说", "走", "跑", "跳", "吃", "喝"};
|
||||
String[] subjects = {"我", "你", "他", "她", "它", "我们", "你们", "他们", "这个", "那个"};
|
||||
|
||||
// 添加一些固定的测试句子
|
||||
mTestSentences.add("我的手有很多手指头");
|
||||
mTestSentences.add("他的头发很长");
|
||||
mTestSentences.add("她的眼睛很大");
|
||||
mTestSentences.add("这个鼻子很高");
|
||||
mTestSentences.add("那个嘴巴很小");
|
||||
|
||||
Log.d(TAG, "Generated " + mTestSentences.size() + " test sentences");
|
||||
}
|
||||
|
||||
// 设置是否使用EJML进行相似度计算
|
||||
public void setUseEJMLForSimilarity(boolean useEJML) {
|
||||
this.useEJMLForSimilarity = useEJML;
|
||||
}
|
||||
|
||||
// 检查BGE模型是否初始化
|
||||
public boolean isInitialized() {
|
||||
return mBgeEngine != null && mBgeEngine.isInitialized();
|
||||
}
|
||||
|
||||
// 释放BGE模型资源
|
||||
public void deinitialize() {
|
||||
if (mBgeEngine != null) {
|
||||
mBgeEngine.deinitialize();
|
||||
mBgeEngine = null;
|
||||
Log.d(TAG, "BGE engine deinitialized");
|
||||
}
|
||||
}
|
||||
}
|
||||
104
app/src/main/java/com/digitalperson/engine/BasicTokenizer.java
Normal file
104
app/src/main/java/com/digitalperson/engine/BasicTokenizer.java
Normal file
@@ -0,0 +1,104 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
package com.digitalperson.engine;
|
||||
|
||||
import com.google.common.base.Ascii;
|
||||
import com.google.common.collect.Iterables;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
/** Basic tokenization (punctuation splitting, lower casing, etc.) */
|
||||
public final class BasicTokenizer {
|
||||
private final boolean doLowerCase;
|
||||
|
||||
public BasicTokenizer(boolean doLowerCase) {
|
||||
this.doLowerCase = doLowerCase;
|
||||
}
|
||||
|
||||
public List<String> tokenize(String text) {
|
||||
String cleanedText = cleanText(text);
|
||||
|
||||
List<String> origTokens = whitespaceTokenize(cleanedText);
|
||||
|
||||
StringBuilder stringBuilder = new StringBuilder();
|
||||
for (String token : origTokens) {
|
||||
if (doLowerCase) {
|
||||
token = Ascii.toLowerCase(token);
|
||||
}
|
||||
List<String> list = runSplitOnPunc(token);
|
||||
for (String subToken : list) {
|
||||
stringBuilder.append(subToken).append(" ");
|
||||
}
|
||||
}
|
||||
return whitespaceTokenize(stringBuilder.toString());
|
||||
}
|
||||
|
||||
/* Performs invalid character removal and whitespace cleanup on text. */
|
||||
static String cleanText(String text) {
|
||||
if (text == null) {
|
||||
throw new NullPointerException("The input String is null.");
|
||||
}
|
||||
|
||||
StringBuilder stringBuilder = new StringBuilder("");
|
||||
for (int index = 0; index < text.length(); index++) {
|
||||
char ch = text.charAt(index);
|
||||
|
||||
// Skip the characters that cannot be used.
|
||||
if (CharChecker.isInvalid(ch) || CharChecker.isControl(ch)) {
|
||||
continue;
|
||||
}
|
||||
if (CharChecker.isWhitespace(ch)) {
|
||||
stringBuilder.append(" ");
|
||||
} else {
|
||||
stringBuilder.append(ch);
|
||||
}
|
||||
}
|
||||
return stringBuilder.toString();
|
||||
}
|
||||
|
||||
/* Runs basic whitespace cleaning and splitting on a piece of text. */
|
||||
static List<String> whitespaceTokenize(String text) {
|
||||
if (text == null) {
|
||||
throw new NullPointerException("The input String is null.");
|
||||
}
|
||||
return Arrays.asList(text.split(" "));
|
||||
}
|
||||
|
||||
/* Splits punctuation on a piece of text. */
|
||||
static List<String> runSplitOnPunc(String text) {
|
||||
if (text == null) {
|
||||
throw new NullPointerException("The input String is null.");
|
||||
}
|
||||
|
||||
List<String> tokens = new ArrayList<>();
|
||||
boolean startNewWord = true;
|
||||
for (int i = 0; i < text.length(); i++) {
|
||||
char ch = text.charAt(i);
|
||||
if (CharChecker.isPunctuation(ch)) {
|
||||
tokens.add(String.valueOf(ch));
|
||||
startNewWord = true;
|
||||
} else {
|
||||
if (startNewWord) {
|
||||
tokens.add("");
|
||||
startNewWord = false;
|
||||
}
|
||||
tokens.set(tokens.size() - 1, Iterables.getLast(tokens) + ch);
|
||||
}
|
||||
}
|
||||
|
||||
return tokens;
|
||||
}
|
||||
}
|
||||
263
app/src/main/java/com/digitalperson/engine/BgeEngineRKNN.java
Normal file
263
app/src/main/java/com/digitalperson/engine/BgeEngineRKNN.java
Normal file
@@ -0,0 +1,263 @@
|
||||
package com.digitalperson.engine;
|
||||
|
||||
import android.content.Context;
|
||||
import android.util.Log;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
public class BgeEngineRKNN {
|
||||
private static final String TAG = "BgeEngineRKNN";
|
||||
private final long nativePtr;
|
||||
private final Context mContext;
|
||||
private boolean mIsInitialized = false;
|
||||
private FullTokenizer mFullTokenizer = null;
|
||||
private Map<String, Integer> mVocabMap = new HashMap<>();
|
||||
|
||||
static {
|
||||
try {
|
||||
// Load dependent libraries
|
||||
System.loadLibrary("rknnrt");
|
||||
System.loadLibrary("bgeEngine");
|
||||
Log.d(TAG, "Successfully loaded librknnrt.so and libbgeEngine.so");
|
||||
} catch (UnsatisfiedLinkError e) {
|
||||
Log.e(TAG, "Failed to load native library", e);
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
public BgeEngineRKNN(Context context) {
|
||||
mContext = context;
|
||||
try {
|
||||
nativePtr = createBgeEngine();
|
||||
if (nativePtr == 0) {
|
||||
throw new RuntimeException("Failed to create native BGE engine");
|
||||
}
|
||||
} catch (UnsatisfiedLinkError e) {
|
||||
Log.e(TAG, "Failed to load native library", e);
|
||||
throw new RuntimeException("Failed to load native library: " + e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
|
||||
public boolean initialize(String modelPath) {
|
||||
// 自动查找词汇表路径
|
||||
String vocabPath = modelPath.replace("bge-small-zh-v1.5.rknn", "vocab.txt");
|
||||
return initialize(modelPath, vocabPath);
|
||||
}
|
||||
|
||||
public boolean initialize(String modelPath, String vocabPath) {
|
||||
if (mIsInitialized) {
|
||||
Log.i(TAG, "Model already initialized");
|
||||
return true;
|
||||
}
|
||||
|
||||
Log.d(TAG, "Loading BGE model: " + modelPath);
|
||||
Log.d(TAG, "Loading vocab: " + vocabPath);
|
||||
|
||||
// 加载词汇表
|
||||
if (!loadVocab(vocabPath)) {
|
||||
Log.e(TAG, "Failed to load vocab");
|
||||
return false;
|
||||
}
|
||||
|
||||
// 创建FullTokenizer实例
|
||||
mFullTokenizer = new FullTokenizer(mVocabMap, true); // true表示使用小写
|
||||
|
||||
int ret = loadModel(nativePtr, modelPath, vocabPath);
|
||||
if (ret == 0) {
|
||||
mIsInitialized = true;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Load vocabulary from file
|
||||
* @param vocabPath Path to vocabulary file
|
||||
* @return True if successful, false otherwise
|
||||
*/
|
||||
private boolean loadVocab(String vocabPath) {
|
||||
try {
|
||||
java.io.BufferedReader reader = new java.io.BufferedReader(
|
||||
new java.io.FileReader(vocabPath));
|
||||
String line;
|
||||
int index = 0;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
line = line.trim();
|
||||
if (!line.isEmpty()) {
|
||||
mVocabMap.put(line, index);
|
||||
index++;
|
||||
}
|
||||
}
|
||||
reader.close();
|
||||
Log.d(TAG, "Vocab loaded successfully with " + mVocabMap.size() + " tokens");
|
||||
return true;
|
||||
} catch (java.io.IOException e) {
|
||||
Log.e(TAG, "Failed to load vocab: " + e.getMessage());
|
||||
e.printStackTrace();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public void deinitialize() {
|
||||
if (nativePtr != 0) {
|
||||
freeModel(nativePtr);
|
||||
}
|
||||
mIsInitialized = false;
|
||||
}
|
||||
|
||||
public boolean isInitialized() {
|
||||
return mIsInitialized;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the embedding dimension of the model
|
||||
* @return Embedding dimension
|
||||
*/
|
||||
public int getEmbeddingDim() {
|
||||
if (!mIsInitialized) {
|
||||
Log.e(TAG, "Engine not initialized");
|
||||
return -1;
|
||||
}
|
||||
return getEmbeddingDim(nativePtr);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate embedding for a single text
|
||||
* @param text Input text
|
||||
* @return Embedding vector as float array
|
||||
*/
|
||||
public float[] getEmbedding(String text) {
|
||||
if (!mIsInitialized) {
|
||||
Log.e(TAG, "Engine not initialized");
|
||||
return null;
|
||||
}
|
||||
return getEmbeddingNative(nativePtr, text);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate cosine similarity between two texts
|
||||
* @param text1 First text
|
||||
* @param text2 Second text
|
||||
* @return Cosine similarity score between -1.0 and 1.0
|
||||
*/
|
||||
public float calculateSimilarity(String text1, String text2) {
|
||||
if (!mIsInitialized) {
|
||||
Log.e(TAG, "Engine not initialized");
|
||||
return 0.0f;
|
||||
}
|
||||
return calculateSimilarityNative(nativePtr, text1, text2);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get token ids for a text using FullTokenizer
|
||||
* @param text Input text
|
||||
* @return Token ids as int array
|
||||
*/
|
||||
public int[] getTokens(String text) {
|
||||
if (!mIsInitialized) {
|
||||
Log.e(TAG, "Engine not initialized");
|
||||
return null;
|
||||
}
|
||||
if (mFullTokenizer == null) {
|
||||
Log.e(TAG, "FullTokenizer not initialized");
|
||||
return null;
|
||||
}
|
||||
|
||||
// 使用FullTokenizer进行tokenization
|
||||
java.util.List<String> tokens = mFullTokenizer.tokenize(text);
|
||||
java.util.List<Integer> ids = mFullTokenizer.convertTokensToIds(tokens);
|
||||
|
||||
// 转换为int数组
|
||||
int[] result = new int[ids.size()];
|
||||
for (int i = 0; i < ids.size(); i++) {
|
||||
result[i] = ids.get(i);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate [UNK] ratio for token ids
|
||||
* @param tokens Token ids array
|
||||
* @return [UNK] ratio between 0.0 and 1.0
|
||||
*/
|
||||
public float calculateUnkRatio(int[] tokens) {
|
||||
if (!mIsInitialized) {
|
||||
Log.e(TAG, "Engine not initialized");
|
||||
return 0.0f;
|
||||
}
|
||||
if (mFullTokenizer == null || tokens == null || tokens.length == 0) {
|
||||
return 0.0f;
|
||||
}
|
||||
|
||||
// 获取[UNK]的id
|
||||
int unkId = mVocabMap.getOrDefault("[UNK]", -1);
|
||||
if (unkId == -1) {
|
||||
return 0.0f;
|
||||
}
|
||||
|
||||
// 统计[UNK]的数量
|
||||
int unkCount = 0;
|
||||
for (int token : tokens) {
|
||||
if (token == unkId) {
|
||||
unkCount++;
|
||||
}
|
||||
}
|
||||
|
||||
return (float) unkCount / tokens.length;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get BERT model inputs (input_ids, attention_mask, token_type_ids)
|
||||
* @param text Input text
|
||||
* @param maxSeqLength Maximum sequence length
|
||||
* @return Array of int arrays: [input_ids, attention_mask, token_type_ids]
|
||||
*/
|
||||
public int[][] getBertInputs(String text, int maxSeqLength) {
|
||||
if (!mIsInitialized) {
|
||||
Log.e(TAG, "Engine not initialized");
|
||||
return null;
|
||||
}
|
||||
if (mFullTokenizer == null) {
|
||||
Log.e(TAG, "FullTokenizer not initialized");
|
||||
return null;
|
||||
}
|
||||
|
||||
// 使用FullTokenizer进行tokenization
|
||||
java.util.List<String> tokens = mFullTokenizer.tokenize(text);
|
||||
java.util.List<Integer> ids = mFullTokenizer.convertTokensToIds(tokens);
|
||||
|
||||
// 准备BERT输入
|
||||
int[] inputIds = new int[maxSeqLength];
|
||||
int[] attentionMask = new int[maxSeqLength];
|
||||
int[] tokenTypeIds = new int[maxSeqLength];
|
||||
|
||||
// 填充[CLS] token
|
||||
inputIds[0] = mVocabMap.getOrDefault("[CLS]", 101);
|
||||
attentionMask[0] = 1;
|
||||
|
||||
// 填充文本token
|
||||
int textLength = ids.size();
|
||||
int maxTextLength = maxSeqLength - 2; // 减去[CLS]和[SEP]
|
||||
int actualLength = Math.min(textLength, maxTextLength);
|
||||
|
||||
for (int i = 0; i < actualLength; i++) {
|
||||
inputIds[i + 1] = ids.get(i);
|
||||
attentionMask[i + 1] = 1;
|
||||
}
|
||||
|
||||
// 填充[SEP] token
|
||||
inputIds[actualLength + 1] = mVocabMap.getOrDefault("[SEP]", 102);
|
||||
attentionMask[actualLength + 1] = 1;
|
||||
|
||||
return new int[][]{inputIds, attentionMask, tokenTypeIds};
|
||||
}
|
||||
|
||||
// Native methods
|
||||
private native long createBgeEngine();
|
||||
private native int loadModel(long nativePtr, String modelPath, String vocabPath);
|
||||
private native void freeModel(long ptr);
|
||||
private native float[] getEmbeddingNative(long ptr, String text);
|
||||
private native int getEmbeddingDim(long ptr);
|
||||
private native float calculateSimilarityNative(long ptr, String text1, String text2);
|
||||
}
|
||||
58
app/src/main/java/com/digitalperson/engine/CharChecker.java
Normal file
58
app/src/main/java/com/digitalperson/engine/CharChecker.java
Normal file
@@ -0,0 +1,58 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
package com.digitalperson.engine;
|
||||
|
||||
/** Character classification helpers used by tokenization: whitespace, control, punctuation. */
final class CharChecker {

    /** True for NUL and the Unicode replacement character (empty/unknown char). */
    public static boolean isInvalid(char ch) {
        return ch == 0 || ch == 0xfffd;
    }

    /** True for control characters, excluding anything Java considers whitespace. */
    public static boolean isControl(char ch) {
        if (Character.isWhitespace(ch)) {
            return false;
        }
        switch (Character.getType(ch)) {
            case Character.CONTROL:
            case Character.FORMAT:
                return true;
            default:
                return false;
        }
    }

    /** True when the character should be treated as whitespace (Java whitespace or Unicode separators). */
    public static boolean isWhitespace(char ch) {
        if (Character.isWhitespace(ch)) {
            return true;
        }
        switch (Character.getType(ch)) {
            case Character.SPACE_SEPARATOR:
            case Character.LINE_SEPARATOR:
            case Character.PARAGRAPH_SEPARATOR:
                return true;
            default:
                return false;
        }
    }

    /** True for any Unicode punctuation category. */
    public static boolean isPunctuation(char ch) {
        switch (Character.getType(ch)) {
            case Character.CONNECTOR_PUNCTUATION:
            case Character.DASH_PUNCTUATION:
            case Character.START_PUNCTUATION:
            case Character.END_PUNCTUATION:
            case Character.INITIAL_QUOTE_PUNCTUATION:
            case Character.FINAL_QUOTE_PUNCTUATION:
            case Character.OTHER_PUNCTUATION:
                return true;
            default:
                return false;
        }
    }

    // Static utility holder; never instantiated.
    private CharChecker() {}
}
|
||||
@@ -0,0 +1,52 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
package com.digitalperson.engine;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* A java realization of Bert tokenization. Original python code:
|
||||
* https://github.com/google-research/bert/blob/master/tokenization.py runs full tokenization to
|
||||
* tokenize a String into split subtokens or ids.
|
||||
*/
|
||||
public final class FullTokenizer {
|
||||
private final BasicTokenizer basicTokenizer;
|
||||
private final WordpieceTokenizer wordpieceTokenizer;
|
||||
private final Map<String, Integer> dic;
|
||||
|
||||
public FullTokenizer(Map<String, Integer> inputDic, boolean doLowerCase) {
|
||||
dic = inputDic;
|
||||
basicTokenizer = new BasicTokenizer(doLowerCase);
|
||||
wordpieceTokenizer = new WordpieceTokenizer(inputDic);
|
||||
}
|
||||
|
||||
public List<String> tokenize(String text) {
|
||||
List<String> splitTokens = new ArrayList<>();
|
||||
for (String token : basicTokenizer.tokenize(text)) {
|
||||
splitTokens.addAll(wordpieceTokenizer.tokenize(token));
|
||||
}
|
||||
return splitTokens;
|
||||
}
|
||||
|
||||
public List<Integer> convertTokensToIds(List<String> tokens) {
|
||||
List<Integer> outputIds = new ArrayList<>();
|
||||
for (String token : tokens) {
|
||||
outputIds.add(dic.get(token));
|
||||
}
|
||||
return outputIds;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,94 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
package com.digitalperson.engine;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/** Word piece tokenization to split a piece of text into its word pieces. */
|
||||
public final class WordpieceTokenizer {
|
||||
private final Map<String, Integer> dic;
|
||||
|
||||
private static final String UNKNOWN_TOKEN = "[UNK]"; // For unknown words.
|
||||
private static final int MAX_INPUTCHARS_PER_WORD = 200;
|
||||
|
||||
public WordpieceTokenizer(Map<String, Integer> vocab) {
|
||||
dic = vocab;
|
||||
}
|
||||
|
||||
/**
|
||||
* Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first
|
||||
* algorithm to perform tokenization using the given vocabulary. For example: input = "unaffable",
|
||||
* output = ["un", "##aff", "##able"].
|
||||
*
|
||||
* @param text: A single token or whitespace separated tokens. This should have already been
|
||||
* passed through `BasicTokenizer.
|
||||
* @return A list of wordpiece tokens.
|
||||
*/
|
||||
public List<String> tokenize(String text) {
|
||||
if (text == null) {
|
||||
throw new NullPointerException("The input String is null.");
|
||||
}
|
||||
|
||||
List<String> outputTokens = new ArrayList<>();
|
||||
for (String token : BasicTokenizer.whitespaceTokenize(text)) {
|
||||
|
||||
if (token.length() > MAX_INPUTCHARS_PER_WORD) {
|
||||
outputTokens.add(UNKNOWN_TOKEN);
|
||||
continue;
|
||||
}
|
||||
|
||||
boolean isBad = false; // Mark if a word cannot be tokenized into known subwords.
|
||||
int start = 0;
|
||||
List<String> subTokens = new ArrayList<>();
|
||||
|
||||
while (start < token.length()) {
|
||||
String curSubStr = "";
|
||||
|
||||
int end = token.length(); // Longer substring matches first.
|
||||
while (start < end) {
|
||||
String subStr =
|
||||
(start == 0) ? token.substring(start, end) : "##" + token.substring(start, end);
|
||||
if (dic.containsKey(subStr)) {
|
||||
curSubStr = subStr;
|
||||
break;
|
||||
}
|
||||
end--;
|
||||
}
|
||||
|
||||
// The word doesn't contain any known subwords.
|
||||
if ("".equals(curSubStr)) {
|
||||
isBad = true;
|
||||
break;
|
||||
}
|
||||
|
||||
// curSubStr is the longeset subword that can be found.
|
||||
subTokens.add(curSubStr);
|
||||
|
||||
// Proceed to tokenize the resident string.
|
||||
start = end;
|
||||
}
|
||||
|
||||
if (isBad) {
|
||||
outputTokens.add(UNKNOWN_TOKEN);
|
||||
} else {
|
||||
outputTokens.addAll(subTokens);
|
||||
}
|
||||
}
|
||||
|
||||
return outputTokens;
|
||||
}
|
||||
}
|
||||
@@ -185,7 +185,7 @@ class QuestionGenerationAgent(
|
||||
val generationPrompt = buildGenerationPrompt(prompt, userProfile)
|
||||
|
||||
// 3. 调用大模型生成题目
|
||||
generateQuestionFromLLM(generationPrompt) { generatedQuestion ->
|
||||
generateQuestionFromLLM(generationPrompt, prompt) { generatedQuestion ->
|
||||
if (generatedQuestion == null) {
|
||||
Log.w(TAG, "Failed to generate question")
|
||||
return@generateQuestionFromLLM
|
||||
@@ -301,18 +301,22 @@ class QuestionGenerationAgent(
|
||||
/**
|
||||
* 调用LLM生成题目
|
||||
*/
|
||||
private fun generateQuestionFromLLM(prompt: String, onResult: (GeneratedQuestion?) -> Unit) {
|
||||
private fun generateQuestionFromLLM(
|
||||
promptText: String,
|
||||
promptMeta: QuestionPrompt,
|
||||
onResult: (GeneratedQuestion?) -> Unit
|
||||
) {
|
||||
// 优先使用本地LLM,如果不可用则使用云端LLM
|
||||
if (llmManager != null) {
|
||||
// 使用本地LLM
|
||||
llmManager.generate(prompt) { response: String ->
|
||||
parseGeneratedQuestion(response, onResult)
|
||||
llmManager.generate(promptText) { response: String ->
|
||||
parseGeneratedQuestion(response, promptMeta, onResult)
|
||||
}
|
||||
} else if (cloudLLMGenerator != null) {
|
||||
// 使用云端LLM
|
||||
Log.d(TAG, "Using cloud LLM to generate question")
|
||||
cloudLLMGenerator.invoke(prompt) { response ->
|
||||
parseGeneratedQuestion(response, onResult)
|
||||
cloudLLMGenerator.invoke(promptText) { response ->
|
||||
parseGeneratedQuestion(response, promptMeta, onResult)
|
||||
}
|
||||
} else {
|
||||
Log.e(TAG, "No LLM available (neither local nor cloud)")
|
||||
@@ -323,16 +327,20 @@ class QuestionGenerationAgent(
|
||||
/**
|
||||
* 解析生成的题目
|
||||
*/
|
||||
private fun parseGeneratedQuestion(response: String, onResult: (GeneratedQuestion?) -> Unit) {
|
||||
private fun parseGeneratedQuestion(
|
||||
response: String,
|
||||
promptMeta: QuestionPrompt,
|
||||
onResult: (GeneratedQuestion?) -> Unit
|
||||
) {
|
||||
try {
|
||||
val json = extractJsonFromResponse(response)
|
||||
if (json != null) {
|
||||
val question = GeneratedQuestion(
|
||||
content = json.getString("content"),
|
||||
answer = json.getString("answer"),
|
||||
subject = "生活适应",
|
||||
grade = 1,
|
||||
difficulty = 1
|
||||
subject = promptMeta.subject,
|
||||
grade = promptMeta.grade,
|
||||
difficulty = promptMeta.difficulty
|
||||
)
|
||||
onResult(question)
|
||||
} else {
|
||||
|
||||
@@ -7,6 +7,7 @@ import android.util.Log
|
||||
import com.digitalperson.config.AppConfig
|
||||
import java.io.File
|
||||
import java.io.FileOutputStream
|
||||
import java.io.InputStream
|
||||
|
||||
object FileHelper {
|
||||
private const val TAG = AppConfig.TAG
|
||||
@@ -64,6 +65,54 @@ object FileHelper {
|
||||
val files = arrayOf(AppConfig.FaceRecognition.MODEL_NAME)
|
||||
return copyAssetsToInternal(context, AppConfig.FaceRecognition.MODEL_DIR, outDir, files)
|
||||
}
|
||||
|
||||
/**
 * Copies the BGE model files from assets into
 * [Context.getFilesDir]/[AppConfig.Bge.ASSET_DIR].
 * A file already present with the same length as the asset is skipped
 * (same behavior as [com.digitalperson.embedding.SimilarityManager]).
 *
 * @return the destination directory, or null if any file failed to copy
 */
@JvmStatic
fun copyBgeModels(context: Context): File? {
    val assetDir = AppConfig.Bge.ASSET_DIR
    val targetDir = File(context.filesDir, assetDir)
    targetDir.mkdirs()
    val assetNames = listOf(
        AppConfig.Bge.MODEL_FILE,
        "vocab.txt",
        "tokenizer.json",
        "tokenizer_config.json",
    )
    for (name in assetNames) {
        val destination = File(targetDir, name)
        val assetPath = "$assetDir/$name"
        try {
            // Skip when an existing copy matches the asset's reported size;
            // an unreliable size (-1) forces a re-copy.
            val upToDate = destination.exists() &&
                context.assets.open(assetPath).use { probe ->
                    val expected = assetSizeOrNegative(probe)
                    expected >= 0 && destination.length() == expected
                }
            if (upToDate) continue
            context.assets.open(assetPath).use { input ->
                FileOutputStream(destination).use { output -> input.copyTo(output) }
            }
            Log.i(TAG, "Copied BGE asset: $name")
        } catch (e: Exception) {
            Log.e(TAG, "copyBgeModels failed for $name: ${e.message}", e)
            return null
        }
    }
    return targetDir
}
|
||||
|
||||
/**
 * Best-effort size probe for an asset stream. [InputStream.available] is
 * unreliable on some implementations, so any failure — or a non-positive
 * result — returns -1, which forces the caller to re-copy the asset.
 */
private fun assetSizeOrNegative(input: InputStream): Long {
    val reported = try {
        input.available()
    } catch (_: Exception) {
        return -1L
    }
    return if (reported > 0) reported.toLong() else -1L
}
|
||||
|
||||
fun ensureDir(dir: File): File {
|
||||
if (!dir.exists()) {
|
||||
|
||||
Reference in New Issue
Block a user