local llm supported
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
package com.digitalperson.util
|
||||
|
||||
import android.content.ContentUris
|
||||
import android.content.Context
|
||||
import android.provider.MediaStore
|
||||
import android.util.Log
|
||||
import com.digitalperson.config.AppConfig
|
||||
import java.io.File
|
||||
@@ -48,13 +50,270 @@ object FileHelper {
|
||||
)
|
||||
return copyAssetsToInternal(context, AppConfig.Asr.MODEL_DIR, outDir, files)
|
||||
}
|
||||
|
||||
@JvmStatic
|
||||
fun copyRetinaFaceAssets(context: Context): File {
|
||||
val outDir = File(context.filesDir, AppConfig.Face.MODEL_DIR)
|
||||
val files = arrayOf(AppConfig.Face.MODEL_NAME)
|
||||
return copyAssetsToInternal(context, AppConfig.Face.MODEL_DIR, outDir, files)
|
||||
}
|
||||
|
||||
@JvmStatic
|
||||
fun copyInsightFaceAssets(context: Context): File {
|
||||
val outDir = File(context.filesDir, AppConfig.FaceRecognition.MODEL_DIR)
|
||||
val files = arrayOf(AppConfig.FaceRecognition.MODEL_NAME)
|
||||
return copyAssetsToInternal(context, AppConfig.FaceRecognition.MODEL_DIR, outDir, files)
|
||||
}
|
||||
|
||||
fun ensureDir(dir: File): File {
|
||||
if (!dir.exists()) dir.mkdirs()
|
||||
if (!dir.exists()) {
|
||||
val created = dir.mkdirs()
|
||||
if (!created) {
|
||||
Log.e(TAG, "Failed to create directory: ${dir.absolutePath}")
|
||||
// 如果创建失败,使用应用内部存储
|
||||
return File("/data/data/${dir.parentFile?.parentFile?.name}/files/llm")
|
||||
}
|
||||
}
|
||||
return dir
|
||||
}
|
||||
|
||||
fun getAsrAudioDir(context: Context): File {
|
||||
return ensureDir(File(context.filesDir, "asr_audio"))
|
||||
}
|
||||
|
||||
// @JvmStatic
|
||||
// 当前使用的模型文件名
|
||||
private const val MODEL_FILE_NAME = "Qwen3-0.6B-rk3588-w8a8.rkllm"
|
||||
|
||||
fun getLLMModelPath(context: Context): String {
|
||||
Log.d(TAG, "=== getLLMModelPath START ===")
|
||||
|
||||
// 从应用内部存储目录加载模型
|
||||
val llmDir = ensureDir(File(context.filesDir, "llm"))
|
||||
|
||||
Log.d(TAG, "Loading models from: ${llmDir.absolutePath}")
|
||||
|
||||
// 检查文件是否存在
|
||||
val rkllmFile = File(llmDir, MODEL_FILE_NAME)
|
||||
|
||||
if (!rkllmFile.exists()) {
|
||||
Log.e(TAG, "RKLLM model not found: ${rkllmFile.absolutePath}")
|
||||
} else {
|
||||
Log.i(TAG, "RKLLM model exists, size: ${rkllmFile.length() / (1024*1024)} MB")
|
||||
}
|
||||
|
||||
val modelPath = rkllmFile.absolutePath
|
||||
Log.i(TAG, "Using RKLLM model path: $modelPath")
|
||||
Log.d(TAG, "=== getLLMModelPath END ===")
|
||||
return modelPath
|
||||
}
|
||||
|
||||
/**
|
||||
* 异步下载模型文件,带进度回调
|
||||
* @param context 上下文
|
||||
* @param onProgress 进度回调 (currentFile, downloadedBytes, totalBytes, progressPercent)
|
||||
* @param onComplete 完成回调 (success, message)
|
||||
*/
|
||||
@JvmStatic
|
||||
fun downloadModelFilesWithProgress(
|
||||
context: Context,
|
||||
onProgress: (String, Long, Long, Int) -> Unit,
|
||||
onComplete: (Boolean, String) -> Unit
|
||||
) {
|
||||
Log.d(TAG, "=== downloadModelFilesWithProgress START ===")
|
||||
|
||||
val llmDir = ensureDir(File(context.filesDir, "llm"))
|
||||
|
||||
// 模型文件列表 - 使用 DeepSeek-R1-Distill-Qwen-1.5B 模型
|
||||
val modelFiles = listOf(
|
||||
MODEL_FILE_NAME
|
||||
)
|
||||
|
||||
// 在后台线程下载
|
||||
Thread {
|
||||
try {
|
||||
var allSuccess = true
|
||||
var totalDownloaded: Long = 0
|
||||
var totalSize: Long = 0
|
||||
|
||||
// 首先计算总大小
|
||||
for (fileName in modelFiles) {
|
||||
val modelFile = File(llmDir, fileName)
|
||||
if (!modelFile.exists() || modelFile.length() == 0L) {
|
||||
val size = getFileSizeFromServer("http://192.168.1.19:5000/download/$fileName")
|
||||
if (size > 0) {
|
||||
totalSize += size
|
||||
} else {
|
||||
// 如果无法获取文件大小,使用估计值
|
||||
when (fileName) {
|
||||
MODEL_FILE_NAME -> totalSize += 1L * 1024 * 1024 * 1024 // 1.5B模型约1GB
|
||||
else -> totalSize += 1L * 1024 * 1024 * 1024 // 1GB 默认
|
||||
}
|
||||
Log.i(TAG, "Using estimated size for $fileName: ${totalSize / (1024*1024)} MB")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (fileName in modelFiles) {
|
||||
val modelFile = File(llmDir, fileName)
|
||||
if (!modelFile.exists() || modelFile.length() == 0L) {
|
||||
Log.i(TAG, "Downloading model file: $fileName")
|
||||
try {
|
||||
downloadFileWithProgress(
|
||||
"http://192.168.1.19:5000/download/$fileName",
|
||||
modelFile
|
||||
) { downloaded, total ->
|
||||
val progress = if (totalSize > 0) {
|
||||
((totalDownloaded + downloaded) * 100 / totalSize).toInt()
|
||||
} else 0
|
||||
onProgress(fileName, downloaded, total, progress)
|
||||
}
|
||||
totalDownloaded += modelFile.length()
|
||||
Log.i(TAG, "Downloaded model file: $fileName, size: ${modelFile.length() / (1024*1024)} MB")
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Failed to download model file $fileName: ${e.message}")
|
||||
allSuccess = false
|
||||
}
|
||||
} else {
|
||||
totalDownloaded += modelFile.length()
|
||||
Log.i(TAG, "Model file exists: $fileName, size: ${modelFile.length() / (1024*1024)} MB")
|
||||
}
|
||||
}
|
||||
Log.d(TAG, "=== downloadModelFilesWithProgress END ===")
|
||||
if (allSuccess) {
|
||||
onComplete(true, "模型下载完成")
|
||||
} else {
|
||||
onComplete(false, "部分模型下载失败")
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Download failed: ${e.message}")
|
||||
onComplete(false, "下载失败: ${e.message}")
|
||||
}
|
||||
}.start()
|
||||
}
|
||||
|
||||
/**
|
||||
* 从服务器获取文件大小
|
||||
*/
|
||||
private fun getFileSizeFromServer(url: String): Long {
|
||||
return try {
|
||||
val connection = java.net.URL(url).openConnection() as java.net.HttpURLConnection
|
||||
connection.requestMethod = "HEAD"
|
||||
connection.connectTimeout = 15000
|
||||
connection.readTimeout = 15000
|
||||
|
||||
// 从响应头获取 Content-Length,避免 int 溢出
|
||||
val contentLengthStr = connection.getHeaderField("Content-Length")
|
||||
var size = 0L
|
||||
|
||||
if (contentLengthStr != null) {
|
||||
try {
|
||||
size = contentLengthStr.toLong()
|
||||
if (size < 0) {
|
||||
Log.w(TAG, "Invalid Content-Length value: $size")
|
||||
size = 0
|
||||
}
|
||||
} catch (e: NumberFormatException) {
|
||||
Log.w(TAG, "Invalid Content-Length format: $contentLengthStr")
|
||||
size = 0
|
||||
}
|
||||
} else {
|
||||
val contentLength = connection.contentLength
|
||||
if (contentLength > 0) {
|
||||
size = contentLength.toLong()
|
||||
} else {
|
||||
Log.w(TAG, "Content-Length not available or invalid: $contentLength")
|
||||
size = 0
|
||||
}
|
||||
}
|
||||
|
||||
connection.disconnect()
|
||||
Log.i(TAG, "File size for $url: $size bytes")
|
||||
size
|
||||
} catch (e: Exception) {
|
||||
Log.w(TAG, "Failed to get file size: ${e.message}")
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 从网络下载文件,带进度回调
|
||||
*/
|
||||
private fun downloadFileWithProgress(
|
||||
url: String,
|
||||
destination: File,
|
||||
onProgress: (Long, Long) -> Unit
|
||||
) {
|
||||
val connection = java.net.URL(url).openConnection() as java.net.HttpURLConnection
|
||||
connection.connectTimeout = 30000
|
||||
connection.readTimeout = 6000000
|
||||
|
||||
// 从响应头获取 Content-Length,避免 int 溢出
|
||||
val contentLengthStr = connection.getHeaderField("Content-Length")
|
||||
val totalSize = if (contentLengthStr != null) {
|
||||
try {
|
||||
contentLengthStr.toLong()
|
||||
} catch (e: NumberFormatException) {
|
||||
Log.w(TAG, "Invalid Content-Length format: $contentLengthStr")
|
||||
0
|
||||
}
|
||||
} else {
|
||||
connection.contentLength.toLong()
|
||||
}
|
||||
Log.i(TAG, "Downloading file $url, size: $totalSize bytes")
|
||||
|
||||
try {
|
||||
connection.inputStream.use { input ->
|
||||
FileOutputStream(destination).use { output ->
|
||||
val buffer = ByteArray(8192)
|
||||
var downloaded: Long = 0
|
||||
var bytesRead: Int
|
||||
|
||||
while (input.read(buffer).also { bytesRead = it } != -1) {
|
||||
output.write(buffer, 0, bytesRead)
|
||||
downloaded += bytesRead
|
||||
onProgress(downloaded, totalSize)
|
||||
}
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
connection.disconnect()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查本地 LLM 模型是否可用
|
||||
*/
|
||||
@JvmStatic
|
||||
fun isLocalLLMAvailable(context: Context): Boolean {
|
||||
val llmDir = File(context.filesDir, "llm")
|
||||
|
||||
val rkllmFile = File(llmDir, MODEL_FILE_NAME)
|
||||
|
||||
val rkllmExists = rkllmFile.exists() && rkllmFile.length() > 0
|
||||
|
||||
Log.i(TAG, "LLM model check: rkllm=$rkllmExists")
|
||||
Log.i(TAG, "RKLLM file: ${rkllmFile.absolutePath}, size: ${if (rkllmFile.exists()) rkllmFile.length() / (1024*1024) else 0} MB")
|
||||
|
||||
return rkllmExists
|
||||
}
|
||||
|
||||
/**
|
||||
* 从网络下载文件
|
||||
*/
|
||||
private fun downloadFile(url: String, destination: File) {
|
||||
val connection = java.net.URL(url).openConnection() as java.net.HttpURLConnection
|
||||
connection.connectTimeout = 30000 // 30秒超时
|
||||
connection.readTimeout = 60000 // 60秒读取超时
|
||||
|
||||
try {
|
||||
connection.inputStream.use { input ->
|
||||
FileOutputStream(destination).use { output ->
|
||||
input.copyTo(output)
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
connection.disconnect()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user