face recognition refine

This commit is contained in:
gcw_4spBpAfv
2026-03-10 15:40:05 +08:00
parent d5767156b9
commit ec1f7d2e72
1579 changed files with 4286 additions and 319 deletions

View File

@@ -0,0 +1,69 @@
{
"version": 1,
"questions": [
{
"content": "老师要去菜市场买菜,需要付4元钱。我给了老奶奶5元,需要找回多少钱?",
"answer": "1",
"subject": "数学",
"grade": 1,
"difficulty": 1
},
{
"content": "老师去超市买了一个2块钱的苹果和一个3块钱的火龙果,一共要花多少钱呀?",
"answer": "5",
"subject": "数学",
"grade": 1,
"difficulty": 1
},
{
"content": "10-4等于多少",
"answer": "6",
"subject": "数学",
"grade": 1,
"difficulty": 1
},
{
"content": "5+7等于多少",
"answer": "12",
"subject": "数学",
"grade": 1,
"difficulty": 1
},
{
"content": "8-3等于多少",
"answer": "5",
"subject": "数学",
"grade": 1,
"difficulty": 1
},
{
"content": "我们上完厕所应该干什么呀?",
"answer": "洗手",
"subject": "生活适应",
"grade": 1,
"difficulty": 1
},
{
"content": "太阳从哪边升起呀?",
"answer": "东方",
"subject": "生活适应",
"grade": 1,
"difficulty": 1
},
{
"content": "一年有几个季节呀?",
"answer": "4",
"subject": "生活适应",
"grade": 1,
"difficulty": 1
},
{
"content": "你知道厕所怎么走吗?帮老师指个路",
"answer": "4",
"subject": "生活适应",
"grade": 1,
"difficulty": 1
}
]
}

View File

@@ -123,10 +123,12 @@ std::vector<RetinaFaceEngineRKNN::PriorBox> RetinaFaceEngineRKNN::buildPriors()
bool RetinaFaceEngineRKNN::parseRetinaOutputs(
rknn_output* outputs,
std::vector<float>* locOut,
std::vector<float>* scoreOut) const {
std::vector<float>* scoreOut,
std::vector<float>* landmarkOut) const {
std::vector<std::vector<float>> locCandidates;
std::vector<std::vector<float>> confCandidates2;
std::vector<std::vector<float>> scoreCandidates1;
std::vector<std::vector<float>> landmarkCandidates;
const int anchors8 = (inputSize_ / 8) * (inputSize_ / 8) * 2;
const int anchors16 = (inputSize_ / 16) * (inputSize_ / 16) * 2;
@@ -141,6 +143,9 @@ bool RetinaFaceEngineRKNN::parseRetinaOutputs(
const int expectedConf8_1 = anchors8;
const int expectedConf16_1 = anchors16;
const int expectedConf32_1 = anchors32;
const int expectedLmk8 = anchors8 * 10;
const int expectedLmk16 = anchors16 * 10;
const int expectedLmk32 = anchors32 * 10;
for (uint32_t i = 0; i < ioNum_.n_output; ++i) {
const size_t elems = tensorElemCount(outputAttrs_[i]);
@@ -162,10 +167,15 @@ bool RetinaFaceEngineRKNN::parseRetinaOutputs(
scoreCandidates1.push_back(std::move(data));
continue;
}
if (e == expectedLmk8 || e == expectedLmk16 || e == expectedLmk32 || e == totalAnchors * 10) {
landmarkCandidates.push_back(std::move(data));
continue;
}
}
locOut->clear();
scoreOut->clear();
landmarkOut->clear();
auto sortBySize = [](const std::vector<float>& a, const std::vector<float>& b) {
return a.size() > b.size();
@@ -173,6 +183,7 @@ bool RetinaFaceEngineRKNN::parseRetinaOutputs(
std::sort(locCandidates.begin(), locCandidates.end(), sortBySize);
std::sort(confCandidates2.begin(), confCandidates2.end(), sortBySize);
std::sort(scoreCandidates1.begin(), scoreCandidates1.end(), sortBySize);
std::sort(landmarkCandidates.begin(), landmarkCandidates.end(), sortBySize);
auto mergeLoc = [&]() -> bool {
if (locCandidates.empty()) return false;
@@ -246,16 +257,41 @@ bool RetinaFaceEngineRKNN::parseRetinaOutputs(
return false;
};
// Assembles the landmark tensor into *landmarkOut.
// Accepts either three per-stride tensors (largest first, i.e. stride 8, 16, 32
// after the descending size sort) or one fused tensor covering every anchor.
// Returns false when neither layout is present.
auto mergeLandmark = [&]() -> bool {
    if (landmarkCandidates.empty()) {
        return false;
    }

    const auto hasElems = [](const std::vector<float>& tensor, int expected) {
        return static_cast<int>(tensor.size()) == expected;
    };

    // Layout 1: one tensor per stride, concatenated in stride order.
    const bool splitByStride =
        landmarkCandidates.size() >= 3 &&
        hasElems(landmarkCandidates[0], expectedLmk8) &&
        hasElems(landmarkCandidates[1], expectedLmk16) &&
        hasElems(landmarkCandidates[2], expectedLmk32);
    if (splitByStride) {
        landmarkOut->reserve(static_cast<size_t>(totalAnchors) * 10);
        for (size_t level = 0; level < 3; ++level) {
            const std::vector<float>& part = landmarkCandidates[level];
            landmarkOut->insert(landmarkOut->end(), part.begin(), part.end());
        }
        return true;
    }

    // Layout 2: a single tensor holding 10 values for every anchor.
    for (const auto& tensor : landmarkCandidates) {
        if (hasElems(tensor, totalAnchors * 10)) {
            *landmarkOut = tensor;
            return true;
        }
    }
    return false;
};
const bool locOk = mergeLoc();
bool scoreOk = mergeScoreFrom2Class();
if (!scoreOk) {
scoreOk = mergeScoreFrom1Class();
}
const bool lmkOk = mergeLandmark();
if (!locOk || !scoreOk) {
LOGW("Unable to parse retina outputs, loc_candidates=%zu, conf2_candidates=%zu, conf1_candidates=%zu",
locCandidates.size(), confCandidates2.size(), scoreCandidates1.size());
LOGW("Unable to parse retina outputs, loc_candidates=%zu, conf2_candidates=%zu, conf1_candidates=%zu, lmk_candidates=%zu",
locCandidates.size(), confCandidates2.size(), scoreCandidates1.size(), landmarkCandidates.size());
return false;
}
if (!lmkOk) {
LOGW("Retina outputs parsed without landmarks, lmk_candidates=%zu", landmarkCandidates.size());
}
return true;
}
@@ -369,7 +405,8 @@ std::vector<float> RetinaFaceEngineRKNN::detect(
std::vector<float> loc;
std::vector<float> scores;
if (!parseRetinaOutputs(outputs.data(), &loc, &scores)) {
std::vector<float> landmarks;
if (!parseRetinaOutputs(outputs.data(), &loc, &scores, &landmarks)) {
rknn_outputs_release(ctx_, ioNum_.n_output, outputs.data());
return empty;
}
@@ -381,6 +418,7 @@ std::vector<float> RetinaFaceEngineRKNN::detect(
rknn_outputs_release(ctx_, ioNum_.n_output, outputs.data());
return empty;
}
const bool hasLandmarks = landmarks.size() >= anchorCount * 10;
std::vector<FaceCandidate> candidates;
candidates.reserve(anchorCount / 8);
@@ -405,6 +443,17 @@ std::vector<float> RetinaFaceEngineRKNN::detect(
box.right = std::min(static_cast<float>(width), (cx + w * 0.5f) * width);
box.bottom = std::min(static_cast<float>(height), (cy + h * 0.5f) * height);
box.score = score;
box.hasLandmarks = hasLandmarks;
if (hasLandmarks) {
for (int k = 0; k < 5; ++k) {
const float lx = priors[i].cx + landmarks[i * 10 + k * 2 + 0] * kVariance0 * priors[i].w;
const float ly = priors[i].cy + landmarks[i * 10 + k * 2 + 1] * kVariance0 * priors[i].h;
box.landmarks[k * 2 + 0] = lx * width;
box.landmarks[k * 2 + 1] = ly * height;
}
} else {
box.landmarks.fill(-1.0f);
}
candidates.push_back(box);
}
@@ -412,13 +461,16 @@ std::vector<float> RetinaFaceEngineRKNN::detect(
std::vector<FaceCandidate> filtered = nms(candidates, nmsThreshold_);
std::vector<float> result;
result.reserve(filtered.size() * 5);
result.reserve(filtered.size() * 15);
for (const auto& f : filtered) {
result.push_back(f.left);
result.push_back(f.top);
result.push_back(f.right);
result.push_back(f.bottom);
result.push_back(f.score);
for (int k = 0; k < 10; ++k) {
result.push_back(f.hasLandmarks ? f.landmarks[k] : -1.0f);
}
}
return result;
}

View File

@@ -4,6 +4,7 @@
#include <cstdint>
#include <string>
#include <vector>
#include <array>
#include "rknn_api.h"
@@ -30,6 +31,8 @@ private:
float right;
float bottom;
float score;
std::array<float, 10> landmarks;
bool hasLandmarks = false;
};
static size_t tensorElemCount(const rknn_tensor_attr& attr);
@@ -40,7 +43,8 @@ private:
bool parseRetinaOutputs(
rknn_output* outputs,
std::vector<float>* locOut,
std::vector<float>* scoreOut) const;
std::vector<float>* scoreOut,
std::vector<float>* landmarkOut) const;
rknn_context ctx_ = 0;
bool initialized_ = false;

View File

@@ -6,7 +6,10 @@ import android.graphics.Bitmap
import android.os.Bundle
import android.util.Log
import android.widget.Toast
import androidx.appcompat.app.AlertDialog
import androidx.camera.core.CameraSelector
import com.digitalperson.engine.RetinaFaceEngineRKNN
import com.digitalperson.face.FaceBox
import androidx.camera.core.ImageAnalysis
import androidx.camera.core.ImageProxy
import androidx.camera.core.Preview
@@ -28,13 +31,19 @@ import com.digitalperson.metrics.TraceManager
import com.digitalperson.metrics.TraceSession
import com.digitalperson.tts.TtsController
import com.digitalperson.interaction.DigitalHumanInteractionController
import com.digitalperson.data.DatabaseInitializer
import com.digitalperson.interaction.InteractionActionHandler
import com.digitalperson.interaction.InteractionState
import com.digitalperson.interaction.UserMemoryStore
import com.digitalperson.llm.LLMManager
import com.digitalperson.llm.LLMManagerCallback
import com.digitalperson.util.FileHelper
import com.digitalperson.data.AppDatabase
import com.digitalperson.data.entity.ChatMessage
import com.digitalperson.interaction.ConversationBufferMemory
import com.digitalperson.interaction.ConversationSummaryMemory
import java.io.File
import android.graphics.BitmapFactory
import org.json.JSONObject
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
@@ -92,6 +101,8 @@ class Live2DChatActivity : AppCompatActivity() {
private lateinit var faceDetectionPipeline: FaceDetectionPipeline
private lateinit var interactionController: DigitalHumanInteractionController
private lateinit var userMemoryStore: UserMemoryStore
private lateinit var conversationBufferMemory: ConversationBufferMemory
private lateinit var conversationSummaryMemory: ConversationSummaryMemory
private var facePipelineReady: Boolean = false
private var cameraProvider: ProcessCameraProvider? = null
private lateinit var cameraAnalyzerExecutor: ExecutorService
@@ -102,6 +113,8 @@ class Live2DChatActivity : AppCompatActivity() {
private val recentConversationLines = ArrayList<String>()
private var recentConversationDirty: Boolean = false
private var lastFacePresent: Boolean = false
private var lastFaceIdentityId: String? = null
private var lastFaceRecognizedName: String? = null
override fun onRequestPermissionsResult(
requestCode: Int,
@@ -152,12 +165,20 @@ class Live2DChatActivity : AppCompatActivity() {
speakingPlayerViewId = 0,
live2dViewId = R.id.live2d_view
)
cameraPreviewView = findViewById(R.id.camera_preview)
cameraPreviewView.implementationMode = PreviewView.ImplementationMode.COMPATIBLE
faceOverlayView = findViewById(R.id.face_overlay)
cameraAnalyzerExecutor = Executors.newSingleThreadExecutor()
// 初始化数据库
val databaseInitializer = DatabaseInitializer(applicationContext)
databaseInitializer.initialize()
userMemoryStore = UserMemoryStore(applicationContext)
val database = AppDatabase.getInstance(applicationContext)
conversationBufferMemory = ConversationBufferMemory(database)
conversationSummaryMemory = ConversationSummaryMemory(database, llmManager)
interactionController = DigitalHumanInteractionController(
scope = ioScope,
handler = object : InteractionActionHandler {
@@ -165,8 +186,10 @@ class Live2DChatActivity : AppCompatActivity() {
runOnUiThread {
uiManager.appendToUi("\n[State] $state\n")
}
Log.i(TAG_ACTIVITY, "\n[State] $state\n")
if (state == InteractionState.IDLE) {
analyzeUserProfileInIdleIfNeeded()
Log.i(TAG_ACTIVITY, "[analyze] done")
}
}
@@ -205,6 +228,8 @@ class Live2DChatActivity : AppCompatActivity() {
override fun loadLatestThought(): String? = userMemoryStore.getLatestThought()
override fun loadRecentThoughts(timeRangeMs: Long): List<String> = userMemoryStore.getRecentThoughts(timeRangeMs)
override fun addToChatHistory(role: String, content: String) {
appendConversationLine(role, content)
}
@@ -212,7 +237,14 @@ class Live2DChatActivity : AppCompatActivity() {
override fun addAssistantMessageToCloudHistory(content: String) {
cloudApiManager.addAssistantMessage(content)
}
}
/**
 * Returns the text of a question fetched from the memory store for [faceId]
 * (one that face has not been asked yet), or a stock question when the store has none.
 */
override fun getRandomQuestion(faceId: String): String =
    userMemoryStore.getRandomUnansweredQuestion(faceId)?.content ?: "你喜欢什么颜色呀?"
},
context = applicationContext
)
faceDetectionPipeline = FaceDetectionPipeline(
context = applicationContext,
@@ -220,10 +252,21 @@ class Live2DChatActivity : AppCompatActivity() {
faceOverlayView.updateResult(result)
},
onPresenceChanged = { present, faceIdentityId, recognizedName ->
if (present == lastFacePresent) return@FaceDetectionPipeline
lastFacePresent = present
Log.d(TAG_ACTIVITY, "present=$present, faceIdentityId=$faceIdentityId, recognized=$recognizedName")
interactionController.onFacePresenceChanged(present, faceIdentityId, recognizedName)
if (present != lastFacePresent) {
lastFacePresent = present
Log.d(TAG_ACTIVITY, "presence changed: present=$present")
interactionController.onFacePresenceChanged(present)
if (!present) {
lastFaceIdentityId = null
lastFaceRecognizedName = null
}
}
if (present && (faceIdentityId != lastFaceIdentityId || recognizedName != lastFaceRecognizedName)) {
lastFaceIdentityId = faceIdentityId
lastFaceRecognizedName = recognizedName
Log.d(TAG_ACTIVITY, "identity update: faceIdentityId=$faceIdentityId, recognized=$recognizedName")
interactionController.onFaceIdentityUpdated(faceIdentityId, recognizedName)
}
}
)
@@ -261,7 +304,7 @@ class Live2DChatActivity : AppCompatActivity() {
try {
val ttsModeSwitch = findViewById<android.widget.Switch>(R.id.tts_mode_switch)
ttsModeSwitch.isChecked = false // 默认使用本地TTS
ttsModeSwitch.isChecked = true // 默认使用本地TTS
ttsModeSwitch.setOnCheckedChangeListener { _, isChecked ->
ttsController.setUseQCloudTts(isChecked)
uiManager.showToast("TTS模式已切换到${if (isChecked) "腾讯云" else "本地"}")
@@ -297,10 +340,6 @@ class Live2DChatActivity : AppCompatActivity() {
vadManager = VadManager(this)
vadManager.setCallback(createVadCallback())
// 初始化本地 LLM用于 memory 状态)
initLLM()
interactionController.start()
// 检查是否需要下载模型
if (!FileHelper.isLocalLLMAvailable(this)) {
// 显示下载进度对话框
@@ -332,19 +371,31 @@ class Live2DChatActivity : AppCompatActivity() {
Log.i(AppConfig.TAG, "Local LLM is available, enabling local LLM switch")
// 显示本地 LLM 开关,并同步状态
uiManager.showLLMSwitch(false)
// 初始化本地 LLM
initLLM()
// 重新初始化 ConversationSummaryMemory
conversationSummaryMemory = ConversationSummaryMemory(database, llmManager)
// 启动交互控制器
interactionController.start()
// 下载完成后初始化其他组件
initializeOtherComponents()
}
} else {
Log.e(AppConfig.TAG, "Failed to download model files: $message")
uiManager.showToast("模型下载失败: $message", Toast.LENGTH_LONG)
// 显示错误弹窗,阻止应用继续运行
showModelDownloadErrorDialog(message)
}
// 下载完成后初始化其他组件
initializeOtherComponents()
}
}
)
} else {
// 模型已存在,直接初始化其他组件
// 模型已存在,初始化本地 LLM
initLLM()
// 重新初始化 ConversationSummaryMemory
conversationSummaryMemory = ConversationSummaryMemory(database, llmManager)
// 启动交互控制器
interactionController.start()
// 直接初始化其他组件
initializeOtherComponents()
// 显示本地 LLM 开关,并同步状态
uiManager.showLLMSwitch(false)
@@ -404,6 +455,304 @@ class Live2DChatActivity : AppCompatActivity() {
ioScope.launch {
asrManager.runAsrWorker()
}
// 测试人脸识别(延迟执行,确保所有组件初始化完成)
// ioScope.launch {
// kotlinx.coroutines.delay(10000) // 等待3秒确保所有组件初始化完成
// runOnUiThread {
// runFaceRecognitionTest()
// }
// }
}
/**
 * Shows a non-cancelable dialog reporting that the local LLM model download failed.
 *
 * The only action offered is "exit", which finishes this activity.
 * Nothing is shown when the activity is already finishing or destroyed: this is reached
 * from an asynchronous download callback, and attaching a dialog to a dead window throws.
 *
 * @param errorMessage failure description embedded in the dialog body.
 */
private fun showModelDownloadErrorDialog(errorMessage: String) {
    if (isFinishing || isDestroyed) return

    AlertDialog.Builder(this)
        .setTitle("模型下载失败")
        .setMessage("本地 LLM 模型下载失败,应用无法正常运行。\n\n错误信息:$errorMessage\n\n请检查网络连接后重启应用。")
        .setCancelable(false)
        .setPositiveButton("退出应用") { _, _ ->
            finish()
        }
        .show()
}
/**
 * Runs the face-recognition similarity self-test against images hosted on the test server.
 *
 * On [ioScope] it fetches the image list, downloads every image, detects one face per image,
 * and then compares every pair of images. For each pair it reports two numbers: the value
 * returned by the recognizer's `testSimilarityBetween`, and a cosine similarity computed here
 * from the two embeddings returned by `extractEmbedding`. Progress and results go to Logcat
 * and to the on-screen log.
 *
 * Face detection runs once per image (not once per pair), because every [detectFace] call
 * creates and initializes a fresh detection engine.
 */
private fun runFaceRecognitionTest() {
    Log.i(TAG_ACTIVITY, "Starting face recognition test...")
    uiManager.appendToUi("\n[测试] 开始人脸识别相似度测试...\n")

    ioScope.launch {
        try {
            // Ask the server which images exist in its test directory.
            val imageUrls = fetchImageListFromServer("http://192.168.1.19:5000/api/face_test_images")
            if (imageUrls.isEmpty()) {
                Log.e(AppConfig.TAG, "No images found in server directory")
                runOnUiThread {
                    uiManager.appendToUi("\n[测试] 服务器目录中没有找到图片文件\n")
                }
                return@launch
            }
            Log.i(AppConfig.TAG, "[测试]Found ${imageUrls.size} images: $imageUrls")
            runOnUiThread {
                uiManager.appendToUi("\n[测试] 发现 ${imageUrls.size} 张测试图片\n")
            }

            // Download everything first; images that fail to download are skipped.
            val bitmaps = mutableListOf<Pair<String, Bitmap>>()
            for (url in imageUrls) {
                Log.d(AppConfig.TAG, "[测试]Downloading test image: $url")
                val bitmap = downloadImage(url)
                if (bitmap != null) {
                    val fileName = url.substringAfterLast("/")
                    bitmaps.add(fileName to bitmap)
                    Log.d(AppConfig.TAG, "[测试]Downloaded image $fileName successfully")
                } else {
                    Log.e(AppConfig.TAG, "[测试]Failed to download image: $url")
                }
            }
            if (bitmaps.size < 2) {
                Log.e(AppConfig.TAG, "[测试]Not enough test images downloaded")
                runOnUiThread {
                    uiManager.appendToUi("\n[测试] 测试图片下载失败,无法进行测试\n")
                }
                return@launch
            }

            Log.i(AppConfig.TAG, "[测试]Starting similarity comparison for ${bitmaps.size} images...")

            // Detect once per image; faces[k] belongs to bitmaps[k] (null when no face was found).
            val faces = bitmaps.map { (_, bitmap) -> detectFace(bitmap) }

            // Compare every unordered pair of images.
            for (i in bitmaps.indices) {
                for (j in i + 1 until bitmaps.size) {
                    val (fileName1, bitmap1) = bitmaps[i]
                    val (fileName2, bitmap2) = bitmaps[j]
                    Log.d(AppConfig.TAG, "[测试]Comparing $fileName1 with $fileName2")

                    val face1 = faces[i]
                    val face2 = faces[j]
                    Log.d(AppConfig.TAG, "[测试]Face detection result: face1=$face1, face2=$face2")

                    if (face1 != null && face2 != null) {
                        Log.d(AppConfig.TAG, "[测试]Detected faces, calculating similarity...")
                        val recognizer = faceDetectionPipeline?.getRecognizer()
                        val similarity = recognizer?.testSimilarityBetween(
                            bitmap1, face1, bitmap2, face2
                        )
                        // Cosine similarity of the embeddings from extractEmbedding; -1 marks "not computable".
                        val similarityRaw = recognizer?.run {
                            val emb1 = extractEmbedding(bitmap1, face1)
                            val emb2 = extractEmbedding(bitmap2, face2)
                            // Equal, non-zero length is required; otherwise indexing emb2 would throw.
                            if (emb1.isNotEmpty() && emb1.size == emb2.size) {
                                var dot = 0f
                                var n1 = 0f
                                var n2 = 0f
                                for (k in emb1.indices) {
                                    dot += emb1[k] * emb2[k]
                                    n1 += emb1[k] * emb1[k]
                                    n2 += emb2[k] * emb2[k]
                                }
                                // Guard against division by (near) zero norms.
                                if (n1 > 1e-12f && n2 > 1e-12f) {
                                    (dot / (kotlin.math.sqrt(n1) * kotlin.math.sqrt(n2))).coerceIn(-1f, 1f)
                                } else -1f
                            } else -1f
                        }
                        Log.d(AppConfig.TAG, "[测试]Similarity result: $similarity")

                        if (similarity != null && similarity >= 0) {
                            val message = "[测试] 图片 $fileName1 和 $fileName2 的相似度: $similarity"
                            val compareMessage = "[测试] 对齐后=$similarity, 原始裁剪=$similarityRaw"
                            Log.i(AppConfig.TAG, message)
                            Log.i(AppConfig.TAG, compareMessage)
                            runOnUiThread {
                                uiManager.appendToUi("\n$message\n")
                                uiManager.appendToUi("$compareMessage\n")
                            }
                        } else {
                            Log.w(AppConfig.TAG, "[测试]Failed to calculate similarity: $similarity")
                            runOnUiThread {
                                uiManager.appendToUi("\n[测试] 计算相似度失败: $similarity\n")
                            }
                        }
                    } else {
                        val message = "[测试] 无法检测到人脸: $fileName1 和 $fileName2"
                        Log.w(AppConfig.TAG, message)
                        runOnUiThread {
                            uiManager.appendToUi("\n$message\n")
                        }
                    }
                }
            }

            Log.i(AppConfig.TAG, "[测试]Face recognition test completed")
            runOnUiThread {
                uiManager.appendToUi("\n[测试] 人脸识别相似度测试完成\n")
            }
        } catch (e: Exception) {
            Log.e(AppConfig.TAG, "Error during face recognition test: ${e.message}", e)
            runOnUiThread {
                uiManager.appendToUi("\n[测试] 测试过程中发生错误: ${e.message}\n")
            }
        }
    }
}
/**
 * Asks the test server which face images it has and turns the answer into download URLs.
 *
 * Sends a GET to [apiUrl], expects a JSON object with an `images` array of file names, and
 * prefixes each name with [apiUrl] where `/api/face_test_images` is replaced by
 * `/shared_files/face_test`.
 *
 * @param apiUrl endpoint that lists the image file names.
 * @return the image URLs; empty when the server answers with a non-200 status or when any
 *   exception occurs (the exception is logged, not rethrown).
 */
private fun fetchImageListFromServer(apiUrl: String): List<String> {
    return try {
        val http = (java.net.URL(apiUrl).openConnection() as java.net.HttpURLConnection).apply {
            requestMethod = "GET"
            connectTimeout = 10000
            readTimeout = 10000
            setRequestProperty("Accept", "application/json")
        }
        val collected = mutableListOf<String>()
        try {
            val status = http.responseCode
            if (status != 200) {
                Log.e(AppConfig.TAG, "API request failed with code: $status")
            } else {
                http.inputStream.use { stream ->
                    val body = stream.bufferedReader().use { reader -> reader.readText() }
                    Log.d(AppConfig.TAG, "API response: $body")

                    val names = org.json.JSONObject(body).getJSONArray("images")
                    // Images are served from a sibling path of the listing endpoint.
                    val imageRoot = apiUrl.replace("/api/face_test_images", "/shared_files/face_test")
                    repeat(names.length()) { index ->
                        val imageUrl = "$imageRoot/${names.getString(index)}"
                        collected.add(imageUrl)
                        Log.d(AppConfig.TAG, "Added image URL: $imageUrl")
                    }
                }
            }
        } finally {
            http.disconnect()
        }
        collected
    } catch (e: Exception) {
        Log.e(AppConfig.TAG, "Failed to fetch image list: ${e.message}", e)
        emptyList()
    }
}
/**
 * Checks whether [url] answers an HTTP HEAD request with status 200.
 *
 * Uses 3-second connect and read timeouts. The connection is always disconnected, including
 * when reading the response code fails.
 *
 * @param url address to probe.
 * @return `true` only for a 200 response; `false` for any other status or any exception
 *   (malformed URL, non-HTTP URL, I/O error, timeout).
 */
private fun checkUrlExists(url: String): Boolean {
    return try {
        val connection = java.net.URL(url).openConnection() as java.net.HttpURLConnection
        try {
            connection.requestMethod = "HEAD"
            connection.connectTimeout = 3000
            connection.readTimeout = 3000
            connection.responseCode == 200
        } finally {
            // Release the connection even when responseCode throws.
            connection.disconnect()
        }
    } catch (e: Exception) {
        false
    }
}
/**
 * Downloads the image at [url] and decodes it into a [Bitmap].
 *
 * The image is first written to a temporary file in the cache directory via
 * `FileHelper.downloadTestImage` and then decoded from that file. The temporary file is
 * deleted on every path — success, failed download, or exception — so failed runs do not
 * accumulate files in the cache directory.
 *
 * @param url address of the image to fetch.
 * @return the decoded bitmap, or `null` when the download fails, the file is missing,
 *   decoding yields nothing, or an exception occurs (logged, not rethrown).
 */
private fun downloadImage(url: String): Bitmap? {
    return try {
        val tempFile = File(cacheDir, "temp_test_image_${System.currentTimeMillis()}.jpg")
        try {
            val success = FileHelper.downloadTestImage(url, tempFile)
            if (success && tempFile.exists()) {
                BitmapFactory.decodeFile(tempFile.absolutePath)
            } else {
                Log.e(AppConfig.TAG, "Failed to download image: $url")
                null
            }
        } finally {
            // Always clean up, including partial files left by a failed download.
            tempFile.delete()
        }
    } catch (e: Exception) {
        Log.e(AppConfig.TAG, "Failed to download image: ${e.message}", e)
        null
    }
}
/**
 * Detects faces in [bitmap] with a freshly created RetinaFace engine and returns the first one.
 *
 * The engine output is read as a flat array of fixed-size records: 5 values
 * (left, top, right, bottom, score), optionally followed by 10 landmark values. A record size
 * of 15 is preferred whenever the array length is divisible by 15; otherwise 5 is tried.
 * Landmarks are treated as present only when all 10 values are non-negative.
 *
 * The engine is released after detection even when detection throws. It is not released
 * when initialization fails, matching the previous behavior.
 *
 * @param bitmap image to search.
 * @return the first detected face, or `null` when the engine cannot be initialized, nothing
 *   is detected, the output length matches neither record size, or an exception occurs.
 */
private fun detectFace(bitmap: Bitmap): FaceBox? {
    Log.d(AppConfig.TAG, "[测试]Detecting face in bitmap: ${bitmap.width}x${bitmap.height}")
    return try {
        val engine = RetinaFaceEngineRKNN()
        Log.d(AppConfig.TAG, "[测试]Initializing RetinaFace engine...")
        if (!engine.initialize(applicationContext)) {
            Log.e(AppConfig.TAG, "[测试]Failed to initialize RetinaFace engine")
            return null
        }
        Log.d(AppConfig.TAG, "[测试]RetinaFace engine initialized successfully")

        val raw = try {
            engine.detect(bitmap).also {
                Log.d(AppConfig.TAG, "[测试]Face detection result: ${it.joinToString(", ")}")
            }
        } finally {
            // Free the native engine even if detect() throws.
            engine.release()
        }

        if (raw.isEmpty()) {
            Log.w(AppConfig.TAG, "[测试]No faces detected in bitmap")
            return null
        }

        val stride = when {
            raw.size % 15 == 0 -> 15
            raw.size % 5 == 0 -> 5
            else -> 0
        }
        Log.d(AppConfig.TAG, "[测试]Stride: $stride, raw size: ${raw.size}")
        if (stride == 0) return null

        // raw is non-empty and a multiple of stride, so there is at least one record.
        val faceCount = raw.size / stride
        Log.d(AppConfig.TAG, "[测试]Detected $faceCount faces")

        // Only the first record is used.
        val lm = if (stride >= 15) raw.copyOfRange(5, 15) else null
        val hasLm = lm?.all { it >= 0f } == true
        val faceBox = FaceBox(
            left = raw[0],
            top = raw[1],
            right = raw[2],
            bottom = raw[3],
            score = raw[4],
            hasLandmarks = hasLm,
            landmarks = if (hasLm) lm else null
        )
        Log.d(AppConfig.TAG, "[测试]Created face box: $faceBox")
        faceBox
    } catch (e: Exception) {
        Log.e(AppConfig.TAG, "[测试]Failed to detect face: ${e.message}", e)
        null
    }
}
private fun createAsrCallback() = object : AsrManager.AsrCallback {
@@ -557,7 +906,6 @@ class Live2DChatActivity : AppCompatActivity() {
try { cameraAnalyzerExecutor.shutdown() } catch (_: Throwable) {}
try { ttsController.release() } catch (_: Throwable) {}
try { llmManager?.destroy() } catch (_: Throwable) {}
try { userMemoryStore.close() } catch (_: Throwable) {}
try { uiManager.release() } catch (_: Throwable) {}
try { audioProcessor.release() } catch (_: Throwable) {}
}
@@ -686,6 +1034,9 @@ class Live2DChatActivity : AppCompatActivity() {
uiManager.appendToUi("\n[LOG] 已打断TTS播放\n")
}
// 通知状态机用户开始说话,立即进入对话状态
interactionController.onUserStartSpeaking()
if (!audioProcessor.initMicrophone(micPermissions, AppConfig.REQUEST_RECORD_AUDIO_PERMISSION)) {
uiManager.showToast("麦克风初始化失败/无权限")
return
@@ -844,6 +1195,14 @@ class Live2DChatActivity : AppCompatActivity() {
recentConversationLines.removeAt(0)
}
recentConversationDirty = true
// 同时添加到对话记忆中
val memoryRole = if (role == "用户") "user" else "assistant"
conversationBufferMemory.addMessage(activeUserId, memoryRole, text.trim())
// 定期保存到数据库
if (recentConversationLines.size % 5 == 0) {
conversationBufferMemory.saveToDatabase(activeUserId)
}
}
private fun buildCloudPromptWithUserProfile(userText: String): String {
@@ -854,6 +1213,13 @@ class Live2DChatActivity : AppCompatActivity() {
profile.gender?.takeIf { it.isNotBlank() }?.let { profileParts.add("性别:$it") }
profile.hobbies?.takeIf { it.isNotBlank() }?.let { profileParts.add("爱好:$it") }
profile.profileSummary?.takeIf { it.isNotBlank() }?.let { profileParts.add("画像:$it") }
// 添加对话摘要
val conversationSummary = conversationSummaryMemory.getSummary(activeUserId)
if (conversationSummary.isNotBlank()) {
profileParts.add("对话摘要:$conversationSummary")
}
if (profileParts.isEmpty()) return userText
return buildString {
append("[用户画像]\n")
@@ -864,9 +1230,26 @@ class Live2DChatActivity : AppCompatActivity() {
}
private fun analyzeUserProfileInIdleIfNeeded() {
if (!recentConversationDirty || !activeUserId.startsWith("face_")) return
if (recentConversationLines.isEmpty()) return
val dialogue = recentConversationLines.joinToString("\n")
if (!activeUserId.startsWith("face_")) {
Log.d(AppConfig.TAG, "faceID is not face_")
return
}
// 使用 conversationBufferMemory 获取对话消息
val messages = conversationBufferMemory.getMessages(activeUserId)
Log.d(AppConfig.TAG, "msg is empty? ${messages.isEmpty()}")
val hasUserMessages = messages.any { it.role == "user" }
Log.d(AppConfig.TAG, "msg has user messages? $hasUserMessages")
if (messages.isEmpty() || !hasUserMessages) return
// 生成对话摘要
conversationSummaryMemory.generateSummary(activeUserId, messages) { summary ->
Log.d(AppConfig.TAG, "Generated conversation summary for $activeUserId: $summary")
}
// 使用 conversationBufferMemory 的对话记录提取用户信息
val dialogue = messages.joinToString("\n") { "${it.role}: ${it.content}" }
requestLocalProfileExtraction(dialogue) { raw ->
try {
val json = parseFirstJsonObject(raw)
@@ -879,7 +1262,10 @@ class Live2DChatActivity : AppCompatActivity() {
userMemoryStore.updateDisplayName(activeUserId, name)
}
userMemoryStore.updateProfile(activeUserId, age, gender, hobbies, summary)
recentConversationDirty = false
// 清空已处理的对话记录
conversationBufferMemory.clear(activeUserId)
runOnUiThread {
uiManager.appendToUi("\n[Memory] 已更新用户画像: $activeUserId\n")
}
@@ -901,7 +1287,7 @@ class Live2DChatActivity : AppCompatActivity() {
Log.i(TAG_LLM, "Routing profile extraction to LOCAL")
local.generateResponseWithSystem(
"你是信息抽取器。仅输出JSON对象不要其他文字。字段为name,age,gender,hobbies,summary。",
"请从以下对话提取用户信息,未知填空字符串:\n$dialogue"
"请从以下对话提取用户信息,未知填空字符串,注意不需要\n$dialogue"
)
} catch (e: Exception) {
pendingLocalProfileCallback = null

View File

@@ -50,7 +50,7 @@ object AppConfig {
object FaceRecognition {
const val MODEL_DIR = "Insightface"
const val MODEL_NAME = "ms1mv3_arcface_r18.rknn"
const val SIMILARITY_THRESHOLD = 0.5f
const val SIMILARITY_THRESHOLD = 0.6f
const val GREETING_COOLDOWN_MS = 6000L
}

View File

@@ -0,0 +1,54 @@
package com.digitalperson.data
import android.content.Context
import androidx.room.Database
import androidx.room.Room
import androidx.room.RoomDatabase
import com.digitalperson.data.dao.QuestionDao
import com.digitalperson.data.dao.UserAnswerDao
import com.digitalperson.data.dao.UserMemoryDao
import com.digitalperson.data.dao.ChatMessageDao
import com.digitalperson.data.dao.ConversationSummaryDao
import com.digitalperson.data.entity.Question
import com.digitalperson.data.entity.UserAnswer
import com.digitalperson.data.entity.UserMemory
import com.digitalperson.data.entity.ChatMessageEntity
import com.digitalperson.data.entity.ConversationSummaryEntity
/**
 * Room database of the app (schema version 4, stored in `digital_human.db`).
 *
 * Holds user memories, questions, user answers, chat messages and conversation summaries,
 * and exposes one DAO per area. Obtain the process-wide instance through [getInstance].
 */
@Database(
    version = 4,
    exportSchema = false,
    entities = [
        UserMemory::class,
        Question::class,
        UserAnswer::class,
        ChatMessageEntity::class,
        ConversationSummaryEntity::class
    ]
)
abstract class AppDatabase : RoomDatabase() {

    abstract fun userMemoryDao(): UserMemoryDao

    abstract fun questionDao(): QuestionDao

    abstract fun userAnswerDao(): UserAnswerDao

    abstract fun chatMessageDao(): ChatMessageDao

    abstract fun conversationSummaryDao(): ConversationSummaryDao

    companion object {
        private const val DATABASE_NAME = "digital_human.db"

        @Volatile
        private var INSTANCE: AppDatabase? = null

        /**
         * Returns the shared database, creating it on first use.
         *
         * The instance is checked once without the lock and once more inside it, so the
         * database is built only once.
         */
        fun getInstance(context: Context): AppDatabase {
            val existing = INSTANCE
            if (existing != null) return existing

            return synchronized(this) {
                val rechecked = INSTANCE
                if (rechecked != null) {
                    rechecked
                } else {
                    val created = buildDatabase(context)
                    INSTANCE = created
                    created
                }
            }
        }

        /**
         * Builds the Room database on the application context, with destructive fallback
         * migration enabled on the builder.
         */
        private fun buildDatabase(context: Context): AppDatabase {
            val builder = Room.databaseBuilder(
                context.applicationContext,
                AppDatabase::class.java,
                DATABASE_NAME
            )
            return builder
                .fallbackToDestructiveMigration()
                .build()
        }
    }
}

View File

@@ -0,0 +1,86 @@
package com.digitalperson.data
import android.content.Context
import com.digitalperson.data.entity.Question
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.launch
import org.json.JSONObject
import java.io.IOException
import java.io.InputStream
/**
 * Seeds the question table of [AppDatabase] from the bundled `questions.json` asset.
 *
 * All work is launched on [Dispatchers.IO] via [GlobalScope], so the public methods return
 * immediately and the import finishes some time later.
 *
 * @param context used to open the asset and to obtain the database instance.
 */
class DatabaseInitializer(private val context: Context) {

    /**
     * Imports the questions when the question table reports no subjects.
     *
     * NOTE(review): emptiness is judged by `getAllSubjects()`, which ignores rows whose
     * subject is null; a table holding only subject-less questions would be re-imported.
     */
    fun initialize() {
        // Run off the main thread; Room rejects main-thread queries by default.
        GlobalScope.launch(Dispatchers.IO) {
            val db = AppDatabase.getInstance(context)
            val questionDao = db.questionDao()

            if (questionDao.getAllSubjects().isEmpty()) {
                importQuestionsFromJson()
            }
        }
    }

    /**
     * Reads `questions.json` from assets and inserts every entry of its `questions` array.
     *
     * The asset is read to the end and closed on every path. All rows are inserted in one
     * transaction, so a failure part-way leaves the table unchanged instead of half-filled
     * (a half-filled table would make [initialize] skip the import forever).
     * Failures are printed and swallowed.
     */
    private fun importQuestionsFromJson() {
        try {
            // readText() drains the stream; available() + a single read() may return fewer bytes.
            val jsonString = context.assets.open("questions.json").use { input ->
                input.bufferedReader(Charsets.UTF_8).readText()
            }
            val questionsArray = JSONObject(jsonString).getJSONArray("questions")

            // Parse everything first so a malformed entry aborts before any row is written.
            val questions = (0 until questionsArray.length()).map { i ->
                val questionJson = questionsArray.getJSONObject(i)
                Question(
                    id = 0,
                    content = questionJson.getString("content"),
                    answer = if (questionJson.isNull("answer")) null else questionJson.getString("answer"),
                    subject = if (questionJson.isNull("subject")) null else questionJson.getString("subject"),
                    grade = if (questionJson.isNull("grade")) null else questionJson.getInt("grade"),
                    difficulty = if (questionJson.isNull("difficulty")) 1 else questionJson.getInt("difficulty"),
                    createdAt = System.currentTimeMillis()
                )
            }

            val db = AppDatabase.getInstance(context)
            val questionDao = db.questionDao()
            db.runInTransaction(Runnable {
                questions.forEach { questionDao.insert(it) }
            })

            println("Successfully imported ${questionsArray.length()} questions from JSON")
        } catch (e: IOException) {
            e.printStackTrace()
            println("Error reading questions.json: ${e.message}")
        } catch (e: Exception) {
            e.printStackTrace()
            println("Error importing questions: ${e.message}")
        }
    }

    /**
     * Clears the database and imports the questions again.
     *
     * NOTE(review): this calls `clearAllTables()`, which empties every table of
     * [AppDatabase], not only the question table — confirm that is intended.
     */
    fun forceImportQuestions() {
        GlobalScope.launch(Dispatchers.IO) {
            val db = AppDatabase.getInstance(context)
            db.clearAllTables()
            importQuestionsFromJson()
        }
    }
}

View File

@@ -0,0 +1,47 @@
package com.digitalperson.data.dao
import androidx.room.Dao
import androidx.room.Insert
import androidx.room.Query
import com.digitalperson.data.entity.ChatMessageEntity
import com.digitalperson.data.entity.ConversationSummaryEntity
/**
 * DAO for chat messages stored in the `chat_messages` table.
 */
@Dao
interface ChatMessageDao {
/** Inserts one message and returns the `Long` Room reports for the insert (the new row id per Room's `@Insert` contract). */
@Insert
fun insert(message: ChatMessageEntity): Long
/** Returns every message of [userId], oldest first (ascending `timestamp`). */
@Query("SELECT * FROM chat_messages WHERE user_id = :userId ORDER BY timestamp ASC")
fun getMessagesByUserId(userId: String): List<ChatMessageEntity>
/** Returns the messages of [userId] whose `timestamp` is at or after [startTime], oldest first. */
@Query("""
SELECT * FROM chat_messages
WHERE user_id = :userId AND timestamp >= :startTime
ORDER BY timestamp ASC
""")
fun getMessagesByUserIdAndTime(userId: String, startTime: Long): List<ChatMessageEntity>
/** Deletes every message of [userId]. */
@Query("DELETE FROM chat_messages WHERE user_id = :userId")
fun clearMessagesByUserId(userId: String)
/** Returns how many messages are stored for [userId]. */
@Query("SELECT COUNT(*) FROM chat_messages WHERE user_id = :userId")
fun getMessageCountByUserId(userId: String): Int
}
/**
 * DAO for conversation summaries stored in the `conversation_summaries` table.
 */
@Dao
interface ConversationSummaryDao {
/** Inserts one summary and returns the `Long` Room reports for the insert (the new row id per Room's `@Insert` contract). */
@Insert
fun insert(summary: ConversationSummaryEntity): Long
/** Returns the summary of [userId] with the greatest `created_at`, or `null` when none exists. */
@Query("SELECT * FROM conversation_summaries WHERE user_id = :userId ORDER BY created_at DESC LIMIT 1")
fun getLatestSummaryByUserId(userId: String): ConversationSummaryEntity?
/** Deletes every summary of [userId]. */
@Query("DELETE FROM conversation_summaries WHERE user_id = :userId")
fun clearSummariesByUserId(userId: String)
}

View File

@@ -0,0 +1,44 @@
package com.digitalperson.data.dao
import androidx.room.Dao
import androidx.room.Insert
import androidx.room.Query
import com.digitalperson.data.entity.Question
/**
 * DAO for the `questions` table.
 */
@Dao
interface QuestionDao {
/** Inserts one question and returns the `Long` Room reports for the insert (the new row id per Room's `@Insert` contract). */
@Insert
fun insert(question: Question): Long
/** Returns all questions of [subject], ordered by ascending difficulty. */
@Query("SELECT * FROM questions WHERE subject = :subject ORDER BY difficulty")
fun getQuestionsBySubject(subject: String): List<Question>
/** Returns all questions of [grade], ordered by ascending difficulty. */
@Query("SELECT * FROM questions WHERE grade = :grade ORDER BY difficulty")
fun getQuestionsByGrade(grade: Int): List<Question>
/** Returns one randomly chosen question matching both [subject] and [grade], or `null` when none matches. */
@Query("SELECT * FROM questions WHERE subject = :subject AND grade = :grade ORDER BY RANDOM() LIMIT 1")
fun getRandomQuestionBySubjectAndGrade(subject: String, grade: Int): Question?
/** Returns one randomly chosen question of [subject], or `null` when none matches. */
@Query("SELECT * FROM questions WHERE subject = :subject ORDER BY RANDOM() LIMIT 1")
fun getRandomQuestionBySubject(subject: String): Question?
/** Returns one randomly chosen question, or `null` when the table is empty. */
@Query("SELECT * FROM questions ORDER BY RANDOM() LIMIT 1")
fun getRandomQuestion(): Question?
/**
 * Returns one randomly chosen question for which `user_answers` holds no row with this
 * question id and [userId], or `null` when every question already has such a row.
 */
@Query("""
SELECT q.*
FROM questions q
WHERE NOT EXISTS (
SELECT 1 FROM user_answers ua
WHERE ua.question_id = q.id AND ua.user_id = :userId
)
ORDER BY RANDOM() LIMIT 1
""")
fun getRandomUnansweredQuestion(userId: String): Question?
/** Returns the question with id [questionId], or `null` when it does not exist. */
@Query("SELECT * FROM questions WHERE id = :questionId")
fun getQuestionById(questionId: Long): Question?
/** Returns the distinct non-null subjects, sorted by subject. */
@Query("SELECT DISTINCT subject FROM questions WHERE subject IS NOT NULL ORDER BY subject")
fun getAllSubjects(): List<String>
}

View File

@@ -0,0 +1,52 @@
package com.digitalperson.data.dao
import androidx.room.Dao
import androidx.room.Insert
import androidx.room.Query
import androidx.room.ColumnInfo
import com.digitalperson.data.entity.UserAnswer
// DAO for the `user_answers` table: insertion, per-user history, incorrect answers and evaluation counts.
@Dao
interface UserAnswerDao {
// Inserts one answer row and returns the Long produced by Room's @Insert (the new row id).
@Insert
fun insert(userAnswer: UserAnswer): Long
// The user's answers, newest first (answered_at descending), capped at limit (default 50).
@Query("SELECT * FROM user_answers WHERE user_id = :userId ORDER BY answered_at DESC LIMIT :limit")
fun getUserAnswers(userId: String, limit: Int = 50): List<UserAnswer>
// The user's answers whose evaluation is exactly 'INCORRECT', joined with the question text and
// reference answer, newest first, capped at limit (default 10).
// Column aliases (userId, questionId, ...) are chosen to match the field names of UserAnswerWithQuestion.
@Query("""
SELECT ua.id, ua.user_id as userId, ua.question_id as questionId, ua.user_answer as userAnswer, ua.evaluation, ua.answered_at as answeredAt, q.content as question_content, q.answer as question_answer
FROM user_answers ua
JOIN questions q ON ua.question_id = q.id
WHERE ua.user_id = :userId AND ua.evaluation = 'INCORRECT'
ORDER BY ua.answered_at DESC
LIMIT :limit
""")
fun getIncorrectAnswers(userId: String, limit: Int = 10): List<UserAnswerWithQuestion>
// Number of the user's answers per distinct evaluation value (one row per value).
@Query("""
SELECT evaluation, COUNT(*) as count
FROM user_answers
WHERE user_id = :userId
GROUP BY evaluation
""")
fun getAnswerStatistics(userId: String): List<AnswerStatistic>
}
// Holds one row of answer statistics: an evaluation value and how many answers carry it.
// evaluation is nullable because the underlying user_answers.evaluation column is nullable.
data class AnswerStatistic(
val evaluation: String?,
val count: Int
)
// Result row of a joined query over user answers and their questions.
// The first six fields mirror a user_answers row; question_content / question_answer
// come from the joined questions row (its content and answer columns).
data class UserAnswerWithQuestion(
val id: Long,
val userId: String,
val questionId: Long,
val userAnswer: String?,
val evaluation: String?,
val answeredAt: Long,
@ColumnInfo(name = "question_content") val question_content: String,
@ColumnInfo(name = "question_answer") val question_answer: String?
)

View File

@@ -0,0 +1,31 @@
package com.digitalperson.data.dao
import androidx.room.Dao
import androidx.room.Insert
import androidx.room.OnConflictStrategy
import androidx.room.Query
import com.digitalperson.data.entity.UserMemory
// DAO for the `user_memory` table (one row per user_id): upsert, lookups and partial updates.
@Dao
interface UserMemoryDao {
// Inserts the row, replacing any existing row that conflicts (same primary key).
@Insert(onConflict = OnConflictStrategy.REPLACE)
fun insert(userMemory: UserMemory)
// Memory row for userId, or null when the user is unknown.
@Query("SELECT * FROM user_memory WHERE user_id = :userId")
fun getUserMemory(userId: String): UserMemory?
// The non-empty last_thought of whichever user was seen most recently (no per-user filter),
// or null when no row has a thought.
@Query("SELECT last_thought FROM user_memory WHERE last_thought IS NOT NULL AND last_thought != '' ORDER BY last_seen_at DESC LIMIT 1")
fun getLatestThought(): String?
// Sets last_thought for userId and stamps last_seen_at with timestamp; no effect if the row is missing.
@Query("UPDATE user_memory SET last_thought = :thought, last_seen_at = :timestamp WHERE user_id = :userId")
fun updateThought(userId: String, thought: String, timestamp: Long)
// Sets display_name for userId and stamps last_seen_at with timestamp.
@Query("UPDATE user_memory SET display_name = :displayName, last_seen_at = :timestamp WHERE user_id = :userId")
fun updateDisplayName(userId: String, displayName: String, timestamp: Long)
// Overwrites age, gender, hobbies and profile_summary for userId (nulls are written as-is)
// and stamps last_seen_at with timestamp.
@Query("UPDATE user_memory SET age = :age, gender = :gender, hobbies = :hobbies, profile_summary = :summary, last_seen_at = :timestamp WHERE user_id = :userId")
fun updateProfile(userId: String, age: String?, gender: String?, hobbies: String?, summary: String?, timestamp: Long)
// Rows that have a non-empty last_thought and were seen at or after cutoffTime, most recent first.
@Query("SELECT * FROM user_memory WHERE last_thought IS NOT NULL AND last_thought != '' AND last_seen_at >= :cutoffTime ORDER BY last_seen_at DESC")
fun getRecentThoughts(cutoffTime: Long): List<UserMemory>
}

View File

@@ -0,0 +1,44 @@
package com.digitalperson.data.entity
import androidx.room.Entity
import androidx.room.PrimaryKey
import androidx.room.ColumnInfo
/**
* In-memory chat message model (not a Room entity).
* @param id message id; defaults to 0
* @param userId user id, e.g. face_1
* @param role speaker role: user or assistant
* @param content message text
* @param timestamp message time; defaults to System.currentTimeMillis() at construction
*/
data class ChatMessage(
val id: Long = 0,
val userId: String,
val role: String,
val content: String,
val timestamp: Long = System.currentTimeMillis()
)
/**
* Room entity for a chat message (table `chat_messages`).
* The primary key `id` is auto-generated; the remaining columns are
* user_id, role, content and timestamp.
*/
@Entity(tableName = "chat_messages")
data class ChatMessageEntity(
@PrimaryKey(autoGenerate = true) val id: Long,
@ColumnInfo(name = "user_id") val userId: String,
@ColumnInfo(name = "role") val role: String,
@ColumnInfo(name = "content") val content: String,
@ColumnInfo(name = "timestamp") val timestamp: Long
)
/**
* Room entity for a conversation summary (table `conversation_summaries`).
* The primary key `id` is auto-generated; the remaining columns are
* user_id, summary and created_at.
*/
@Entity(tableName = "conversation_summaries")
data class ConversationSummaryEntity(
@PrimaryKey(autoGenerate = true) val id: Long,
@ColumnInfo(name = "user_id") val userId: String,
@ColumnInfo(name = "summary") val summary: String,
@ColumnInfo(name = "created_at") val createdAt: Long
)

View File

@@ -0,0 +1,16 @@
package com.digitalperson.data.entity
import androidx.room.Entity
import androidx.room.PrimaryKey
import androidx.room.ColumnInfo
// Room entity for the `questions` table.
// id is the auto-generated primary key; content is the question text; answer, subject and grade
// are optional (nullable); difficulty and created_at are required.
@Entity(tableName = "questions")
data class Question(
@PrimaryKey(autoGenerate = true) val id: Long,
@ColumnInfo(name = "content") val content: String,
@ColumnInfo(name = "answer") val answer: String?,
@ColumnInfo(name = "subject") val subject: String?,
@ColumnInfo(name = "grade") val grade: Int?,
@ColumnInfo(name = "difficulty") val difficulty: Int,
@ColumnInfo(name = "created_at") val createdAt: Long,
)

View File

@@ -0,0 +1,15 @@
package com.digitalperson.data.entity
import androidx.room.Entity
import androidx.room.PrimaryKey
import androidx.room.ColumnInfo
// Room entity for the `user_answers` table: one answer given by a user to a question.
// id is the auto-generated primary key; question_id holds the id of the answered question;
// user_answer and evaluation are optional (nullable); answered_at is required.
@Entity(tableName = "user_answers")
data class UserAnswer(
@PrimaryKey(autoGenerate = true) val id: Long,
@ColumnInfo(name = "user_id") val userId: String,
@ColumnInfo(name = "question_id") val questionId: Long,
@ColumnInfo(name = "user_answer") val userAnswer: String?,
@ColumnInfo(name = "evaluation") val evaluation: String?,
@ColumnInfo(name = "answered_at") val answeredAt: Long,
)

View File

@@ -0,0 +1,19 @@
package com.digitalperson.data.entity
import androidx.room.Entity
import androidx.room.PrimaryKey
import androidx.room.ColumnInfo
// Room entity for the `user_memory` table: one row of remembered profile data per user.
// user_id is the primary key (not auto-generated); last_seen_at is required;
// every other column is optional (nullable). Note that age is stored as a String.
@Entity(tableName = "user_memory")
data class UserMemory(
@PrimaryKey @ColumnInfo(name = "user_id") val userId: String,
@ColumnInfo(name = "display_name") val displayName: String?,
@ColumnInfo(name = "last_seen_at") val lastSeenAt: Long,
@ColumnInfo(name = "age") val age: String?,
@ColumnInfo(name = "gender") val gender: String?,
@ColumnInfo(name = "hobbies") val hobbies: String?,
@ColumnInfo(name = "preferences") val preferences: String?,
@ColumnInfo(name = "last_topics") val lastTopics: String?,
@ColumnInfo(name = "last_thought") val lastThought: String?,
@ColumnInfo(name = "profile_summary") val profileSummary: String?,
)

View File

@@ -5,8 +5,8 @@ import android.graphics.Bitmap
import android.util.Log
import com.digitalperson.config.AppConfig
import com.digitalperson.engine.RetinaFaceEngineRKNN
import java.util.ArrayDeque
import java.util.concurrent.atomic.AtomicBoolean
import kotlin.math.abs
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.SupervisorJob
@@ -20,6 +20,8 @@ data class FaceBox(
val right: Float,
val bottom: Float,
val score: Float,
val hasLandmarks: Boolean = false,
val landmarks: FloatArray? = null,
)
data class FaceDetectionResult(
@@ -29,24 +31,27 @@ data class FaceDetectionResult(
)
class FaceDetectionPipeline(
private val context: Context,
context: Context,
private val onResult: (FaceDetectionResult) -> Unit,
private val onPresenceChanged: (present: Boolean, faceIdentityId: String?, recognizedName: String?) -> Unit,
) {
private val appContext = context.applicationContext
private val engine = RetinaFaceEngineRKNN()
private val recognizer = FaceRecognizer(context)
private val recognizer = FaceRecognizer(appContext)
private val userMemoryStore = com.digitalperson.interaction.UserMemoryStore(appContext)
private val scope = CoroutineScope(SupervisorJob() + Dispatchers.Default)
private val frameInFlight = AtomicBoolean(false)
private val initialized = AtomicBoolean(false)
private var trackFace: FaceBox? = null
private var trackId: Long = 0
private var trackStableSinceMs: Long = 0
private var analyzedTrackId: Long = -1
private var lastFaceIdentityId: String? = null
private var lastRecognizedName: String? = null
private val fusionEmbeddings = ArrayDeque<FloatArray>()
private val fusionQualities = ArrayDeque<Float>()
fun initialize(): Boolean {
val detectorOk = engine.initialize(context)
val detectorOk = engine.initialize(appContext)
val recognizerOk = recognizer.initialize()
val ok = detectorOk && recognizerOk
initialized.set(ok)
@@ -69,11 +74,28 @@ class FaceDetectionPipeline(
val width = bitmap.width
val height = bitmap.height
val raw = engine.detect(bitmap)
if (raw.isEmpty()) {
withContext(Dispatchers.Main) {
onPresenceChanged(false, null, null)
onResult(FaceDetectionResult(width, height, emptyList()))
}
return@launch
}
val faceCount = raw.size / 5
val stride = when {
raw.size % 15 == 0 -> 15
raw.size % 5 == 0 -> 5
else -> 0
}
if (stride == 0) {
Log.w(AppConfig.TAG, "[Face] invalid detector output size=${raw.size}")
}
val faceCount = if (stride == 0) 0 else raw.size / stride
val faces = ArrayList<FaceBox>(faceCount)
var i = 0
while (i + 4 < raw.size) {
while (i + 4 < raw.size && stride > 0) {
val lm = if (stride >= 15) raw.copyOfRange(i + 5, i + 15) else null
val hasLm = lm?.all { it >= 0f } == true
faces.add(
FaceBox(
left = raw[i],
@@ -81,9 +103,11 @@ class FaceDetectionPipeline(
right = raw[i + 2],
bottom = raw[i + 3],
score = raw[i + 4],
hasLandmarks = hasLm,
landmarks = if (hasLm) lm else null,
)
)
i += 5
i += stride
}
// 过滤太小的人脸
val minFaceSize = 50 // 最小人脸大小(像素)
@@ -117,10 +141,11 @@ class FaceDetectionPipeline(
val now = System.currentTimeMillis()
if (faces.isEmpty()) {
trackFace = null
trackStableSinceMs = 0
analyzedTrackId = -1
lastFaceIdentityId = null
lastRecognizedName = null
fusionEmbeddings.clear()
fusionQualities.clear()
return
}
@@ -128,51 +153,132 @@ class FaceDetectionPipeline(
val prev = trackFace
if (prev == null || iou(prev, primary) < AppConfig.Face.TRACK_IOU_THRESHOLD) {
trackId += 1
trackStableSinceMs = now
analyzedTrackId = -1
lastFaceIdentityId = null
lastRecognizedName = null
fusionEmbeddings.clear()
fusionQualities.clear()
}
trackFace = primary
val stableMs = now - trackStableSinceMs
val frontal = isFrontal(primary, bitmap.width, bitmap.height)
if (stableMs < AppConfig.Face.STABLE_MS || !frontal) {
val quality = estimateQuality(primary, bitmap.width, bitmap.height)
// Log.e(AppConfig.TAG, "estimateQuality: ${quality}")
// 识别尽早进行:只要人脸清晰且朝向满足,就先完成 faceId 解析;问候稳定时机由状态机控制。
if (!frontal || quality < 0.65f) {
return
}
if (analyzedTrackId == trackId) {
return
}
val match = recognizer.resolveIdentity(bitmap, primary)
val embedding = recognizer.extractEmbeddingAligned(bitmap, primary)
if (embedding.isEmpty()) return
addEmbeddingForFusion(embedding, quality)
val fused = fuseEmbeddings() ?: return
if (fusionEmbeddings.size < 4) {
return
}
val match = recognizer.resolveIdentityFromEmbedding(fused)
analyzedTrackId = trackId
lastFaceIdentityId = match.matchedId?.let { "face_$it" }
lastRecognizedName = match.matchedName
// 从 user_memory 表中获取名字
lastRecognizedName = lastFaceIdentityId?.let { userId ->
userMemoryStore.getMemory(userId)?.displayName
}
Log.i(
AppConfig.TAG,
"[Face] stable track=$trackId faceId=${lastFaceIdentityId} matched=${match.matchedName} score=${match.similarity}"
"[Face] stable track=$trackId faceId=${lastFaceIdentityId} matched=${lastRecognizedName} score=${match.similarity} fusionN=${fusionEmbeddings.size}"
)
}
/**
 * Appends an embedding and its quality weight to the fusion window.
 *
 * The window keeps at most 8 entries: when it is already full, the oldest
 * embedding/weight pair is dropped first. The weight is clamped to [0.1, 1].
 */
private fun addEmbeddingForFusion(embedding: FloatArray, quality: Float) {
    val windowFull = fusionEmbeddings.size >= 8
    if (windowFull) {
        fusionEmbeddings.removeFirst()
        fusionQualities.removeFirst()
    }
    val weight = quality.coerceIn(0.1f, 1f)
    fusionEmbeddings.addLast(embedding)
    fusionQualities.addLast(weight)
}
/**
 * Fuses the buffered embeddings into a single unit-length vector.
 *
 * Computes the quality-weighted mean of the buffered embeddings and rescales it
 * to unit L2 norm. Embeddings whose length differs from the first buffered one are skipped.
 *
 * @return the fused vector, or null when the buffer is empty, the first embedding has
 *         zero length, or the accumulated weight is not above 1e-6.
 */
private fun fuseEmbeddings(): FloatArray? {
    val reference = fusionEmbeddings.firstOrNull() ?: return null
    val dim = reference.size
    if (dim == 0) return null

    val fused = FloatArray(dim)
    var weightTotal = 0f
    // zip stops at the shorter of the two buffers, pairing each embedding with its weight.
    for ((vector, weight) in fusionEmbeddings.zip(fusionQualities)) {
        if (vector.size != dim) continue
        for (k in 0 until dim) fused[k] += vector[k] * weight
        weightTotal += weight
    }
    if (weightTotal <= 1e-6f) return null

    for (k in fused.indices) fused[k] /= weightTotal

    // Floor the squared norm so an all-zero mean cannot cause a division by zero.
    val squaredNorm = fused.fold(0f) { acc, v -> acc + v * v }
    val length = kotlin.math.sqrt(squaredNorm.coerceAtLeast(1e-12f))
    for (k in fused.indices) fused[k] /= length
    return fused
}
// Heuristic check that a detected face is roughly frontal and located near the frame centre.
// Returns false when landmarks are missing (fewer than 10 values) or any geometric test below fails.
// Landmark values read here: [0],[1] left eye x/y; [2],[3] right eye x/y; [4],[5] nose x/y;
// [7] left mouth-corner y; [9] right mouth-corner y.
// NOTE(review): minX/maxX/minY/maxY are declared twice below (0.15/0.85 and then 0.2/0.8).
// This looks like removed and added diff lines rendered together — confirm which set is current.
private fun isFrontal(face: FaceBox, frameW: Int, frameH: Int): Boolean {
if (!face.hasLandmarks || face.landmarks == null || face.landmarks.size < 10) {
return false
}
val lm = face.landmarks
val w = face.right - face.left
val h = face.bottom - face.top
// Reject boxes narrower or shorter than the configured minimum face size.
if (w < AppConfig.Face.FRONTAL_MIN_FACE_SIZE || h < AppConfig.Face.FRONTAL_MIN_FACE_SIZE) {
return false
}
// Reject boxes whose width/height ratio deviates from 1 by more than the configured limit.
val aspectDiff = abs((w / h) - 1f)
if (aspectDiff > AppConfig.Face.FRONTAL_MAX_ASPECT_DIFF) {
return false
}
val leftEyeX = lm[0]
val leftEyeY = lm[1]
val rightEyeX = lm[2]
val rightEyeY = lm[3]
val noseX = lm[4]
val noseY = lm[5]
val mouthLeftY = lm[7]
val mouthRightY = lm[9]
// Ordering checks: right eye to the right of the left eye, nose below the eye midpoint,
// mouth midpoint below the nose.
if (rightEyeX <= leftEyeX) return false
if (noseY <= (leftEyeY + rightEyeY) * 0.5f) return false
if ((mouthLeftY + mouthRightY) * 0.5f <= noseY) return false
// Vertical eye offset, relative to box height, must not exceed 0.12.
val eyeDy = kotlin.math.abs(leftEyeY - rightEyeY) / h
if (eyeDy > 0.12f) return false
// Horizontal nose offset from the eye midpoint, relative to box width, must not exceed 0.12.
val eyeMidX = (leftEyeX + rightEyeX) * 0.5f
val noseOffset = kotlin.math.abs(noseX - eyeMidX) / w
if (noseOffset > 0.12f) return false
// The box centre must fall inside the central region of the frame.
val cx = (face.left + face.right) * 0.5f
val cy = (face.top + face.bottom) * 0.5f
val minX = frameW * 0.15f
val maxX = frameW * 0.85f
val minY = frameH * 0.15f
val maxY = frameH * 0.85f
val minX = frameW * 0.2f
val maxX = frameW * 0.8f
val minY = frameH * 0.2f
val maxY = frameH * 0.8f
return cx in minX..maxX && cy in minY..maxY
}
/**
 * Scores a detected face in the range 0..1 for use as a fusion weight / quality gate.
 *
 * The score is a weighted sum of three terms:
 *  - detector confidence (weight 0.45),
 *  - face area relative to the frame, saturating at 8% of the frame (weight 0.2),
 *  - landmark geometry: eye tilt and nose offset from the eye midpoint, each
 *    dropping linearly to 0 at 0.2 of the box height / width (weight 0.35).
 *
 * @return 0 when the face carries no usable landmarks (fewer than 10 values).
 */
private fun estimateQuality(face: FaceBox, frameW: Int, frameH: Int): Float {
    val points = face.landmarks
    if (points == null) return 0f
    if (!face.hasLandmarks || points.size < 10) return 0f

    val boxW = (face.right - face.left).coerceAtLeast(1f)
    val boxH = (face.bottom - face.top).coerceAtLeast(1f)

    // Size term: fraction of the frame covered by the box.
    val frameArea = (frameW * frameH).toFloat()
    val coverage = ((boxW * boxH) / frameArea).coerceIn(0f, 1f)
    val sizeTerm = (coverage / 0.08f).coerceIn(0f, 1f)

    // Geometry term: points[1]/points[3] are the eye y values, points[0]/points[2] the eye x values,
    // points[4] the nose x value.
    val eyeTilt = kotlin.math.abs(points[1] - points[3]) / boxH
    val eyeCenterX = (points[0] + points[2]) * 0.5f
    val noseShift = kotlin.math.abs(points[4] - eyeCenterX) / boxW
    val tiltTerm = (1f - (eyeTilt / 0.2f)).coerceIn(0f, 1f)
    val shiftTerm = (1f - (noseShift / 0.2f)).coerceIn(0f, 1f)
    val geometryTerm = tiltTerm * shiftTerm

    val confidenceTerm = face.score.coerceIn(0f, 1f)
    return (confidenceTerm * 0.45f) + (sizeTerm * 0.2f) + (geometryTerm * 0.35f)
}
private fun iou(a: FaceBox, b: FaceBox): Float {
val left = maxOf(a.left, b.left)
val top = maxOf(a.top, b.top)
@@ -193,4 +299,8 @@ class FaceDetectionPipeline(
recognizer.release()
initialized.set(false)
}
/** Exposes the recognizer instance owned by this pipeline. */
fun getRecognizer(): FaceRecognizer = recognizer
}

View File

@@ -11,7 +11,6 @@ import java.nio.ByteOrder
data class FaceProfile(
val id: Long,
val name: String?,
val embedding: FloatArray,
)
@@ -21,7 +20,6 @@ class FaceFeatureStore(context: Context) : SQLiteOpenHelper(context, DB_NAME, nu
"""
CREATE TABLE IF NOT EXISTS face_profiles (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT,
embedding BLOB NOT NULL,
updated_at INTEGER NOT NULL
)
@@ -37,16 +35,14 @@ class FaceFeatureStore(context: Context) : SQLiteOpenHelper(context, DB_NAME, nu
fun loadAllProfiles(): List<FaceProfile> {
val db = readableDatabase
val list = ArrayList<FaceProfile>()
db.rawQuery("SELECT id, name, embedding FROM face_profiles", null).use { c ->
db.rawQuery("SELECT id, embedding FROM face_profiles", null).use { c ->
val idIdx = c.getColumnIndexOrThrow("id")
val nameIdx = c.getColumnIndexOrThrow("name")
val embIdx = c.getColumnIndexOrThrow("embedding")
while (c.moveToNext()) {
val embBlob = c.getBlob(embIdx) ?: continue
list.add(
FaceProfile(
id = c.getLong(idIdx),
name = c.getString(nameIdx),
embedding = blobToFloatArray(embBlob),
)
)
@@ -55,10 +51,8 @@ class FaceFeatureStore(context: Context) : SQLiteOpenHelper(context, DB_NAME, nu
return list
}
fun insertProfile(name: String?, embedding: FloatArray): Long {
val safeName = name?.takeIf { it.isNotBlank() }
fun insertProfile(embedding: FloatArray): Long {
val values = ContentValues().apply {
put("name", safeName)
put("embedding", floatArrayToBlob(embedding))
put("updated_at", System.currentTimeMillis())
}
@@ -68,7 +62,7 @@ class FaceFeatureStore(context: Context) : SQLiteOpenHelper(context, DB_NAME, nu
values,
SQLiteDatabase.CONFLICT_NONE
)
Log.i(AppConfig.TAG, "[FaceFeatureStore] insertProfile name='$safeName' rowId=$rowId dim=${embedding.size}")
Log.i(AppConfig.TAG, "[FaceFeatureStore] insertProfile rowId=$rowId dim=${embedding.size}")
return rowId
}
@@ -88,6 +82,6 @@ class FaceFeatureStore(context: Context) : SQLiteOpenHelper(context, DB_NAME, nu
companion object {
private const val DB_NAME = "face_feature.db"
private const val DB_VERSION = 2
private const val DB_VERSION = 3
}
}

View File

@@ -2,6 +2,8 @@ package com.digitalperson.face
import android.content.Context
import android.graphics.Bitmap
import android.graphics.Canvas
import android.graphics.Matrix
import android.util.Log
import com.digitalperson.config.AppConfig
import com.digitalperson.engine.ArcFaceEngineRKNN
@@ -9,7 +11,6 @@ import kotlin.math.sqrt
data class FaceRecognitionResult(
val matchedId: Long?,
val matchedName: String?,
val similarity: Float,
val embeddingDim: Int,
)
@@ -41,12 +42,11 @@ class FaceRecognizer(context: Context) {
}
fun identify(bitmap: Bitmap, face: FaceBox): FaceRecognitionResult {
if (!initialized) return FaceRecognitionResult(null, null, 0f, 0)
if (!initialized) return FaceRecognitionResult(null, 0f, 0)
val embedding = extractEmbedding(bitmap, face)
if (embedding.isEmpty()) return FaceRecognitionResult(null, null, 0f, 0)
if (embedding.isEmpty()) return FaceRecognitionResult(null, 0f, 0)
var bestId: Long? = null
var bestName: String? = null
var bestScore = -1f
for (p in cache) {
if (p.embedding.size != embedding.size) continue
@@ -54,13 +54,12 @@ class FaceRecognizer(context: Context) {
if (score > bestScore) {
bestScore = score
bestId = p.id
bestName = p.name
}
}
if (bestScore >= AppConfig.FaceRecognition.SIMILARITY_THRESHOLD) {
return FaceRecognitionResult(bestId, bestName, bestScore, embedding.size)
return FaceRecognitionResult(bestId, bestScore, embedding.size)
}
return FaceRecognitionResult(null, null, bestScore, embedding.size)
return FaceRecognitionResult(null, bestScore, embedding.size)
}
fun extractEmbedding(bitmap: Bitmap, face: FaceBox): FloatArray {
@@ -68,28 +67,208 @@ class FaceRecognizer(context: Context) {
return engine.extractEmbedding(bitmap, face.left, face.top, face.right, face.bottom)
}
private fun addProfile(name: String?, embedding: FloatArray): Long {
/**
 * Extracts an embedding from a landmark-aligned crop of the face.
 *
 * Falls back to [extractEmbedding] (plain bounding-box crop) when alignment is not
 * possible. The temporary aligned bitmap is always recycled.
 *
 * @return the embedding, or an empty array when the recognizer is not initialized.
 */
fun extractEmbeddingAligned(bitmap: Bitmap, face: FaceBox): FloatArray {
    if (!initialized) return FloatArray(0)

    val warped = alignFaceToArcFace(bitmap, face)
    if (warped == null) {
        return extractEmbedding(bitmap, face)
    }

    try {
        // The aligned bitmap is the face itself, so the crop rectangle is the whole image.
        val right = warped.width.toFloat()
        val bottom = warped.height.toFloat()
        return engine.extractEmbedding(warped, 0f, 0f, right, bottom)
    } finally {
        warped.recycle()
    }
}
/** Test entry point; delegates directly to [extractEmbeddingAligned]. */
fun extractEmbeddingAlignedForTest(bitmap: Bitmap, face: FaceBox): FloatArray =
    extractEmbeddingAligned(bitmap, face)
/**
 * Test helper: extracts the embedding of the same face twice and compares the two results.
 *
 * @param bitmap test image
 * @param face face box inside [bitmap]
 * @return cosine similarity of the two extractions, or -1 when the recognizer is not
 *         initialized or either extraction yields an empty embedding
 */
fun testSimilarity(bitmap: Bitmap, face: FaceBox): Float {
    if (!initialized) return -1f

    // First extraction pass.
    val firstPass = extractEmbedding(bitmap, face).takeIf { it.isNotEmpty() } ?: return -1f
    // Second extraction pass on the identical input.
    val secondPass = extractEmbedding(bitmap, face).takeIf { it.isNotEmpty() } ?: return -1f

    return cosineSimilarity(firstPass, secondPass)
}
/**
 * Test helper: computes the similarity between faces taken from two images.
 *
 * Embeddings are extracted through the landmark-aligned path, which itself falls back
 * to a bounding-box crop when landmarks are missing.
 *
 * @param bitmap1 first image
 * @param face1 face box inside [bitmap1]
 * @param bitmap2 second image
 * @param face2 face box inside [bitmap2]
 * @return cosine similarity of the two embeddings, or -1 when the recognizer is not
 *         initialized or either embedding is empty
 */
fun testSimilarityBetween(bitmap1: Bitmap, face1: FaceBox, bitmap2: Bitmap, face2: FaceBox): Float {
    Log.d(AppConfig.TAG, "[FaceRecognizer] testSimilarityBetween: initialized=$initialized")
    if (!initialized) {
        Log.e(AppConfig.TAG, "[FaceRecognizer] testSimilarityBetween: recognizer not initialized")
        return -1f
    }

    Log.d(AppConfig.TAG, "[FaceRecognizer] testSimilarityBetween: extracting aligned embedding1...")
    val first = extractEmbeddingAlignedForTest(bitmap1, face1)
    Log.d(AppConfig.TAG, "[FaceRecognizer] testSimilarityBetween: embedding1 size=${first.size}")

    Log.d(AppConfig.TAG, "[FaceRecognizer] testSimilarityBetween: extracting aligned embedding2...")
    val second = extractEmbeddingAlignedForTest(bitmap2, face2)
    Log.d(AppConfig.TAG, "[FaceRecognizer] testSimilarityBetween: embedding2 size=${second.size}")

    val firstMissing = first.isEmpty()
    val secondMissing = second.isEmpty()
    if (firstMissing || secondMissing) {
        Log.e(AppConfig.TAG, "[FaceRecognizer] testSimilarityBetween: embedding extraction failed - embedding1 empty=${firstMissing}, embedding2 empty=${secondMissing}")
        return -1f
    }

    return cosineSimilarity(first, second).also { score ->
        Log.d(AppConfig.TAG, "[FaceRecognizer] testSimilarityBetween: similarity=$score")
    }
}
/**
 * Warps the face into a 112x112 crop by mapping its five landmarks onto fixed
 * reference positions with a similarity transform (see [estimateSimilarity]).
 *
 * @return a newly allocated ARGB_8888 bitmap owned by the caller, or null when the
 *         face has no usable landmarks or the transform cannot be estimated.
 */
private fun alignFaceToArcFace(bitmap: Bitmap, face: FaceBox): Bitmap? {
    val points = face.landmarks
    if (points == null) return null
    if (!face.hasLandmarks || points.size < 10) return null

    val targetSize = 112

    // Detected landmarks as (x, y) pairs: left eye, right eye, nose, left mouth, right mouth.
    val detected = Array(5) { k -> floatArrayOf(points[2 * k], points[2 * k + 1]) }

    // Reference positions of the same five points inside the 112x112 output.
    val reference = arrayOf(
        floatArrayOf(38.2946f, 51.6963f),
        floatArrayOf(73.5318f, 51.5014f),
        floatArrayOf(56.0252f, 71.7366f),
        floatArrayOf(41.5493f, 92.3655f),
        floatArrayOf(70.7299f, 92.2041f),
    )

    val transform = estimateSimilarity(detected, reference) ?: return null

    // Matrix.setValues takes the 3x3 matrix row by row; the last row makes it affine.
    val warp = Matrix().apply {
        setValues(
            floatArrayOf(
                transform[0], transform[1], transform[2],
                transform[3], transform[4], transform[5],
                0f, 0f, 1f
            )
        )
    }

    return Bitmap.createBitmap(targetSize, targetSize, Bitmap.Config.ARGB_8888).also { aligned ->
        Canvas(aligned).drawBitmap(bitmap, warp, null)
    }
}
/**
 * Least-squares 2D similarity transform (rotation, uniform scale, translation; no
 * reflection) that maps the points in [src] onto the points in [dst].
 *
 * @param src source points, each an (x, y) pair
 * @param dst target points, same count and order as [src]
 * @return six values `[m00, m01, tx, m10, m11, ty]` — the top two rows of the affine
 *         matrix — or null when the point counts differ, fewer than two points are
 *         given, or the configuration is degenerate (source spread or correlation
 *         below 1e-6)
 */
private fun estimateSimilarity(
    src: Array<FloatArray>,
    dst: Array<FloatArray>,
): FloatArray? {
    val count = src.size
    if (count != dst.size || count < 2) return null

    // Centroids of both point sets.
    var srcMeanX = 0f
    var srcMeanY = 0f
    var dstMeanX = 0f
    var dstMeanY = 0f
    for (k in 0 until count) {
        val s = src[k]
        val d = dst[k]
        srcMeanX += s[0]
        srcMeanY += s[1]
        dstMeanX += d[0]
        dstMeanY += d[1]
    }
    val divisor = count.toFloat()
    srcMeanX /= divisor
    srcMeanY /= divisor
    dstMeanX /= divisor
    dstMeanY /= divisor

    // Dot and cross correlation of the centred points, plus the source spread.
    var dot = 0f
    var cross = 0f
    var spread = 0f
    for (k in 0 until count) {
        val sx = src[k][0] - srcMeanX
        val sy = src[k][1] - srcMeanY
        val dx = dst[k][0] - dstMeanX
        val dy = dst[k][1] - dstMeanY
        dot += sx * dx + sy * dy
        cross += sx * dy - sy * dx
        spread += sx * sx + sy * sy
    }
    if (spread < 1e-6f) return null

    val magnitude = kotlin.math.sqrt(dot * dot + cross * cross)
    if (magnitude < 1e-6f) return null

    val cosTheta = dot / magnitude
    val sinTheta = cross / magnitude
    val scale = magnitude / spread

    val m00 = scale * cosTheta
    val m10 = scale * sinTheta
    val m01 = -m10
    val m11 = m00

    // Translation moves the transformed source centroid onto the target centroid.
    val tx = dstMeanX - (m00 * srcMeanX + m01 * srcMeanY)
    val ty = dstMeanY - (m10 * srcMeanX + m11 * srcMeanY)
    return floatArrayOf(m00, m01, tx, m10, m11, ty)
}
private fun addProfile(embedding: FloatArray): Long {
val normalized = normalize(embedding)
val rowId = store.insertProfile(name, normalized)
val rowId = store.insertProfile(normalized)
if (rowId > 0) {
cache.add(FaceProfile(id = rowId, name = name, embedding = normalized))
cache.add(FaceProfile(id = rowId, embedding = normalized))
}
return rowId
}
fun resolveIdentity(bitmap: Bitmap, face: FaceBox): FaceRecognitionResult {
val match = identify(bitmap, face)
if (match.matchedId != null) return match
val embedding = extractEmbedding(bitmap, face)
if (embedding.isEmpty()) return match
val newId = addProfile(name = null, embedding = embedding)
if (newId <= 0L) return match
return FaceRecognitionResult(
matchedId = newId,
matchedName = null,
similarity = match.similarity,
embeddingDim = embedding.size
)
val embedding = extractEmbeddingAligned(bitmap, face)
return resolveIdentityFromEmbedding(embedding)
}
/**
 * Resolves an identity for an already-extracted embedding.
 *
 * The embedding is normalized and compared against every cached profile of the same
 * dimension. If the best cosine similarity reaches
 * `AppConfig.FaceRecognition.SIMILARITY_THRESHOLD`, that profile's id is returned.
 * Otherwise the embedding is enrolled as a new profile and the new id is returned.
 *
 * @return result carrying the matched or newly enrolled id (null if enrolment failed),
 *         the best similarity seen (-1 when nothing was comparable) and the embedding
 *         dimension; `(null, -1, 0)` when the recognizer is not initialized or the
 *         embedding is empty
 */
fun resolveIdentityFromEmbedding(embedding: FloatArray): FaceRecognitionResult {
    if (!initialized || embedding.isEmpty()) return FaceRecognitionResult(null, -1f, 0)

    val query = normalize(embedding)
    var winnerId: Long? = null
    var winnerScore = -1f
    for (profile in cache) {
        if (profile.embedding.size == query.size) {
            val candidateScore = cosineSimilarity(query, profile.embedding)
            if (candidateScore > winnerScore) {
                winnerScore = candidateScore
                winnerId = profile.id
            }
        }
    }

    val matched = winnerId != null && winnerScore >= AppConfig.FaceRecognition.SIMILARITY_THRESHOLD
    if (matched) {
        return FaceRecognitionResult(winnerId, winnerScore, query.size)
    }

    // No sufficiently similar profile: enrol this embedding as a new identity.
    val enrolledId = addProfile(query)
    val resolvedId: Long? = if (enrolledId > 0L) enrolledId else null
    return FaceRecognitionResult(resolvedId, winnerScore, query.size)
}
fun release() {

View File

@@ -0,0 +1,258 @@
package com.digitalperson.interaction
import android.util.Log
import com.digitalperson.data.AppDatabase
import com.digitalperson.data.entity.ChatMessage
import com.digitalperson.data.entity.ChatMessageEntity
import com.digitalperson.data.entity.ConversationSummaryEntity
import com.digitalperson.llm.LLMManager
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.launch
/**
 * Per-user, in-memory buffer of chat messages with optional Room persistence.
 *
 * [addMessage] drops the oldest entry whenever a user's buffer grows past [maxMessages].
 * Database work runs on [Dispatchers.IO]; because [loadFromDatabase] publishes its result
 * from that dispatcher, the backing map is a concurrent map. The per-user lists themselves
 * are not synchronized.
 * NOTE(review): presumably all non-database calls come from a single thread — confirm against callers.
 * NOTE(review): database work is launched on GlobalScope and is therefore not tied to any lifecycle.
 *
 * @param db Room database providing `chatMessageDao()`
 * @param maxMessages size above which [addMessage] evicts the oldest message (default 50)
 */
class ConversationBufferMemory(private val db: AppDatabase, private val maxMessages: Int = 50) {
    // Concurrent because loadFromDatabase writes from Dispatchers.IO while callers read elsewhere.
    private val userMessages: MutableMap<String, MutableList<ChatMessage>> =
        java.util.concurrent.ConcurrentHashMap()
    private val TAG = "ConversationBufferMemory"

    /**
     * Appends a message to the user's buffer.
     *
     * @param userId user id
     * @param role speaker role: user or assistant
     * @param content message text
     */
    fun addMessage(userId: String, role: String, content: String) {
        val messages = userMessages.getOrPut(userId) { mutableListOf() }
        messages.add(ChatMessage(userId = userId, role = role, content = content))
        // Cap the buffer so memory use does not grow without bound.
        if (messages.size > maxMessages) {
            messages.removeAt(0)
        }
        Log.d(TAG, "Added message for user $userId: $role - $content")
    }

    /**
     * Returns a snapshot copy of the user's buffered conversation.
     *
     * @param userId user id
     * @return the buffered messages, oldest first; empty when the user has none
     */
    fun getMessages(userId: String): List<ChatMessage> =
        userMessages[userId]?.toList().orEmpty()

    /**
     * Returns a snapshot copy of the user's most recent messages.
     *
     * The result is always an independent copy, never a view of the live buffer, so it
     * stays valid when further messages are added. A negative [limit] is treated as 0.
     *
     * @param userId user id
     * @param limit maximum number of messages to return
     * @return up to [limit] newest messages, oldest first
     */
    fun getRecentMessages(userId: String, limit: Int): List<ChatMessage> =
        userMessages[userId]?.takeLast(limit.coerceAtLeast(0)).orEmpty()

    /**
     * Asynchronously replaces the user's in-memory buffer with the rows stored in the database.
     *
     * @param userId user id
     */
    fun loadFromDatabase(userId: String) {
        GlobalScope.launch(Dispatchers.IO) {
            try {
                val messages = db.chatMessageDao().getMessagesByUserId(userId)
                    .map { entity ->
                        ChatMessage(
                            id = entity.id,
                            userId = entity.userId,
                            role = entity.role,
                            content = entity.content,
                            timestamp = entity.timestamp
                        )
                    }
                    .toMutableList()
                userMessages[userId] = messages
                Log.d(TAG, "Loaded ${messages.size} messages for user $userId from database")
            } catch (e: Exception) {
                Log.e(TAG, "Error loading messages from database: ${e.message}")
            }
        }
    }

    /**
     * Asynchronously replaces the user's stored messages with the current buffer contents.
     *
     * The buffer is copied on the calling thread before the background work starts, so
     * messages added afterwards cannot disturb the write (they are saved by the next call).
     *
     * @param userId user id
     */
    fun saveToDatabase(userId: String) {
        val snapshot = userMessages[userId]?.toList().orEmpty()
        GlobalScope.launch(Dispatchers.IO) {
            try {
                val dao = db.chatMessageDao()
                dao.clearMessagesByUserId(userId)
                snapshot.forEach { message ->
                    dao.insert(
                        ChatMessageEntity(
                            id = message.id,
                            userId = message.userId,
                            role = message.role,
                            content = message.content,
                            timestamp = message.timestamp
                        )
                    )
                }
                Log.d(TAG, "Saved messages for user $userId to database")
            } catch (e: Exception) {
                Log.e(TAG, "Error saving messages to database: ${e.message}")
            }
        }
    }

    /**
     * Drops the user's in-memory buffer immediately and deletes the stored rows asynchronously.
     *
     * @param userId user id
     */
    fun clear(userId: String) {
        userMessages.remove(userId)
        GlobalScope.launch(Dispatchers.IO) {
            try {
                db.chatMessageDao().clearMessagesByUserId(userId)
                Log.d(TAG, "Cleared messages for user $userId")
            } catch (e: Exception) {
                Log.e(TAG, "Error clearing messages: ${e.message}")
            }
        }
    }

    /**
     * Renders the user's buffer as LLM context text, one `role: content` line per message.
     *
     * @param userId user id
     * @return the context string; empty when the user has no buffered messages
     */
    fun toContextString(userId: String): String =
        getMessages(userId).joinToString("\n") { "${it.role}: ${it.content}" }
}
/**
 * Generates, caches and persists one conversation summary per user.
 *
 * Summaries are produced by [llmManager] and cached in memory; database work runs on
 * [Dispatchers.IO]. The cache is a concurrent map because it is written from the LLM
 * callback and from the IO dispatcher as well as by callers.
 * NOTE(review): database work is launched on GlobalScope and is therefore not tied to any lifecycle.
 *
 * @param db Room database providing `conversationSummaryDao()`
 * @param llmManager LLM used to write summaries; when null, no summary can be generated
 */
class ConversationSummaryMemory(private val db: AppDatabase, private val llmManager: LLMManager?) {
    private val userSummaries: MutableMap<String, String> = java.util.concurrent.ConcurrentHashMap()
    private val TAG = "ConversationSummaryMemory"

    /**
     * Returns the cached summary for the user.
     *
     * @param userId user id
     * @return the cached summary, or an empty string when none is cached
     */
    fun getSummary(userId: String): String = userSummaries[userId] ?: ""

    /**
     * Asks the LLM to summarize a conversation, then caches and persists the result.
     *
     * [onComplete] receives an empty string when [messages] is empty or no LLM manager is
     * available; otherwise it receives the trimmed LLM output. The previous stored
     * summaries of the user are replaced asynchronously.
     *
     * @param userId user id
     * @param messages conversation to summarize
     * @param onComplete completion callback
     */
    fun generateSummary(userId: String, messages: List<ChatMessage>, onComplete: (String) -> Unit) {
        if (messages.isEmpty()) {
            onComplete("")
            return
        }
        if (llmManager == null) {
            Log.w(TAG, "LLM manager is not initialized, cannot generate summary")
            onComplete("")
            return
        }

        val historyText = messages.joinToString("\n") { "${it.role}: ${it.content}" }

        // The history is appended after trimIndent() instead of being interpolated into the
        // template: interpolated multi-line text has unindented lines, which would make
        // trimIndent() detect a common indent of zero and leave the template's own indent in place.
        val instructions = """
            请总结以下对话的关键信息,包括:
            1. 用户的主要问题或需求
            2. 重要的个人信息(如姓名、爱好等)
            3. 对话的核心结论或共识
            输出简洁明了的总结不超过500字。
            对话历史:
        """.trimIndent()
        val prompt = instructions + "\n" + historyText

        Log.d(TAG, "Generating summary for user $userId")
        llmManager.generate(prompt) { summary ->
            val trimmedSummary = summary.trim()
            userSummaries[userId] = trimmedSummary

            // Persist: keep only the newest summary for this user.
            GlobalScope.launch(Dispatchers.IO) {
                try {
                    val dao = db.conversationSummaryDao()
                    dao.clearSummariesByUserId(userId)
                    dao.insert(
                        ConversationSummaryEntity(
                            id = 0,
                            userId = userId,
                            summary = trimmedSummary,
                            createdAt = System.currentTimeMillis()
                        )
                    )
                    Log.d(TAG, "Saved summary for user $userId to database")
                } catch (e: Exception) {
                    Log.e(TAG, "Error saving summary to database: ${e.message}")
                }
            }

            onComplete(trimmedSummary)
            Log.d(TAG, "Generated summary for user $userId: $trimmedSummary")
        }
    }

    /**
     * Asynchronously loads the user's latest stored summary into the cache.
     * The cache is left untouched when no summary is stored.
     *
     * @param userId user id
     */
    fun loadFromDatabase(userId: String) {
        GlobalScope.launch(Dispatchers.IO) {
            try {
                val summaryEntity = db.conversationSummaryDao().getLatestSummaryByUserId(userId)
                if (summaryEntity != null) {
                    userSummaries[userId] = summaryEntity.summary
                    Log.d(TAG, "Loaded summary for user $userId from database")
                }
            } catch (e: Exception) {
                Log.e(TAG, "Error loading summary from database: ${e.message}")
            }
        }
    }

    /**
     * Drops the cached summary immediately and deletes the stored summaries asynchronously.
     *
     * @param userId user id
     */
    fun clear(userId: String) {
        userSummaries.remove(userId)
        GlobalScope.launch(Dispatchers.IO) {
            try {
                db.conversationSummaryDao().clearSummariesByUserId(userId)
                Log.d(TAG, "Cleared summary for user $userId")
            } catch (e: Exception) {
                Log.e(TAG, "Error clearing summary: ${e.message}")
            }
        }
    }
}

View File

@@ -1,5 +1,7 @@
package com.digitalperson.interaction
import android.content.Context
import com.digitalperson.util.SmartGreetingUtil
import android.util.Log
import com.digitalperson.config.AppConfig
import kotlinx.coroutines.CoroutineScope
@@ -32,18 +34,25 @@ interface InteractionActionHandler {
fun onRememberUser(faceIdentityId: String, name: String?)
fun saveThought(thought: String)
fun loadLatestThought(): String?
fun loadRecentThoughts(timeRangeMs: Long): List<String>
fun addToChatHistory(role: String, content: String)
fun addAssistantMessageToCloudHistory(content: String)
fun getRandomQuestion(faceId: String): String
}
class DigitalHumanInteractionController(
private val scope: CoroutineScope,
private val handler: InteractionActionHandler,
private val context: Context
) {
private val smartGreetingUtil = SmartGreetingUtil(context)
private val TAG: String = "DigitalHumanInteraction"
private var state: InteractionState = InteractionState.IDLE
private var facePresent: Boolean = false
private var recognizedName: String? = null
private var currentFaceId: String? = null
private var faceSeenSinceMs: Long = 0L
private var proactiveRound = 0
private var hasPendingUserReply = false
@@ -58,59 +67,86 @@ class DigitalHumanInteractionController(
scheduleMemoryMode()
}
fun onFacePresenceChanged(present: Boolean, faceIdentityId: String?, recognized: String?) {
Log.d(TAG, "onFacePresenceChanged: present=$present, faceIdentityId=$faceIdentityId, recognized=$recognized, state=$state")
facePresent = present
if (!faceIdentityId.isNullOrBlank()) {
handler.onRememberUser(faceIdentityId, recognized)
}
if (!recognized.isNullOrBlank()) {
recognizedName = recognized
fun onFacePresenceChanged(present: Boolean) {
Log.d(TAG, "onFacePresenceChanged: present=$present, state=$state")
facePresent = present
val now = System.currentTimeMillis()
if (present) {
if (faceSeenSinceMs == 0L) {
faceSeenSinceMs = now
}
} else {
faceSeenSinceMs = 0L
}
// 首次出现就启动稳定计时,到点只要人还在就问候。
if (present) {
val stableMs = now - faceSeenSinceMs
val remain = AppConfig.Face.STABLE_MS - stableMs
if (remain > 0) {
faceStableJob?.cancel()
faceStableJob = scope.launch {
delay(remain)
if (!facePresent) return@launch
if (state == InteractionState.IDLE || state == InteractionState.MEMORY || state == InteractionState.FAREWELL) {
if (currentFaceId.isNullOrBlank()) {
Log.d(TAG, "Greeting as unknown user: identity still unavailable after stable timeout")
}
enterGreeting()
}
}
} else if (state == InteractionState.IDLE || state == InteractionState.MEMORY || state == InteractionState.FAREWELL) {
enterGreeting()
}
return
}
// 统一延迟处理
// 人脸消失:保留小延迟,避免瞬时抖动导致频繁告别
faceStableJob?.cancel()
faceStableJob = scope.launch {
delay(AppConfig.Face.STABLE_MS)
if (present) {
// 人脸出现后的处理
if (facePresent && (state == InteractionState.IDLE || state == InteractionState.MEMORY)) {
enterGreeting()
} else if (state == InteractionState.FAREWELL) {
enterGreeting()
}
if (facePresent) return@launch
if (state != InteractionState.IDLE && state != InteractionState.MEMORY && state != InteractionState.FAREWELL) {
enterFarewell()
} else {
// 人脸消失后的处理
if (state != InteractionState.IDLE && state != InteractionState.MEMORY && state != InteractionState.FAREWELL) {
enterFarewell()
} else {
scheduleMemoryMode()
}
scheduleMemoryMode()
}
}
}
/**
 * Records identity information reported for the face currently in view.
 *
 * A non-blank [faceIdentityId] becomes the current face id and is forwarded to the handler
 * together with [recognized] (passed through unchanged, so it may be null or blank).
 * A non-blank [recognized] name replaces the remembered display name.
 * Nothing happens when both arguments are null or blank.
 */
fun onFaceIdentityUpdated(faceIdentityId: String?, recognized: String?) {
    val identity = faceIdentityId?.takeUnless { it.isBlank() }
    val displayName = recognized?.takeUnless { it.isBlank() }
    if (identity == null && displayName == null) return

    Log.d(TAG, "onFaceIdentityUpdated: faceIdentityId=$faceIdentityId, recognized=$recognized, state=$state")

    if (identity != null) {
        currentFaceId = identity
        handler.onRememberUser(identity, recognized)
    }
    if (displayName != null) {
        recognizedName = displayName
    }
}
// Called when the user starts speaking: flags a pending user reply, switches to the
// DIALOGUE state, and cancels waitReplyJob and proactiveJob if they exist.
fun onUserStartSpeaking() {
Log.d(TAG, "onUserStartSpeaking called, current state: $state")
hasPendingUserReply = true
// Enter the dialogue state immediately, regardless of the current state.
transitionTo(InteractionState.DIALOGUE)
// Cancel any job that is waiting on a timeout.
waitReplyJob?.cancel()
proactiveJob?.cancel()
}
fun onUserAsrText(text: String) {
val userText = text.trim()
if (userText.isBlank()) return
if (userText.contains("你在想什么")) {
val thought = handler.loadLatestThought()
if (!thought.isNullOrBlank()) {
handler.speak("我刚才在想:$thought")
handler.appendText("\n[回忆] $thought\n")
transitionTo(InteractionState.DIALOGUE)
handler.playMotion("haru_g_m15.motion3.json")
return
}
}
if (userText.contains("再见")) {
enterFarewell()
return
}
// TODO: 后续应该是通过大模型来进行判断是否进入 farewell 状态
// if (userText.contains("再见")) {
// enterFarewell()
// return
// }
hasPendingUserReply = true
when (state) {
@@ -140,15 +176,46 @@ fun onFacePresenceChanged(present: Boolean, faceIdentityId: String?, recognized:
private fun enterGreeting() {
transitionTo(InteractionState.GREETING)
val greet = if (!recognizedName.isNullOrBlank()) {
handler.playMotion("haru_g_m22.motion3.json")
"你好,$recognizedName,很高兴再次见到你。"
// 使用智能问候语
val (isFestival, festivalName) = smartGreetingUtil.isFestivalToday()
if (isFestival) {
// 节日问候使用本地LLM生成
val prompt = smartGreetingUtil.getFestivalGreetingPrompt(festivalName, recognizedName)
handler.requestLocalThought(prompt) { greeting ->
scope.launch {
if (!greeting.isNullOrBlank()) {
// 使用LLM生成的问候语
handler.playMotion(if (!recognizedName.isNullOrBlank()) "haru_g_m22.motion3.json" else "haru_g_m01.motion3.json")
handler.speak(greeting)
handler.appendText("\n[问候] $greeting\n")
handler.addToChatHistory("assistant", greeting)
handler.addAssistantMessageToCloudHistory(greeting)
transitionTo(InteractionState.WAITING_REPLY)
handler.playMotion("haru_g_m17.motion3.json")
scheduleWaitingReplyTimeout()
} else {
// LLM生成失败使用默认问候语
useDefaultGreeting()
}
}
}
} else {
handler.playMotion("haru_g_m01.motion3.json")
"你好,很高兴见到你。"
// 非节日或无网络,使用默认问候语
useDefaultGreeting()
}
handler.speak(greet)
handler.appendText("\n[问候] $greet\n")
}
private fun useDefaultGreeting() {
val greeting = smartGreetingUtil.getSmartGreeting(recognizedName)
handler.playMotion(if (!recognizedName.isNullOrBlank()) "haru_g_m22.motion3.json" else "haru_g_m01.motion3.json")
handler.speak(greeting)
handler.appendText("\n[问候] $greeting\n")
handler.addToChatHistory("assistant", greeting)
handler.addAssistantMessageToCloudHistory(greeting)
transitionTo(InteractionState.WAITING_REPLY)
handler.playMotion("haru_g_m17.motion3.json")
scheduleWaitingReplyTimeout()
@@ -176,16 +243,13 @@ fun onFacePresenceChanged(present: Boolean, faceIdentityId: String?, recognized:
private fun askProactiveTopic() {
proactiveJob?.cancel()
val topics = listOf(
"我出一道数学题考考你吧1+6等于多少",
"我们上完厕所应该干什么呀?",
"你喜欢什么颜色呀?",
)
val idx = proactiveRound.coerceIn(0, topics.lastIndex)
val topic = topics[idx]
// 从数据库获取问题
val topic = getQuestionFromDatabase()
handler.playMotion(if (proactiveRound == 0) "haru_g_m15.motion3.json" else "haru_g_m22.motion3.json")
handler.speak(topic)
handler.appendText("\n[主动引导] $topic\n")
handler.speak("嗨小朋友,可以帮老师个忙吗?"+topic)
handler.appendText("\n[主动引导] \"嗨小朋友,可以帮老师个忙吗?\"$topic\n")
// 将引导内容添加到对话历史中
handler.addToChatHistory("助手", topic)
// 将引导内容添加到云对话历史中
@@ -209,6 +273,17 @@ fun onFacePresenceChanged(present: Boolean, faceIdentityId: String?, recognized:
}
}
}
/**
 * Asks the handler for a random question on behalf of the current face id
 * (or the fallback id when no face has been identified).
 *
 * NOTE(review): the handler is expected to pick a question this user has not been asked
 * before; that selection logic lives behind [InteractionActionHandler.getRandomQuestion].
 */
private fun getQuestionFromDatabase(): String = handler.getRandomQuestion(getCurrentFaceId())
/** Returns the face id of the current user, or `"default_face_id"` when none is known yet. */
private fun getCurrentFaceId(): String {
    val knownId = currentFaceId
    return if (knownId != null) knownId else "default_face_id"
}
private fun enterFarewell() {
transitionTo(InteractionState.FAREWELL)
@@ -224,6 +299,7 @@ fun onFacePresenceChanged(present: Boolean, faceIdentityId: String?, recognized:
}
private fun scheduleMemoryMode() {
return // TODO: remove this after testing done! Now just skipping memory
memoryJob?.cancel()
if (facePresent) return
memoryJob = scope.launch {
@@ -248,8 +324,20 @@ fun onFacePresenceChanged(present: Boolean, faceIdentityId: String?, recognized:
private fun requestDialogueReply(userText: String) {
waitReplyJob?.cancel()
proactiveJob?.cancel()
// 获取最近的回忆信息
val recentThoughts = handler.loadRecentThoughts(10 * 60 * 1000)
val context = if (recentThoughts.isNotEmpty()) {
val thoughtsContext = recentThoughts.joinToString("\n") { "- $it" }
"以下是我最近10分钟内的想法\n$thoughtsContext\n\n"
} else {
""
}
// 将回忆信息作为上下文传递给LLM
val userTextWithContext = "$context$userText"
// 按产品要求用户对话统一走云端LLM
handler.requestCloudReply(userText)
handler.requestCloudReply(userTextWithContext)
}
private fun transitionTo(newState: InteractionState) {
@@ -265,4 +353,5 @@ fun onFacePresenceChanged(present: Boolean, faceIdentityId: String?, recognized:
memoryJob?.cancel()
farewellJob?.cancel()
}
}

View File

@@ -1,80 +1,52 @@
package com.digitalperson.interaction
import android.content.ContentValues
import android.content.Context
import android.database.sqlite.SQLiteDatabase
import android.database.sqlite.SQLiteOpenHelper
import com.digitalperson.data.AppDatabase
import com.digitalperson.data.entity.Question
import com.digitalperson.data.entity.UserAnswer
import com.digitalperson.data.entity.UserMemory
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
data class UserMemory(
val userId: String,
val displayName: String?,
val lastSeenAt: Long,
val age: String?,
val gender: String?,
val hobbies: String?,
val preferences: String?,
val lastTopics: String?,
val lastThought: String?,
val profileSummary: String?,
)
class UserMemoryStore(context: Context) : SQLiteOpenHelper(context, DB_NAME, null, DB_VERSION) {
class UserMemoryStore(context: Context) {
private val db = AppDatabase.getInstance(context)
private val userMemoryDao = db.userMemoryDao()
private val questionDao = db.questionDao()
private val userAnswerDao = db.userAnswerDao()
private val memoryCache = LinkedHashMap<String, UserMemory>()
@Volatile private var latestThoughtCache: String? = null
override fun onCreate(db: SQLiteDatabase) {
db.execSQL(
"""
CREATE TABLE IF NOT EXISTS user_memory (
user_id TEXT PRIMARY KEY,
display_name TEXT,
last_seen_at INTEGER NOT NULL,
age TEXT,
gender TEXT,
hobbies TEXT,
preferences TEXT,
last_topics TEXT,
last_thought TEXT,
profile_summary TEXT
)
""".trimIndent()
)
}
override fun onUpgrade(db: SQLiteDatabase, oldVersion: Int, newVersion: Int) {
db.execSQL("DROP TABLE IF EXISTS user_memory")
onCreate(db)
// 评价类型
enum class AnswerEvaluation(val displayName: String) {
CORRECT("正确"),
INCORRECT("错误"),
PARTIALLY_CORRECT("基本正确")
}
fun upsertUserSeen(userId: String, displayName: String?) {
val existing = memoryCache[userId] ?: getMemory(userId)
val now = System.currentTimeMillis()
val mergedName = displayName?.takeIf { it.isNotBlank() } ?: existing?.displayName
val values = ContentValues().apply {
put("user_id", userId)
put("display_name", mergedName)
put("last_seen_at", now)
put("age", existing?.age)
put("gender", existing?.gender)
put("hobbies", existing?.hobbies)
put("preferences", existing?.preferences)
put("last_topics", existing?.lastTopics)
put("last_thought", existing?.lastThought)
put("profile_summary", existing?.profileSummary)
GlobalScope.launch(Dispatchers.IO) {
val existing = getMemorySync(userId)
val now = System.currentTimeMillis()
val mergedName = displayName?.takeIf { it.isNotBlank() } ?: existing?.displayName
val userMemory = UserMemory(
userId = userId,
displayName = mergedName,
lastSeenAt = now,
age = existing?.age,
gender = existing?.gender,
hobbies = existing?.hobbies,
preferences = existing?.preferences,
lastTopics = existing?.lastTopics,
lastThought = existing?.lastThought,
profileSummary = existing?.profileSummary,
)
userMemoryDao.insert(userMemory)
memoryCache[userId] = userMemory
}
writableDatabase.insertWithOnConflict("user_memory", null, values, SQLiteDatabase.CONFLICT_REPLACE)
memoryCache[userId] = UserMemory(
userId = userId,
displayName = mergedName,
lastSeenAt = now,
age = existing?.age,
gender = existing?.gender,
hobbies = existing?.hobbies,
preferences = existing?.preferences,
lastTopics = existing?.lastTopics,
lastThought = existing?.lastThought,
profileSummary = existing?.profileSummary,
)
}
fun updateDisplayName(userId: String, displayName: String?) {
@@ -83,95 +55,164 @@ class UserMemoryStore(context: Context) : SQLiteOpenHelper(context, DB_NAME, nul
}
fun updateThought(userId: String, thought: String) {
upsertUserSeen(userId, null)
val now = System.currentTimeMillis()
val values = ContentValues().apply {
put("last_thought", thought)
put("last_seen_at", now)
GlobalScope.launch(Dispatchers.IO) {
upsertUserSeen(userId, null)
val now = System.currentTimeMillis()
userMemoryDao.updateThought(userId, thought, now)
latestThoughtCache = thought
val existing = memoryCache[userId]
if (existing != null) {
memoryCache[userId] = existing.copy(
lastThought = thought,
lastSeenAt = now
)
}
}
writableDatabase.update("user_memory", values, "user_id=?", arrayOf(userId))
latestThoughtCache = thought
val cached = memoryCache[userId]
memoryCache[userId] = UserMemory(
userId = userId,
displayName = cached?.displayName,
lastSeenAt = now,
age = cached?.age,
gender = cached?.gender,
hobbies = cached?.hobbies,
preferences = cached?.preferences,
lastTopics = cached?.lastTopics,
lastThought = thought,
profileSummary = cached?.profileSummary,
)
}
fun updateProfile(userId: String, age: String?, gender: String?, hobbies: String?, summary: String?) {
upsertUserSeen(userId, null)
val now = System.currentTimeMillis()
val values = ContentValues().apply {
if (age != null) put("age", age)
if (gender != null) put("gender", gender)
if (hobbies != null) put("hobbies", hobbies)
if (summary != null) put("profile_summary", summary)
put("last_seen_at", now)
GlobalScope.launch(Dispatchers.IO) {
upsertUserSeen(userId, null)
val now = System.currentTimeMillis()
userMemoryDao.updateProfile(userId, age, gender, hobbies, summary, now)
val existing = memoryCache[userId]
if (existing != null) {
memoryCache[userId] = existing.copy(
age = age ?: existing.age,
gender = gender ?: existing.gender,
hobbies = hobbies ?: existing.hobbies,
profileSummary = summary ?: existing.profileSummary,
lastSeenAt = now
)
}
}
writableDatabase.update("user_memory", values, "user_id=?", arrayOf(userId))
val cached = memoryCache[userId]
memoryCache[userId] = UserMemory(
userId = userId,
displayName = cached?.displayName,
lastSeenAt = now,
age = age ?: cached?.age,
gender = gender ?: cached?.gender,
hobbies = hobbies ?: cached?.hobbies,
preferences = cached?.preferences,
lastTopics = cached?.lastTopics,
lastThought = cached?.lastThought,
profileSummary = summary ?: cached?.profileSummary,
)
}
fun getMemory(userId: String): UserMemory? {
memoryCache[userId]?.let { return it }
readableDatabase.rawQuery(
"SELECT user_id, display_name, last_seen_at, age, gender, hobbies, preferences, last_topics, last_thought, profile_summary FROM user_memory WHERE user_id=?",
arrayOf(userId)
).use { c ->
if (!c.moveToFirst()) return null
val memory = UserMemory(
userId = c.getString(0),
displayName = c.getString(1),
lastSeenAt = c.getLong(2),
age = c.getString(3),
gender = c.getString(4),
hobbies = c.getString(5),
preferences = c.getString(6),
lastTopics = c.getString(7),
lastThought = c.getString(8),
profileSummary = c.getString(9),
)
memoryCache[userId] = memory
if (!memory.lastThought.isNullOrBlank()) {
latestThoughtCache = memory.lastThought
}
return memory
val memory = userMemoryDao.getUserMemory(userId)
memory?.let { memoryCache[userId] = it }
if (memory?.lastThought != null) {
latestThoughtCache = memory.lastThought
}
return memory
}
/**
 * Blocking lookup used internally: returns the cached memory for [userId] when present,
 * otherwise loads it through the DAO, caches it, and refreshes the latest-thought cache
 * when the loaded record carries a thought. Returns null when the user is unknown.
 */
private fun getMemorySync(userId: String): UserMemory? {
    val cached = memoryCache[userId]
    if (cached != null) return cached

    val loaded = userMemoryDao.getUserMemory(userId) ?: return null
    memoryCache[userId] = loaded
    val thought = loaded.lastThought
    if (thought != null) {
        latestThoughtCache = thought
    }
    return loaded
}
fun getLatestThought(): String? {
latestThoughtCache?.let { return it }
readableDatabase.rawQuery(
"SELECT last_thought FROM user_memory WHERE last_thought IS NOT NULL AND last_thought != '' ORDER BY last_seen_at DESC LIMIT 1",
null
).use { c ->
if (!c.moveToFirst()) return null
return c.getString(0).also { latestThoughtCache = it }
val thought = userMemoryDao.getLatestThought()
thought?.let { latestThoughtCache = it }
return thought
}
/**
 * Returns the non-null thoughts of the memories the DAO reports for the window that starts
 * [timeRangeMs] milliseconds before now (ten minutes by default).
 */
fun getRecentThoughts(timeRangeMs: Long = 10 * 60 * 1000): List<String> {
    val oldestAllowed = System.currentTimeMillis() - timeRangeMs
    val thoughts = ArrayList<String>()
    for (memory in userMemoryDao.getRecentThoughts(oldestAllowed)) {
        val thought = memory.lastThought
        if (thought != null) thoughts.add(thought)
    }
    return thoughts
}
// 问题相关方法
/**
 * Stores a new question stamped with the current time and returns the value produced by
 * the DAO insert. The id is passed as 0.
 */
fun addQuestion(content: String, answer: String?, subject: String?, grade: Int?, difficulty: Int = 1): Long =
    questionDao.insert(
        Question(
            id = 0,
            content = content,
            answer = answer,
            subject = subject,
            grade = grade,
            difficulty = difficulty,
            createdAt = System.currentTimeMillis(),
        )
    )
/** Returns the questions the DAO reports for [subject]. */
fun getQuestionsBySubject(subject: String): List<Question> =
    questionDao.getQuestionsBySubject(subject)
/** Returns the questions the DAO reports for [grade]. */
fun getQuestionsByGrade(grade: Int): List<Question> =
    questionDao.getQuestionsByGrade(grade)
/**
 * Picks a random question, optionally restricted by [subject] and/or [grade].
 *
 * - both filters set: random question matching subject and grade;
 * - only [subject] set: random question for that subject;
 * - only [grade] set: random question for that grade (previously the grade was silently
 *   ignored and a question from any grade could be returned);
 * - no filter: random question from the whole table.
 *
 * @return a matching question, or null when none exists.
 */
fun getRandomQuestion(subject: String? = null, grade: Int? = null): Question? = when {
    subject != null && grade != null ->
        questionDao.getRandomQuestionBySubjectAndGrade(subject, grade)
    subject != null ->
        questionDao.getRandomQuestionBySubject(subject)
    // No dedicated DAO query is visible for grade-only selection, so choose from the
    // grade's question list in memory.
    grade != null ->
        questionDao.getQuestionsByGrade(grade).randomOrNull()
    else ->
        questionDao.getRandomQuestion()
}
/** Returns whatever the DAO selects as a random unanswered question for [userId], or null. */
fun getRandomUnansweredQuestion(userId: String): Question? =
    questionDao.getRandomUnansweredQuestion(userId)
/**
 * Persists one answer given by [userId] to question [questionId].
 *
 * The insert is launched on [Dispatchers.IO] via [GlobalScope], so this call returns
 * immediately; the answer timestamp is taken inside the coroutine. [evaluation] is stored
 * by its enum name, or as null when not evaluated.
 */
fun recordUserAnswer(userId: String, questionId: Long, userAnswer: String?, evaluation: AnswerEvaluation?) {
    GlobalScope.launch(Dispatchers.IO) {
        val record = UserAnswer(
            id = 0,
            userId = userId,
            questionId = questionId,
            userAnswer = userAnswer,
            evaluation = evaluation?.name,
            answeredAt = System.currentTimeMillis(),
        )
        userAnswerDao.insert(record)
    }
}
companion object {
private const val DB_NAME = "digital_human_memory.db"
private const val DB_VERSION = 2
/** Returns up to [limit] stored answers for [userId], as reported by the DAO. */
fun getUserAnswers(userId: String, limit: Int = 50): List<UserAnswer> =
    userAnswerDao.getUserAnswers(userId, limit)
/**
 * Returns up to [limit] incorrect answers of [userId], each paired with its question.
 *
 * Answers whose question can no longer be found are dropped; a missing answer text is
 * reported as an empty string. One question lookup is issued per answer.
 */
fun getIncorrectAnswers(userId: String, limit: Int = 10): List<Pair<Question, String>> {
    val pairs = mutableListOf<Pair<Question, String>>()
    for (wrong in userAnswerDao.getIncorrectAnswers(userId, limit)) {
        val question = questionDao.getQuestionById(wrong.questionId) ?: continue
        pairs.add(Pair(question, wrong.userAnswer ?: ""))
    }
    return pairs
}
/**
 * Summarizes the answers of [userId] by evaluation.
 *
 * @return a map with the keys `"total"`, `"correct"`, `"incorrect"` and
 * `"partially_correct"` (in that order). `"total"` counts every row, including rows whose
 * evaluation is null or unrecognized.
 */
fun getQuestionStatistics(userId: String): Map<String, Int> {
    var total = 0
    var correct = 0
    var incorrect = 0
    var partiallyCorrect = 0

    for (row in userAnswerDao.getAnswerStatistics(userId)) {
        val count = row.count
        total += count
        // Accumulate rather than overwrite, so the per-category numbers stay consistent
        // with the total even if the same evaluation appears in more than one row.
        when (row.evaluation) {
            AnswerEvaluation.CORRECT.name -> correct += count
            AnswerEvaluation.INCORRECT.name -> incorrect += count
            AnswerEvaluation.PARTIALLY_CORRECT.name -> partiallyCorrect += count
        }
    }

    return mutableMapOf(
        "total" to total,
        "correct" to correct,
        "incorrect" to incorrect,
        "partially_correct" to partiallyCorrect,
    )
}
/** Looks up a single question by its id; null when the DAO finds none. */
fun getQuestionById(questionId: Long): Question? =
    questionDao.getQuestionById(questionId)
/** Returns the subject names reported by the question DAO. */
fun getAllSubjects(): List<String> =
    questionDao.getAllSubjects()
}

View File

@@ -43,4 +43,30 @@ class LLMManager(modelPath: String, callback: LLMManagerCallback) :
val msg = "<System>$systemPrompt<User>$userPrompt<Assistant>"
say(msg)
}
/**
 * Generates text for a single prompt and hands the accumulated output to [onComplete].
 *
 * Tokens delivered with state NORMAL are appended, except the literal tokens `<think>`,
 * `</think>` and a lone newline. The first callback with any other state ends the
 * generation: [onComplete] receives whatever was collected so far (possibly empty) and is
 * invoked at most once, even if further callbacks arrive.
 *
 * @param prompt user prompt; it is wrapped as `<User>…<Assistant>` before being sent.
 * @param onComplete receives the collected text when generation stops.
 */
fun generate(prompt: String, onComplete: (String) -> Unit) {
    val msg = "<User>$prompt<Assistant>"
    // StringBuilder avoids re-copying the whole result on every streamed token.
    val output = StringBuilder()
    var completed = false

    // Temporary callback that collects tokens for this request only.
    val tempCallback = object : LLMCallback {
        override fun onCallback(data: String, state: LLMCallback.State) {
            if (completed) return
            if (state == LLMCallback.State.NORMAL) {
                if (data != "<think>" && data != "</think>" && data != "\n") {
                    output.append(data)
                }
            } else {
                // Guard against a second terminal state reporting completion twice.
                completed = true
                onComplete(output.toString())
            }
        }
    }

    // Delegate to the underlying inference call.
    say(msg, tempCallback)
}
}

View File

@@ -38,6 +38,23 @@ open class RKLLM(modelPath: String, callback: LLMCallback) {
infer(mInstance, text)
}
// Runs inference on `text` with `callback` installed as mCallback for the duration of the
// infer() call. If the native instance handle is 0, reports an ERROR to `callback` and returns.
// NOTE(review): mCallback is swapped without any locking visible here — confirm that say()
// is never invoked concurrently.
protected fun say(text: String, callback: LLMCallback) {
if (mInstance == 0L) {
callback.onCallback("RKLLM is not initialized", LLMCallback.State.ERROR)
return
}
// Remember the instance-level callback so it can be restored afterwards.
val originalCallback = mCallback
// Temporarily replace the callback with the caller's one.
mCallback = callback
try {
// NOTE(review): the swap only covers callbacks delivered before infer() returns —
// presumably infer() blocks until generation finishes; confirm.
infer(mInstance, text)
} finally {
// Restore the original callback, even when infer() throws.
mCallback = originalCallback
}
}
fun callbackFromNative(data: String, state: Int) {
var s = LLMCallback.State.ERROR
s = if (state == 0) LLMCallback.State.FINISH

View File

@@ -19,8 +19,11 @@ import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import java.nio.ByteBuffer
import java.util.UUID
import java.util.concurrent.CountDownLatch
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicLong
class QCloudTtsManager(private val context: Context) {
@@ -37,11 +40,16 @@ class QCloudTtsManager(private val context: Context) {
data object End : TtsQueueItem()
}
private val ttsQueue = LinkedBlockingQueue<TtsQueueItem>()
private val ttsStopped = AtomicBoolean(false)
private val ttsWorkerRunning = AtomicBoolean(false)
private val ttsPlaying = AtomicBoolean(false)
private val ttsWorkerRunning = AtomicBoolean(false)
private val interrupting = AtomicBoolean(false)
private val playEpoch = AtomicLong(0)
private val writtenFrames = AtomicLong(0)
private val audioTrackLock = Any()
private val ttsQueue = LinkedBlockingQueue<TtsQueueItem>()
@Volatile private var lastEnqueuedText: String = ""
@Volatile private var lastEnqueuedAtMs: Long = 0L
private val ioScope = CoroutineScope(Dispatchers.IO)
@@ -100,6 +108,15 @@ class QCloudTtsManager(private val context: Context) {
ttsStopped.set(false)
}
val cleanedSeg = seg.trimEnd('.', '。', '!', '', '?', '', ',', '', ';', '', ':', '')
if (cleanedSeg.isBlank()) return
val now = System.currentTimeMillis()
// Guard against accidental duplicate enqueue from streaming/final flush overlap.
if (cleanedSeg == lastEnqueuedText && now - lastEnqueuedAtMs < 3000L) {
Log.w(TAG, "Skip duplicate segment within 3s: '$cleanedSeg'")
return
}
lastEnqueuedText = cleanedSeg
lastEnqueuedAtMs = now
ttsQueue.offer(TtsQueueItem.Segment(cleanedSeg))
ensureTtsWorker()
}
@@ -126,20 +143,23 @@ class QCloudTtsManager(private val context: Context) {
fun stop() {
ttsStopped.set(true)
ttsPlaying.set(false)
playEpoch.incrementAndGet()
writtenFrames.set(0L)
ttsQueue.clear()
ttsQueue.offer(TtsQueueItem.End)
try {
synthesizer?.cancel()
synthesizer = null
audioTrack?.pause()
audioTrack?.flush()
pauseAndFlushAudioTrack()
} catch (_: Throwable) {
}
}
fun interruptForNewTurn(waitTimeoutMs: Long = 300): Boolean {
if (!interrupting.compareAndSet(false, true)) return false
if (!interrupting.compareAndSet(false, true)) {
return false
}
try {
val hadPendingPlayback = ttsPlaying.get() || ttsWorkerRunning.get() || ttsQueue.isNotEmpty()
if (!hadPendingPlayback) {
@@ -150,15 +170,17 @@ class QCloudTtsManager(private val context: Context) {
ttsStopped.set(true)
ttsPlaying.set(false)
playEpoch.incrementAndGet()
writtenFrames.set(0L)
ttsQueue.clear()
ttsQueue.offer(TtsQueueItem.End)
try {
synthesizer?.cancel()
synthesizer = null
audioTrack?.pause()
audioTrack?.flush()
} catch (_: Throwable) {
pauseAndFlushAudioTrack()
} catch (e: Throwable) {
Log.e(TAG, "Error during interruption: ${e.message}")
}
val deadline = System.currentTimeMillis() + waitTimeoutMs
@@ -212,21 +234,28 @@ class QCloudTtsManager(private val context: Context) {
while (true) {
val item = ttsQueue.take()
if (ttsStopped.get()) break
if (ttsStopped.get()) {
break
}
when (item) {
is TtsQueueItem.Segment -> {
ttsPlaying.set(true)
callback?.onSetSpeaking(true)
Log.d(TAG, "QCloud TTS started: processing segment '${item.text}'")
callback?.onTtsStarted(item.text)
val startMs = System.currentTimeMillis()
val segmentEpoch = playEpoch.get()
val segmentDone = CountDownLatch(1)
try {
if (audioTrack.playState != AudioTrack.PLAYSTATE_PLAYING) {
audioTrack.play()
playAudioTrack(audioTrack)
// Always cancel stale instance before creating a new one.
try {
synthesizer?.cancel()
} catch (_: Throwable) {
}
synthesizer = null
val credential = Credential(
AppConfig.QCloud.APP_ID,
@@ -270,40 +299,46 @@ class QCloudTtsManager(private val context: Context) {
val listener = object : RealTimeSpeechSynthesizerListener() {
override fun onSynthesisStart(response: SpeechSynthesizerResponse) {
Log.d(TAG, "onSynthesisStart: ${response.sessionId}")
Log.d(TAG, "Starting synthesizer: ${response.sessionId}")
}
override fun onSynthesisEnd(response: SpeechSynthesizerResponse) {
Log.d(TAG, "onSynthesisEnd: ${response.sessionId}")
val ttsMs = System.currentTimeMillis() - startMs
callback?.onTtsSegmentCompleted(ttsMs)
segmentDone.countDown()
}
override fun onAudioResult(buffer: ByteBuffer) {
if (ttsStopped.get() || segmentEpoch != playEpoch.get()) {
return
}
val data = ByteArray(buffer.remaining())
buffer.get(data)
// 播放pcm
audioTrack.write(data, 0, data.size)
writeAudioTrack(audioTrack, data)
}
override fun onTextResult(response: SpeechSynthesizerResponse) {
Log.d(TAG, "onTextResult: ${response.sessionId}")
}
override fun onTextResult(response: SpeechSynthesizerResponse) {}
override fun onSynthesisCancel() {
Log.d(TAG, "onSynthesisCancel")
segmentDone.countDown()
}
override fun onSynthesisFail(response: SpeechSynthesizerResponse) {
Log.e(TAG, "onSynthesisFail: ${response.sessionId}, error: ${response.message}")
segmentDone.countDown()
}
}
synthesizer = RealTimeSpeechSynthesizer(proxy, credential, request, listener)
synthesizer?.start()
segmentDone.await(40, TimeUnit.SECONDS)
} catch (e: Exception) {
Log.e(TAG, "QCloud TTS error: ${e.message}", e)
} finally {
if (synthesizer != null && (ttsStopped.get() || segmentEpoch != playEpoch.get())) {
try { synthesizer?.cancel() } catch (_: Throwable) {}
}
}
}
@@ -324,7 +359,54 @@ class QCloudTtsManager(private val context: Context) {
}
private fun waitForPlaybackComplete(audioTrack: AudioTrack) {
// 等待音频播放完成
Thread.sleep(1000)
val targetFrames = writtenFrames.get()
if (targetFrames <= 0L) return
val timeoutAt = System.currentTimeMillis() + 1800L
while (System.currentTimeMillis() < timeoutAt) {
val playedFrames = audioTrack.playbackHeadPosition.toLong() and 0xFFFFFFFFL
if (playedFrames >= targetFrames) break
try {
Thread.sleep(20)
} catch (_: Throwable) {
break
}
}
writtenFrames.set(0L)
}
/** Starts [track] under the audio-track lock unless it is already in the playing state. */
private fun playAudioTrack(track: AudioTrack) {
    synchronized(audioTrackLock) {
        val alreadyPlaying = track.playState == AudioTrack.PLAYSTATE_PLAYING
        if (alreadyPlaying) return
        track.play()
    }
}
/**
 * Writes one PCM chunk to [track] under the audio-track lock and updates [writtenFrames].
 *
 * Nothing is written when [data] is empty or the track is not playing. Only the bytes that
 * `AudioTrack.write` reports as accepted are counted; previously the full chunk size was
 * added even when the write was short or failed, which inflated the frame target that
 * playback completion is compared against.
 */
private fun writeAudioTrack(track: AudioTrack, data: ByteArray) {
    if (data.isEmpty()) return
    synchronized(audioTrackLock) {
        if (track.playState != AudioTrack.PLAYSTATE_PLAYING) return
        val bytesWritten = track.write(data, 0, data.size)
        if (bytesWritten > 0) {
            // Two bytes per frame, as in the original accounting.
            // NOTE(review): this assumes 16-bit mono PCM — confirm against the track setup.
            writtenFrames.addAndGet((bytesWritten / 2).toLong())
        } else if (bytesWritten < 0) {
            Log.w(TAG, "AudioTrack.write failed with code $bytesWritten")
        }
    }
}
/**
 * Pauses the shared audio track if it is playing, then flushes it, all under the
 * audio-track lock. Does nothing when no track exists. A failure in either step is logged
 * and does not prevent the other step from running.
 */
private fun pauseAndFlushAudioTrack() {
    synchronized(audioTrackLock) {
        val track = audioTrack
        if (track != null) {
            runCatching {
                if (track.playState == AudioTrack.PLAYSTATE_PLAYING) {
                    track.pause()
                }
            }.onFailure { e ->
                Log.e(TAG, "Error pausing track: ${e.message}")
            }
            runCatching {
                track.flush()
            }.onFailure { e ->
                Log.e(TAG, "Error flushing track: ${e.message}")
            }
        }
    }
}
}

View File

@@ -11,7 +11,7 @@ class TtsController(private val context: Context) {
private var localTts: TtsManager? = null
private var qcloudTts: QCloudTtsManager? = null
private var useQCloudTts = false
private var useQCloudTts = true
interface TtsCallback {
fun onTtsStarted(text: String)
@@ -27,29 +27,34 @@ class TtsController(private val context: Context) {
fun setCallback(callback: TtsCallback) {
this.callback = callback
bindCallbacksIfReady()
}
private fun bindCallbacksIfReady() {
val cb = callback ?: return
localTts?.setCallback(object : TtsManager.TtsCallback {
override fun onTtsStarted(text: String) {
callback.onTtsStarted(text)
cb.onTtsStarted(text)
}
override fun onTtsCompleted() {
callback.onTtsCompleted()
cb.onTtsCompleted()
}
override fun onTtsSegmentCompleted(durationMs: Long) {
callback.onTtsSegmentCompleted(durationMs)
cb.onTtsSegmentCompleted(durationMs)
}
override fun isTtsStopped(): Boolean {
return callback.isTtsStopped()
return cb.isTtsStopped()
}
override fun onClearAsrQueue() {
callback.onClearAsrQueue()
cb.onClearAsrQueue()
}
override fun onSetSpeaking(speaking: Boolean) {
callback.onSetSpeaking(speaking)
cb.onSetSpeaking(speaking)
}
override fun getCurrentTrace() = null
@@ -73,36 +78,36 @@ class TtsController(private val context: Context) {
}
override fun onEndTurn() {
callback.onEndTurn()
cb.onEndTurn()
}
})
qcloudTts?.setCallback(object : QCloudTtsManager.TtsCallback {
override fun onTtsStarted(text: String) {
callback.onTtsStarted(text)
cb.onTtsStarted(text)
}
override fun onTtsCompleted() {
callback.onTtsCompleted()
cb.onTtsCompleted()
}
override fun onTtsSegmentCompleted(durationMs: Long) {
callback.onTtsSegmentCompleted(durationMs)
cb.onTtsSegmentCompleted(durationMs)
}
override fun isTtsStopped(): Boolean {
return callback.isTtsStopped()
return cb.isTtsStopped()
}
override fun onClearAsrQueue() {
callback.onClearAsrQueue()
cb.onClearAsrQueue()
}
override fun onSetSpeaking(speaking: Boolean) {
callback.onSetSpeaking(speaking)
cb.onSetSpeaking(speaking)
}
override fun onEndTurn() {
callback.onEndTurn()
cb.onEndTurn()
}
})
}
@@ -118,6 +123,9 @@ class TtsController(private val context: Context) {
val qcloudInit = qcloudTts?.init() ?: false
Log.d(TAG, "QCloud TTS init: $qcloudInit")
// setCallback() is usually called before init(), so rebind here.
bindCallbacksIfReady()
return localInit || qcloudInit
}

View File

@@ -317,4 +317,17 @@ object FileHelper {
connection.disconnect()
}
}
/**
 * Downloads a test image from [url] into [destination].
 *
 * @return true when the download finished without throwing; false when an [Exception]
 * was raised (it is logged, not rethrown).
 */
fun downloadTestImage(url: String, destination: File): Boolean {
    try {
        downloadFile(url, destination)
    } catch (e: Exception) {
        Log.e(TAG, "Failed to download test image: ${e.message}", e)
        return false
    }
    return true
}
}

View File

@@ -0,0 +1,99 @@
package com.digitalperson.util
import android.content.Context
import android.location.LocationManager
import android.net.ConnectivityManager
import android.net.NetworkCapabilities
import android.os.Build
import java.text.SimpleDateFormat
import java.util.*
/**
 * Builds greeting texts that depend on the current date and time of day, plus a few small
 * helpers (network check, formatted timestamp, LLM prompt for festival greetings).
 *
 * All date/time lookups use the device's default calendar at the moment of the call.
 */
class SmartGreetingUtil(private val context: Context) {

    // Fixed-date festivals keyed by (month, day), month being 1-based.
    private val festivals = mapOf(
        Pair(1, 1) to "元旦",
        Pair(2, 14) to "情人节",
        Pair(3, 8) to "妇女节",
        Pair(3, 12) to "植树节",
        Pair(4, 1) to "愚人节",
        Pair(5, 1) to "劳动节",
        Pair(5, 4) to "青年节",
        Pair(6, 1) to "儿童节",
        Pair(7, 1) to "建党节",
        Pair(8, 1) to "建军节",
        Pair(9, 10) to "教师节",
        Pair(10, 1) to "国庆节",
        Pair(12, 25) to "圣诞节"
    )

    /**
     * Returns a greeting for the current hour: 06-11 morning, 12-17 afternoon,
     * 18-22 evening, anything else (23-05) the late-night phrase.
     */
    fun getTimeBasedGreeting(): String {
        val hour = Calendar.getInstance().get(Calendar.HOUR_OF_DAY)
        return when {
            hour in 6..11 -> "早上好"
            hour in 12..17 -> "下午好"
            hour in 18..22 -> "晚上好"
            else -> "夜深了"
        }
    }

    /**
     * Checks whether today is one of the known festivals.
     *
     * @return `(true, festivalName)` on a festival day, otherwise `(false, "")`.
     */
    fun isFestivalToday(): Pair<Boolean, String> {
        val calendar = Calendar.getInstance()
        val month = calendar.get(Calendar.MONTH) + 1 // Calendar.MONTH is zero-based.
        val day = calendar.get(Calendar.DAY_OF_MONTH)
        val festival = festivals[Pair(month, day)]
        return Pair(festival != null, festival ?: "")
    }

    /**
     * Builds the complete greeting sentence.
     *
     * On a festival day the greeting mentions the festival; otherwise it starts with the
     * time-of-day greeting. When [name] is null or blank the non-festival variant also asks
     * for the user's name.
     */
    fun getSmartGreeting(name: String?): String {
        val (isFestival, festivalName) = isFestivalToday()
        return if (isFestival) {
            if (name.isNullOrBlank()) {
                "小朋友好!今天是${festivalName},祝你${festivalName}快乐哦!"
            } else {
                "你好,${name}!今天是${festivalName},祝你${festivalName}快乐哦!"
            }
        } else {
            val timeGreeting = getTimeBasedGreeting()
            if (name.isNullOrBlank()) {
                "${timeGreeting},很高兴见到你。请问你叫什么名字呀?"
            } else {
                "${timeGreeting}${name},很高兴见到你。"
            }
        }
    }

    /**
     * Reports whether the active network advertises internet capability.
     *
     * @return false when the connectivity service is unavailable, no network is active,
     * or the active network lacks the capability.
     */
    fun isNetworkConnected(): Boolean {
        // getSystemService may return null; a plain `as` cast would throw in that case.
        val connectivityManager =
            context.getSystemService(Context.CONNECTIVITY_SERVICE) as? ConnectivityManager
                ?: return false
        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
            val network = connectivityManager.activeNetwork ?: return false
            val capabilities = connectivityManager.getNetworkCapabilities(network) ?: return false
            return capabilities.hasCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
        } else {
            // Pre-M fallback using the older NetworkInfo API.
            val networkInfo = connectivityManager.activeNetworkInfo ?: return false
            return networkInfo.isConnected
        }
    }

    /** Returns the current date and time formatted as `yyyy年MM月dd日 HH:mm` (China locale). */
    fun getCurrentDateTime(): String {
        val sdf = SimpleDateFormat("yyyy年MM月dd日 HH:mm", Locale.CHINA)
        return sdf.format(Date())
    }

    /**
     * Builds the prompt sent to the local LLM to generate a festival greeting.
     * When [name] is non-blank, the prompt additionally names the person being greeted.
     */
    fun getFestivalGreetingPrompt(festivalName: String, name: String?): String {
        val basePrompt = "今天是$festivalName,请生成一条适合小学生的节日问候语,语气活泼友好,简短亲切。"
        return if (name.isNullOrBlank()) {
            basePrompt
        } else {
            "$basePrompt 问候对象叫$name"
        }
    }
}

View File

@@ -130,7 +130,7 @@
android:id="@+id/tts_mode_switch"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:checked="false" />
android:checked="true" />
</LinearLayout>
<LinearLayout

View File

@@ -0,0 +1,4 @@
<?xml version="1.0" encoding="utf-8"?>
<resources>
<!-- Latvian language resources -->
</resources>

View File

@@ -3,5 +3,5 @@
<string name="start">开始</string>
<string name="stop">结束</string>
<string name="hint">点击“开始”说话;识别后会请求大模型并用 TTS 播放回复。</string>
<string name="system_prompt">你是一名小学女老师喜欢回答学生的各种问题请简洁但温柔地回答每个回答不超过30字。在每次回复的最前面,用方括号标注你的心情,格式为[中性、悲伤、高兴、生气、恐惧、撒娇、震惊、厌恶],例如:[高兴]同学你好呀!请问有什么问题吗?</string>
<string name="system_prompt">你是一个特殊学校一年级的数字人老师,你的名字叫,小鱼老师,你的任务是教这些特殊学校的学生一些基础的生活常识。和这些小学生说话要有耐心,一定要讲明白,尽量用简短的语句、活泼的语气来回复。你可以和他们日常对话和《教材》相关的话题。在生成回复后,请你先检查一下内容是否符合我们约定的主题。请使用口语对话的形式跟学生聊天。在每次回复的最前面,用方括号标注你的心情,格式为[中性、悲伤、高兴、生气、恐惧、撒娇、震惊、厌恶],例如:[高兴]同学你好呀!请问有什么问题吗?</string>
</resources>