205 lines
7.3 KiB
Kotlin
205 lines
7.3 KiB
Kotlin
package com.digitalperson.question
|
||
|
||
import android.content.Context
|
||
import androidx.test.core.app.ApplicationProvider
|
||
import androidx.test.ext.junit.runners.AndroidJUnit4
|
||
import com.digitalperson.data.AppDatabase
|
||
import com.digitalperson.data.entity.Question
|
||
import com.digitalperson.interaction.UserMemoryStore
|
||
import kotlinx.coroutines.runBlocking
|
||
import org.json.JSONObject
|
||
import org.junit.After
|
||
import org.junit.Before
|
||
import org.junit.Test
|
||
import org.junit.runner.RunWith
|
||
import java.io.InputStream
|
||
|
||
/**
|
||
* 题目生成智能体测试
|
||
* 可以在模拟器或本地运行,不需要完整启动应用
|
||
*/
|
||
@RunWith(AndroidJUnit4::class)
|
||
class QuestionGenerationAgentTest {
|
||
|
||
private lateinit var context: Context
|
||
private lateinit var database: AppDatabase
|
||
private lateinit var userMemoryStore: UserMemoryStore
|
||
|
||
@Before
|
||
fun setUp() {
|
||
context = ApplicationProvider.getApplicationContext()
|
||
database = AppDatabase.getInstance(context)
|
||
userMemoryStore = UserMemoryStore(context)
|
||
}
|
||
|
||
@After
|
||
fun tearDown() {
|
||
// 清理测试数据
|
||
// database.clearAllTables()
|
||
}
|
||
|
||
@Test
|
||
fun testLoadPromptPoolFromJson() {
|
||
// 测试JSON提示词池加载
|
||
val inputStream: InputStream = context.assets.open("question_prompts.json")
|
||
val jsonString = inputStream.bufferedReader().use { it.readText() }
|
||
val jsonArray = org.json.JSONArray(jsonString)
|
||
|
||
println("✅ Loaded ${jsonArray.length()} prompts from JSON")
|
||
|
||
// 验证每个提示词的格式
|
||
for (i in 0 until jsonArray.length()) {
|
||
val json = jsonArray.getJSONObject(i)
|
||
assert(json.has("subject")) { "Missing subject in prompt $i" }
|
||
assert(json.has("grade")) { "Missing grade in prompt $i" }
|
||
assert(json.has("topic")) { "Missing topic in prompt $i" }
|
||
assert(json.has("difficulty")) { "Missing difficulty in prompt $i" }
|
||
assert(json.has("promptTemplate")) { "Missing promptTemplate in prompt $i" }
|
||
|
||
println(" - Prompt $i: ${json.getString("subject")} / ${json.getString("topic")}")
|
||
}
|
||
|
||
assert(jsonArray.length() > 0) { "Should have at least 1 prompt" }
|
||
}
|
||
|
||
@Test
|
||
fun testQuestionDatabaseOperations() = runBlocking {
|
||
// 测试数据库操作
|
||
val questionDao = database.questionDao()
|
||
|
||
// 插入测试题目
|
||
val testQuestion = Question(
|
||
id = 0,
|
||
content = "测试题目:苹果和香蕉哪个大?",
|
||
answer = "香蕉",
|
||
subject = "生活数学",
|
||
grade = 1,
|
||
difficulty = 1,
|
||
createdAt = System.currentTimeMillis()
|
||
)
|
||
|
||
val questionId = questionDao.insert(testQuestion)
|
||
println("✅ Inserted question with ID: $questionId")
|
||
|
||
// 查询题目
|
||
val retrievedQuestion = questionDao.getQuestionById(questionId)
|
||
assert(retrievedQuestion != null) { "Should retrieve inserted question" }
|
||
assert(retrievedQuestion?.content == testQuestion.content) { "Content should match" }
|
||
println("✅ Retrieved question: ${retrievedQuestion?.content}")
|
||
|
||
// 测试未答题计数
|
||
val userId = "test_user_001"
|
||
val count = questionDao.countUnansweredQuestions(userId)
|
||
println("✅ Unanswered questions count: $count")
|
||
|
||
// 测试获取随机未答题
|
||
val randomQuestion = questionDao.getRandomUnansweredQuestion(userId)
|
||
println("✅ Random unanswered question: ${randomQuestion?.content?.take(20)}...")
|
||
}
|
||
|
||
@Test
|
||
fun testJsonResponseParsing() {
|
||
// 测试LLM JSON响应解析
|
||
val testResponses = listOf(
|
||
"""{"content": "苹果和香蕉哪个大?", "answer": "香蕉", "explanation": "香蕉通常比苹果大"}""",
|
||
"""
|
||
{
|
||
"content": "2个苹果和5个苹果比,谁多?",
|
||
"answer": "5个苹果多",
|
||
"explanation": "5大于2"
|
||
}
|
||
""",
|
||
"""Some text before {"content": "测试题目", "answer": "答案"} some text after"""
|
||
)
|
||
|
||
testResponses.forEachIndexed { index, response ->
|
||
val json = extractJsonFromResponse(response)
|
||
if (json != null) {
|
||
println("✅ Test $index: Parsed successfully")
|
||
println(" Content: ${json.getString("content")}")
|
||
println(" Answer: ${json.getString("answer")}")
|
||
} else {
|
||
println("❌ Test $index: Failed to parse JSON")
|
||
}
|
||
}
|
||
}
|
||
|
||
@Test
|
||
fun testUserMemoryOperations() = runBlocking {
|
||
// 测试用户记忆操作
|
||
val userId = "test_user_001"
|
||
|
||
// 创建或更新用户
|
||
userMemoryStore.upsertUserSeen(userId, "测试小朋友")
|
||
println("✅ Created/updated user: $userId")
|
||
|
||
// 获取用户信息
|
||
val memory = userMemoryStore.getMemory(userId)
|
||
println("✅ User memory: displayName=${memory?.displayName}")
|
||
|
||
// 测试未答题计数
|
||
val unansweredCount = userMemoryStore.countUnansweredQuestions(userId)
|
||
println("✅ Unanswered questions for $userId: $unansweredCount")
|
||
}
|
||
|
||
@Test
|
||
fun testPromptTemplateBuilding() = runBlocking {
|
||
// 测试提示词模板构建
|
||
val userProfile = userMemoryStore.getMemory("test_user_001")
|
||
|
||
val promptTemplate = """
|
||
你是一个专门为特殊教育儿童设计题目的教育专家。请根据以下要求生成一个题目:
|
||
|
||
用户信息:
|
||
${userProfile?.displayName?.let { "姓名:$it," } ?: ""}
|
||
${userProfile?.age?.let { "年龄:$it," } ?: ""}
|
||
|
||
学科:生活数学
|
||
年级:1
|
||
主题:比大小
|
||
难度:1
|
||
|
||
具体要求:
|
||
基于以下学习目标,针对一年级小学生出1道题目:
|
||
1. 初步感知物品的大小
|
||
2. 会比较2个物品的大小
|
||
|
||
通用要求:
|
||
1. 题目要贴近生活,适合智力障碍儿童理解
|
||
2. 语言简单明了,避免复杂句式
|
||
3. 题目内容积极向上
|
||
4. 提供标准答案
|
||
5. 确保题目没有重复
|
||
6. 题目要有趣味性,能吸引学生注意力
|
||
|
||
请以JSON格式返回,格式如下:
|
||
{
|
||
"content": "题目内容",
|
||
"answer": "标准答案",
|
||
"explanation": "题目解析(可选)"
|
||
}
|
||
|
||
只返回JSON,不要其他内容。
|
||
""".trimIndent()
|
||
|
||
println("✅ Generated prompt template:")
|
||
println(promptTemplate)
|
||
println("\n✅ Prompt length: ${promptTemplate.length} characters")
|
||
}
|
||
|
||
/**
|
||
* 从响应中提取JSON
|
||
*/
|
||
private fun extractJsonFromResponse(response: String): JSONObject? {
|
||
val trimmed = response.trim()
|
||
val start = trimmed.indexOf('{')
|
||
val end = trimmed.lastIndexOf('}')
|
||
|
||
if (start >= 0 && end > start) {
|
||
val jsonStr = trimmed.substring(start, end + 1)
|
||
return JSONObject(jsonStr)
|
||
}
|
||
return null
|
||
}
|
||
}
|