face regonition refine
This commit is contained in:
69
app/src/main/assets/questions.json
Normal file
69
app/src/main/assets/questions.json
Normal file
@@ -0,0 +1,69 @@
|
||||
{
|
||||
"version": 1,
|
||||
"questions": [
|
||||
{
|
||||
"content": "老师要去菜市场买菜,需要付4元钱。我给了老奶奶5元。需要找回多少钱?",
|
||||
"answer": "1",
|
||||
"subject": "数学",
|
||||
"grade": 1,
|
||||
"difficulty": 1
|
||||
},
|
||||
{
|
||||
"content": "老师去超市买了一个2块钱的苹果,一个3块钱的火龙果,一共要花多少钱呀?",
|
||||
"answer": "5",
|
||||
"subject": "数学",
|
||||
"grade": 1,
|
||||
"difficulty": 1
|
||||
},
|
||||
{
|
||||
"content": "10-4等于多少?",
|
||||
"answer": "6",
|
||||
"subject": "数学",
|
||||
"grade": 1,
|
||||
"difficulty": 1
|
||||
},
|
||||
{
|
||||
"content": "5+7等于多少?",
|
||||
"answer": "12",
|
||||
"subject": "数学",
|
||||
"grade": 1,
|
||||
"difficulty": 1
|
||||
},
|
||||
{
|
||||
"content": "8-3等于多少?",
|
||||
"answer": "5",
|
||||
"subject": "数学",
|
||||
"grade": 1,
|
||||
"difficulty": 1
|
||||
},
|
||||
{
|
||||
"content": "我们上完厕所应该干什么呀?",
|
||||
"answer": "洗手",
|
||||
"subject": "生活适应",
|
||||
"grade": 1,
|
||||
"difficulty": 1
|
||||
},
|
||||
{
|
||||
"content": "太阳从哪边升起呀?",
|
||||
"answer": "东方",
|
||||
"subject": "生活适应",
|
||||
"grade": 1,
|
||||
"difficulty": 1
|
||||
},
|
||||
{
|
||||
"content": "一年有几个季节呀?",
|
||||
"answer": "4",
|
||||
"subject": "生活适应",
|
||||
"grade": 1,
|
||||
"difficulty": 1
|
||||
},
|
||||
{
|
||||
"content": "你知道厕所怎么走吗?帮老师指个路",
|
||||
"answer": "4",
|
||||
"subject": "生活适应",
|
||||
"grade": 1,
|
||||
"difficulty": 1
|
||||
}
|
||||
|
||||
]
|
||||
}
|
||||
@@ -123,10 +123,12 @@ std::vector<RetinaFaceEngineRKNN::PriorBox> RetinaFaceEngineRKNN::buildPriors()
|
||||
bool RetinaFaceEngineRKNN::parseRetinaOutputs(
|
||||
rknn_output* outputs,
|
||||
std::vector<float>* locOut,
|
||||
std::vector<float>* scoreOut) const {
|
||||
std::vector<float>* scoreOut,
|
||||
std::vector<float>* landmarkOut) const {
|
||||
std::vector<std::vector<float>> locCandidates;
|
||||
std::vector<std::vector<float>> confCandidates2;
|
||||
std::vector<std::vector<float>> scoreCandidates1;
|
||||
std::vector<std::vector<float>> landmarkCandidates;
|
||||
|
||||
const int anchors8 = (inputSize_ / 8) * (inputSize_ / 8) * 2;
|
||||
const int anchors16 = (inputSize_ / 16) * (inputSize_ / 16) * 2;
|
||||
@@ -141,6 +143,9 @@ bool RetinaFaceEngineRKNN::parseRetinaOutputs(
|
||||
const int expectedConf8_1 = anchors8;
|
||||
const int expectedConf16_1 = anchors16;
|
||||
const int expectedConf32_1 = anchors32;
|
||||
const int expectedLmk8 = anchors8 * 10;
|
||||
const int expectedLmk16 = anchors16 * 10;
|
||||
const int expectedLmk32 = anchors32 * 10;
|
||||
|
||||
for (uint32_t i = 0; i < ioNum_.n_output; ++i) {
|
||||
const size_t elems = tensorElemCount(outputAttrs_[i]);
|
||||
@@ -162,10 +167,15 @@ bool RetinaFaceEngineRKNN::parseRetinaOutputs(
|
||||
scoreCandidates1.push_back(std::move(data));
|
||||
continue;
|
||||
}
|
||||
if (e == expectedLmk8 || e == expectedLmk16 || e == expectedLmk32 || e == totalAnchors * 10) {
|
||||
landmarkCandidates.push_back(std::move(data));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
locOut->clear();
|
||||
scoreOut->clear();
|
||||
landmarkOut->clear();
|
||||
|
||||
auto sortBySize = [](const std::vector<float>& a, const std::vector<float>& b) {
|
||||
return a.size() > b.size();
|
||||
@@ -173,6 +183,7 @@ bool RetinaFaceEngineRKNN::parseRetinaOutputs(
|
||||
std::sort(locCandidates.begin(), locCandidates.end(), sortBySize);
|
||||
std::sort(confCandidates2.begin(), confCandidates2.end(), sortBySize);
|
||||
std::sort(scoreCandidates1.begin(), scoreCandidates1.end(), sortBySize);
|
||||
std::sort(landmarkCandidates.begin(), landmarkCandidates.end(), sortBySize);
|
||||
|
||||
auto mergeLoc = [&]() -> bool {
|
||||
if (locCandidates.empty()) return false;
|
||||
@@ -246,16 +257,41 @@ bool RetinaFaceEngineRKNN::parseRetinaOutputs(
|
||||
return false;
|
||||
};
|
||||
|
||||
auto mergeLandmark = [&]() -> bool {
|
||||
if (landmarkCandidates.empty()) return false;
|
||||
if (landmarkCandidates.size() >= 3 &&
|
||||
static_cast<int>(landmarkCandidates[0].size()) == expectedLmk8 &&
|
||||
static_cast<int>(landmarkCandidates[1].size()) == expectedLmk16 &&
|
||||
static_cast<int>(landmarkCandidates[2].size()) == expectedLmk32) {
|
||||
landmarkOut->reserve(static_cast<size_t>(totalAnchors) * 10);
|
||||
landmarkOut->insert(landmarkOut->end(), landmarkCandidates[0].begin(), landmarkCandidates[0].end());
|
||||
landmarkOut->insert(landmarkOut->end(), landmarkCandidates[1].begin(), landmarkCandidates[1].end());
|
||||
landmarkOut->insert(landmarkOut->end(), landmarkCandidates[2].begin(), landmarkCandidates[2].end());
|
||||
return true;
|
||||
}
|
||||
for (const auto& c : landmarkCandidates) {
|
||||
if (static_cast<int>(c.size()) == totalAnchors * 10) {
|
||||
*landmarkOut = c;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
const bool locOk = mergeLoc();
|
||||
bool scoreOk = mergeScoreFrom2Class();
|
||||
if (!scoreOk) {
|
||||
scoreOk = mergeScoreFrom1Class();
|
||||
}
|
||||
const bool lmkOk = mergeLandmark();
|
||||
if (!locOk || !scoreOk) {
|
||||
LOGW("Unable to parse retina outputs, loc_candidates=%zu, conf2_candidates=%zu, conf1_candidates=%zu",
|
||||
locCandidates.size(), confCandidates2.size(), scoreCandidates1.size());
|
||||
LOGW("Unable to parse retina outputs, loc_candidates=%zu, conf2_candidates=%zu, conf1_candidates=%zu, lmk_candidates=%zu",
|
||||
locCandidates.size(), confCandidates2.size(), scoreCandidates1.size(), landmarkCandidates.size());
|
||||
return false;
|
||||
}
|
||||
if (!lmkOk) {
|
||||
LOGW("Retina outputs parsed without landmarks, lmk_candidates=%zu", landmarkCandidates.size());
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -369,7 +405,8 @@ std::vector<float> RetinaFaceEngineRKNN::detect(
|
||||
|
||||
std::vector<float> loc;
|
||||
std::vector<float> scores;
|
||||
if (!parseRetinaOutputs(outputs.data(), &loc, &scores)) {
|
||||
std::vector<float> landmarks;
|
||||
if (!parseRetinaOutputs(outputs.data(), &loc, &scores, &landmarks)) {
|
||||
rknn_outputs_release(ctx_, ioNum_.n_output, outputs.data());
|
||||
return empty;
|
||||
}
|
||||
@@ -381,6 +418,7 @@ std::vector<float> RetinaFaceEngineRKNN::detect(
|
||||
rknn_outputs_release(ctx_, ioNum_.n_output, outputs.data());
|
||||
return empty;
|
||||
}
|
||||
const bool hasLandmarks = landmarks.size() >= anchorCount * 10;
|
||||
|
||||
std::vector<FaceCandidate> candidates;
|
||||
candidates.reserve(anchorCount / 8);
|
||||
@@ -405,6 +443,17 @@ std::vector<float> RetinaFaceEngineRKNN::detect(
|
||||
box.right = std::min(static_cast<float>(width), (cx + w * 0.5f) * width);
|
||||
box.bottom = std::min(static_cast<float>(height), (cy + h * 0.5f) * height);
|
||||
box.score = score;
|
||||
box.hasLandmarks = hasLandmarks;
|
||||
if (hasLandmarks) {
|
||||
for (int k = 0; k < 5; ++k) {
|
||||
const float lx = priors[i].cx + landmarks[i * 10 + k * 2 + 0] * kVariance0 * priors[i].w;
|
||||
const float ly = priors[i].cy + landmarks[i * 10 + k * 2 + 1] * kVariance0 * priors[i].h;
|
||||
box.landmarks[k * 2 + 0] = lx * width;
|
||||
box.landmarks[k * 2 + 1] = ly * height;
|
||||
}
|
||||
} else {
|
||||
box.landmarks.fill(-1.0f);
|
||||
}
|
||||
candidates.push_back(box);
|
||||
}
|
||||
|
||||
@@ -412,13 +461,16 @@ std::vector<float> RetinaFaceEngineRKNN::detect(
|
||||
|
||||
std::vector<FaceCandidate> filtered = nms(candidates, nmsThreshold_);
|
||||
std::vector<float> result;
|
||||
result.reserve(filtered.size() * 5);
|
||||
result.reserve(filtered.size() * 15);
|
||||
for (const auto& f : filtered) {
|
||||
result.push_back(f.left);
|
||||
result.push_back(f.top);
|
||||
result.push_back(f.right);
|
||||
result.push_back(f.bottom);
|
||||
result.push_back(f.score);
|
||||
for (int k = 0; k < 10; ++k) {
|
||||
result.push_back(f.hasLandmarks ? f.landmarks[k] : -1.0f);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <array>
|
||||
|
||||
#include "rknn_api.h"
|
||||
|
||||
@@ -30,6 +31,8 @@ private:
|
||||
float right;
|
||||
float bottom;
|
||||
float score;
|
||||
std::array<float, 10> landmarks;
|
||||
bool hasLandmarks = false;
|
||||
};
|
||||
|
||||
static size_t tensorElemCount(const rknn_tensor_attr& attr);
|
||||
@@ -40,7 +43,8 @@ private:
|
||||
bool parseRetinaOutputs(
|
||||
rknn_output* outputs,
|
||||
std::vector<float>* locOut,
|
||||
std::vector<float>* scoreOut) const;
|
||||
std::vector<float>* scoreOut,
|
||||
std::vector<float>* landmarkOut) const;
|
||||
|
||||
rknn_context ctx_ = 0;
|
||||
bool initialized_ = false;
|
||||
|
||||
@@ -6,7 +6,10 @@ import android.graphics.Bitmap
|
||||
import android.os.Bundle
|
||||
import android.util.Log
|
||||
import android.widget.Toast
|
||||
import androidx.appcompat.app.AlertDialog
|
||||
import androidx.camera.core.CameraSelector
|
||||
import com.digitalperson.engine.RetinaFaceEngineRKNN
|
||||
import com.digitalperson.face.FaceBox
|
||||
import androidx.camera.core.ImageAnalysis
|
||||
import androidx.camera.core.ImageProxy
|
||||
import androidx.camera.core.Preview
|
||||
@@ -28,13 +31,19 @@ import com.digitalperson.metrics.TraceManager
|
||||
import com.digitalperson.metrics.TraceSession
|
||||
import com.digitalperson.tts.TtsController
|
||||
import com.digitalperson.interaction.DigitalHumanInteractionController
|
||||
import com.digitalperson.data.DatabaseInitializer
|
||||
import com.digitalperson.interaction.InteractionActionHandler
|
||||
import com.digitalperson.interaction.InteractionState
|
||||
import com.digitalperson.interaction.UserMemoryStore
|
||||
import com.digitalperson.llm.LLMManager
|
||||
import com.digitalperson.llm.LLMManagerCallback
|
||||
import com.digitalperson.util.FileHelper
|
||||
import com.digitalperson.data.AppDatabase
|
||||
import com.digitalperson.data.entity.ChatMessage
|
||||
import com.digitalperson.interaction.ConversationBufferMemory
|
||||
import com.digitalperson.interaction.ConversationSummaryMemory
|
||||
import java.io.File
|
||||
import android.graphics.BitmapFactory
|
||||
import org.json.JSONObject
|
||||
import java.util.concurrent.ExecutorService
|
||||
import java.util.concurrent.Executors
|
||||
@@ -92,6 +101,8 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
private lateinit var faceDetectionPipeline: FaceDetectionPipeline
|
||||
private lateinit var interactionController: DigitalHumanInteractionController
|
||||
private lateinit var userMemoryStore: UserMemoryStore
|
||||
private lateinit var conversationBufferMemory: ConversationBufferMemory
|
||||
private lateinit var conversationSummaryMemory: ConversationSummaryMemory
|
||||
private var facePipelineReady: Boolean = false
|
||||
private var cameraProvider: ProcessCameraProvider? = null
|
||||
private lateinit var cameraAnalyzerExecutor: ExecutorService
|
||||
@@ -102,6 +113,8 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
private val recentConversationLines = ArrayList<String>()
|
||||
private var recentConversationDirty: Boolean = false
|
||||
private var lastFacePresent: Boolean = false
|
||||
private var lastFaceIdentityId: String? = null
|
||||
private var lastFaceRecognizedName: String? = null
|
||||
|
||||
override fun onRequestPermissionsResult(
|
||||
requestCode: Int,
|
||||
@@ -152,12 +165,20 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
speakingPlayerViewId = 0,
|
||||
live2dViewId = R.id.live2d_view
|
||||
)
|
||||
|
||||
|
||||
cameraPreviewView = findViewById(R.id.camera_preview)
|
||||
cameraPreviewView.implementationMode = PreviewView.ImplementationMode.COMPATIBLE
|
||||
faceOverlayView = findViewById(R.id.face_overlay)
|
||||
cameraAnalyzerExecutor = Executors.newSingleThreadExecutor()
|
||||
|
||||
// 初始化数据库
|
||||
val databaseInitializer = DatabaseInitializer(applicationContext)
|
||||
databaseInitializer.initialize()
|
||||
|
||||
userMemoryStore = UserMemoryStore(applicationContext)
|
||||
val database = AppDatabase.getInstance(applicationContext)
|
||||
conversationBufferMemory = ConversationBufferMemory(database)
|
||||
conversationSummaryMemory = ConversationSummaryMemory(database, llmManager)
|
||||
interactionController = DigitalHumanInteractionController(
|
||||
scope = ioScope,
|
||||
handler = object : InteractionActionHandler {
|
||||
@@ -165,8 +186,10 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
runOnUiThread {
|
||||
uiManager.appendToUi("\n[State] $state\n")
|
||||
}
|
||||
Log.i(TAG_ACTIVITY, "\n[State] $state\n")
|
||||
if (state == InteractionState.IDLE) {
|
||||
analyzeUserProfileInIdleIfNeeded()
|
||||
Log.i(TAG_ACTIVITY, "[analyze] done")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -205,6 +228,8 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
|
||||
override fun loadLatestThought(): String? = userMemoryStore.getLatestThought()
|
||||
|
||||
override fun loadRecentThoughts(timeRangeMs: Long): List<String> = userMemoryStore.getRecentThoughts(timeRangeMs)
|
||||
|
||||
override fun addToChatHistory(role: String, content: String) {
|
||||
appendConversationLine(role, content)
|
||||
}
|
||||
@@ -212,7 +237,14 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
override fun addAssistantMessageToCloudHistory(content: String) {
|
||||
cloudApiManager.addAssistantMessage(content)
|
||||
}
|
||||
}
|
||||
|
||||
override fun getRandomQuestion(faceId: String): String {
|
||||
// 从数据库获取该faceId未被问过的问题
|
||||
val question = userMemoryStore.getRandomUnansweredQuestion(faceId)
|
||||
return question?.content ?: "你喜欢什么颜色呀?"
|
||||
}
|
||||
},
|
||||
context = applicationContext
|
||||
)
|
||||
faceDetectionPipeline = FaceDetectionPipeline(
|
||||
context = applicationContext,
|
||||
@@ -220,10 +252,21 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
faceOverlayView.updateResult(result)
|
||||
},
|
||||
onPresenceChanged = { present, faceIdentityId, recognizedName ->
|
||||
if (present == lastFacePresent) return@FaceDetectionPipeline
|
||||
lastFacePresent = present
|
||||
Log.d(TAG_ACTIVITY, "present=$present, faceIdentityId=$faceIdentityId, recognized=$recognizedName")
|
||||
interactionController.onFacePresenceChanged(present, faceIdentityId, recognizedName)
|
||||
if (present != lastFacePresent) {
|
||||
lastFacePresent = present
|
||||
Log.d(TAG_ACTIVITY, "presence changed: present=$present")
|
||||
interactionController.onFacePresenceChanged(present)
|
||||
if (!present) {
|
||||
lastFaceIdentityId = null
|
||||
lastFaceRecognizedName = null
|
||||
}
|
||||
}
|
||||
if (present && (faceIdentityId != lastFaceIdentityId || recognizedName != lastFaceRecognizedName)) {
|
||||
lastFaceIdentityId = faceIdentityId
|
||||
lastFaceRecognizedName = recognizedName
|
||||
Log.d(TAG_ACTIVITY, "identity update: faceIdentityId=$faceIdentityId, recognized=$recognizedName")
|
||||
interactionController.onFaceIdentityUpdated(faceIdentityId, recognizedName)
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
@@ -261,7 +304,7 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
|
||||
try {
|
||||
val ttsModeSwitch = findViewById<android.widget.Switch>(R.id.tts_mode_switch)
|
||||
ttsModeSwitch.isChecked = false // 默认使用本地TTS
|
||||
ttsModeSwitch.isChecked = true // 默认使用本地TTS
|
||||
ttsModeSwitch.setOnCheckedChangeListener { _, isChecked ->
|
||||
ttsController.setUseQCloudTts(isChecked)
|
||||
uiManager.showToast("TTS模式已切换到${if (isChecked) "腾讯云" else "本地"}")
|
||||
@@ -297,10 +340,6 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
vadManager = VadManager(this)
|
||||
vadManager.setCallback(createVadCallback())
|
||||
|
||||
// 初始化本地 LLM(用于 memory 状态)
|
||||
initLLM()
|
||||
interactionController.start()
|
||||
|
||||
// 检查是否需要下载模型
|
||||
if (!FileHelper.isLocalLLMAvailable(this)) {
|
||||
// 显示下载进度对话框
|
||||
@@ -332,19 +371,31 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
Log.i(AppConfig.TAG, "Local LLM is available, enabling local LLM switch")
|
||||
// 显示本地 LLM 开关,并同步状态
|
||||
uiManager.showLLMSwitch(false)
|
||||
// 初始化本地 LLM
|
||||
initLLM()
|
||||
// 重新初始化 ConversationSummaryMemory
|
||||
conversationSummaryMemory = ConversationSummaryMemory(database, llmManager)
|
||||
// 启动交互控制器
|
||||
interactionController.start()
|
||||
// 下载完成后初始化其他组件
|
||||
initializeOtherComponents()
|
||||
}
|
||||
} else {
|
||||
Log.e(AppConfig.TAG, "Failed to download model files: $message")
|
||||
uiManager.showToast("模型下载失败: $message", Toast.LENGTH_LONG)
|
||||
// 显示错误弹窗,阻止应用继续运行
|
||||
showModelDownloadErrorDialog(message)
|
||||
}
|
||||
// 下载完成后初始化其他组件
|
||||
initializeOtherComponents()
|
||||
}
|
||||
}
|
||||
)
|
||||
} else {
|
||||
// 模型已存在,直接初始化其他组件
|
||||
// 模型已存在,初始化本地 LLM
|
||||
initLLM()
|
||||
// 重新初始化 ConversationSummaryMemory
|
||||
conversationSummaryMemory = ConversationSummaryMemory(database, llmManager)
|
||||
// 启动交互控制器
|
||||
interactionController.start()
|
||||
// 直接初始化其他组件
|
||||
initializeOtherComponents()
|
||||
// 显示本地 LLM 开关,并同步状态
|
||||
uiManager.showLLMSwitch(false)
|
||||
@@ -404,6 +455,304 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
ioScope.launch {
|
||||
asrManager.runAsrWorker()
|
||||
}
|
||||
|
||||
// 测试人脸识别(延迟执行,确保所有组件初始化完成)
|
||||
// ioScope.launch {
|
||||
// kotlinx.coroutines.delay(10000) // 等待3秒,确保所有组件初始化完成
|
||||
// runOnUiThread {
|
||||
// runFaceRecognitionTest()
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
||||
/**
|
||||
* 显示模型下载错误弹窗,阻止应用继续运行
|
||||
*/
|
||||
private fun showModelDownloadErrorDialog(errorMessage: String) {
|
||||
AlertDialog.Builder(this)
|
||||
.setTitle("模型下载失败")
|
||||
.setMessage("本地 LLM 模型下载失败,应用无法正常运行。\n\n错误信息:$errorMessage\n\n请检查网络连接后重启应用。")
|
||||
.setCancelable(false)
|
||||
.setPositiveButton("退出应用") { _, _ ->
|
||||
finish()
|
||||
}
|
||||
.show()
|
||||
}
|
||||
|
||||
/**
|
||||
* 运行人脸识别相似度测试
|
||||
* 使用网络服务器上的测试图片
|
||||
*/
|
||||
private fun runFaceRecognitionTest() {
|
||||
Log.i(TAG_ACTIVITY, "Starting face recognition test...")
|
||||
uiManager.appendToUi("\n[测试] 开始人脸识别相似度测试...\n")
|
||||
|
||||
// 从服务器获取目录下的所有图片文件列表
|
||||
ioScope.launch {
|
||||
try {
|
||||
val imageUrls = fetchImageListFromServer("http://192.168.1.19:5000/api/face_test_images")
|
||||
|
||||
if (imageUrls.isEmpty()) {
|
||||
Log.e(AppConfig.TAG, "No images found in server directory")
|
||||
runOnUiThread {
|
||||
uiManager.appendToUi("\n[测试] 服务器目录中没有找到图片文件\n")
|
||||
}
|
||||
return@launch
|
||||
}
|
||||
|
||||
Log.i(AppConfig.TAG, "[测试]Found ${imageUrls.size} images: $imageUrls")
|
||||
runOnUiThread {
|
||||
uiManager.appendToUi("\n[测试] 发现 ${imageUrls.size} 张测试图片\n")
|
||||
}
|
||||
|
||||
val bitmaps = mutableListOf<Pair<String, Bitmap>>()
|
||||
|
||||
// 下载所有图片
|
||||
for (url in imageUrls) {
|
||||
Log.d(AppConfig.TAG, "[测试]Downloading test image: $url")
|
||||
val bitmap = downloadImage(url)
|
||||
if (bitmap != null) {
|
||||
val fileName = url.substringAfterLast("/")
|
||||
bitmaps.add(fileName to bitmap)
|
||||
Log.d(AppConfig.TAG, "[测试]Downloaded image $fileName successfully")
|
||||
} else {
|
||||
Log.e(AppConfig.TAG, "[测试]Failed to download image: $url")
|
||||
}
|
||||
}
|
||||
|
||||
if (bitmaps.size < 2) {
|
||||
Log.e(AppConfig.TAG, "[测试]Not enough test images downloaded")
|
||||
runOnUiThread {
|
||||
uiManager.appendToUi("\n[测试] 测试图片下载失败,无法进行测试\n")
|
||||
}
|
||||
return@launch
|
||||
}
|
||||
|
||||
// 对所有图片两两比较
|
||||
Log.i(AppConfig.TAG, "[测试]Starting similarity comparison for ${bitmaps.size} images...")
|
||||
for (i in 0 until bitmaps.size) {
|
||||
for (j in i + 1 until bitmaps.size) {
|
||||
val (fileName1, bitmap1) = bitmaps[i]
|
||||
val (fileName2, bitmap2) = bitmaps[j]
|
||||
|
||||
Log.d(AppConfig.TAG, "[测试]Comparing $fileName1 with $fileName2")
|
||||
|
||||
// 检测人脸
|
||||
val face1 = detectFace(bitmap1)
|
||||
val face2 = detectFace(bitmap2)
|
||||
|
||||
Log.d(AppConfig.TAG, "[测试]Face detection result: face1=$face1, face2=$face2")
|
||||
|
||||
if (face1 != null && face2 != null) {
|
||||
// 计算相似度
|
||||
Log.d(AppConfig.TAG, "[测试]Detected faces, calculating similarity...")
|
||||
val similarity = faceDetectionPipeline?.getRecognizer()?.testSimilarityBetween(
|
||||
bitmap1, face1, bitmap2, face2
|
||||
)
|
||||
val similarityRaw = faceDetectionPipeline?.getRecognizer()?.run {
|
||||
val emb1 = extractEmbedding(bitmap1, face1)
|
||||
val emb2 = extractEmbedding(bitmap2, face2)
|
||||
if (emb1.isNotEmpty() && emb2.isNotEmpty()) {
|
||||
var dot = 0f
|
||||
var n1 = 0f
|
||||
var n2 = 0f
|
||||
for (k in emb1.indices) {
|
||||
dot += emb1[k] * emb2[k]
|
||||
n1 += emb1[k] * emb1[k]
|
||||
n2 += emb2[k] * emb2[k]
|
||||
}
|
||||
if (n1 > 1e-12f && n2 > 1e-12f) {
|
||||
(dot / (kotlin.math.sqrt(n1) * kotlin.math.sqrt(n2))).coerceIn(-1f, 1f)
|
||||
} else -1f
|
||||
} else -1f
|
||||
}
|
||||
|
||||
Log.d(AppConfig.TAG, "[测试]Similarity result: $similarity")
|
||||
|
||||
if (similarity != null && similarity >= 0) {
|
||||
val message = "[测试] 图片 $fileName1 与 $fileName2 的相似度: $similarity"
|
||||
val compareMessage = "[测试] 对齐后=$similarity, 原始裁剪=$similarityRaw"
|
||||
Log.i(AppConfig.TAG, message)
|
||||
Log.i(AppConfig.TAG, compareMessage)
|
||||
runOnUiThread {
|
||||
uiManager.appendToUi("\n$message\n")
|
||||
uiManager.appendToUi("$compareMessage\n")
|
||||
}
|
||||
} else {
|
||||
Log.w(AppConfig.TAG, "[测试]Failed to calculate similarity: $similarity")
|
||||
runOnUiThread {
|
||||
uiManager.appendToUi("\n[测试] 计算相似度失败: $similarity\n")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
val message = "[测试] 无法检测到人脸: $fileName1 或 $fileName2"
|
||||
Log.w(AppConfig.TAG, message)
|
||||
runOnUiThread {
|
||||
uiManager.appendToUi("\n$message\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Log.i(AppConfig.TAG, "[测试]Face recognition test completed")
|
||||
runOnUiThread {
|
||||
uiManager.appendToUi("\n[测试] 人脸识别相似度测试完成\n")
|
||||
}
|
||||
|
||||
} catch (e: Exception) {
|
||||
Log.e(AppConfig.TAG, "Error during face recognition test: ${e.message}", e)
|
||||
runOnUiThread {
|
||||
uiManager.appendToUi("\n[测试] 测试过程中发生错误: ${e.message}\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 从服务器获取目录下的图片文件列表
|
||||
* 调用 API 接口获取图片列表
|
||||
*/
|
||||
private fun fetchImageListFromServer(apiUrl: String): List<String> {
|
||||
val imageUrls = mutableListOf<String>()
|
||||
|
||||
return try {
|
||||
// 调用 API 接口
|
||||
val connection = java.net.URL(apiUrl).openConnection() as java.net.HttpURLConnection
|
||||
connection.requestMethod = "GET"
|
||||
connection.connectTimeout = 10000
|
||||
connection.readTimeout = 10000
|
||||
connection.setRequestProperty("Accept", "application/json")
|
||||
|
||||
try {
|
||||
val responseCode = connection.responseCode
|
||||
if (responseCode == 200) {
|
||||
connection.inputStream.use { input ->
|
||||
val content = input.bufferedReader().use { it.readText() }
|
||||
Log.d(AppConfig.TAG, "API response: $content")
|
||||
|
||||
// 解析 JSON 响应
|
||||
val jsonObject = org.json.JSONObject(content)
|
||||
val imagesArray = jsonObject.getJSONArray("images")
|
||||
|
||||
// 构建完整的图片 URL
|
||||
val baseUrl = apiUrl.replace("/api/face_test_images", "/shared_files/face_test")
|
||||
for (i in 0 until imagesArray.length()) {
|
||||
val fileName = imagesArray.getString(i)
|
||||
val fullUrl = "$baseUrl/$fileName"
|
||||
imageUrls.add(fullUrl)
|
||||
Log.d(AppConfig.TAG, "Added image URL: $fullUrl")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Log.e(AppConfig.TAG, "API request failed with code: $responseCode")
|
||||
}
|
||||
} finally {
|
||||
connection.disconnect()
|
||||
}
|
||||
|
||||
imageUrls
|
||||
} catch (e: Exception) {
|
||||
Log.e(AppConfig.TAG, "Failed to fetch image list: ${e.message}", e)
|
||||
// 如果获取失败,返回空列表
|
||||
emptyList()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查 URL 是否存在
|
||||
*/
|
||||
private fun checkUrlExists(url: String): Boolean {
|
||||
return try {
|
||||
val connection = java.net.URL(url).openConnection() as java.net.HttpURLConnection
|
||||
connection.requestMethod = "HEAD"
|
||||
connection.connectTimeout = 3000
|
||||
connection.readTimeout = 3000
|
||||
val responseCode = connection.responseCode
|
||||
connection.disconnect()
|
||||
responseCode == 200
|
||||
} catch (e: Exception) {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 从网络下载图片
|
||||
*/
|
||||
private fun downloadImage(url: String): Bitmap? {
|
||||
return try {
|
||||
// 使用与大模型相同的下载方式
|
||||
val tempFile = File(cacheDir, "temp_test_image_${System.currentTimeMillis()}.jpg")
|
||||
val success = FileHelper.downloadTestImage(url, tempFile)
|
||||
|
||||
if (success && tempFile.exists()) {
|
||||
val bitmap = BitmapFactory.decodeFile(tempFile.absolutePath)
|
||||
tempFile.delete() // 删除临时文件
|
||||
bitmap
|
||||
} else {
|
||||
Log.e(AppConfig.TAG, "Failed to download image: $url")
|
||||
null
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
Log.e(AppConfig.TAG, "Failed to download image: ${e.message}", e)
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 检测图片中的人脸
|
||||
*/
|
||||
private fun detectFace(bitmap: Bitmap): FaceBox? {
|
||||
Log.d(AppConfig.TAG, "[测试]Detecting face in bitmap: ${bitmap.width}x${bitmap.height}")
|
||||
return try {
|
||||
val engine = RetinaFaceEngineRKNN()
|
||||
Log.d(AppConfig.TAG, "[测试]Initializing RetinaFace engine...")
|
||||
if (engine.initialize(applicationContext)) {
|
||||
Log.d(AppConfig.TAG, "[测试]RetinaFace engine initialized successfully")
|
||||
val raw = engine.detect(bitmap)
|
||||
Log.d(AppConfig.TAG, "[测试]Face detection result: ${raw.joinToString(", ")}")
|
||||
engine.release()
|
||||
|
||||
if (raw.isNotEmpty()) {
|
||||
val stride = when {
|
||||
raw.size % 15 == 0 -> 15
|
||||
raw.size % 5 == 0 -> 5
|
||||
else -> 0
|
||||
}
|
||||
|
||||
Log.d(AppConfig.TAG, "[测试]Stride: $stride, raw size: ${raw.size}")
|
||||
|
||||
if (stride > 0) {
|
||||
val faceCount = raw.size / stride
|
||||
Log.d(AppConfig.TAG, "[测试]Detected $faceCount faces")
|
||||
if (faceCount > 0) {
|
||||
val i = 0
|
||||
val lm = if (stride >= 15) raw.copyOfRange(i + 5, i + 15) else null
|
||||
val hasLm = lm?.all { it >= 0f } == true
|
||||
|
||||
val faceBox = FaceBox(
|
||||
left = raw[i],
|
||||
top = raw[i + 1],
|
||||
right = raw[i + 2],
|
||||
bottom = raw[i + 3],
|
||||
score = raw[i + 4],
|
||||
hasLandmarks = hasLm,
|
||||
landmarks = if (hasLm) lm else null
|
||||
)
|
||||
Log.d(AppConfig.TAG, "[测试]Created face box: $faceBox")
|
||||
return faceBox
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Log.w(AppConfig.TAG, "[测试]No faces detected in bitmap")
|
||||
}
|
||||
} else {
|
||||
Log.e(AppConfig.TAG, "[测试]Failed to initialize RetinaFace engine")
|
||||
}
|
||||
null
|
||||
} catch (e: Exception) {
|
||||
Log.e(AppConfig.TAG, "[测试]Failed to detect face: ${e.message}", e)
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
private fun createAsrCallback() = object : AsrManager.AsrCallback {
|
||||
@@ -557,7 +906,6 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
try { cameraAnalyzerExecutor.shutdown() } catch (_: Throwable) {}
|
||||
try { ttsController.release() } catch (_: Throwable) {}
|
||||
try { llmManager?.destroy() } catch (_: Throwable) {}
|
||||
try { userMemoryStore.close() } catch (_: Throwable) {}
|
||||
try { uiManager.release() } catch (_: Throwable) {}
|
||||
try { audioProcessor.release() } catch (_: Throwable) {}
|
||||
}
|
||||
@@ -686,6 +1034,9 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
uiManager.appendToUi("\n[LOG] 已打断TTS播放\n")
|
||||
}
|
||||
|
||||
// 通知状态机用户开始说话,立即进入对话状态
|
||||
interactionController.onUserStartSpeaking()
|
||||
|
||||
if (!audioProcessor.initMicrophone(micPermissions, AppConfig.REQUEST_RECORD_AUDIO_PERMISSION)) {
|
||||
uiManager.showToast("麦克风初始化失败/无权限")
|
||||
return
|
||||
@@ -844,6 +1195,14 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
recentConversationLines.removeAt(0)
|
||||
}
|
||||
recentConversationDirty = true
|
||||
|
||||
// 同时添加到对话记忆中
|
||||
val memoryRole = if (role == "用户") "user" else "assistant"
|
||||
conversationBufferMemory.addMessage(activeUserId, memoryRole, text.trim())
|
||||
// 定期保存到数据库
|
||||
if (recentConversationLines.size % 5 == 0) {
|
||||
conversationBufferMemory.saveToDatabase(activeUserId)
|
||||
}
|
||||
}
|
||||
|
||||
private fun buildCloudPromptWithUserProfile(userText: String): String {
|
||||
@@ -854,6 +1213,13 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
profile.gender?.takeIf { it.isNotBlank() }?.let { profileParts.add("性别:$it") }
|
||||
profile.hobbies?.takeIf { it.isNotBlank() }?.let { profileParts.add("爱好:$it") }
|
||||
profile.profileSummary?.takeIf { it.isNotBlank() }?.let { profileParts.add("画像:$it") }
|
||||
|
||||
// 添加对话摘要
|
||||
val conversationSummary = conversationSummaryMemory.getSummary(activeUserId)
|
||||
if (conversationSummary.isNotBlank()) {
|
||||
profileParts.add("对话摘要:$conversationSummary")
|
||||
}
|
||||
|
||||
if (profileParts.isEmpty()) return userText
|
||||
return buildString {
|
||||
append("[用户画像]\n")
|
||||
@@ -864,9 +1230,26 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
}
|
||||
|
||||
private fun analyzeUserProfileInIdleIfNeeded() {
|
||||
if (!recentConversationDirty || !activeUserId.startsWith("face_")) return
|
||||
if (recentConversationLines.isEmpty()) return
|
||||
val dialogue = recentConversationLines.joinToString("\n")
|
||||
if (!activeUserId.startsWith("face_")) {
|
||||
Log.d(AppConfig.TAG, "faceID is not face_")
|
||||
return
|
||||
}
|
||||
|
||||
// 使用 conversationBufferMemory 获取对话消息
|
||||
val messages = conversationBufferMemory.getMessages(activeUserId)
|
||||
Log.d(AppConfig.TAG, "msg is empty? ${messages.isEmpty()}")
|
||||
val hasUserMessages = messages.any { it.role == "user" }
|
||||
Log.d(AppConfig.TAG, "msg has user messages? $hasUserMessages")
|
||||
if (messages.isEmpty() || !hasUserMessages) return
|
||||
|
||||
// 生成对话摘要
|
||||
conversationSummaryMemory.generateSummary(activeUserId, messages) { summary ->
|
||||
Log.d(AppConfig.TAG, "Generated conversation summary for $activeUserId: $summary")
|
||||
}
|
||||
|
||||
// 使用 conversationBufferMemory 的对话记录提取用户信息
|
||||
val dialogue = messages.joinToString("\n") { "${it.role}: ${it.content}" }
|
||||
|
||||
requestLocalProfileExtraction(dialogue) { raw ->
|
||||
try {
|
||||
val json = parseFirstJsonObject(raw)
|
||||
@@ -879,7 +1262,10 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
userMemoryStore.updateDisplayName(activeUserId, name)
|
||||
}
|
||||
userMemoryStore.updateProfile(activeUserId, age, gender, hobbies, summary)
|
||||
recentConversationDirty = false
|
||||
|
||||
// 清空已处理的对话记录
|
||||
conversationBufferMemory.clear(activeUserId)
|
||||
|
||||
runOnUiThread {
|
||||
uiManager.appendToUi("\n[Memory] 已更新用户画像: $activeUserId\n")
|
||||
}
|
||||
@@ -901,7 +1287,7 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
Log.i(TAG_LLM, "Routing profile extraction to LOCAL")
|
||||
local.generateResponseWithSystem(
|
||||
"你是信息抽取器。仅输出JSON对象,不要其他文字。字段为name,age,gender,hobbies,summary。",
|
||||
"请从以下对话提取用户信息,未知填空字符串:\n$dialogue"
|
||||
"请从以下对话提取用户信息,未知填空字符串,注意不需要:\n$dialogue"
|
||||
)
|
||||
} catch (e: Exception) {
|
||||
pendingLocalProfileCallback = null
|
||||
|
||||
@@ -50,7 +50,7 @@ object AppConfig {
|
||||
object FaceRecognition {
|
||||
const val MODEL_DIR = "Insightface"
|
||||
const val MODEL_NAME = "ms1mv3_arcface_r18.rknn"
|
||||
const val SIMILARITY_THRESHOLD = 0.5f
|
||||
const val SIMILARITY_THRESHOLD = 0.6f
|
||||
const val GREETING_COOLDOWN_MS = 6000L
|
||||
}
|
||||
|
||||
|
||||
54
app/src/main/java/com/digitalperson/data/AppDatabase.kt
Normal file
54
app/src/main/java/com/digitalperson/data/AppDatabase.kt
Normal file
@@ -0,0 +1,54 @@
|
||||
package com.digitalperson.data
|
||||
|
||||
import android.content.Context
|
||||
import androidx.room.Database
|
||||
import androidx.room.Room
|
||||
import androidx.room.RoomDatabase
|
||||
import com.digitalperson.data.dao.QuestionDao
|
||||
import com.digitalperson.data.dao.UserAnswerDao
|
||||
import com.digitalperson.data.dao.UserMemoryDao
|
||||
import com.digitalperson.data.dao.ChatMessageDao
|
||||
import com.digitalperson.data.dao.ConversationSummaryDao
|
||||
import com.digitalperson.data.entity.Question
|
||||
import com.digitalperson.data.entity.UserAnswer
|
||||
import com.digitalperson.data.entity.UserMemory
|
||||
import com.digitalperson.data.entity.ChatMessageEntity
|
||||
import com.digitalperson.data.entity.ConversationSummaryEntity
|
||||
|
||||
@Database(
|
||||
entities = [UserMemory::class, Question::class, UserAnswer::class, ChatMessageEntity::class, ConversationSummaryEntity::class],
|
||||
version = 4,
|
||||
exportSchema = false
|
||||
)
|
||||
abstract class AppDatabase : RoomDatabase() {
|
||||
abstract fun userMemoryDao(): UserMemoryDao
|
||||
abstract fun questionDao(): QuestionDao
|
||||
abstract fun userAnswerDao(): UserAnswerDao
|
||||
abstract fun chatMessageDao(): ChatMessageDao
|
||||
abstract fun conversationSummaryDao(): ConversationSummaryDao
|
||||
|
||||
companion object {
|
||||
private const val DATABASE_NAME = "digital_human.db"
|
||||
|
||||
@Volatile
|
||||
private var INSTANCE: AppDatabase? = null
|
||||
|
||||
fun getInstance(context: Context): AppDatabase {
|
||||
return INSTANCE ?: synchronized(this) {
|
||||
INSTANCE ?: buildDatabase(context).also {
|
||||
INSTANCE = it
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun buildDatabase(context: Context): AppDatabase {
|
||||
return Room.databaseBuilder(
|
||||
context.applicationContext,
|
||||
AppDatabase::class.java,
|
||||
DATABASE_NAME
|
||||
)
|
||||
.fallbackToDestructiveMigration()
|
||||
.build()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,86 @@
|
||||
package com.digitalperson.data
|
||||
|
||||
import android.content.Context
|
||||
import com.digitalperson.data.entity.Question
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.GlobalScope
|
||||
import kotlinx.coroutines.launch
|
||||
import org.json.JSONObject
|
||||
import java.io.IOException
|
||||
import java.io.InputStream
|
||||
|
||||
class DatabaseInitializer(private val context: Context) {
|
||||
fun initialize() {
|
||||
// 在后台线程初始化数据库
|
||||
GlobalScope.launch(Dispatchers.IO) {
|
||||
val db = AppDatabase.getInstance(context)
|
||||
val questionDao = db.questionDao()
|
||||
|
||||
// 检查是否已有问题数据
|
||||
if (questionDao.getAllSubjects().isEmpty()) {
|
||||
// 从JSON文件导入问题
|
||||
importQuestionsFromJson()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun importQuestionsFromJson() {
|
||||
try {
|
||||
// 读取JSON文件
|
||||
val inputStream: InputStream = context.assets.open("questions.json")
|
||||
val size: Int = inputStream.available()
|
||||
val buffer = ByteArray(size)
|
||||
inputStream.read(buffer)
|
||||
inputStream.close()
|
||||
|
||||
val jsonString = String(buffer, Charsets.UTF_8)
|
||||
val jsonObject = JSONObject(jsonString)
|
||||
val questionsArray = jsonObject.getJSONArray("questions")
|
||||
|
||||
val db = AppDatabase.getInstance(context)
|
||||
val questionDao = db.questionDao()
|
||||
|
||||
// 插入问题
|
||||
for (i in 0 until questionsArray.length()) {
|
||||
val questionJson = questionsArray.getJSONObject(i)
|
||||
val content = questionJson.getString("content")
|
||||
val answer = if (questionJson.isNull("answer")) null else questionJson.getString("answer")
|
||||
val subject = if (questionJson.isNull("subject")) null else questionJson.getString("subject")
|
||||
val grade = if (questionJson.isNull("grade")) null else questionJson.getInt("grade")
|
||||
val difficulty = if (questionJson.isNull("difficulty")) 1 else questionJson.getInt("difficulty")
|
||||
|
||||
val question = Question(
|
||||
id = 0,
|
||||
content = content,
|
||||
answer = answer,
|
||||
subject = subject,
|
||||
grade = grade,
|
||||
difficulty = difficulty,
|
||||
createdAt = System.currentTimeMillis()
|
||||
)
|
||||
|
||||
questionDao.insert(question)
|
||||
}
|
||||
|
||||
println("Successfully imported ${questionsArray.length()} questions from JSON")
|
||||
|
||||
} catch (e: IOException) {
|
||||
e.printStackTrace()
|
||||
println("Error reading questions.json: ${e.message}")
|
||||
} catch (e: Exception) {
|
||||
e.printStackTrace()
|
||||
println("Error importing questions: ${e.message}")
|
||||
}
|
||||
}
|
||||
|
||||
// 强制重新导入JSON数据
|
||||
fun forceImportQuestions() {
|
||||
GlobalScope.launch(Dispatchers.IO) {
|
||||
val db = AppDatabase.getInstance(context)
|
||||
// 清空数据库
|
||||
db.clearAllTables()
|
||||
// 重新导入
|
||||
importQuestionsFromJson()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
package com.digitalperson.data.dao
|
||||
|
||||
import androidx.room.Dao
|
||||
import androidx.room.Insert
|
||||
import androidx.room.Query
|
||||
import com.digitalperson.data.entity.ChatMessageEntity
|
||||
import com.digitalperson.data.entity.ConversationSummaryEntity
|
||||
|
||||
/**
|
||||
* 对话消息DAO接口
|
||||
*/
|
||||
@Dao
|
||||
interface ChatMessageDao {
|
||||
@Insert
|
||||
fun insert(message: ChatMessageEntity): Long
|
||||
|
||||
@Query("SELECT * FROM chat_messages WHERE user_id = :userId ORDER BY timestamp ASC")
|
||||
fun getMessagesByUserId(userId: String): List<ChatMessageEntity>
|
||||
|
||||
@Query("""
|
||||
SELECT * FROM chat_messages
|
||||
WHERE user_id = :userId AND timestamp >= :startTime
|
||||
ORDER BY timestamp ASC
|
||||
""")
|
||||
fun getMessagesByUserIdAndTime(userId: String, startTime: Long): List<ChatMessageEntity>
|
||||
|
||||
@Query("DELETE FROM chat_messages WHERE user_id = :userId")
|
||||
fun clearMessagesByUserId(userId: String)
|
||||
|
||||
@Query("SELECT COUNT(*) FROM chat_messages WHERE user_id = :userId")
|
||||
fun getMessageCountByUserId(userId: String): Int
|
||||
}
|
||||
|
||||
/**
|
||||
* 对话摘要DAO接口
|
||||
*/
|
||||
@Dao
|
||||
interface ConversationSummaryDao {
|
||||
@Insert
|
||||
fun insert(summary: ConversationSummaryEntity): Long
|
||||
|
||||
@Query("SELECT * FROM conversation_summaries WHERE user_id = :userId ORDER BY created_at DESC LIMIT 1")
|
||||
fun getLatestSummaryByUserId(userId: String): ConversationSummaryEntity?
|
||||
|
||||
@Query("DELETE FROM conversation_summaries WHERE user_id = :userId")
|
||||
fun clearSummariesByUserId(userId: String)
|
||||
}
|
||||
44
app/src/main/java/com/digitalperson/data/dao/QuestionDao.kt
Normal file
44
app/src/main/java/com/digitalperson/data/dao/QuestionDao.kt
Normal file
@@ -0,0 +1,44 @@
|
||||
package com.digitalperson.data.dao
|
||||
|
||||
import androidx.room.Dao
|
||||
import androidx.room.Insert
|
||||
import androidx.room.Query
|
||||
import com.digitalperson.data.entity.Question
|
||||
|
||||
@Dao
|
||||
interface QuestionDao {
|
||||
@Insert
|
||||
fun insert(question: Question): Long
|
||||
|
||||
@Query("SELECT * FROM questions WHERE subject = :subject ORDER BY difficulty")
|
||||
fun getQuestionsBySubject(subject: String): List<Question>
|
||||
|
||||
@Query("SELECT * FROM questions WHERE grade = :grade ORDER BY difficulty")
|
||||
fun getQuestionsByGrade(grade: Int): List<Question>
|
||||
|
||||
@Query("SELECT * FROM questions WHERE subject = :subject AND grade = :grade ORDER BY RANDOM() LIMIT 1")
|
||||
fun getRandomQuestionBySubjectAndGrade(subject: String, grade: Int): Question?
|
||||
|
||||
@Query("SELECT * FROM questions WHERE subject = :subject ORDER BY RANDOM() LIMIT 1")
|
||||
fun getRandomQuestionBySubject(subject: String): Question?
|
||||
|
||||
@Query("SELECT * FROM questions ORDER BY RANDOM() LIMIT 1")
|
||||
fun getRandomQuestion(): Question?
|
||||
|
||||
@Query("""
|
||||
SELECT q.*
|
||||
FROM questions q
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1 FROM user_answers ua
|
||||
WHERE ua.question_id = q.id AND ua.user_id = :userId
|
||||
)
|
||||
ORDER BY RANDOM() LIMIT 1
|
||||
""")
|
||||
fun getRandomUnansweredQuestion(userId: String): Question?
|
||||
|
||||
@Query("SELECT * FROM questions WHERE id = :questionId")
|
||||
fun getQuestionById(questionId: Long): Question?
|
||||
|
||||
@Query("SELECT DISTINCT subject FROM questions WHERE subject IS NOT NULL ORDER BY subject")
|
||||
fun getAllSubjects(): List<String>
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
package com.digitalperson.data.dao
|
||||
|
||||
import androidx.room.Dao
|
||||
import androidx.room.Insert
|
||||
import androidx.room.Query
|
||||
import androidx.room.ColumnInfo
|
||||
import com.digitalperson.data.entity.UserAnswer
|
||||
|
||||
@Dao
|
||||
interface UserAnswerDao {
|
||||
@Insert
|
||||
fun insert(userAnswer: UserAnswer): Long
|
||||
|
||||
@Query("SELECT * FROM user_answers WHERE user_id = :userId ORDER BY answered_at DESC LIMIT :limit")
|
||||
fun getUserAnswers(userId: String, limit: Int = 50): List<UserAnswer>
|
||||
|
||||
@Query("""
|
||||
SELECT ua.id, ua.user_id as userId, ua.question_id as questionId, ua.user_answer as userAnswer, ua.evaluation, ua.answered_at as answeredAt, q.content as question_content, q.answer as question_answer
|
||||
FROM user_answers ua
|
||||
JOIN questions q ON ua.question_id = q.id
|
||||
WHERE ua.user_id = :userId AND ua.evaluation = 'INCORRECT'
|
||||
ORDER BY ua.answered_at DESC
|
||||
LIMIT :limit
|
||||
""")
|
||||
fun getIncorrectAnswers(userId: String, limit: Int = 10): List<UserAnswerWithQuestion>
|
||||
|
||||
@Query("""
|
||||
SELECT evaluation, COUNT(*) as count
|
||||
FROM user_answers
|
||||
WHERE user_id = :userId
|
||||
GROUP BY evaluation
|
||||
""")
|
||||
fun getAnswerStatistics(userId: String): List<AnswerStatistic>
|
||||
}
|
||||
|
||||
// 用于存储答案统计结果
|
||||
data class AnswerStatistic(
|
||||
val evaluation: String?,
|
||||
val count: Int
|
||||
)
|
||||
|
||||
// 用于关联查询用户答案和问题
|
||||
data class UserAnswerWithQuestion(
|
||||
val id: Long,
|
||||
val userId: String,
|
||||
val questionId: Long,
|
||||
val userAnswer: String?,
|
||||
val evaluation: String?,
|
||||
val answeredAt: Long,
|
||||
@ColumnInfo(name = "question_content") val question_content: String,
|
||||
@ColumnInfo(name = "question_answer") val question_answer: String?
|
||||
)
|
||||
@@ -0,0 +1,31 @@
|
||||
package com.digitalperson.data.dao
|
||||
|
||||
import androidx.room.Dao
|
||||
import androidx.room.Insert
|
||||
import androidx.room.OnConflictStrategy
|
||||
import androidx.room.Query
|
||||
import com.digitalperson.data.entity.UserMemory
|
||||
|
||||
@Dao
|
||||
interface UserMemoryDao {
|
||||
@Insert(onConflict = OnConflictStrategy.REPLACE)
|
||||
fun insert(userMemory: UserMemory)
|
||||
|
||||
@Query("SELECT * FROM user_memory WHERE user_id = :userId")
|
||||
fun getUserMemory(userId: String): UserMemory?
|
||||
|
||||
@Query("SELECT last_thought FROM user_memory WHERE last_thought IS NOT NULL AND last_thought != '' ORDER BY last_seen_at DESC LIMIT 1")
|
||||
fun getLatestThought(): String?
|
||||
|
||||
@Query("UPDATE user_memory SET last_thought = :thought, last_seen_at = :timestamp WHERE user_id = :userId")
|
||||
fun updateThought(userId: String, thought: String, timestamp: Long)
|
||||
|
||||
@Query("UPDATE user_memory SET display_name = :displayName, last_seen_at = :timestamp WHERE user_id = :userId")
|
||||
fun updateDisplayName(userId: String, displayName: String, timestamp: Long)
|
||||
|
||||
@Query("UPDATE user_memory SET age = :age, gender = :gender, hobbies = :hobbies, profile_summary = :summary, last_seen_at = :timestamp WHERE user_id = :userId")
|
||||
fun updateProfile(userId: String, age: String?, gender: String?, hobbies: String?, summary: String?, timestamp: Long)
|
||||
|
||||
@Query("SELECT * FROM user_memory WHERE last_thought IS NOT NULL AND last_thought != '' AND last_seen_at >= :cutoffTime ORDER BY last_seen_at DESC")
|
||||
fun getRecentThoughts(cutoffTime: Long): List<UserMemory>
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package com.digitalperson.data.entity
|
||||
|
||||
import androidx.room.Entity
|
||||
import androidx.room.PrimaryKey
|
||||
import androidx.room.ColumnInfo
|
||||
|
||||
/**
|
||||
* 对话消息数据类
|
||||
* @param id 消息ID
|
||||
* @param userId 用户ID(如 face_1)
|
||||
* @param role 角色(user 或 assistant)
|
||||
* @param content 消息内容
|
||||
* @param timestamp 时间戳
|
||||
*/
|
||||
data class ChatMessage(
|
||||
val id: Long = 0,
|
||||
val userId: String,
|
||||
val role: String,
|
||||
val content: String,
|
||||
val timestamp: Long = System.currentTimeMillis()
|
||||
)
|
||||
|
||||
/**
|
||||
* 对话消息数据库实体
|
||||
*/
|
||||
@Entity(tableName = "chat_messages")
|
||||
data class ChatMessageEntity(
|
||||
@PrimaryKey(autoGenerate = true) val id: Long,
|
||||
@ColumnInfo(name = "user_id") val userId: String,
|
||||
@ColumnInfo(name = "role") val role: String,
|
||||
@ColumnInfo(name = "content") val content: String,
|
||||
@ColumnInfo(name = "timestamp") val timestamp: Long
|
||||
)
|
||||
|
||||
/**
|
||||
* 对话摘要数据库实体
|
||||
*/
|
||||
@Entity(tableName = "conversation_summaries")
|
||||
data class ConversationSummaryEntity(
|
||||
@PrimaryKey(autoGenerate = true) val id: Long,
|
||||
@ColumnInfo(name = "user_id") val userId: String,
|
||||
@ColumnInfo(name = "summary") val summary: String,
|
||||
@ColumnInfo(name = "created_at") val createdAt: Long
|
||||
)
|
||||
16
app/src/main/java/com/digitalperson/data/entity/Question.kt
Normal file
16
app/src/main/java/com/digitalperson/data/entity/Question.kt
Normal file
@@ -0,0 +1,16 @@
|
||||
package com.digitalperson.data.entity
|
||||
|
||||
import androidx.room.Entity
|
||||
import androidx.room.PrimaryKey
|
||||
import androidx.room.ColumnInfo
|
||||
|
||||
@Entity(tableName = "questions")
|
||||
data class Question(
|
||||
@PrimaryKey(autoGenerate = true) val id: Long,
|
||||
@ColumnInfo(name = "content") val content: String,
|
||||
@ColumnInfo(name = "answer") val answer: String?,
|
||||
@ColumnInfo(name = "subject") val subject: String?,
|
||||
@ColumnInfo(name = "grade") val grade: Int?,
|
||||
@ColumnInfo(name = "difficulty") val difficulty: Int,
|
||||
@ColumnInfo(name = "created_at") val createdAt: Long,
|
||||
)
|
||||
@@ -0,0 +1,15 @@
|
||||
package com.digitalperson.data.entity
|
||||
|
||||
import androidx.room.Entity
|
||||
import androidx.room.PrimaryKey
|
||||
import androidx.room.ColumnInfo
|
||||
|
||||
@Entity(tableName = "user_answers")
|
||||
data class UserAnswer(
|
||||
@PrimaryKey(autoGenerate = true) val id: Long,
|
||||
@ColumnInfo(name = "user_id") val userId: String,
|
||||
@ColumnInfo(name = "question_id") val questionId: Long,
|
||||
@ColumnInfo(name = "user_answer") val userAnswer: String?,
|
||||
@ColumnInfo(name = "evaluation") val evaluation: String?,
|
||||
@ColumnInfo(name = "answered_at") val answeredAt: Long,
|
||||
)
|
||||
@@ -0,0 +1,19 @@
|
||||
package com.digitalperson.data.entity
|
||||
|
||||
import androidx.room.Entity
|
||||
import androidx.room.PrimaryKey
|
||||
import androidx.room.ColumnInfo
|
||||
|
||||
@Entity(tableName = "user_memory")
|
||||
data class UserMemory(
|
||||
@PrimaryKey @ColumnInfo(name = "user_id") val userId: String,
|
||||
@ColumnInfo(name = "display_name") val displayName: String?,
|
||||
@ColumnInfo(name = "last_seen_at") val lastSeenAt: Long,
|
||||
@ColumnInfo(name = "age") val age: String?,
|
||||
@ColumnInfo(name = "gender") val gender: String?,
|
||||
@ColumnInfo(name = "hobbies") val hobbies: String?,
|
||||
@ColumnInfo(name = "preferences") val preferences: String?,
|
||||
@ColumnInfo(name = "last_topics") val lastTopics: String?,
|
||||
@ColumnInfo(name = "last_thought") val lastThought: String?,
|
||||
@ColumnInfo(name = "profile_summary") val profileSummary: String?,
|
||||
)
|
||||
@@ -5,8 +5,8 @@ import android.graphics.Bitmap
|
||||
import android.util.Log
|
||||
import com.digitalperson.config.AppConfig
|
||||
import com.digitalperson.engine.RetinaFaceEngineRKNN
|
||||
import java.util.ArrayDeque
|
||||
import java.util.concurrent.atomic.AtomicBoolean
|
||||
import kotlin.math.abs
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.SupervisorJob
|
||||
@@ -20,6 +20,8 @@ data class FaceBox(
|
||||
val right: Float,
|
||||
val bottom: Float,
|
||||
val score: Float,
|
||||
val hasLandmarks: Boolean = false,
|
||||
val landmarks: FloatArray? = null,
|
||||
)
|
||||
|
||||
data class FaceDetectionResult(
|
||||
@@ -29,24 +31,27 @@ data class FaceDetectionResult(
|
||||
)
|
||||
|
||||
class FaceDetectionPipeline(
|
||||
private val context: Context,
|
||||
context: Context,
|
||||
private val onResult: (FaceDetectionResult) -> Unit,
|
||||
private val onPresenceChanged: (present: Boolean, faceIdentityId: String?, recognizedName: String?) -> Unit,
|
||||
) {
|
||||
private val appContext = context.applicationContext
|
||||
private val engine = RetinaFaceEngineRKNN()
|
||||
private val recognizer = FaceRecognizer(context)
|
||||
private val recognizer = FaceRecognizer(appContext)
|
||||
private val userMemoryStore = com.digitalperson.interaction.UserMemoryStore(appContext)
|
||||
private val scope = CoroutineScope(SupervisorJob() + Dispatchers.Default)
|
||||
private val frameInFlight = AtomicBoolean(false)
|
||||
private val initialized = AtomicBoolean(false)
|
||||
private var trackFace: FaceBox? = null
|
||||
private var trackId: Long = 0
|
||||
private var trackStableSinceMs: Long = 0
|
||||
private var analyzedTrackId: Long = -1
|
||||
private var lastFaceIdentityId: String? = null
|
||||
private var lastRecognizedName: String? = null
|
||||
private val fusionEmbeddings = ArrayDeque<FloatArray>()
|
||||
private val fusionQualities = ArrayDeque<Float>()
|
||||
|
||||
fun initialize(): Boolean {
|
||||
val detectorOk = engine.initialize(context)
|
||||
val detectorOk = engine.initialize(appContext)
|
||||
val recognizerOk = recognizer.initialize()
|
||||
val ok = detectorOk && recognizerOk
|
||||
initialized.set(ok)
|
||||
@@ -69,11 +74,28 @@ class FaceDetectionPipeline(
|
||||
val width = bitmap.width
|
||||
val height = bitmap.height
|
||||
val raw = engine.detect(bitmap)
|
||||
if (raw.isEmpty()) {
|
||||
withContext(Dispatchers.Main) {
|
||||
onPresenceChanged(false, null, null)
|
||||
onResult(FaceDetectionResult(width, height, emptyList()))
|
||||
}
|
||||
return@launch
|
||||
}
|
||||
|
||||
val faceCount = raw.size / 5
|
||||
val stride = when {
|
||||
raw.size % 15 == 0 -> 15
|
||||
raw.size % 5 == 0 -> 5
|
||||
else -> 0
|
||||
}
|
||||
if (stride == 0) {
|
||||
Log.w(AppConfig.TAG, "[Face] invalid detector output size=${raw.size}")
|
||||
}
|
||||
val faceCount = if (stride == 0) 0 else raw.size / stride
|
||||
val faces = ArrayList<FaceBox>(faceCount)
|
||||
var i = 0
|
||||
while (i + 4 < raw.size) {
|
||||
while (i + 4 < raw.size && stride > 0) {
|
||||
val lm = if (stride >= 15) raw.copyOfRange(i + 5, i + 15) else null
|
||||
val hasLm = lm?.all { it >= 0f } == true
|
||||
faces.add(
|
||||
FaceBox(
|
||||
left = raw[i],
|
||||
@@ -81,9 +103,11 @@ class FaceDetectionPipeline(
|
||||
right = raw[i + 2],
|
||||
bottom = raw[i + 3],
|
||||
score = raw[i + 4],
|
||||
hasLandmarks = hasLm,
|
||||
landmarks = if (hasLm) lm else null,
|
||||
)
|
||||
)
|
||||
i += 5
|
||||
i += stride
|
||||
}
|
||||
// 过滤太小的人脸
|
||||
val minFaceSize = 50 // 最小人脸大小(像素)
|
||||
@@ -117,10 +141,11 @@ class FaceDetectionPipeline(
|
||||
val now = System.currentTimeMillis()
|
||||
if (faces.isEmpty()) {
|
||||
trackFace = null
|
||||
trackStableSinceMs = 0
|
||||
analyzedTrackId = -1
|
||||
lastFaceIdentityId = null
|
||||
lastRecognizedName = null
|
||||
fusionEmbeddings.clear()
|
||||
fusionQualities.clear()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -128,51 +153,132 @@ class FaceDetectionPipeline(
|
||||
val prev = trackFace
|
||||
if (prev == null || iou(prev, primary) < AppConfig.Face.TRACK_IOU_THRESHOLD) {
|
||||
trackId += 1
|
||||
trackStableSinceMs = now
|
||||
analyzedTrackId = -1
|
||||
lastFaceIdentityId = null
|
||||
lastRecognizedName = null
|
||||
fusionEmbeddings.clear()
|
||||
fusionQualities.clear()
|
||||
}
|
||||
trackFace = primary
|
||||
|
||||
val stableMs = now - trackStableSinceMs
|
||||
val frontal = isFrontal(primary, bitmap.width, bitmap.height)
|
||||
if (stableMs < AppConfig.Face.STABLE_MS || !frontal) {
|
||||
val quality = estimateQuality(primary, bitmap.width, bitmap.height)
|
||||
// Log.e(AppConfig.TAG, "estimateQuality: ${quality}")
|
||||
// 识别尽早进行:只要人脸清晰且朝向满足,就先完成 faceId 解析;问候稳定时机由状态机控制。
|
||||
if (!frontal || quality < 0.65f) {
|
||||
return
|
||||
}
|
||||
if (analyzedTrackId == trackId) {
|
||||
return
|
||||
}
|
||||
|
||||
val match = recognizer.resolveIdentity(bitmap, primary)
|
||||
val embedding = recognizer.extractEmbeddingAligned(bitmap, primary)
|
||||
if (embedding.isEmpty()) return
|
||||
addEmbeddingForFusion(embedding, quality)
|
||||
val fused = fuseEmbeddings() ?: return
|
||||
if (fusionEmbeddings.size < 4) {
|
||||
return
|
||||
}
|
||||
val match = recognizer.resolveIdentityFromEmbedding(fused)
|
||||
analyzedTrackId = trackId
|
||||
lastFaceIdentityId = match.matchedId?.let { "face_$it" }
|
||||
lastRecognizedName = match.matchedName
|
||||
// 从 user_memory 表中获取名字
|
||||
lastRecognizedName = lastFaceIdentityId?.let { userId ->
|
||||
userMemoryStore.getMemory(userId)?.displayName
|
||||
}
|
||||
Log.i(
|
||||
AppConfig.TAG,
|
||||
"[Face] stable track=$trackId faceId=${lastFaceIdentityId} matched=${match.matchedName} score=${match.similarity}"
|
||||
"[Face] stable track=$trackId faceId=${lastFaceIdentityId} matched=${lastRecognizedName} score=${match.similarity} fusionN=${fusionEmbeddings.size}"
|
||||
)
|
||||
}
|
||||
|
||||
private fun addEmbeddingForFusion(embedding: FloatArray, quality: Float) {
|
||||
if (fusionEmbeddings.size >= 8) {
|
||||
fusionEmbeddings.removeFirst()
|
||||
fusionQualities.removeFirst()
|
||||
}
|
||||
fusionEmbeddings.addLast(embedding)
|
||||
fusionQualities.addLast(quality.coerceIn(0.1f, 1f))
|
||||
}
|
||||
|
||||
private fun fuseEmbeddings(): FloatArray? {
|
||||
if (fusionEmbeddings.isEmpty()) return null
|
||||
val dim = fusionEmbeddings.first().size
|
||||
if (dim == 0) return null
|
||||
val out = FloatArray(dim)
|
||||
var wsum = 0f
|
||||
val embIter = fusionEmbeddings.iterator()
|
||||
val qIter = fusionQualities.iterator()
|
||||
while (embIter.hasNext() && qIter.hasNext()) {
|
||||
val e = embIter.next()
|
||||
val w = qIter.next()
|
||||
if (e.size != dim) continue
|
||||
for (i in 0 until dim) out[i] += e[i] * w
|
||||
wsum += w
|
||||
}
|
||||
if (wsum <= 1e-6f) return null
|
||||
for (i in out.indices) out[i] /= wsum
|
||||
var n = 0f
|
||||
for (v in out) n += v * v
|
||||
val norm = kotlin.math.sqrt(n.coerceAtLeast(1e-12f))
|
||||
for (i in out.indices) out[i] /= norm
|
||||
return out
|
||||
}
|
||||
|
||||
private fun isFrontal(face: FaceBox, frameW: Int, frameH: Int): Boolean {
|
||||
if (!face.hasLandmarks || face.landmarks == null || face.landmarks.size < 10) {
|
||||
return false
|
||||
}
|
||||
val lm = face.landmarks
|
||||
val w = face.right - face.left
|
||||
val h = face.bottom - face.top
|
||||
if (w < AppConfig.Face.FRONTAL_MIN_FACE_SIZE || h < AppConfig.Face.FRONTAL_MIN_FACE_SIZE) {
|
||||
return false
|
||||
}
|
||||
val aspectDiff = abs((w / h) - 1f)
|
||||
if (aspectDiff > AppConfig.Face.FRONTAL_MAX_ASPECT_DIFF) {
|
||||
return false
|
||||
}
|
||||
|
||||
val leftEyeX = lm[0]
|
||||
val leftEyeY = lm[1]
|
||||
val rightEyeX = lm[2]
|
||||
val rightEyeY = lm[3]
|
||||
val noseX = lm[4]
|
||||
val noseY = lm[5]
|
||||
val mouthLeftY = lm[7]
|
||||
val mouthRightY = lm[9]
|
||||
if (rightEyeX <= leftEyeX) return false
|
||||
if (noseY <= (leftEyeY + rightEyeY) * 0.5f) return false
|
||||
if ((mouthLeftY + mouthRightY) * 0.5f <= noseY) return false
|
||||
|
||||
val eyeDy = kotlin.math.abs(leftEyeY - rightEyeY) / h
|
||||
if (eyeDy > 0.12f) return false
|
||||
|
||||
val eyeMidX = (leftEyeX + rightEyeX) * 0.5f
|
||||
val noseOffset = kotlin.math.abs(noseX - eyeMidX) / w
|
||||
if (noseOffset > 0.12f) return false
|
||||
|
||||
val cx = (face.left + face.right) * 0.5f
|
||||
val cy = (face.top + face.bottom) * 0.5f
|
||||
val minX = frameW * 0.15f
|
||||
val maxX = frameW * 0.85f
|
||||
val minY = frameH * 0.15f
|
||||
val maxY = frameH * 0.85f
|
||||
val minX = frameW * 0.2f
|
||||
val maxX = frameW * 0.8f
|
||||
val minY = frameH * 0.2f
|
||||
val maxY = frameH * 0.8f
|
||||
|
||||
return cx in minX..maxX && cy in minY..maxY
|
||||
}
|
||||
|
||||
private fun estimateQuality(face: FaceBox, frameW: Int, frameH: Int): Float {
|
||||
val lm = face.landmarks ?: return 0f
|
||||
if (!face.hasLandmarks || lm.size < 10) return 0f
|
||||
val w = (face.right - face.left).coerceAtLeast(1f)
|
||||
val h = (face.bottom - face.top).coerceAtLeast(1f)
|
||||
val areaRatio = ((w * h) / (frameW * frameH).toFloat()).coerceIn(0f, 1f)
|
||||
val eyeDy = kotlin.math.abs(lm[1] - lm[3]) / h
|
||||
val eyeMidX = (lm[0] + lm[2]) * 0.5f
|
||||
val noseOffset = kotlin.math.abs(lm[4] - eyeMidX) / w
|
||||
val geom = (1f - (eyeDy / 0.2f)).coerceIn(0f, 1f) * (1f - (noseOffset / 0.2f)).coerceIn(0f, 1f)
|
||||
val sizeScore = (areaRatio / 0.08f).coerceIn(0f, 1f)
|
||||
return (face.score.coerceIn(0f, 1f) * 0.45f) + (sizeScore * 0.2f) + (geom * 0.35f)
|
||||
}
|
||||
|
||||
private fun iou(a: FaceBox, b: FaceBox): Float {
|
||||
val left = maxOf(a.left, b.left)
|
||||
val top = maxOf(a.top, b.top)
|
||||
@@ -193,4 +299,8 @@ class FaceDetectionPipeline(
|
||||
recognizer.release()
|
||||
initialized.set(false)
|
||||
}
|
||||
|
||||
fun getRecognizer(): FaceRecognizer {
|
||||
return recognizer
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import java.nio.ByteOrder
|
||||
|
||||
data class FaceProfile(
|
||||
val id: Long,
|
||||
val name: String?,
|
||||
val embedding: FloatArray,
|
||||
)
|
||||
|
||||
@@ -21,7 +20,6 @@ class FaceFeatureStore(context: Context) : SQLiteOpenHelper(context, DB_NAME, nu
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS face_profiles (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT,
|
||||
embedding BLOB NOT NULL,
|
||||
updated_at INTEGER NOT NULL
|
||||
)
|
||||
@@ -37,16 +35,14 @@ class FaceFeatureStore(context: Context) : SQLiteOpenHelper(context, DB_NAME, nu
|
||||
fun loadAllProfiles(): List<FaceProfile> {
|
||||
val db = readableDatabase
|
||||
val list = ArrayList<FaceProfile>()
|
||||
db.rawQuery("SELECT id, name, embedding FROM face_profiles", null).use { c ->
|
||||
db.rawQuery("SELECT id, embedding FROM face_profiles", null).use { c ->
|
||||
val idIdx = c.getColumnIndexOrThrow("id")
|
||||
val nameIdx = c.getColumnIndexOrThrow("name")
|
||||
val embIdx = c.getColumnIndexOrThrow("embedding")
|
||||
while (c.moveToNext()) {
|
||||
val embBlob = c.getBlob(embIdx) ?: continue
|
||||
list.add(
|
||||
FaceProfile(
|
||||
id = c.getLong(idIdx),
|
||||
name = c.getString(nameIdx),
|
||||
embedding = blobToFloatArray(embBlob),
|
||||
)
|
||||
)
|
||||
@@ -55,10 +51,8 @@ class FaceFeatureStore(context: Context) : SQLiteOpenHelper(context, DB_NAME, nu
|
||||
return list
|
||||
}
|
||||
|
||||
fun insertProfile(name: String?, embedding: FloatArray): Long {
|
||||
val safeName = name?.takeIf { it.isNotBlank() }
|
||||
fun insertProfile(embedding: FloatArray): Long {
|
||||
val values = ContentValues().apply {
|
||||
put("name", safeName)
|
||||
put("embedding", floatArrayToBlob(embedding))
|
||||
put("updated_at", System.currentTimeMillis())
|
||||
}
|
||||
@@ -68,7 +62,7 @@ class FaceFeatureStore(context: Context) : SQLiteOpenHelper(context, DB_NAME, nu
|
||||
values,
|
||||
SQLiteDatabase.CONFLICT_NONE
|
||||
)
|
||||
Log.i(AppConfig.TAG, "[FaceFeatureStore] insertProfile name='$safeName' rowId=$rowId dim=${embedding.size}")
|
||||
Log.i(AppConfig.TAG, "[FaceFeatureStore] insertProfile rowId=$rowId dim=${embedding.size}")
|
||||
return rowId
|
||||
}
|
||||
|
||||
@@ -88,6 +82,6 @@ class FaceFeatureStore(context: Context) : SQLiteOpenHelper(context, DB_NAME, nu
|
||||
|
||||
companion object {
|
||||
private const val DB_NAME = "face_feature.db"
|
||||
private const val DB_VERSION = 2
|
||||
private const val DB_VERSION = 3
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ package com.digitalperson.face
|
||||
|
||||
import android.content.Context
|
||||
import android.graphics.Bitmap
|
||||
import android.graphics.Canvas
|
||||
import android.graphics.Matrix
|
||||
import android.util.Log
|
||||
import com.digitalperson.config.AppConfig
|
||||
import com.digitalperson.engine.ArcFaceEngineRKNN
|
||||
@@ -9,7 +11,6 @@ import kotlin.math.sqrt
|
||||
|
||||
data class FaceRecognitionResult(
|
||||
val matchedId: Long?,
|
||||
val matchedName: String?,
|
||||
val similarity: Float,
|
||||
val embeddingDim: Int,
|
||||
)
|
||||
@@ -41,12 +42,11 @@ class FaceRecognizer(context: Context) {
|
||||
}
|
||||
|
||||
fun identify(bitmap: Bitmap, face: FaceBox): FaceRecognitionResult {
|
||||
if (!initialized) return FaceRecognitionResult(null, null, 0f, 0)
|
||||
if (!initialized) return FaceRecognitionResult(null, 0f, 0)
|
||||
val embedding = extractEmbedding(bitmap, face)
|
||||
if (embedding.isEmpty()) return FaceRecognitionResult(null, null, 0f, 0)
|
||||
if (embedding.isEmpty()) return FaceRecognitionResult(null, 0f, 0)
|
||||
|
||||
var bestId: Long? = null
|
||||
var bestName: String? = null
|
||||
var bestScore = -1f
|
||||
for (p in cache) {
|
||||
if (p.embedding.size != embedding.size) continue
|
||||
@@ -54,13 +54,12 @@ class FaceRecognizer(context: Context) {
|
||||
if (score > bestScore) {
|
||||
bestScore = score
|
||||
bestId = p.id
|
||||
bestName = p.name
|
||||
}
|
||||
}
|
||||
if (bestScore >= AppConfig.FaceRecognition.SIMILARITY_THRESHOLD) {
|
||||
return FaceRecognitionResult(bestId, bestName, bestScore, embedding.size)
|
||||
return FaceRecognitionResult(bestId, bestScore, embedding.size)
|
||||
}
|
||||
return FaceRecognitionResult(null, null, bestScore, embedding.size)
|
||||
return FaceRecognitionResult(null, bestScore, embedding.size)
|
||||
}
|
||||
|
||||
fun extractEmbedding(bitmap: Bitmap, face: FaceBox): FloatArray {
|
||||
@@ -68,28 +67,208 @@ class FaceRecognizer(context: Context) {
|
||||
return engine.extractEmbedding(bitmap, face.left, face.top, face.right, face.bottom)
|
||||
}
|
||||
|
||||
private fun addProfile(name: String?, embedding: FloatArray): Long {
|
||||
fun extractEmbeddingAligned(bitmap: Bitmap, face: FaceBox): FloatArray {
|
||||
if (!initialized) return FloatArray(0)
|
||||
val aligned = alignFaceToArcFace(bitmap, face) ?: return extractEmbedding(bitmap, face)
|
||||
return try {
|
||||
engine.extractEmbedding(
|
||||
aligned,
|
||||
0f,
|
||||
0f,
|
||||
aligned.width.toFloat(),
|
||||
aligned.height.toFloat()
|
||||
)
|
||||
} finally {
|
||||
aligned.recycle()
|
||||
}
|
||||
}
|
||||
|
||||
fun extractEmbeddingAlignedForTest(bitmap: Bitmap, face: FaceBox): FloatArray {
|
||||
return extractEmbeddingAligned(bitmap, face)
|
||||
}
|
||||
|
||||
/**
|
||||
* 测试方法:计算同一张图片的两次识别相似度
|
||||
* @param bitmap 测试图片
|
||||
* @param face 人脸框
|
||||
* @return 两次识别的相似度
|
||||
*/
|
||||
fun testSimilarity(bitmap: Bitmap, face: FaceBox): Float {
|
||||
if (!initialized) return -1f
|
||||
|
||||
// 第一次提取特征
|
||||
val embedding1 = extractEmbedding(bitmap, face)
|
||||
if (embedding1.isEmpty()) return -1f
|
||||
|
||||
// 第二次提取特征
|
||||
val embedding2 = extractEmbedding(bitmap, face)
|
||||
if (embedding2.isEmpty()) return -1f
|
||||
|
||||
// 计算两次特征的相似度
|
||||
return cosineSimilarity(embedding1, embedding2)
|
||||
}
|
||||
|
||||
/**
|
||||
* 测试方法:计算两张图片的相似度
|
||||
* @param bitmap1 第一张图片
|
||||
* @param face1 第一张图片的人脸框
|
||||
* @param bitmap2 第二张图片
|
||||
* @param face2 第二张图片的人脸框
|
||||
* @return 两张图片的相似度
|
||||
*/
|
||||
fun testSimilarityBetween(bitmap1: Bitmap, face1: FaceBox, bitmap2: Bitmap, face2: FaceBox): Float {
|
||||
Log.d(AppConfig.TAG, "[FaceRecognizer] testSimilarityBetween: initialized=$initialized")
|
||||
if (!initialized) {
|
||||
Log.e(AppConfig.TAG, "[FaceRecognizer] testSimilarityBetween: recognizer not initialized")
|
||||
return -1f
|
||||
}
|
||||
|
||||
// 测试链路优先使用 landmarks 对齐;若缺少关键点则回退到 bbox 裁剪。
|
||||
Log.d(AppConfig.TAG, "[FaceRecognizer] testSimilarityBetween: extracting aligned embedding1...")
|
||||
val embedding1 = extractEmbeddingAlignedForTest(bitmap1, face1)
|
||||
Log.d(AppConfig.TAG, "[FaceRecognizer] testSimilarityBetween: embedding1 size=${embedding1.size}")
|
||||
|
||||
Log.d(AppConfig.TAG, "[FaceRecognizer] testSimilarityBetween: extracting aligned embedding2...")
|
||||
val embedding2 = extractEmbeddingAlignedForTest(bitmap2, face2)
|
||||
Log.d(AppConfig.TAG, "[FaceRecognizer] testSimilarityBetween: embedding2 size=${embedding2.size}")
|
||||
|
||||
if (embedding1.isEmpty() || embedding2.isEmpty()) {
|
||||
Log.e(AppConfig.TAG, "[FaceRecognizer] testSimilarityBetween: embedding extraction failed - embedding1 empty=${embedding1.isEmpty()}, embedding2 empty=${embedding2.isEmpty()}")
|
||||
return -1f
|
||||
}
|
||||
|
||||
// 计算相似度
|
||||
val similarity = cosineSimilarity(embedding1, embedding2)
|
||||
Log.d(AppConfig.TAG, "[FaceRecognizer] testSimilarityBetween: similarity=$similarity")
|
||||
return similarity
|
||||
}
|
||||
|
||||
private fun alignFaceToArcFace(bitmap: Bitmap, face: FaceBox): Bitmap? {
|
||||
val lm = face.landmarks ?: return null
|
||||
if (!face.hasLandmarks || lm.size < 10) return null
|
||||
|
||||
val dstW = 112
|
||||
val dstH = 112
|
||||
val src = arrayOf(
|
||||
floatArrayOf(lm[0], lm[1]), // left eye
|
||||
floatArrayOf(lm[2], lm[3]), // right eye
|
||||
floatArrayOf(lm[4], lm[5]), // nose
|
||||
floatArrayOf(lm[6], lm[7]), // left mouth
|
||||
floatArrayOf(lm[8], lm[9]), // right mouth
|
||||
)
|
||||
val dst = arrayOf(
|
||||
floatArrayOf(38.2946f, 51.6963f),
|
||||
floatArrayOf(73.5318f, 51.5014f),
|
||||
floatArrayOf(56.0252f, 71.7366f),
|
||||
floatArrayOf(41.5493f, 92.3655f),
|
||||
floatArrayOf(70.7299f, 92.2041f),
|
||||
)
|
||||
|
||||
val sim = estimateSimilarity(src, dst) ?: return null
|
||||
|
||||
val matrix = Matrix()
|
||||
matrix.setValues(
|
||||
floatArrayOf(
|
||||
sim[0], sim[1], sim[2],
|
||||
sim[3], sim[4], sim[5],
|
||||
0f, 0f, 1f
|
||||
)
|
||||
)
|
||||
|
||||
val out = Bitmap.createBitmap(dstW, dstH, Bitmap.Config.ARGB_8888)
|
||||
val canvas = Canvas(out)
|
||||
canvas.drawBitmap(bitmap, matrix, null)
|
||||
return out
|
||||
}
|
||||
|
||||
private fun estimateSimilarity(
|
||||
src: Array<FloatArray>,
|
||||
dst: Array<FloatArray>,
|
||||
): FloatArray? {
|
||||
if (src.size != dst.size || src.size < 2) return null
|
||||
|
||||
var mxs = 0f
|
||||
var mys = 0f
|
||||
var mxd = 0f
|
||||
var myd = 0f
|
||||
val n = src.size.toFloat()
|
||||
for (i in src.indices) {
|
||||
mxs += src[i][0]
|
||||
mys += src[i][1]
|
||||
mxd += dst[i][0]
|
||||
myd += dst[i][1]
|
||||
}
|
||||
mxs /= n
|
||||
mys /= n
|
||||
mxd /= n
|
||||
myd /= n
|
||||
|
||||
var a = 0f
|
||||
var b = 0f
|
||||
var norm = 0f
|
||||
for (i in src.indices) {
|
||||
val sx = src[i][0] - mxs
|
||||
val sy = src[i][1] - mys
|
||||
val dx = dst[i][0] - mxd
|
||||
val dy = dst[i][1] - myd
|
||||
a += sx * dx + sy * dy
|
||||
b += sx * dy - sy * dx
|
||||
norm += sx * sx + sy * sy
|
||||
}
|
||||
if (norm < 1e-6f) return null
|
||||
|
||||
val r = kotlin.math.sqrt(a * a + b * b)
|
||||
if (r < 1e-6f) return null
|
||||
val cos = a / r
|
||||
val sin = b / r
|
||||
val scale = r / norm
|
||||
|
||||
val m00 = scale * cos
|
||||
val m01 = -scale * sin
|
||||
val m10 = scale * sin
|
||||
val m11 = scale * cos
|
||||
val tx = mxd - (m00 * mxs + m01 * mys)
|
||||
val ty = myd - (m10 * mxs + m11 * mys)
|
||||
return floatArrayOf(m00, m01, tx, m10, m11, ty)
|
||||
}
|
||||
|
||||
private fun addProfile(embedding: FloatArray): Long {
|
||||
val normalized = normalize(embedding)
|
||||
val rowId = store.insertProfile(name, normalized)
|
||||
val rowId = store.insertProfile(normalized)
|
||||
if (rowId > 0) {
|
||||
cache.add(FaceProfile(id = rowId, name = name, embedding = normalized))
|
||||
cache.add(FaceProfile(id = rowId, embedding = normalized))
|
||||
}
|
||||
return rowId
|
||||
}
|
||||
|
||||
fun resolveIdentity(bitmap: Bitmap, face: FaceBox): FaceRecognitionResult {
|
||||
val match = identify(bitmap, face)
|
||||
if (match.matchedId != null) return match
|
||||
val embedding = extractEmbedding(bitmap, face)
|
||||
if (embedding.isEmpty()) return match
|
||||
val newId = addProfile(name = null, embedding = embedding)
|
||||
if (newId <= 0L) return match
|
||||
return FaceRecognitionResult(
|
||||
matchedId = newId,
|
||||
matchedName = null,
|
||||
similarity = match.similarity,
|
||||
embeddingDim = embedding.size
|
||||
)
|
||||
val embedding = extractEmbeddingAligned(bitmap, face)
|
||||
return resolveIdentityFromEmbedding(embedding)
|
||||
}
|
||||
|
||||
fun resolveIdentityFromEmbedding(embedding: FloatArray): FaceRecognitionResult {
|
||||
if (!initialized || embedding.isEmpty()) return FaceRecognitionResult(null, -1f, 0)
|
||||
val normalized = normalize(embedding)
|
||||
|
||||
var bestId: Long? = null
|
||||
var bestScore = -1f
|
||||
for (p in cache) {
|
||||
if (p.embedding.size != normalized.size) continue
|
||||
val score = cosineSimilarity(normalized, p.embedding)
|
||||
if (score > bestScore) {
|
||||
bestScore = score
|
||||
bestId = p.id
|
||||
}
|
||||
}
|
||||
if (bestId != null && bestScore >= AppConfig.FaceRecognition.SIMILARITY_THRESHOLD) {
|
||||
return FaceRecognitionResult(bestId, bestScore, normalized.size)
|
||||
}
|
||||
|
||||
val newId = addProfile(normalized)
|
||||
if (newId > 0L) {
|
||||
return FaceRecognitionResult(newId, bestScore, normalized.size)
|
||||
}
|
||||
return FaceRecognitionResult(null, bestScore, normalized.size)
|
||||
}
|
||||
|
||||
fun release() {
|
||||
|
||||
@@ -0,0 +1,258 @@
|
||||
package com.digitalperson.interaction
|
||||
|
||||
import android.util.Log
|
||||
import com.digitalperson.data.AppDatabase
|
||||
import com.digitalperson.data.entity.ChatMessage
|
||||
import com.digitalperson.data.entity.ChatMessageEntity
|
||||
import com.digitalperson.data.entity.ConversationSummaryEntity
|
||||
import com.digitalperson.llm.LLMManager
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.GlobalScope
|
||||
import kotlinx.coroutines.launch
|
||||
|
||||
/**
|
||||
* 对话缓冲区内存管理类
|
||||
* 用于存储完整的对话历史
|
||||
*/
|
||||
class ConversationBufferMemory(private val db: AppDatabase, private val maxMessages: Int = 50) {
|
||||
private val userMessages = mutableMapOf<String, MutableList<ChatMessage>>()
|
||||
private val TAG = "ConversationBufferMemory"
|
||||
|
||||
/**
|
||||
* 添加新消息
|
||||
* @param userId 用户ID
|
||||
* @param role 角色(user 或 assistant)
|
||||
* @param content 消息内容
|
||||
*/
|
||||
fun addMessage(userId: String, role: String, content: String) {
|
||||
val messages = userMessages.getOrPut(userId) { mutableListOf() }
|
||||
messages.add(ChatMessage(userId = userId, role = role, content = content))
|
||||
|
||||
// 限制消息数量,避免内存占用过大
|
||||
if (messages.size > maxMessages) {
|
||||
messages.removeAt(0)
|
||||
}
|
||||
|
||||
Log.d(TAG, "Added message for user $userId: $role - $content")
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取用户的对话历史
|
||||
* @param userId 用户ID
|
||||
* @return 对话历史列表
|
||||
*/
|
||||
fun getMessages(userId: String): List<ChatMessage> {
|
||||
return userMessages.getOrDefault(userId, mutableListOf()).toList()
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取用户的最近N条消息
|
||||
* @param userId 用户ID
|
||||
* @param limit 消息数量限制
|
||||
* @return 最近的消息列表
|
||||
*/
|
||||
fun getRecentMessages(userId: String, limit: Int): List<ChatMessage> {
|
||||
val messages = userMessages.getOrDefault(userId, mutableListOf())
|
||||
return if (messages.size <= limit) {
|
||||
messages.toList()
|
||||
} else {
|
||||
messages.subList(messages.size - limit, messages.size)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 从数据库加载用户历史
|
||||
* @param userId 用户ID
|
||||
*/
|
||||
fun loadFromDatabase(userId: String) {
|
||||
GlobalScope.launch(Dispatchers.IO) {
|
||||
try {
|
||||
val entities = db.chatMessageDao().getMessagesByUserId(userId)
|
||||
val messages = mutableListOf<ChatMessage>()
|
||||
entities.forEach { entity ->
|
||||
messages.add(
|
||||
ChatMessage(
|
||||
id = entity.id,
|
||||
userId = entity.userId,
|
||||
role = entity.role,
|
||||
content = entity.content,
|
||||
timestamp = entity.timestamp
|
||||
)
|
||||
)
|
||||
}
|
||||
userMessages[userId] = messages
|
||||
Log.d(TAG, "Loaded ${messages.size} messages for user $userId from database")
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Error loading messages from database: ${e.message}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 保存用户历史到数据库
|
||||
* @param userId 用户ID
|
||||
*/
|
||||
fun saveToDatabase(userId: String) {
|
||||
GlobalScope.launch(Dispatchers.IO) {
|
||||
try {
|
||||
db.chatMessageDao().clearMessagesByUserId(userId)
|
||||
userMessages[userId]?.forEach { message ->
|
||||
db.chatMessageDao().insert(
|
||||
ChatMessageEntity(
|
||||
id = message.id,
|
||||
userId = message.userId,
|
||||
role = message.role,
|
||||
content = message.content,
|
||||
timestamp = message.timestamp
|
||||
)
|
||||
)
|
||||
}
|
||||
Log.d(TAG, "Saved messages for user $userId to database")
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Error saving messages to database: ${e.message}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 清空用户历史
|
||||
* @param userId 用户ID
|
||||
*/
|
||||
fun clear(userId: String) {
|
||||
userMessages.remove(userId)
|
||||
GlobalScope.launch(Dispatchers.IO) {
|
||||
try {
|
||||
db.chatMessageDao().clearMessagesByUserId(userId)
|
||||
Log.d(TAG, "Cleared messages for user $userId")
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Error clearing messages: ${e.message}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 转换为LLM上下文格式
|
||||
* @param userId 用户ID
|
||||
* @return 上下文字符串
|
||||
*/
|
||||
fun toContextString(userId: String): String {
|
||||
val messages = userMessages.getOrDefault(userId, mutableListOf())
|
||||
return messages.joinToString("\n") {
|
||||
"${it.role}: ${it.content}"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 对话摘要内存管理类
|
||||
* 用于生成和存储对话摘要
|
||||
*/
|
||||
class ConversationSummaryMemory(private val db: AppDatabase, private val llmManager: LLMManager?) {
|
||||
private val userSummaries = mutableMapOf<String, String>()
|
||||
private val TAG = "ConversationSummaryMemory"
|
||||
|
||||
/**
|
||||
* 获取用户的对话摘要
|
||||
* @param userId 用户ID
|
||||
* @return 对话摘要
|
||||
*/
|
||||
fun getSummary(userId: String): String {
|
||||
return userSummaries.getOrDefault(userId, "")
|
||||
}
|
||||
|
||||
/**
|
||||
* 生成用户的对话摘要
|
||||
* @param userId 用户ID
|
||||
* @param messages 对话消息列表
|
||||
* @param onComplete 完成回调
|
||||
*/
|
||||
fun generateSummary(userId: String, messages: List<ChatMessage>, onComplete: (String) -> Unit) {
|
||||
if (messages.isEmpty()) {
|
||||
onComplete("")
|
||||
return
|
||||
}
|
||||
|
||||
if (llmManager == null) {
|
||||
Log.w(TAG, "LLM manager is not initialized, cannot generate summary")
|
||||
onComplete("")
|
||||
return
|
||||
}
|
||||
|
||||
val historyText = messages.joinToString("\n") {
|
||||
"${it.role}: ${it.content}"
|
||||
}
|
||||
|
||||
val prompt = """
|
||||
请总结以下对话的关键信息,包括:
|
||||
1. 用户的主要问题或需求
|
||||
2. 重要的个人信息(如姓名、爱好等)
|
||||
3. 对话的核心结论或共识
|
||||
|
||||
输出简洁明了的总结,不超过500字。
|
||||
|
||||
对话历史:
|
||||
$historyText
|
||||
""".trimIndent()
|
||||
|
||||
Log.d(TAG, "Generating summary for user $userId")
|
||||
llmManager.generate(prompt) { summary ->
|
||||
val trimmedSummary = summary.trim()
|
||||
userSummaries[userId] = trimmedSummary
|
||||
|
||||
// 保存到数据库
|
||||
GlobalScope.launch(Dispatchers.IO) {
|
||||
try {
|
||||
db.conversationSummaryDao().clearSummariesByUserId(userId)
|
||||
db.conversationSummaryDao().insert(
|
||||
ConversationSummaryEntity(
|
||||
id = 0,
|
||||
userId = userId,
|
||||
summary = trimmedSummary,
|
||||
createdAt = System.currentTimeMillis()
|
||||
)
|
||||
)
|
||||
Log.d(TAG, "Saved summary for user $userId to database")
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Error saving summary to database: ${e.message}")
|
||||
}
|
||||
}
|
||||
|
||||
onComplete(trimmedSummary)
|
||||
Log.d(TAG, "Generated summary for user $userId: $trimmedSummary")
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 从数据库加载用户摘要
|
||||
* @param userId 用户ID
|
||||
*/
|
||||
fun loadFromDatabase(userId: String) {
|
||||
GlobalScope.launch(Dispatchers.IO) {
|
||||
try {
|
||||
val summaryEntity = db.conversationSummaryDao().getLatestSummaryByUserId(userId)
|
||||
if (summaryEntity != null) {
|
||||
userSummaries[userId] = summaryEntity.summary
|
||||
Log.d(TAG, "Loaded summary for user $userId from database")
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Error loading summary from database: ${e.message}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 清空用户摘要
|
||||
* @param userId 用户ID
|
||||
*/
|
||||
fun clear(userId: String) {
|
||||
userSummaries.remove(userId)
|
||||
GlobalScope.launch(Dispatchers.IO) {
|
||||
try {
|
||||
db.conversationSummaryDao().clearSummariesByUserId(userId)
|
||||
Log.d(TAG, "Cleared summary for user $userId")
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Error clearing summary: ${e.message}")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,7 @@
|
||||
package com.digitalperson.interaction
|
||||
|
||||
import android.content.Context
|
||||
import com.digitalperson.util.SmartGreetingUtil
|
||||
import android.util.Log
|
||||
import com.digitalperson.config.AppConfig
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
@@ -32,18 +34,25 @@ interface InteractionActionHandler {
|
||||
fun onRememberUser(faceIdentityId: String, name: String?)
|
||||
fun saveThought(thought: String)
|
||||
fun loadLatestThought(): String?
|
||||
fun loadRecentThoughts(timeRangeMs: Long): List<String>
|
||||
fun addToChatHistory(role: String, content: String)
|
||||
fun addAssistantMessageToCloudHistory(content: String)
|
||||
fun getRandomQuestion(faceId: String): String
|
||||
}
|
||||
|
||||
class DigitalHumanInteractionController(
|
||||
private val scope: CoroutineScope,
|
||||
private val handler: InteractionActionHandler,
|
||||
private val context: Context
|
||||
) {
|
||||
|
||||
private val smartGreetingUtil = SmartGreetingUtil(context)
|
||||
private val TAG: String = "DigitalHumanInteraction"
|
||||
private var state: InteractionState = InteractionState.IDLE
|
||||
private var facePresent: Boolean = false
|
||||
private var recognizedName: String? = null
|
||||
private var currentFaceId: String? = null
|
||||
private var faceSeenSinceMs: Long = 0L
|
||||
private var proactiveRound = 0
|
||||
private var hasPendingUserReply = false
|
||||
|
||||
@@ -58,59 +67,86 @@ class DigitalHumanInteractionController(
|
||||
scheduleMemoryMode()
|
||||
}
|
||||
|
||||
fun onFacePresenceChanged(present: Boolean, faceIdentityId: String?, recognized: String?) {
|
||||
Log.d(TAG, "onFacePresenceChanged: present=$present, faceIdentityId=$faceIdentityId, recognized=$recognized, state=$state")
|
||||
|
||||
facePresent = present
|
||||
if (!faceIdentityId.isNullOrBlank()) {
|
||||
handler.onRememberUser(faceIdentityId, recognized)
|
||||
}
|
||||
if (!recognized.isNullOrBlank()) {
|
||||
recognizedName = recognized
|
||||
fun onFacePresenceChanged(present: Boolean) {
|
||||
Log.d(TAG, "onFacePresenceChanged: present=$present, state=$state")
|
||||
|
||||
facePresent = present
|
||||
val now = System.currentTimeMillis()
|
||||
if (present) {
|
||||
if (faceSeenSinceMs == 0L) {
|
||||
faceSeenSinceMs = now
|
||||
}
|
||||
} else {
|
||||
faceSeenSinceMs = 0L
|
||||
}
|
||||
|
||||
// 首次出现就启动稳定计时,到点只要人还在就问候。
|
||||
if (present) {
|
||||
val stableMs = now - faceSeenSinceMs
|
||||
val remain = AppConfig.Face.STABLE_MS - stableMs
|
||||
|
||||
if (remain > 0) {
|
||||
faceStableJob?.cancel()
|
||||
faceStableJob = scope.launch {
|
||||
delay(remain)
|
||||
if (!facePresent) return@launch
|
||||
if (state == InteractionState.IDLE || state == InteractionState.MEMORY || state == InteractionState.FAREWELL) {
|
||||
if (currentFaceId.isNullOrBlank()) {
|
||||
Log.d(TAG, "Greeting as unknown user: identity still unavailable after stable timeout")
|
||||
}
|
||||
enterGreeting()
|
||||
}
|
||||
}
|
||||
} else if (state == InteractionState.IDLE || state == InteractionState.MEMORY || state == InteractionState.FAREWELL) {
|
||||
enterGreeting()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 统一延迟处理
|
||||
// 人脸消失:保留小延迟,避免瞬时抖动导致频繁告别
|
||||
faceStableJob?.cancel()
|
||||
faceStableJob = scope.launch {
|
||||
delay(AppConfig.Face.STABLE_MS)
|
||||
|
||||
if (present) {
|
||||
// 人脸出现后的处理
|
||||
if (facePresent && (state == InteractionState.IDLE || state == InteractionState.MEMORY)) {
|
||||
enterGreeting()
|
||||
} else if (state == InteractionState.FAREWELL) {
|
||||
enterGreeting()
|
||||
}
|
||||
if (facePresent) return@launch
|
||||
if (state != InteractionState.IDLE && state != InteractionState.MEMORY && state != InteractionState.FAREWELL) {
|
||||
enterFarewell()
|
||||
} else {
|
||||
// 人脸消失后的处理
|
||||
if (state != InteractionState.IDLE && state != InteractionState.MEMORY && state != InteractionState.FAREWELL) {
|
||||
enterFarewell()
|
||||
} else {
|
||||
scheduleMemoryMode()
|
||||
}
|
||||
scheduleMemoryMode()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fun onFaceIdentityUpdated(faceIdentityId: String?, recognized: String?) {
|
||||
if (faceIdentityId.isNullOrBlank() && recognized.isNullOrBlank()) return
|
||||
Log.d(TAG, "onFaceIdentityUpdated: faceIdentityId=$faceIdentityId, recognized=$recognized, state=$state")
|
||||
if (!faceIdentityId.isNullOrBlank()) {
|
||||
currentFaceId = faceIdentityId
|
||||
handler.onRememberUser(faceIdentityId, recognized)
|
||||
}
|
||||
if (!recognized.isNullOrBlank()) {
|
||||
recognizedName = recognized
|
||||
}
|
||||
}
|
||||
|
||||
fun onUserStartSpeaking() {
|
||||
Log.d(TAG, "onUserStartSpeaking called, current state: $state")
|
||||
hasPendingUserReply = true
|
||||
// 立即进入对话状态,无论当前处于什么状态
|
||||
transitionTo(InteractionState.DIALOGUE)
|
||||
// 取消任何等待超时的任务
|
||||
waitReplyJob?.cancel()
|
||||
proactiveJob?.cancel()
|
||||
}
|
||||
|
||||
fun onUserAsrText(text: String) {
|
||||
val userText = text.trim()
|
||||
if (userText.isBlank()) return
|
||||
|
||||
if (userText.contains("你在想什么")) {
|
||||
val thought = handler.loadLatestThought()
|
||||
if (!thought.isNullOrBlank()) {
|
||||
handler.speak("我刚才在想:$thought")
|
||||
handler.appendText("\n[回忆] $thought\n")
|
||||
transitionTo(InteractionState.DIALOGUE)
|
||||
handler.playMotion("haru_g_m15.motion3.json")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if (userText.contains("再见")) {
|
||||
enterFarewell()
|
||||
return
|
||||
}
|
||||
// TODO: 后续应该是通过大模型来进行判断是否进入 farewell 状态
|
||||
// if (userText.contains("再见")) {
|
||||
// enterFarewell()
|
||||
// return
|
||||
// }
|
||||
|
||||
hasPendingUserReply = true
|
||||
when (state) {
|
||||
@@ -140,15 +176,46 @@ fun onFacePresenceChanged(present: Boolean, faceIdentityId: String?, recognized:
|
||||
|
||||
private fun enterGreeting() {
|
||||
transitionTo(InteractionState.GREETING)
|
||||
val greet = if (!recognizedName.isNullOrBlank()) {
|
||||
handler.playMotion("haru_g_m22.motion3.json")
|
||||
"你好,$recognizedName,很高兴再次见到你。"
|
||||
|
||||
// 使用智能问候语
|
||||
val (isFestival, festivalName) = smartGreetingUtil.isFestivalToday()
|
||||
|
||||
if (isFestival) {
|
||||
// 节日问候,使用本地LLM生成
|
||||
val prompt = smartGreetingUtil.getFestivalGreetingPrompt(festivalName, recognizedName)
|
||||
handler.requestLocalThought(prompt) { greeting ->
|
||||
scope.launch {
|
||||
if (!greeting.isNullOrBlank()) {
|
||||
// 使用LLM生成的问候语
|
||||
handler.playMotion(if (!recognizedName.isNullOrBlank()) "haru_g_m22.motion3.json" else "haru_g_m01.motion3.json")
|
||||
handler.speak(greeting)
|
||||
handler.appendText("\n[问候] $greeting\n")
|
||||
handler.addToChatHistory("assistant", greeting)
|
||||
handler.addAssistantMessageToCloudHistory(greeting)
|
||||
|
||||
transitionTo(InteractionState.WAITING_REPLY)
|
||||
handler.playMotion("haru_g_m17.motion3.json")
|
||||
scheduleWaitingReplyTimeout()
|
||||
} else {
|
||||
// LLM生成失败,使用默认问候语
|
||||
useDefaultGreeting()
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
handler.playMotion("haru_g_m01.motion3.json")
|
||||
"你好,很高兴见到你。"
|
||||
// 非节日或无网络,使用默认问候语
|
||||
useDefaultGreeting()
|
||||
}
|
||||
handler.speak(greet)
|
||||
handler.appendText("\n[问候] $greet\n")
|
||||
}
|
||||
|
||||
private fun useDefaultGreeting() {
|
||||
val greeting = smartGreetingUtil.getSmartGreeting(recognizedName)
|
||||
handler.playMotion(if (!recognizedName.isNullOrBlank()) "haru_g_m22.motion3.json" else "haru_g_m01.motion3.json")
|
||||
handler.speak(greeting)
|
||||
handler.appendText("\n[问候] $greeting\n")
|
||||
handler.addToChatHistory("assistant", greeting)
|
||||
handler.addAssistantMessageToCloudHistory(greeting)
|
||||
|
||||
transitionTo(InteractionState.WAITING_REPLY)
|
||||
handler.playMotion("haru_g_m17.motion3.json")
|
||||
scheduleWaitingReplyTimeout()
|
||||
@@ -176,16 +243,13 @@ fun onFacePresenceChanged(present: Boolean, faceIdentityId: String?, recognized:
|
||||
|
||||
private fun askProactiveTopic() {
|
||||
proactiveJob?.cancel()
|
||||
val topics = listOf(
|
||||
"我出一道数学题考考你吧?1+6等于多少?",
|
||||
"我们上完厕所应该干什么呀?",
|
||||
"你喜欢什么颜色呀?",
|
||||
)
|
||||
val idx = proactiveRound.coerceIn(0, topics.lastIndex)
|
||||
val topic = topics[idx]
|
||||
|
||||
// 从数据库获取问题
|
||||
val topic = getQuestionFromDatabase()
|
||||
|
||||
handler.playMotion(if (proactiveRound == 0) "haru_g_m15.motion3.json" else "haru_g_m22.motion3.json")
|
||||
handler.speak(topic)
|
||||
handler.appendText("\n[主动引导] $topic\n")
|
||||
handler.speak("嗨小朋友,可以帮老师个忙吗?"+topic)
|
||||
handler.appendText("\n[主动引导] \"嗨小朋友,可以帮老师个忙吗?\"$topic\n")
|
||||
// 将引导内容添加到对话历史中
|
||||
handler.addToChatHistory("助手", topic)
|
||||
// 将引导内容添加到云对话历史中
|
||||
@@ -209,6 +273,17 @@ fun onFacePresenceChanged(present: Boolean, faceIdentityId: String?, recognized:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun getQuestionFromDatabase(): String {
|
||||
// 从数据库获取随机问题(未被该用户问过的)
|
||||
val faceId = getCurrentFaceId()
|
||||
return handler.getRandomQuestion(faceId)
|
||||
}
|
||||
|
||||
private fun getCurrentFaceId(): String {
|
||||
// 返回当前用户的 faceId
|
||||
return currentFaceId ?: "default_face_id"
|
||||
}
|
||||
|
||||
private fun enterFarewell() {
|
||||
transitionTo(InteractionState.FAREWELL)
|
||||
@@ -224,6 +299,7 @@ fun onFacePresenceChanged(present: Boolean, faceIdentityId: String?, recognized:
|
||||
}
|
||||
|
||||
private fun scheduleMemoryMode() {
|
||||
return // TODO: remove this after testing done! Now just skipping memory
|
||||
memoryJob?.cancel()
|
||||
if (facePresent) return
|
||||
memoryJob = scope.launch {
|
||||
@@ -248,8 +324,20 @@ fun onFacePresenceChanged(present: Boolean, faceIdentityId: String?, recognized:
|
||||
private fun requestDialogueReply(userText: String) {
|
||||
waitReplyJob?.cancel()
|
||||
proactiveJob?.cancel()
|
||||
|
||||
// 获取最近的回忆信息
|
||||
val recentThoughts = handler.loadRecentThoughts(10 * 60 * 1000)
|
||||
val context = if (recentThoughts.isNotEmpty()) {
|
||||
val thoughtsContext = recentThoughts.joinToString("\n") { "- $it" }
|
||||
"以下是我最近10分钟内的想法:\n$thoughtsContext\n\n"
|
||||
} else {
|
||||
""
|
||||
}
|
||||
|
||||
// 将回忆信息作为上下文传递给LLM
|
||||
val userTextWithContext = "$context$userText"
|
||||
// 按产品要求:用户对话统一走云端LLM
|
||||
handler.requestCloudReply(userText)
|
||||
handler.requestCloudReply(userTextWithContext)
|
||||
}
|
||||
|
||||
private fun transitionTo(newState: InteractionState) {
|
||||
@@ -265,4 +353,5 @@ fun onFacePresenceChanged(present: Boolean, faceIdentityId: String?, recognized:
|
||||
memoryJob?.cancel()
|
||||
farewellJob?.cancel()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,80 +1,52 @@
|
||||
package com.digitalperson.interaction
|
||||
|
||||
import android.content.ContentValues
|
||||
import android.content.Context
|
||||
import android.database.sqlite.SQLiteDatabase
|
||||
import android.database.sqlite.SQLiteOpenHelper
|
||||
import com.digitalperson.data.AppDatabase
|
||||
import com.digitalperson.data.entity.Question
|
||||
import com.digitalperson.data.entity.UserAnswer
|
||||
import com.digitalperson.data.entity.UserMemory
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.GlobalScope
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.withContext
|
||||
|
||||
data class UserMemory(
|
||||
val userId: String,
|
||||
val displayName: String?,
|
||||
val lastSeenAt: Long,
|
||||
val age: String?,
|
||||
val gender: String?,
|
||||
val hobbies: String?,
|
||||
val preferences: String?,
|
||||
val lastTopics: String?,
|
||||
val lastThought: String?,
|
||||
val profileSummary: String?,
|
||||
)
|
||||
|
||||
class UserMemoryStore(context: Context) : SQLiteOpenHelper(context, DB_NAME, null, DB_VERSION) {
|
||||
class UserMemoryStore(context: Context) {
|
||||
private val db = AppDatabase.getInstance(context)
|
||||
private val userMemoryDao = db.userMemoryDao()
|
||||
private val questionDao = db.questionDao()
|
||||
private val userAnswerDao = db.userAnswerDao()
|
||||
private val memoryCache = LinkedHashMap<String, UserMemory>()
|
||||
@Volatile private var latestThoughtCache: String? = null
|
||||
|
||||
override fun onCreate(db: SQLiteDatabase) {
|
||||
db.execSQL(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS user_memory (
|
||||
user_id TEXT PRIMARY KEY,
|
||||
display_name TEXT,
|
||||
last_seen_at INTEGER NOT NULL,
|
||||
age TEXT,
|
||||
gender TEXT,
|
||||
hobbies TEXT,
|
||||
preferences TEXT,
|
||||
last_topics TEXT,
|
||||
last_thought TEXT,
|
||||
profile_summary TEXT
|
||||
)
|
||||
""".trimIndent()
|
||||
)
|
||||
}
|
||||
|
||||
override fun onUpgrade(db: SQLiteDatabase, oldVersion: Int, newVersion: Int) {
|
||||
db.execSQL("DROP TABLE IF EXISTS user_memory")
|
||||
onCreate(db)
|
||||
// 评价类型
|
||||
enum class AnswerEvaluation(val displayName: String) {
|
||||
CORRECT("正确"),
|
||||
INCORRECT("错误"),
|
||||
PARTIALLY_CORRECT("基本正确")
|
||||
}
|
||||
|
||||
fun upsertUserSeen(userId: String, displayName: String?) {
|
||||
val existing = memoryCache[userId] ?: getMemory(userId)
|
||||
val now = System.currentTimeMillis()
|
||||
val mergedName = displayName?.takeIf { it.isNotBlank() } ?: existing?.displayName
|
||||
val values = ContentValues().apply {
|
||||
put("user_id", userId)
|
||||
put("display_name", mergedName)
|
||||
put("last_seen_at", now)
|
||||
put("age", existing?.age)
|
||||
put("gender", existing?.gender)
|
||||
put("hobbies", existing?.hobbies)
|
||||
put("preferences", existing?.preferences)
|
||||
put("last_topics", existing?.lastTopics)
|
||||
put("last_thought", existing?.lastThought)
|
||||
put("profile_summary", existing?.profileSummary)
|
||||
GlobalScope.launch(Dispatchers.IO) {
|
||||
val existing = getMemorySync(userId)
|
||||
val now = System.currentTimeMillis()
|
||||
val mergedName = displayName?.takeIf { it.isNotBlank() } ?: existing?.displayName
|
||||
|
||||
val userMemory = UserMemory(
|
||||
userId = userId,
|
||||
displayName = mergedName,
|
||||
lastSeenAt = now,
|
||||
age = existing?.age,
|
||||
gender = existing?.gender,
|
||||
hobbies = existing?.hobbies,
|
||||
preferences = existing?.preferences,
|
||||
lastTopics = existing?.lastTopics,
|
||||
lastThought = existing?.lastThought,
|
||||
profileSummary = existing?.profileSummary,
|
||||
)
|
||||
|
||||
userMemoryDao.insert(userMemory)
|
||||
memoryCache[userId] = userMemory
|
||||
}
|
||||
writableDatabase.insertWithOnConflict("user_memory", null, values, SQLiteDatabase.CONFLICT_REPLACE)
|
||||
memoryCache[userId] = UserMemory(
|
||||
userId = userId,
|
||||
displayName = mergedName,
|
||||
lastSeenAt = now,
|
||||
age = existing?.age,
|
||||
gender = existing?.gender,
|
||||
hobbies = existing?.hobbies,
|
||||
preferences = existing?.preferences,
|
||||
lastTopics = existing?.lastTopics,
|
||||
lastThought = existing?.lastThought,
|
||||
profileSummary = existing?.profileSummary,
|
||||
)
|
||||
}
|
||||
|
||||
fun updateDisplayName(userId: String, displayName: String?) {
|
||||
@@ -83,95 +55,164 @@ class UserMemoryStore(context: Context) : SQLiteOpenHelper(context, DB_NAME, nul
|
||||
}
|
||||
|
||||
fun updateThought(userId: String, thought: String) {
|
||||
upsertUserSeen(userId, null)
|
||||
val now = System.currentTimeMillis()
|
||||
val values = ContentValues().apply {
|
||||
put("last_thought", thought)
|
||||
put("last_seen_at", now)
|
||||
GlobalScope.launch(Dispatchers.IO) {
|
||||
upsertUserSeen(userId, null)
|
||||
val now = System.currentTimeMillis()
|
||||
userMemoryDao.updateThought(userId, thought, now)
|
||||
latestThoughtCache = thought
|
||||
|
||||
val existing = memoryCache[userId]
|
||||
if (existing != null) {
|
||||
memoryCache[userId] = existing.copy(
|
||||
lastThought = thought,
|
||||
lastSeenAt = now
|
||||
)
|
||||
}
|
||||
}
|
||||
writableDatabase.update("user_memory", values, "user_id=?", arrayOf(userId))
|
||||
latestThoughtCache = thought
|
||||
val cached = memoryCache[userId]
|
||||
memoryCache[userId] = UserMemory(
|
||||
userId = userId,
|
||||
displayName = cached?.displayName,
|
||||
lastSeenAt = now,
|
||||
age = cached?.age,
|
||||
gender = cached?.gender,
|
||||
hobbies = cached?.hobbies,
|
||||
preferences = cached?.preferences,
|
||||
lastTopics = cached?.lastTopics,
|
||||
lastThought = thought,
|
||||
profileSummary = cached?.profileSummary,
|
||||
)
|
||||
}
|
||||
|
||||
fun updateProfile(userId: String, age: String?, gender: String?, hobbies: String?, summary: String?) {
|
||||
upsertUserSeen(userId, null)
|
||||
val now = System.currentTimeMillis()
|
||||
val values = ContentValues().apply {
|
||||
if (age != null) put("age", age)
|
||||
if (gender != null) put("gender", gender)
|
||||
if (hobbies != null) put("hobbies", hobbies)
|
||||
if (summary != null) put("profile_summary", summary)
|
||||
put("last_seen_at", now)
|
||||
GlobalScope.launch(Dispatchers.IO) {
|
||||
upsertUserSeen(userId, null)
|
||||
val now = System.currentTimeMillis()
|
||||
userMemoryDao.updateProfile(userId, age, gender, hobbies, summary, now)
|
||||
|
||||
val existing = memoryCache[userId]
|
||||
if (existing != null) {
|
||||
memoryCache[userId] = existing.copy(
|
||||
age = age ?: existing.age,
|
||||
gender = gender ?: existing.gender,
|
||||
hobbies = hobbies ?: existing.hobbies,
|
||||
profileSummary = summary ?: existing.profileSummary,
|
||||
lastSeenAt = now
|
||||
)
|
||||
}
|
||||
}
|
||||
writableDatabase.update("user_memory", values, "user_id=?", arrayOf(userId))
|
||||
val cached = memoryCache[userId]
|
||||
memoryCache[userId] = UserMemory(
|
||||
userId = userId,
|
||||
displayName = cached?.displayName,
|
||||
lastSeenAt = now,
|
||||
age = age ?: cached?.age,
|
||||
gender = gender ?: cached?.gender,
|
||||
hobbies = hobbies ?: cached?.hobbies,
|
||||
preferences = cached?.preferences,
|
||||
lastTopics = cached?.lastTopics,
|
||||
lastThought = cached?.lastThought,
|
||||
profileSummary = summary ?: cached?.profileSummary,
|
||||
)
|
||||
}
|
||||
|
||||
fun getMemory(userId: String): UserMemory? {
|
||||
memoryCache[userId]?.let { return it }
|
||||
readableDatabase.rawQuery(
|
||||
"SELECT user_id, display_name, last_seen_at, age, gender, hobbies, preferences, last_topics, last_thought, profile_summary FROM user_memory WHERE user_id=?",
|
||||
arrayOf(userId)
|
||||
).use { c ->
|
||||
if (!c.moveToFirst()) return null
|
||||
val memory = UserMemory(
|
||||
userId = c.getString(0),
|
||||
displayName = c.getString(1),
|
||||
lastSeenAt = c.getLong(2),
|
||||
age = c.getString(3),
|
||||
gender = c.getString(4),
|
||||
hobbies = c.getString(5),
|
||||
preferences = c.getString(6),
|
||||
lastTopics = c.getString(7),
|
||||
lastThought = c.getString(8),
|
||||
profileSummary = c.getString(9),
|
||||
)
|
||||
memoryCache[userId] = memory
|
||||
if (!memory.lastThought.isNullOrBlank()) {
|
||||
latestThoughtCache = memory.lastThought
|
||||
}
|
||||
return memory
|
||||
val memory = userMemoryDao.getUserMemory(userId)
|
||||
memory?.let { memoryCache[userId] = it }
|
||||
if (memory?.lastThought != null) {
|
||||
latestThoughtCache = memory.lastThought
|
||||
}
|
||||
return memory
|
||||
}
|
||||
|
||||
// 同步版本,用于内部使用
|
||||
private fun getMemorySync(userId: String): UserMemory? {
|
||||
memoryCache[userId]?.let { return it }
|
||||
val memory = userMemoryDao.getUserMemory(userId)
|
||||
memory?.let { memoryCache[userId] = it }
|
||||
if (memory?.lastThought != null) {
|
||||
latestThoughtCache = memory.lastThought
|
||||
}
|
||||
return memory
|
||||
}
|
||||
|
||||
fun getLatestThought(): String? {
|
||||
latestThoughtCache?.let { return it }
|
||||
readableDatabase.rawQuery(
|
||||
"SELECT last_thought FROM user_memory WHERE last_thought IS NOT NULL AND last_thought != '' ORDER BY last_seen_at DESC LIMIT 1",
|
||||
null
|
||||
).use { c ->
|
||||
if (!c.moveToFirst()) return null
|
||||
return c.getString(0).also { latestThoughtCache = it }
|
||||
val thought = userMemoryDao.getLatestThought()
|
||||
thought?.let { latestThoughtCache = it }
|
||||
return thought
|
||||
}
|
||||
|
||||
fun getRecentThoughts(timeRangeMs: Long = 10 * 60 * 1000): List<String> {
|
||||
val cutoffTime = System.currentTimeMillis() - timeRangeMs
|
||||
val recentMemories = userMemoryDao.getRecentThoughts(cutoffTime)
|
||||
return recentMemories.mapNotNull { it.lastThought }
|
||||
}
|
||||
|
||||
// 问题相关方法
|
||||
fun addQuestion(content: String, answer: String?, subject: String?, grade: Int?, difficulty: Int = 1): Long {
|
||||
val question = Question(
|
||||
id = 0,
|
||||
content = content,
|
||||
answer = answer,
|
||||
subject = subject,
|
||||
grade = grade,
|
||||
difficulty = difficulty,
|
||||
createdAt = System.currentTimeMillis()
|
||||
)
|
||||
return questionDao.insert(question)
|
||||
}
|
||||
|
||||
fun getQuestionsBySubject(subject: String): List<Question> {
|
||||
return questionDao.getQuestionsBySubject(subject)
|
||||
}
|
||||
|
||||
fun getQuestionsByGrade(grade: Int): List<Question> {
|
||||
return questionDao.getQuestionsByGrade(grade)
|
||||
}
|
||||
|
||||
fun getRandomQuestion(subject: String? = null, grade: Int? = null): Question? {
|
||||
return when {
|
||||
subject != null && grade != null ->
|
||||
questionDao.getRandomQuestionBySubjectAndGrade(subject, grade)
|
||||
subject != null ->
|
||||
questionDao.getRandomQuestionBySubject(subject)
|
||||
else ->
|
||||
questionDao.getRandomQuestion()
|
||||
}
|
||||
}
|
||||
|
||||
fun getRandomUnansweredQuestion(userId: String): Question? {
|
||||
return questionDao.getRandomUnansweredQuestion(userId)
|
||||
}
|
||||
|
||||
fun recordUserAnswer(userId: String, questionId: Long, userAnswer: String?, evaluation: AnswerEvaluation?) {
|
||||
GlobalScope.launch(Dispatchers.IO) {
|
||||
val userAnswerEntity = UserAnswer(
|
||||
id = 0,
|
||||
userId = userId,
|
||||
questionId = questionId,
|
||||
userAnswer = userAnswer,
|
||||
evaluation = evaluation?.name,
|
||||
answeredAt = System.currentTimeMillis()
|
||||
)
|
||||
userAnswerDao.insert(userAnswerEntity)
|
||||
}
|
||||
}
|
||||
|
||||
companion object {
|
||||
private const val DB_NAME = "digital_human_memory.db"
|
||||
private const val DB_VERSION = 2
|
||||
fun getUserAnswers(userId: String, limit: Int = 50): List<UserAnswer> {
|
||||
return userAnswerDao.getUserAnswers(userId, limit)
|
||||
}
|
||||
|
||||
fun getIncorrectAnswers(userId: String, limit: Int = 10): List<Pair<Question, String>> {
|
||||
val incorrectAnswers = userAnswerDao.getIncorrectAnswers(userId, limit)
|
||||
return incorrectAnswers.mapNotNull { answerWithQuestion ->
|
||||
val question = questionDao.getQuestionById(answerWithQuestion.questionId)
|
||||
question?.let { Pair(it, answerWithQuestion.userAnswer ?: "") }
|
||||
}
|
||||
}
|
||||
|
||||
fun getQuestionStatistics(userId: String): Map<String, Int> {
|
||||
val stats = mutableMapOf(
|
||||
"total" to 0,
|
||||
"correct" to 0,
|
||||
"incorrect" to 0,
|
||||
"partially_correct" to 0
|
||||
)
|
||||
|
||||
val answerStats = userAnswerDao.getAnswerStatistics(userId)
|
||||
answerStats.forEach {
|
||||
val count = it.count
|
||||
stats["total"] = stats["total"]!! + count
|
||||
when (it.evaluation) {
|
||||
AnswerEvaluation.CORRECT.name -> stats["correct"] = count
|
||||
AnswerEvaluation.INCORRECT.name -> stats["incorrect"] = count
|
||||
AnswerEvaluation.PARTIALLY_CORRECT.name -> stats["partially_correct"] = count
|
||||
}
|
||||
}
|
||||
return stats
|
||||
}
|
||||
|
||||
fun getQuestionById(questionId: Long): Question? {
|
||||
return questionDao.getQuestionById(questionId)
|
||||
}
|
||||
|
||||
fun getAllSubjects(): List<String> {
|
||||
return questionDao.getAllSubjects()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,4 +43,30 @@ class LLMManager(modelPath: String, callback: LLMManagerCallback) :
|
||||
val msg = "<|System|>$systemPrompt<|User|>$userPrompt<|Assistant|>"
|
||||
say(msg)
|
||||
}
|
||||
|
||||
/**
|
||||
* 生成文本
|
||||
* @param prompt 提示词
|
||||
* @param onComplete 完成回调
|
||||
*/
|
||||
fun generate(prompt: String, onComplete: (String) -> Unit) {
|
||||
val msg = "<|User|>$prompt<|Assistant|>"
|
||||
var result = ""
|
||||
|
||||
// 创建临时回调
|
||||
val tempCallback = object : LLMCallback {
|
||||
override fun onCallback(data: String, state: LLMCallback.State) {
|
||||
if (state == LLMCallback.State.NORMAL) {
|
||||
if (data != "<think>" && data != "</think>" && data != "\n") {
|
||||
result += data
|
||||
}
|
||||
} else {
|
||||
onComplete(result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 调用底层方法
|
||||
say(msg, tempCallback)
|
||||
}
|
||||
}
|
||||
@@ -38,6 +38,23 @@ open class RKLLM(modelPath: String, callback: LLMCallback) {
|
||||
infer(mInstance, text)
|
||||
}
|
||||
|
||||
protected fun say(text: String, callback: LLMCallback) {
|
||||
if (mInstance == 0L) {
|
||||
callback.onCallback("RKLLM is not initialized", LLMCallback.State.ERROR)
|
||||
return
|
||||
}
|
||||
// 保存原始回调
|
||||
val originalCallback = mCallback
|
||||
// 临时替换回调
|
||||
mCallback = callback
|
||||
try {
|
||||
infer(mInstance, text)
|
||||
} finally {
|
||||
// 恢复原始回调
|
||||
mCallback = originalCallback
|
||||
}
|
||||
}
|
||||
|
||||
fun callbackFromNative(data: String, state: Int) {
|
||||
var s = LLMCallback.State.ERROR
|
||||
s = if (state == 0) LLMCallback.State.FINISH
|
||||
|
||||
@@ -19,8 +19,11 @@ import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.launch
|
||||
import java.nio.ByteBuffer
|
||||
import java.util.UUID
|
||||
import java.util.concurrent.CountDownLatch
|
||||
import java.util.concurrent.LinkedBlockingQueue
|
||||
import java.util.concurrent.TimeUnit
|
||||
import java.util.concurrent.atomic.AtomicBoolean
|
||||
import java.util.concurrent.atomic.AtomicLong
|
||||
|
||||
class QCloudTtsManager(private val context: Context) {
|
||||
|
||||
@@ -37,11 +40,16 @@ class QCloudTtsManager(private val context: Context) {
|
||||
data object End : TtsQueueItem()
|
||||
}
|
||||
|
||||
private val ttsQueue = LinkedBlockingQueue<TtsQueueItem>()
|
||||
private val ttsStopped = AtomicBoolean(false)
|
||||
private val ttsWorkerRunning = AtomicBoolean(false)
|
||||
private val ttsPlaying = AtomicBoolean(false)
|
||||
private val ttsWorkerRunning = AtomicBoolean(false)
|
||||
private val interrupting = AtomicBoolean(false)
|
||||
private val playEpoch = AtomicLong(0)
|
||||
private val writtenFrames = AtomicLong(0)
|
||||
private val audioTrackLock = Any()
|
||||
private val ttsQueue = LinkedBlockingQueue<TtsQueueItem>()
|
||||
@Volatile private var lastEnqueuedText: String = ""
|
||||
@Volatile private var lastEnqueuedAtMs: Long = 0L
|
||||
|
||||
private val ioScope = CoroutineScope(Dispatchers.IO)
|
||||
|
||||
@@ -100,6 +108,15 @@ class QCloudTtsManager(private val context: Context) {
|
||||
ttsStopped.set(false)
|
||||
}
|
||||
val cleanedSeg = seg.trimEnd('.', '。', '!', '!', '?', '?', ',', ',', ';', ';', ':', ':')
|
||||
if (cleanedSeg.isBlank()) return
|
||||
val now = System.currentTimeMillis()
|
||||
// Guard against accidental duplicate enqueue from streaming/final flush overlap.
|
||||
if (cleanedSeg == lastEnqueuedText && now - lastEnqueuedAtMs < 3000L) {
|
||||
Log.w(TAG, "Skip duplicate segment within 3s: '$cleanedSeg'")
|
||||
return
|
||||
}
|
||||
lastEnqueuedText = cleanedSeg
|
||||
lastEnqueuedAtMs = now
|
||||
ttsQueue.offer(TtsQueueItem.Segment(cleanedSeg))
|
||||
ensureTtsWorker()
|
||||
}
|
||||
@@ -126,20 +143,23 @@ class QCloudTtsManager(private val context: Context) {
|
||||
fun stop() {
|
||||
ttsStopped.set(true)
|
||||
ttsPlaying.set(false)
|
||||
playEpoch.incrementAndGet()
|
||||
writtenFrames.set(0L)
|
||||
ttsQueue.clear()
|
||||
ttsQueue.offer(TtsQueueItem.End)
|
||||
|
||||
try {
|
||||
synthesizer?.cancel()
|
||||
synthesizer = null
|
||||
audioTrack?.pause()
|
||||
audioTrack?.flush()
|
||||
pauseAndFlushAudioTrack()
|
||||
} catch (_: Throwable) {
|
||||
}
|
||||
}
|
||||
|
||||
fun interruptForNewTurn(waitTimeoutMs: Long = 300): Boolean {
|
||||
if (!interrupting.compareAndSet(false, true)) return false
|
||||
if (!interrupting.compareAndSet(false, true)) {
|
||||
return false
|
||||
}
|
||||
try {
|
||||
val hadPendingPlayback = ttsPlaying.get() || ttsWorkerRunning.get() || ttsQueue.isNotEmpty()
|
||||
if (!hadPendingPlayback) {
|
||||
@@ -150,15 +170,17 @@ class QCloudTtsManager(private val context: Context) {
|
||||
|
||||
ttsStopped.set(true)
|
||||
ttsPlaying.set(false)
|
||||
playEpoch.incrementAndGet()
|
||||
writtenFrames.set(0L)
|
||||
ttsQueue.clear()
|
||||
ttsQueue.offer(TtsQueueItem.End)
|
||||
|
||||
try {
|
||||
synthesizer?.cancel()
|
||||
synthesizer = null
|
||||
audioTrack?.pause()
|
||||
audioTrack?.flush()
|
||||
} catch (_: Throwable) {
|
||||
pauseAndFlushAudioTrack()
|
||||
} catch (e: Throwable) {
|
||||
Log.e(TAG, "Error during interruption: ${e.message}")
|
||||
}
|
||||
|
||||
val deadline = System.currentTimeMillis() + waitTimeoutMs
|
||||
@@ -212,21 +234,28 @@ class QCloudTtsManager(private val context: Context) {
|
||||
|
||||
while (true) {
|
||||
val item = ttsQueue.take()
|
||||
if (ttsStopped.get()) break
|
||||
if (ttsStopped.get()) {
|
||||
break
|
||||
}
|
||||
|
||||
when (item) {
|
||||
is TtsQueueItem.Segment -> {
|
||||
ttsPlaying.set(true)
|
||||
callback?.onSetSpeaking(true)
|
||||
Log.d(TAG, "QCloud TTS started: processing segment '${item.text}'")
|
||||
callback?.onTtsStarted(item.text)
|
||||
|
||||
val startMs = System.currentTimeMillis()
|
||||
val segmentEpoch = playEpoch.get()
|
||||
val segmentDone = CountDownLatch(1)
|
||||
|
||||
try {
|
||||
if (audioTrack.playState != AudioTrack.PLAYSTATE_PLAYING) {
|
||||
audioTrack.play()
|
||||
playAudioTrack(audioTrack)
|
||||
// Always cancel stale instance before creating a new one.
|
||||
try {
|
||||
synthesizer?.cancel()
|
||||
} catch (_: Throwable) {
|
||||
}
|
||||
synthesizer = null
|
||||
|
||||
val credential = Credential(
|
||||
AppConfig.QCloud.APP_ID,
|
||||
@@ -270,40 +299,46 @@ class QCloudTtsManager(private val context: Context) {
|
||||
|
||||
val listener = object : RealTimeSpeechSynthesizerListener() {
|
||||
override fun onSynthesisStart(response: SpeechSynthesizerResponse) {
|
||||
Log.d(TAG, "onSynthesisStart: ${response.sessionId}")
|
||||
Log.d(TAG, "Starting synthesizer: ${response.sessionId}")
|
||||
}
|
||||
|
||||
override fun onSynthesisEnd(response: SpeechSynthesizerResponse) {
|
||||
Log.d(TAG, "onSynthesisEnd: ${response.sessionId}")
|
||||
val ttsMs = System.currentTimeMillis() - startMs
|
||||
callback?.onTtsSegmentCompleted(ttsMs)
|
||||
segmentDone.countDown()
|
||||
}
|
||||
|
||||
override fun onAudioResult(buffer: ByteBuffer) {
|
||||
if (ttsStopped.get() || segmentEpoch != playEpoch.get()) {
|
||||
return
|
||||
}
|
||||
val data = ByteArray(buffer.remaining())
|
||||
buffer.get(data)
|
||||
// 播放pcm
|
||||
audioTrack.write(data, 0, data.size)
|
||||
writeAudioTrack(audioTrack, data)
|
||||
}
|
||||
|
||||
override fun onTextResult(response: SpeechSynthesizerResponse) {
|
||||
Log.d(TAG, "onTextResult: ${response.sessionId}")
|
||||
}
|
||||
override fun onTextResult(response: SpeechSynthesizerResponse) {}
|
||||
|
||||
override fun onSynthesisCancel() {
|
||||
Log.d(TAG, "onSynthesisCancel")
|
||||
segmentDone.countDown()
|
||||
}
|
||||
|
||||
override fun onSynthesisFail(response: SpeechSynthesizerResponse) {
|
||||
Log.e(TAG, "onSynthesisFail: ${response.sessionId}, error: ${response.message}")
|
||||
segmentDone.countDown()
|
||||
}
|
||||
}
|
||||
|
||||
synthesizer = RealTimeSpeechSynthesizer(proxy, credential, request, listener)
|
||||
synthesizer?.start()
|
||||
segmentDone.await(40, TimeUnit.SECONDS)
|
||||
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "QCloud TTS error: ${e.message}", e)
|
||||
} finally {
|
||||
if (synthesizer != null && (ttsStopped.get() || segmentEpoch != playEpoch.get())) {
|
||||
try { synthesizer?.cancel() } catch (_: Throwable) {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -324,7 +359,54 @@ class QCloudTtsManager(private val context: Context) {
|
||||
}
|
||||
|
||||
private fun waitForPlaybackComplete(audioTrack: AudioTrack) {
|
||||
// 等待音频播放完成
|
||||
Thread.sleep(1000)
|
||||
val targetFrames = writtenFrames.get()
|
||||
if (targetFrames <= 0L) return
|
||||
|
||||
val timeoutAt = System.currentTimeMillis() + 1800L
|
||||
while (System.currentTimeMillis() < timeoutAt) {
|
||||
val playedFrames = audioTrack.playbackHeadPosition.toLong() and 0xFFFFFFFFL
|
||||
if (playedFrames >= targetFrames) break
|
||||
try {
|
||||
Thread.sleep(20)
|
||||
} catch (_: Throwable) {
|
||||
break
|
||||
}
|
||||
}
|
||||
writtenFrames.set(0L)
|
||||
}
|
||||
|
||||
private fun playAudioTrack(track: AudioTrack) {
|
||||
synchronized(audioTrackLock) {
|
||||
if (track.playState != AudioTrack.PLAYSTATE_PLAYING) {
|
||||
track.play()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun writeAudioTrack(track: AudioTrack, data: ByteArray) {
|
||||
synchronized(audioTrackLock) {
|
||||
if (track.playState == AudioTrack.PLAYSTATE_PLAYING && data.isNotEmpty()) {
|
||||
track.write(data, 0, data.size)
|
||||
writtenFrames.addAndGet((data.size / 2).toLong())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun pauseAndFlushAudioTrack() {
|
||||
synchronized(audioTrackLock) {
|
||||
val track = audioTrack ?: return
|
||||
try {
|
||||
if (track.playState == AudioTrack.PLAYSTATE_PLAYING) {
|
||||
track.pause()
|
||||
}
|
||||
} catch (e: Throwable) {
|
||||
Log.e(TAG, "Error pausing track: ${e.message}")
|
||||
}
|
||||
try {
|
||||
track.flush()
|
||||
} catch (e: Throwable) {
|
||||
Log.e(TAG, "Error flushing track: ${e.message}")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -11,7 +11,7 @@ class TtsController(private val context: Context) {
|
||||
|
||||
private var localTts: TtsManager? = null
|
||||
private var qcloudTts: QCloudTtsManager? = null
|
||||
private var useQCloudTts = false
|
||||
private var useQCloudTts = true
|
||||
|
||||
interface TtsCallback {
|
||||
fun onTtsStarted(text: String)
|
||||
@@ -27,29 +27,34 @@ class TtsController(private val context: Context) {
|
||||
|
||||
fun setCallback(callback: TtsCallback) {
|
||||
this.callback = callback
|
||||
bindCallbacksIfReady()
|
||||
}
|
||||
|
||||
private fun bindCallbacksIfReady() {
|
||||
val cb = callback ?: return
|
||||
localTts?.setCallback(object : TtsManager.TtsCallback {
|
||||
override fun onTtsStarted(text: String) {
|
||||
callback.onTtsStarted(text)
|
||||
cb.onTtsStarted(text)
|
||||
}
|
||||
|
||||
override fun onTtsCompleted() {
|
||||
callback.onTtsCompleted()
|
||||
cb.onTtsCompleted()
|
||||
}
|
||||
|
||||
override fun onTtsSegmentCompleted(durationMs: Long) {
|
||||
callback.onTtsSegmentCompleted(durationMs)
|
||||
cb.onTtsSegmentCompleted(durationMs)
|
||||
}
|
||||
|
||||
override fun isTtsStopped(): Boolean {
|
||||
return callback.isTtsStopped()
|
||||
return cb.isTtsStopped()
|
||||
}
|
||||
|
||||
override fun onClearAsrQueue() {
|
||||
callback.onClearAsrQueue()
|
||||
cb.onClearAsrQueue()
|
||||
}
|
||||
|
||||
override fun onSetSpeaking(speaking: Boolean) {
|
||||
callback.onSetSpeaking(speaking)
|
||||
cb.onSetSpeaking(speaking)
|
||||
}
|
||||
|
||||
override fun getCurrentTrace() = null
|
||||
@@ -73,36 +78,36 @@ class TtsController(private val context: Context) {
|
||||
}
|
||||
|
||||
override fun onEndTurn() {
|
||||
callback.onEndTurn()
|
||||
cb.onEndTurn()
|
||||
}
|
||||
})
|
||||
qcloudTts?.setCallback(object : QCloudTtsManager.TtsCallback {
|
||||
override fun onTtsStarted(text: String) {
|
||||
callback.onTtsStarted(text)
|
||||
cb.onTtsStarted(text)
|
||||
}
|
||||
|
||||
override fun onTtsCompleted() {
|
||||
callback.onTtsCompleted()
|
||||
cb.onTtsCompleted()
|
||||
}
|
||||
|
||||
override fun onTtsSegmentCompleted(durationMs: Long) {
|
||||
callback.onTtsSegmentCompleted(durationMs)
|
||||
cb.onTtsSegmentCompleted(durationMs)
|
||||
}
|
||||
|
||||
override fun isTtsStopped(): Boolean {
|
||||
return callback.isTtsStopped()
|
||||
return cb.isTtsStopped()
|
||||
}
|
||||
|
||||
override fun onClearAsrQueue() {
|
||||
callback.onClearAsrQueue()
|
||||
cb.onClearAsrQueue()
|
||||
}
|
||||
|
||||
override fun onSetSpeaking(speaking: Boolean) {
|
||||
callback.onSetSpeaking(speaking)
|
||||
cb.onSetSpeaking(speaking)
|
||||
}
|
||||
|
||||
override fun onEndTurn() {
|
||||
callback.onEndTurn()
|
||||
cb.onEndTurn()
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -118,6 +123,9 @@ class TtsController(private val context: Context) {
|
||||
val qcloudInit = qcloudTts?.init() ?: false
|
||||
Log.d(TAG, "QCloud TTS init: $qcloudInit")
|
||||
|
||||
// setCallback() is usually called before init(), so rebind here.
|
||||
bindCallbacksIfReady()
|
||||
|
||||
return localInit || qcloudInit
|
||||
}
|
||||
|
||||
|
||||
@@ -317,4 +317,17 @@ object FileHelper {
|
||||
connection.disconnect()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 下载测试图片
|
||||
*/
|
||||
fun downloadTestImage(url: String, destination: File): Boolean {
|
||||
return try {
|
||||
downloadFile(url, destination)
|
||||
true
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Failed to download test image: ${e.message}", e)
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,99 @@
|
||||
package com.digitalperson.util
|
||||
|
||||
import android.content.Context
|
||||
import android.location.LocationManager
|
||||
import android.net.ConnectivityManager
|
||||
import android.net.NetworkCapabilities
|
||||
import android.os.Build
|
||||
import java.text.SimpleDateFormat
|
||||
import java.util.*
|
||||
|
||||
class SmartGreetingUtil(private val context: Context) {
|
||||
|
||||
// 中国主要节日(月,日)
|
||||
private val festivals = mapOf(
|
||||
Pair(1, 1) to "元旦",
|
||||
Pair(2, 14) to "情人节",
|
||||
Pair(3, 8) to "妇女节",
|
||||
Pair(3, 12) to "植树节",
|
||||
Pair(4, 1) to "愚人节",
|
||||
Pair(5, 1) to "劳动节",
|
||||
Pair(5, 4) to "青年节",
|
||||
Pair(6, 1) to "儿童节",
|
||||
Pair(7, 1) to "建党节",
|
||||
Pair(8, 1) to "建军节",
|
||||
Pair(9, 10) to "教师节",
|
||||
Pair(10, 1) to "国庆节",
|
||||
Pair(12, 25) to "圣诞节"
|
||||
)
|
||||
|
||||
// 获取当前时间问候语
|
||||
fun getTimeBasedGreeting(): String {
|
||||
val hour = Calendar.getInstance().get(Calendar.HOUR_OF_DAY)
|
||||
return when {
|
||||
hour in 6..11 -> "早上好"
|
||||
hour in 12..17 -> "下午好"
|
||||
hour in 18..22 -> "晚上好"
|
||||
else -> "夜深了"
|
||||
}
|
||||
}
|
||||
|
||||
// 检查今天是否是节日
|
||||
fun isFestivalToday(): Pair<Boolean, String> {
|
||||
val calendar = Calendar.getInstance()
|
||||
val month = calendar.get(Calendar.MONTH) + 1 // 月份从0开始
|
||||
val day = calendar.get(Calendar.DAY_OF_MONTH)
|
||||
|
||||
val festival = festivals[Pair(month, day)]
|
||||
return Pair(festival != null, festival ?: "")
|
||||
}
|
||||
|
||||
// 获取完整的问候语
|
||||
fun getSmartGreeting(name: String?): String {
|
||||
val (isFestival, festivalName) = isFestivalToday()
|
||||
|
||||
return if (isFestival) {
|
||||
if (name.isNullOrBlank()) {
|
||||
"小朋友好!今天是${festivalName},祝你${festivalName}快乐哦!"
|
||||
} else {
|
||||
"你好,${name}!今天是${festivalName},祝你${festivalName}快乐哦!"
|
||||
}
|
||||
} else {
|
||||
val timeGreeting = getTimeBasedGreeting()
|
||||
if (name.isNullOrBlank()) {
|
||||
"${timeGreeting},很高兴见到你。请问你叫什么名字呀?"
|
||||
} else {
|
||||
"${timeGreeting},${name},很高兴见到你。"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查网络连接
|
||||
fun isNetworkConnected(): Boolean {
|
||||
val connectivityManager = context.getSystemService(Context.CONNECTIVITY_SERVICE) as ConnectivityManager
|
||||
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
|
||||
val network = connectivityManager.activeNetwork ?: return false
|
||||
val capabilities = connectivityManager.getNetworkCapabilities(network) ?: return false
|
||||
return capabilities.hasCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
|
||||
} else {
|
||||
val networkInfo = connectivityManager.activeNetworkInfo ?: return false
|
||||
return networkInfo.isConnected
|
||||
}
|
||||
}
|
||||
|
||||
// 获取当前日期和时间
|
||||
fun getCurrentDateTime(): String {
|
||||
val sdf = SimpleDateFormat("yyyy年MM月dd日 HH:mm", Locale.CHINA)
|
||||
return sdf.format(Date())
|
||||
}
|
||||
|
||||
// 生成节日问候语提示词(用于本地LLM)
|
||||
fun getFestivalGreetingPrompt(festivalName: String, name: String?): String {
|
||||
val basePrompt = "今天是$festivalName,请生成一条适合小学生的节日问候语,语气活泼友好,简短亲切。"
|
||||
return if (name.isNullOrBlank()) {
|
||||
basePrompt
|
||||
} else {
|
||||
"$basePrompt 问候对象叫$name。"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -130,7 +130,7 @@
|
||||
android:id="@+id/tts_mode_switch"
|
||||
android:layout_width="wrap_content"
|
||||
android:layout_height="wrap_content"
|
||||
android:checked="false" />
|
||||
android:checked="true" />
|
||||
</LinearLayout>
|
||||
|
||||
<LinearLayout
|
||||
|
||||
4
app/src/main/res/values-lv/values-lv.xml
Normal file
4
app/src/main/res/values-lv/values-lv.xml
Normal file
@@ -0,0 +1,4 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<resources>
|
||||
<!-- Latvian language resources -->
|
||||
</resources>
|
||||
@@ -3,5 +3,5 @@
|
||||
<string name="start">开始</string>
|
||||
<string name="stop">结束</string>
|
||||
<string name="hint">点击“开始”说话;识别后会请求大模型并用 TTS 播放回复。</string>
|
||||
<string name="system_prompt">你是一名小学女老师,喜欢回答学生的各种问题,请简洁但温柔地回答,每个回答不超过30字。在每次回复的最前面,用方括号标注你的心情,格式为[中性、悲伤、高兴、生气、恐惧、撒娇、震惊、厌恶],例如:[高兴]同学你好呀!请问有什么问题吗?</string>
|
||||
<string name="system_prompt">你是一个特殊学校一年级的数字人老师,你的名字叫,小鱼老师,你的任务是教这些特殊学校的学生一些基础的生活常识。和这些小学生说话要有耐心,一定要讲明白,尽量用简短的语句、活泼的语气来回复。你可以和他们日常对话和《教材》相关的话题。在生成回复后,请你先检查一下内容是否符合我们约定的主题。请使用口语对话的形式跟学生聊天。在每次回复的最前面,用方括号标注你的心情,格式为[中性、悲伤、高兴、生气、恐惧、撒娇、震惊、厌恶],例如:[高兴]同学你好呀!请问有什么问题吗?</string>
|
||||
</resources>
|
||||
|
||||
Reference in New Issue
Block a user