local llm supported
This commit is contained in:
3
.vscode/settings.json
vendored
3
.vscode/settings.json
vendored
@@ -1,3 +0,0 @@
|
||||
{
|
||||
"java.configuration.updateBuildConfiguration": "interactive"
|
||||
}
|
||||
@@ -5,7 +5,7 @@ plugins {
|
||||
|
||||
android {
|
||||
namespace 'com.digitalperson'
|
||||
compileSdk 33
|
||||
compileSdk 34
|
||||
|
||||
buildFeatures {
|
||||
buildConfig true
|
||||
@@ -24,6 +24,11 @@ android {
|
||||
}
|
||||
}
|
||||
|
||||
// 不压缩大文件,避免内存不足错误
|
||||
aaptOptions {
|
||||
noCompress 'rknn', 'rkllm', 'onnx', 'model', 'bin', 'json'
|
||||
}
|
||||
|
||||
defaultConfig {
|
||||
applicationId "com.digitalperson"
|
||||
minSdk 21
|
||||
@@ -40,7 +45,7 @@ android {
|
||||
buildConfigField "String", "LLM_API_URL", "\"${(project.findProperty('LLM_API_URL') ?: 'https://ark.cn-beijing.volces.com/api/v3/chat/completions').toString()}\""
|
||||
buildConfigField "String", "LLM_API_KEY", "\"${(project.findProperty('LLM_API_KEY') ?: '').toString()}\""
|
||||
buildConfigField "String", "LLM_MODEL", "\"${(project.findProperty('LLM_MODEL') ?: 'doubao-1-5-pro-32k-character-250228').toString()}\""
|
||||
buildConfigField "boolean", "USE_LIVE2D", "${(project.findProperty('USE_LIVE2D') ?: 'false').toString()}"
|
||||
buildConfigField "boolean", "USE_LIVE2D", "${(project.findProperty('USE_LIVE2D') ?: 'true').toString()}"
|
||||
|
||||
ndk {
|
||||
abiFilters "arm64-v8a"
|
||||
@@ -63,6 +68,8 @@ android {
|
||||
}
|
||||
|
||||
dependencies {
|
||||
|
||||
|
||||
implementation 'androidx.core:core-ktx:1.7.0'
|
||||
implementation 'androidx.appcompat:appcompat:1.6.1'
|
||||
implementation 'com.google.android.material:material:1.9.0'
|
||||
@@ -73,6 +80,15 @@ dependencies {
|
||||
androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.1'
|
||||
// ExoPlayer for video playback (used to show silent / speaking videos)
|
||||
implementation 'com.google.android.exoplayer:exoplayer:2.18.6'
|
||||
implementation 'androidx.camera:camera-core:1.3.4'
|
||||
implementation 'androidx.camera:camera-camera2:1.3.4'
|
||||
implementation 'androidx.camera:camera-lifecycle:1.3.4'
|
||||
implementation 'androidx.camera:camera-view:1.3.4'
|
||||
implementation project(':framework')
|
||||
implementation files('../Live2DFramework/Core/android/Live2DCubismCore.aar')
|
||||
|
||||
// Tencent Cloud TTS SDK
|
||||
implementation files('libs/realtime_tts-release-v2.0.16-20260128-d80cafe.aar')
|
||||
implementation 'com.google.code.gson:gson:2.8.9'
|
||||
implementation 'com.squareup.okhttp3:okhttp:4.9.3'
|
||||
}
|
||||
|
||||
BIN
app/libs/realtime_tts-release-v2.0.16-20260128-d80cafe.aar
Normal file
BIN
app/libs/realtime_tts-release-v2.0.16-20260128-d80cafe.aar
Normal file
Binary file not shown.
@@ -71,6 +71,7 @@ TTS Sherpa-ONNX VITS .onnx ❌ 否 CPU ONNX Runtime
|
||||
- haru_g_m25 - 扁嘴
|
||||
- haru_g_m24 - 低头斜看地板,收手到背后
|
||||
- haru_g_m05 扁嘴,张开双手
|
||||
- haru_g_m16 双手捧腮,思考
|
||||
|
||||
### 😠 愤怒类情绪
|
||||
- haru_g_m11 双手交叉,摇头,扁嘴
|
||||
@@ -90,8 +91,6 @@ TTS Sherpa-ONNX VITS .onnx ❌ 否 CPU ONNX Runtime
|
||||
- haru_g_m12 摆手,摇头
|
||||
|
||||
### 😕 困惑类情绪
|
||||
- haru_g_m20 手指点腮,思考,皱眉
|
||||
- haru_g_m16 双手捧腮,思考
|
||||
- haru_g_m14 身体前倾,皱眉
|
||||
- haru_g_m13 身体前倾,双手分开
|
||||
|
||||
@@ -99,6 +98,22 @@ TTS Sherpa-ONNX VITS .onnx ❌ 否 CPU ONNX Runtime
|
||||
- haru_g_m19 脸红微笑
|
||||
|
||||
|
||||
### 担心
|
||||
- haru_g_m20 手指点腮,思考,皱眉
|
||||
|
||||
### ❤️ 关心类情绪
|
||||
- haru_g_m17 靠近侧脸
|
||||
|
||||
6. 其实可以抄一下讯飞的超脑平台的功能:
|
||||
https://aiui-doc.xf-yun.com/project-2/doc-397/
|
||||
|
||||
7. 人脸检测使用的是RKNN zoo里的 retinaface模型,转成了rknn格式,并且使用了wider_face的数据集(验证集)进行了校准,下载地址:
|
||||
https://www.modelscope.cn/datasets/shaoxuan/WIDER_FACE/files
|
||||
|
||||
8. 人脸识别模型是insightface的r18模型,转成了rknn格式,并且使用了 lfw 的数据集进行了校准,下载地址:
|
||||
https://tianchi.aliyun.com/dataset/93864
|
||||
|
||||
9.
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -2,13 +2,17 @@
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android">
|
||||
|
||||
<uses-permission android:name="android.permission.RECORD_AUDIO" />
|
||||
<uses-permission android:name="android.permission.CAMERA" />
|
||||
<uses-permission android:name="android.permission.INTERNET" />
|
||||
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE" />
|
||||
<uses-feature android:name="android.hardware.camera.any" />
|
||||
|
||||
<application
|
||||
android:allowBackup="true"
|
||||
android:label="@string/app_name"
|
||||
android:supportsRtl="true"
|
||||
android:theme="@style/Theme.DigitalPerson">
|
||||
android:theme="@style/Theme.DigitalPerson"
|
||||
android:usesCleartextTraffic="true">
|
||||
|
||||
<activity
|
||||
android:name="com.digitalperson.EntryActivity"
|
||||
|
||||
BIN
app/src/main/assets/Insightface/ms1mv3_arcface_r18.rknn
Normal file
BIN
app/src/main/assets/Insightface/ms1mv3_arcface_r18.rknn
Normal file
Binary file not shown.
BIN
app/src/main/assets/RetinaFace/RetinaFace_mobile320.rknn
Normal file
BIN
app/src/main/assets/RetinaFace/RetinaFace_mobile320.rknn
Normal file
Binary file not shown.
210
app/src/main/cpp/ArcFaceEngineRKNN.cpp
Normal file
210
app/src/main/cpp/ArcFaceEngineRKNN.cpp
Normal file
@@ -0,0 +1,210 @@
|
||||
#include "ArcFaceEngineRKNN.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <android/log.h>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
|
||||
#define LOG_TAG "ArcFaceRKNN"
|
||||
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__)
|
||||
#define LOGW(...) __android_log_print(ANDROID_LOG_WARN, LOG_TAG, __VA_ARGS__)
|
||||
#define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__)
|
||||
|
||||
ArcFaceEngineRKNN::ArcFaceEngineRKNN() = default;
|
||||
|
||||
ArcFaceEngineRKNN::~ArcFaceEngineRKNN() {
|
||||
release();
|
||||
}
|
||||
|
||||
int ArcFaceEngineRKNN::init(const char* modelPath) {
|
||||
release();
|
||||
int ret = rknn_init(&ctx_, (void*)modelPath, 0, 0, nullptr);
|
||||
if (ret != RKNN_SUCC) {
|
||||
LOGE("rknn_init failed: %d model=%s", ret, modelPath);
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::memset(&ioNum_, 0, sizeof(ioNum_));
|
||||
ret = rknn_query(ctx_, RKNN_QUERY_IN_OUT_NUM, &ioNum_, sizeof(ioNum_));
|
||||
if (ret != RKNN_SUCC || ioNum_.n_input < 1 || ioNum_.n_output < 1) {
|
||||
LOGE("query io num failed: ret=%d in=%u out=%u", ret, ioNum_.n_input, ioNum_.n_output);
|
||||
release();
|
||||
return ret != RKNN_SUCC ? ret : -1;
|
||||
}
|
||||
|
||||
std::memset(&inputAttr_, 0, sizeof(inputAttr_));
|
||||
inputAttr_.index = 0;
|
||||
ret = rknn_query(ctx_, RKNN_QUERY_INPUT_ATTR, &inputAttr_, sizeof(inputAttr_));
|
||||
if (ret != RKNN_SUCC) {
|
||||
LOGE("query input attr failed: %d", ret);
|
||||
release();
|
||||
return ret;
|
||||
}
|
||||
if (inputAttr_.n_dims == 4) {
|
||||
if (inputAttr_.fmt == RKNN_TENSOR_NHWC) {
|
||||
inputH_ = static_cast<int>(inputAttr_.dims[1]);
|
||||
inputW_ = static_cast<int>(inputAttr_.dims[2]);
|
||||
} else {
|
||||
inputH_ = static_cast<int>(inputAttr_.dims[2]);
|
||||
inputW_ = static_cast<int>(inputAttr_.dims[3]);
|
||||
}
|
||||
} else if (inputAttr_.n_dims == 3) {
|
||||
inputH_ = static_cast<int>(inputAttr_.dims[1]);
|
||||
inputW_ = static_cast<int>(inputAttr_.dims[2]);
|
||||
}
|
||||
|
||||
std::memset(&outputAttr_, 0, sizeof(outputAttr_));
|
||||
outputAttr_.index = 0;
|
||||
ret = rknn_query(ctx_, RKNN_QUERY_OUTPUT_ATTR, &outputAttr_, sizeof(outputAttr_));
|
||||
if (ret != RKNN_SUCC) {
|
||||
LOGW("query output attr failed: %d", ret);
|
||||
}
|
||||
|
||||
initialized_ = true;
|
||||
LOGI("ArcFace initialized input=%dx%d", inputW_, inputH_);
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::vector<float> ArcFaceEngineRKNN::extractEmbedding(
|
||||
const uint32_t* argbPixels,
|
||||
int width,
|
||||
int height,
|
||||
int strideBytes,
|
||||
float left,
|
||||
float top,
|
||||
float right,
|
||||
float bottom) {
|
||||
LOGI("extractEmbedding called: initialized=%d, ctx=%p, pixels=%p, width=%d, height=%d, left=%.2f, top=%.2f, right=%.2f, bottom=%.2f",
|
||||
initialized_, ctx_, argbPixels, width, height, left, top, right, bottom);
|
||||
|
||||
std::vector<float> empty;
|
||||
if (!initialized_ || ctx_ == 0 || argbPixels == nullptr || width <= 0 || height <= 0) {
|
||||
LOGW("extractEmbedding failed: invalid parameters");
|
||||
return empty;
|
||||
}
|
||||
|
||||
const float faceW = right - left;
|
||||
const float faceH = bottom - top;
|
||||
LOGI("Face size: width=%.2f, height=%.2f", faceW, faceH);
|
||||
|
||||
if (faceW < 4.0f || faceH < 4.0f) {
|
||||
LOGW("extractEmbedding failed: face too small");
|
||||
return empty;
|
||||
}
|
||||
|
||||
const float pad = 0.15f;
|
||||
float x1f = std::max(0.0f, left - faceW * pad);
|
||||
float y1f = std::max(0.0f, top - faceH * pad);
|
||||
float x2f = std::min(static_cast<float>(width), right + faceW * pad);
|
||||
float y2f = std::min(static_cast<float>(height), bottom + faceH * pad);
|
||||
int x1 = static_cast<int>(std::floor(x1f));
|
||||
int y1 = static_cast<int>(std::floor(y1f));
|
||||
int x2 = static_cast<int>(std::ceil(x2f));
|
||||
int y2 = static_cast<int>(std::ceil(y2f));
|
||||
if (x2 <= x1 || y2 <= y1) {
|
||||
return empty;
|
||||
}
|
||||
|
||||
const int cropW = x2 - x1;
|
||||
const int cropH = y2 - y1;
|
||||
const int srcStridePx = strideBytes / 4;
|
||||
std::vector<uint8_t> rgb(inputW_ * inputH_ * 3);
|
||||
for (int y = 0; y < inputH_; ++y) {
|
||||
const int sy = y1 + (y * cropH / inputH_);
|
||||
const uint32_t* srcRow = argbPixels + sy * srcStridePx;
|
||||
uint8_t* dst = rgb.data() + y * inputW_ * 3;
|
||||
for (int x = 0; x < inputW_; ++x) {
|
||||
const int sx = x1 + (x * cropW / inputW_);
|
||||
const uint32_t pixel = srcRow[sx];
|
||||
dst[3 * x + 0] = (pixel >> 16) & 0xFF;
|
||||
dst[3 * x + 1] = (pixel >> 8) & 0xFF;
|
||||
dst[3 * x + 2] = pixel & 0xFF;
|
||||
}
|
||||
}
|
||||
|
||||
rknn_input input{};
|
||||
input.index = 0;
|
||||
input.type = RKNN_TENSOR_UINT8;
|
||||
input.size = rgb.size();
|
||||
input.buf = rgb.data();
|
||||
input.pass_through = 0;
|
||||
input.fmt = (inputAttr_.fmt == RKNN_TENSOR_NCHW) ? RKNN_TENSOR_NCHW : RKNN_TENSOR_NHWC;
|
||||
|
||||
std::vector<uint8_t> nchw;
|
||||
if (input.fmt == RKNN_TENSOR_NCHW) {
|
||||
nchw.resize(rgb.size());
|
||||
const int hw = inputW_ * inputH_;
|
||||
for (int i = 0; i < hw; ++i) {
|
||||
nchw[i] = rgb[3 * i + 0];
|
||||
nchw[hw + i] = rgb[3 * i + 1];
|
||||
nchw[2 * hw + i] = rgb[3 * i + 2];
|
||||
}
|
||||
input.buf = nchw.data();
|
||||
}
|
||||
|
||||
int ret = rknn_inputs_set(ctx_, 1, &input);
|
||||
if (ret != RKNN_SUCC) {
|
||||
LOGW("rknn_inputs_set failed: %d", ret);
|
||||
return empty;
|
||||
}
|
||||
ret = rknn_run(ctx_, nullptr);
|
||||
if (ret != RKNN_SUCC) {
|
||||
LOGW("rknn_run failed: %d", ret);
|
||||
return empty;
|
||||
}
|
||||
LOGI("rknn_run succeeded");
|
||||
|
||||
std::vector<rknn_output> outputs(ioNum_.n_output);
|
||||
for (uint32_t i = 0; i < ioNum_.n_output; ++i) {
|
||||
std::memset(&outputs[i], 0, sizeof(rknn_output));
|
||||
outputs[i].want_float = 1;
|
||||
}
|
||||
ret = rknn_outputs_get(ctx_, ioNum_.n_output, outputs.data(), nullptr);
|
||||
if (ret != RKNN_SUCC) {
|
||||
LOGW("rknn_outputs_get failed: %d", ret);
|
||||
return empty;
|
||||
}
|
||||
LOGI("rknn_outputs_get succeeded: n_output=%u", ioNum_.n_output);
|
||||
|
||||
if (outputs[0].buf == nullptr) {
|
||||
LOGW("Output buffer is null");
|
||||
rknn_outputs_release(ctx_, ioNum_.n_output, outputs.data());
|
||||
return empty;
|
||||
}
|
||||
|
||||
size_t elems = outputAttr_.n_elems > 0 ? static_cast<size_t>(outputAttr_.n_elems) : 0;
|
||||
if (elems == 0) {
|
||||
elems = 1;
|
||||
for (uint32_t d = 0; d < outputAttr_.n_dims; ++d) {
|
||||
if (outputAttr_.dims[d] > 0) {
|
||||
elems *= static_cast<size_t>(outputAttr_.dims[d]);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (elems == 0) {
|
||||
rknn_outputs_release(ctx_, ioNum_.n_output, outputs.data());
|
||||
return empty;
|
||||
}
|
||||
|
||||
const float* ptr = reinterpret_cast<const float*>(outputs[0].buf);
|
||||
std::vector<float> embedding(ptr, ptr + elems);
|
||||
rknn_outputs_release(ctx_, ioNum_.n_output, outputs.data());
|
||||
|
||||
// L2 normalize for cosine similarity.
|
||||
float sum = 0.0f;
|
||||
for (float v : embedding) sum += v * v;
|
||||
const float norm = std::sqrt(std::max(sum, 1e-12f));
|
||||
for (float& v : embedding) v /= norm;
|
||||
return embedding;
|
||||
}
|
||||
|
||||
void ArcFaceEngineRKNN::release() {
|
||||
if (ctx_ != 0) {
|
||||
rknn_destroy(ctx_);
|
||||
ctx_ = 0;
|
||||
}
|
||||
std::memset(&ioNum_, 0, sizeof(ioNum_));
|
||||
std::memset(&inputAttr_, 0, sizeof(inputAttr_));
|
||||
std::memset(&outputAttr_, 0, sizeof(outputAttr_));
|
||||
initialized_ = false;
|
||||
}
|
||||
36
app/src/main/cpp/ArcFaceEngineRKNN.h
Normal file
36
app/src/main/cpp/ArcFaceEngineRKNN.h
Normal file
@@ -0,0 +1,36 @@
|
||||
#ifndef DIGITAL_PERSON_ARCFACE_ENGINE_RKNN_H
|
||||
#define DIGITAL_PERSON_ARCFACE_ENGINE_RKNN_H
|
||||
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
|
||||
#include "rknn_api.h"
|
||||
|
||||
class ArcFaceEngineRKNN {
|
||||
public:
|
||||
ArcFaceEngineRKNN();
|
||||
~ArcFaceEngineRKNN();
|
||||
|
||||
int init(const char* modelPath);
|
||||
std::vector<float> extractEmbedding(
|
||||
const uint32_t* argbPixels,
|
||||
int width,
|
||||
int height,
|
||||
int strideBytes,
|
||||
float left,
|
||||
float top,
|
||||
float right,
|
||||
float bottom);
|
||||
void release();
|
||||
|
||||
private:
|
||||
rknn_context ctx_ = 0;
|
||||
bool initialized_ = false;
|
||||
rknn_input_output_num ioNum_{};
|
||||
rknn_tensor_attr inputAttr_{};
|
||||
rknn_tensor_attr outputAttr_{};
|
||||
int inputW_ = 112;
|
||||
int inputH_ = 112;
|
||||
};
|
||||
|
||||
#endif
|
||||
109
app/src/main/cpp/ArcFaceEngineRKNNJNI.cpp
Normal file
109
app/src/main/cpp/ArcFaceEngineRKNNJNI.cpp
Normal file
@@ -0,0 +1,109 @@
|
||||
#include <jni.h>
|
||||
#include <android/bitmap.h>
|
||||
#include <android/log.h>
|
||||
|
||||
#include "ArcFaceEngineRKNN.h"
|
||||
|
||||
#define LOG_TAG "ArcFaceJNI"
|
||||
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__)
|
||||
#define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__)
|
||||
|
||||
extern "C" {
|
||||
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_com_digitalperson_engine_ArcFaceEngineRKNN_createEngineNative(JNIEnv* env, jobject thiz) {
|
||||
auto* engine = new ArcFaceEngineRKNN();
|
||||
if (engine == nullptr) return 0;
|
||||
return reinterpret_cast<jlong>(engine);
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_com_digitalperson_engine_ArcFaceEngineRKNN_initNative(
|
||||
JNIEnv* env,
|
||||
jobject thiz,
|
||||
jlong ptr,
|
||||
jstring modelPath) {
|
||||
auto* engine = reinterpret_cast<ArcFaceEngineRKNN*>(ptr);
|
||||
if (engine == nullptr || modelPath == nullptr) return -1;
|
||||
const char* model = env->GetStringUTFChars(modelPath, nullptr);
|
||||
if (model == nullptr) return -1;
|
||||
int ret = engine->init(model);
|
||||
env->ReleaseStringUTFChars(modelPath, model);
|
||||
return ret;
|
||||
}
|
||||
|
||||
JNIEXPORT jfloatArray JNICALL
|
||||
Java_com_digitalperson_engine_ArcFaceEngineRKNN_extractEmbeddingNative(
|
||||
JNIEnv* env,
|
||||
jobject thiz,
|
||||
jlong ptr,
|
||||
jobject bitmapObj,
|
||||
jfloat left,
|
||||
jfloat top,
|
||||
jfloat right,
|
||||
jfloat bottom) {
|
||||
LOGI("extractEmbeddingNative called: ptr=%ld, left=%.2f, top=%.2f, right=%.2f, bottom=%.2f", ptr, left, top, right, bottom);
|
||||
|
||||
auto* engine = reinterpret_cast<ArcFaceEngineRKNN*>(ptr);
|
||||
if (engine == nullptr || bitmapObj == nullptr) {
|
||||
LOGE("Engine or bitmap is null: engine=%p, bitmap=%p", engine, bitmapObj);
|
||||
return env->NewFloatArray(0);
|
||||
}
|
||||
|
||||
AndroidBitmapInfo info{};
|
||||
if (AndroidBitmap_getInfo(env, bitmapObj, &info) < 0) {
|
||||
LOGE("AndroidBitmap_getInfo failed");
|
||||
return env->NewFloatArray(0);
|
||||
}
|
||||
LOGI("Bitmap info: width=%d, height=%d, stride=%d, format=%d", info.width, info.height, info.stride, info.format);
|
||||
|
||||
if (info.format != ANDROID_BITMAP_FORMAT_RGBA_8888) {
|
||||
LOGE("Unsupported bitmap format: %d", info.format);
|
||||
return env->NewFloatArray(0);
|
||||
}
|
||||
|
||||
void* pixels = nullptr;
|
||||
if (AndroidBitmap_lockPixels(env, bitmapObj, &pixels) < 0 || pixels == nullptr) {
|
||||
LOGE("AndroidBitmap_lockPixels failed");
|
||||
return env->NewFloatArray(0);
|
||||
}
|
||||
LOGI("Bitmap pixels locked successfully");
|
||||
|
||||
std::vector<float> emb = engine->extractEmbedding(
|
||||
reinterpret_cast<uint32_t*>(pixels),
|
||||
static_cast<int>(info.width),
|
||||
static_cast<int>(info.height),
|
||||
static_cast<int>(info.stride),
|
||||
left,
|
||||
top,
|
||||
right,
|
||||
bottom);
|
||||
|
||||
LOGI("Engine extractEmbedding returned: size=%zu", emb.size());
|
||||
|
||||
AndroidBitmap_unlockPixels(env, bitmapObj);
|
||||
|
||||
jfloatArray out = env->NewFloatArray(static_cast<jsize>(emb.size()));
|
||||
if (out == nullptr) {
|
||||
LOGE("Failed to create float array");
|
||||
return env->NewFloatArray(0);
|
||||
}
|
||||
if (!emb.empty()) {
|
||||
env->SetFloatArrayRegion(out, 0, static_cast<jsize>(emb.size()), emb.data());
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL
|
||||
Java_com_digitalperson_engine_ArcFaceEngineRKNN_releaseNative(
|
||||
JNIEnv* env,
|
||||
jobject thiz,
|
||||
jlong ptr) {
|
||||
auto* engine = reinterpret_cast<ArcFaceEngineRKNN*>(ptr);
|
||||
if (engine != nullptr) {
|
||||
engine->release();
|
||||
delete engine;
|
||||
}
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
@@ -14,6 +14,11 @@ if (ANDROID)
|
||||
set_target_properties(sentencepiece PROPERTIES IMPORTED_LOCATION
|
||||
${JNI_LIBS_DIR}/libsentencepiece.so)
|
||||
|
||||
# 导入 rkllm 库
|
||||
add_library(rkllmrt SHARED IMPORTED)
|
||||
set_target_properties(rkllmrt PROPERTIES IMPORTED_LOCATION
|
||||
${JNI_LIBS_DIR}/librkllmrt.so)
|
||||
|
||||
# Imported static libs
|
||||
add_library(kaldi_native_fbank STATIC IMPORTED)
|
||||
set_target_properties(kaldi_native_fbank PROPERTIES IMPORTED_LOCATION
|
||||
@@ -26,6 +31,12 @@ if (ANDROID)
|
||||
add_library(sensevoiceEngine SHARED
|
||||
SenseVoiceEngineRKNN.cpp
|
||||
SenseVoiceEngineRKNNJNI.cpp
|
||||
RetinaFaceEngineRKNN.cpp
|
||||
RetinaFaceEngineRKNNJNI.cpp
|
||||
ArcFaceEngineRKNN.cpp
|
||||
ArcFaceEngineRKNNJNI.cpp
|
||||
RKLLMEngine.cpp
|
||||
RKLLMEngineJNI.cpp
|
||||
utils/audio_utils.c
|
||||
)
|
||||
|
||||
@@ -40,9 +51,11 @@ if (ANDROID)
|
||||
|
||||
target_link_libraries(sensevoiceEngine
|
||||
rknnrt
|
||||
rkllmrt
|
||||
kaldi_native_fbank
|
||||
sndfile
|
||||
sentencepiece
|
||||
jnigraphics
|
||||
log
|
||||
)
|
||||
endif()
|
||||
|
||||
118
app/src/main/cpp/RKLLMEngine.cpp
Normal file
118
app/src/main/cpp/RKLLMEngine.cpp
Normal file
@@ -0,0 +1,118 @@
|
||||
#include <android/log.h>
|
||||
#include <jni.h>
|
||||
#include "zipformer_headers/rkllm.h"
|
||||
|
||||
#define TAG "RKLLMEngine"
|
||||
#define LOGD(...) __android_log_print(ANDROID_LOG_DEBUG, TAG, __VA_ARGS__)
|
||||
#define LOGW(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__)
|
||||
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
|
||||
#define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
|
||||
|
||||
namespace {
|
||||
// Keep these conservative first; too-large values may fail during init on RK3588.
|
||||
constexpr int kDefaultMaxContextLen = 1024;
|
||||
constexpr int kDefaultMaxNewTokens = 128;
|
||||
}
|
||||
|
||||
struct LLmJniEnv {
|
||||
JNIEnv *env;
|
||||
jobject thiz;
|
||||
jclass clazz;
|
||||
};
|
||||
|
||||
void callbackToJava(const char *text, int state, LLmJniEnv *jenv) {
|
||||
jmethodID method = jenv->env->GetMethodID(jenv->clazz, "callbackFromNative", "(Ljava/lang/String;I)V");
|
||||
jstring jText = text ? jenv->env->NewStringUTF(text) : jenv->env->NewStringUTF("");
|
||||
jenv->env->CallVoidMethod(jenv->thiz, method, jText, state);
|
||||
}
|
||||
|
||||
int callback(RKLLMResult *result, void *userdata, LLMCallState state) {
|
||||
auto jenv = (LLmJniEnv *)userdata;
|
||||
|
||||
if (state == RKLLM_RUN_FINISH) {
|
||||
LOGI("<FINISH/>");
|
||||
callbackToJava(nullptr, 0, jenv);
|
||||
delete jenv;
|
||||
} else if (state == RKLLM_RUN_ERROR) {
|
||||
LOGE("<ERROR/>");
|
||||
callbackToJava(nullptr, -1, jenv);
|
||||
delete jenv;
|
||||
} else if (state == RKLLM_RUN_NORMAL) {
|
||||
//LOGD("NM: [%d] %s", result->token_id, result->text);
|
||||
callbackToJava(result->text, 1, jenv);
|
||||
}
|
||||
return 0; // 返回 0 表示正常继续执行
|
||||
}
|
||||
|
||||
// JNI 方法实现
|
||||
extern "C" {
|
||||
jlong initLLM(JNIEnv *env, jobject thiz, jstring model_path) {
|
||||
const char* modelPath = env->GetStringUTFChars(model_path, nullptr);
|
||||
LLMHandle llmHandle = nullptr;
|
||||
|
||||
//设置参数及初始化
|
||||
RKLLMParam param = rkllm_createDefaultParam();
|
||||
param.model_path = modelPath;
|
||||
|
||||
//设置采样参数
|
||||
param.top_k = 1;
|
||||
param.top_p = 0.95;
|
||||
param.temperature = 0.8;
|
||||
param.repeat_penalty = 1.1;
|
||||
param.frequency_penalty = 0.0;
|
||||
param.presence_penalty = 0.0;
|
||||
|
||||
param.max_new_tokens = kDefaultMaxNewTokens;
|
||||
param.max_context_len = kDefaultMaxContextLen;
|
||||
param.skip_special_token = true;
|
||||
param.extend_param.base_domain_id = 0;
|
||||
|
||||
LOGI("rkllm init with module: %s", modelPath);
|
||||
LOGI("rkllm init params: max_context_len=%d, max_new_tokens=%d, top_k=%d, top_p=%f, temp=%f",
|
||||
param.max_context_len, param.max_new_tokens, param.top_k, param.top_p, param.temperature);
|
||||
int ret = rkllm_init(&llmHandle, ¶m, callback);
|
||||
if (ret == 0){
|
||||
LOGI("rkllm init success, handle=%p", llmHandle);
|
||||
} else {
|
||||
LOGE("rkllm init failed, ret=%d", ret);
|
||||
llmHandle = nullptr;
|
||||
}
|
||||
|
||||
env->ReleaseStringUTFChars(model_path, modelPath);
|
||||
|
||||
return (jlong)llmHandle;
|
||||
}
|
||||
|
||||
void deinitLLM(JNIEnv *env, jobject thiz, jlong handle) {
|
||||
rkllm_destroy((LLMHandle)handle);
|
||||
}
|
||||
|
||||
void infer(JNIEnv *env, jobject thiz, jlong handle, jstring text) {
|
||||
if (handle == 0) {
|
||||
LOGE("rkllm infer called with null handle");
|
||||
jclass clazz = env->GetObjectClass(thiz);
|
||||
jmethodID method = env->GetMethodID(clazz, "callbackFromNative", "(Ljava/lang/String;I)V");
|
||||
jstring jText = env->NewStringUTF("RKLLM handle is null");
|
||||
env->CallVoidMethod(thiz, method, jText, -1);
|
||||
env->DeleteLocalRef(jText);
|
||||
return;
|
||||
}
|
||||
|
||||
auto *jnienv = new LLmJniEnv {
|
||||
.env = env,
|
||||
.thiz = thiz,
|
||||
.clazz = env->GetObjectClass(thiz),
|
||||
};
|
||||
|
||||
RKLLMInput rkllm_input = {};
|
||||
RKLLMInferParam rkllm_infer_params = {};
|
||||
const char* sText = env->GetStringUTFChars(text, nullptr);
|
||||
|
||||
rkllm_infer_params.mode = RKLLM_INFER_GENERATE;
|
||||
rkllm_input.input_type = RKLLM_INPUT_PROMPT;
|
||||
rkllm_input.prompt_input = (char *)sText;
|
||||
|
||||
rkllm_run((LLMHandle)handle, &rkllm_input, &rkllm_infer_params, jnienv);
|
||||
env->ReleaseStringUTFChars(text, sText);
|
||||
}
|
||||
}
|
||||
37
app/src/main/cpp/RKLLMEngineJNI.cpp
Normal file
37
app/src/main/cpp/RKLLMEngineJNI.cpp
Normal file
@@ -0,0 +1,37 @@
|
||||
#include <jni.h>
|
||||
|
||||
// JNI 方法声明
|
||||
extern "C" {
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_com_digitalperson_llm_RKLLM_initLLM(JNIEnv *env, jobject thiz, jstring model_path);
|
||||
|
||||
JNIEXPORT void JNICALL
|
||||
Java_com_digitalperson_llm_RKLLM_deinitLLM(JNIEnv *env, jobject thiz, jlong handle);
|
||||
|
||||
JNIEXPORT void JNICALL
|
||||
Java_com_digitalperson_llm_RKLLM_infer(JNIEnv *env, jobject thiz, jlong handle, jstring text);
|
||||
}
|
||||
|
||||
// 方法实现
|
||||
extern "C" {
|
||||
jlong initLLM(JNIEnv *env, jobject thiz, jstring model_path);
|
||||
void deinitLLM(JNIEnv *env, jobject thiz, jlong handle);
|
||||
void infer(JNIEnv *env, jobject thiz, jlong handle, jstring text);
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_com_digitalperson_llm_RKLLM_initLLM(JNIEnv *env, jobject thiz, jstring model_path) {
|
||||
return initLLM(env, thiz, model_path);
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL
|
||||
Java_com_digitalperson_llm_RKLLM_deinitLLM(JNIEnv *env, jobject thiz, jlong handle) {
|
||||
deinitLLM(env, thiz, handle);
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL
|
||||
Java_com_digitalperson_llm_RKLLM_infer(JNIEnv *env, jobject thiz, jlong handle, jstring text) {
|
||||
infer(env, thiz, handle, text);
|
||||
}
|
||||
}
|
||||
435
app/src/main/cpp/RetinaFaceEngineRKNN.cpp
Normal file
435
app/src/main/cpp/RetinaFaceEngineRKNN.cpp
Normal file
@@ -0,0 +1,435 @@
|
||||
#include "RetinaFaceEngineRKNN.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <android/log.h>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
|
||||
#define LOG_TAG "RetinaFaceRKNN"
|
||||
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__)
|
||||
#define LOGW(...) __android_log_print(ANDROID_LOG_WARN, LOG_TAG, __VA_ARGS__)
|
||||
#define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__)
|
||||
|
||||
namespace {
|
||||
constexpr float kVariance0 = 0.1f;
|
||||
constexpr float kVariance1 = 0.2f;
|
||||
} // namespace
|
||||
|
||||
RetinaFaceEngineRKNN::RetinaFaceEngineRKNN() = default;
|
||||
|
||||
RetinaFaceEngineRKNN::~RetinaFaceEngineRKNN() {
|
||||
release();
|
||||
}
|
||||
|
||||
size_t RetinaFaceEngineRKNN::tensorElemCount(const rknn_tensor_attr& attr) {
|
||||
if (attr.n_elems > 0) {
|
||||
return static_cast<size_t>(attr.n_elems);
|
||||
}
|
||||
if (attr.n_dims <= 0) {
|
||||
return 0;
|
||||
}
|
||||
size_t n = 1;
|
||||
for (uint32_t i = 0; i < attr.n_dims; ++i) {
|
||||
if (attr.dims[i] == 0) continue;
|
||||
n *= static_cast<size_t>(attr.dims[i]);
|
||||
}
|
||||
return n;
|
||||
}
|
||||
|
||||
int RetinaFaceEngineRKNN::init(
|
||||
const char* modelPath,
|
||||
int inputSize,
|
||||
float scoreThreshold,
|
||||
float nmsThreshold) {
|
||||
release();
|
||||
inputSize_ = inputSize;
|
||||
scoreThreshold_ = scoreThreshold;
|
||||
nmsThreshold_ = nmsThreshold;
|
||||
|
||||
int ret = rknn_init(&ctx_, (void*)modelPath, 0, 0, nullptr);
|
||||
if (ret != RKNN_SUCC) {
|
||||
LOGE("rknn_init failed: ret=%d, model=%s", ret, modelPath);
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::memset(&ioNum_, 0, sizeof(ioNum_));
|
||||
ret = rknn_query(ctx_, RKNN_QUERY_IN_OUT_NUM, &ioNum_, sizeof(ioNum_));
|
||||
if (ret != RKNN_SUCC) {
|
||||
LOGE("rknn_query(RKNN_QUERY_IN_OUT_NUM) failed: %d", ret);
|
||||
release();
|
||||
return ret;
|
||||
}
|
||||
if (ioNum_.n_input < 1 || ioNum_.n_output < 1) {
|
||||
LOGE("invalid io num: input=%u output=%u", ioNum_.n_input, ioNum_.n_output);
|
||||
release();
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::memset(&inputAttr_, 0, sizeof(inputAttr_));
|
||||
inputAttr_.index = 0;
|
||||
ret = rknn_query(ctx_, RKNN_QUERY_INPUT_ATTR, &inputAttr_, sizeof(inputAttr_));
|
||||
if (ret != RKNN_SUCC) {
|
||||
LOGE("rknn_query input attr failed: %d", ret);
|
||||
release();
|
||||
return ret;
|
||||
}
|
||||
|
||||
outputAttrs_.clear();
|
||||
outputAttrs_.resize(ioNum_.n_output);
|
||||
for (uint32_t i = 0; i < ioNum_.n_output; ++i) {
|
||||
std::memset(&outputAttrs_[i], 0, sizeof(rknn_tensor_attr));
|
||||
outputAttrs_[i].index = i;
|
||||
int qret = rknn_query(ctx_, RKNN_QUERY_OUTPUT_ATTR, &outputAttrs_[i], sizeof(rknn_tensor_attr));
|
||||
if (qret != RKNN_SUCC) {
|
||||
LOGW("query output attr[%u] failed: %d", i, qret);
|
||||
}
|
||||
LOGI("output[%u] n_elems=%u n_dims=%u type=%d qnt=%d",
|
||||
i, outputAttrs_[i].n_elems, outputAttrs_[i].n_dims, outputAttrs_[i].type, outputAttrs_[i].qnt_type);
|
||||
for (uint32_t d = 0; d < outputAttrs_[i].n_dims; ++d) {
|
||||
LOGI(" output[%u] dim[%u]=%u", i, d, outputAttrs_[i].dims[d]);
|
||||
}
|
||||
}
|
||||
|
||||
initialized_ = true;
|
||||
LOGI("RetinaFace initialized, input_size=%d, outputs=%u", inputSize_, ioNum_.n_output);
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::vector<RetinaFaceEngineRKNN::PriorBox> RetinaFaceEngineRKNN::buildPriors() const {
|
||||
std::vector<PriorBox> priors;
|
||||
const int steps[3] = {8, 16, 32};
|
||||
const int minSizes[3][2] = {{16, 32}, {64, 128}, {256, 512}};
|
||||
for (int s = 0; s < 3; ++s) {
|
||||
const int step = steps[s];
|
||||
const int featW = static_cast<int>(std::ceil(static_cast<float>(inputSize_) / step));
|
||||
const int featH = static_cast<int>(std::ceil(static_cast<float>(inputSize_) / step));
|
||||
for (int y = 0; y < featH; ++y) {
|
||||
for (int x = 0; x < featW; ++x) {
|
||||
for (int k = 0; k < 2; ++k) {
|
||||
const float minSize = static_cast<float>(minSizes[s][k]);
|
||||
PriorBox p;
|
||||
p.cx = (x + 0.5f) * step / inputSize_;
|
||||
p.cy = (y + 0.5f) * step / inputSize_;
|
||||
p.w = minSize / inputSize_;
|
||||
p.h = minSize / inputSize_;
|
||||
priors.push_back(p);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return priors;
|
||||
}
|
||||
|
||||
bool RetinaFaceEngineRKNN::parseRetinaOutputs(
|
||||
rknn_output* outputs,
|
||||
std::vector<float>* locOut,
|
||||
std::vector<float>* scoreOut) const {
|
||||
std::vector<std::vector<float>> locCandidates;
|
||||
std::vector<std::vector<float>> confCandidates2;
|
||||
std::vector<std::vector<float>> scoreCandidates1;
|
||||
|
||||
const int anchors8 = (inputSize_ / 8) * (inputSize_ / 8) * 2;
|
||||
const int anchors16 = (inputSize_ / 16) * (inputSize_ / 16) * 2;
|
||||
const int anchors32 = (inputSize_ / 32) * (inputSize_ / 32) * 2;
|
||||
const int totalAnchors = anchors8 + anchors16 + anchors32;
|
||||
const int expectedLoc8 = anchors8 * 4;
|
||||
const int expectedLoc16 = anchors16 * 4;
|
||||
const int expectedLoc32 = anchors32 * 4;
|
||||
const int expectedConf8_2 = anchors8 * 2;
|
||||
const int expectedConf16_2 = anchors16 * 2;
|
||||
const int expectedConf32_2 = anchors32 * 2;
|
||||
const int expectedConf8_1 = anchors8;
|
||||
const int expectedConf16_1 = anchors16;
|
||||
const int expectedConf32_1 = anchors32;
|
||||
|
||||
for (uint32_t i = 0; i < ioNum_.n_output; ++i) {
|
||||
const size_t elems = tensorElemCount(outputAttrs_[i]);
|
||||
if (elems == 0 || outputs[i].buf == nullptr) continue;
|
||||
|
||||
const float* ptr = reinterpret_cast<const float*>(outputs[i].buf);
|
||||
std::vector<float> data(ptr, ptr + elems);
|
||||
const int e = static_cast<int>(elems);
|
||||
|
||||
if (e == expectedLoc8 || e == expectedLoc16 || e == expectedLoc32 || e == totalAnchors * 4) {
|
||||
locCandidates.push_back(std::move(data));
|
||||
continue;
|
||||
}
|
||||
if (e == expectedConf8_2 || e == expectedConf16_2 || e == expectedConf32_2 || e == totalAnchors * 2) {
|
||||
confCandidates2.push_back(std::move(data));
|
||||
continue;
|
||||
}
|
||||
if (e == expectedConf8_1 || e == expectedConf16_1 || e == expectedConf32_1 || e == totalAnchors) {
|
||||
scoreCandidates1.push_back(std::move(data));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
locOut->clear();
|
||||
scoreOut->clear();
|
||||
|
||||
auto sortBySize = [](const std::vector<float>& a, const std::vector<float>& b) {
|
||||
return a.size() > b.size();
|
||||
};
|
||||
std::sort(locCandidates.begin(), locCandidates.end(), sortBySize);
|
||||
std::sort(confCandidates2.begin(), confCandidates2.end(), sortBySize);
|
||||
std::sort(scoreCandidates1.begin(), scoreCandidates1.end(), sortBySize);
|
||||
|
||||
auto mergeLoc = [&]() -> bool {
|
||||
if (locCandidates.empty()) return false;
|
||||
if (locCandidates.size() >= 3 &&
|
||||
static_cast<int>(locCandidates[0].size()) == anchors8 * 4 &&
|
||||
static_cast<int>(locCandidates[1].size()) == anchors16 * 4 &&
|
||||
static_cast<int>(locCandidates[2].size()) == anchors32 * 4) {
|
||||
locOut->reserve(static_cast<size_t>(totalAnchors) * 4);
|
||||
locOut->insert(locOut->end(), locCandidates[0].begin(), locCandidates[0].end());
|
||||
locOut->insert(locOut->end(), locCandidates[1].begin(), locCandidates[1].end());
|
||||
locOut->insert(locOut->end(), locCandidates[2].begin(), locCandidates[2].end());
|
||||
return true;
|
||||
}
|
||||
for (const auto& c : locCandidates) {
|
||||
if (static_cast<int>(c.size()) == totalAnchors * 4) {
|
||||
*locOut = c;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
auto mergeScoreFrom2Class = [&]() -> bool {
|
||||
if (confCandidates2.empty()) return false;
|
||||
std::vector<float> merged2;
|
||||
if (confCandidates2.size() >= 3 &&
|
||||
static_cast<int>(confCandidates2[0].size()) == expectedConf8_2 &&
|
||||
static_cast<int>(confCandidates2[1].size()) == expectedConf16_2 &&
|
||||
static_cast<int>(confCandidates2[2].size()) == expectedConf32_2) {
|
||||
merged2.reserve(static_cast<size_t>(totalAnchors) * 2);
|
||||
merged2.insert(merged2.end(), confCandidates2[0].begin(), confCandidates2[0].end());
|
||||
merged2.insert(merged2.end(), confCandidates2[1].begin(), confCandidates2[1].end());
|
||||
merged2.insert(merged2.end(), confCandidates2[2].begin(), confCandidates2[2].end());
|
||||
} else {
|
||||
bool found = false;
|
||||
for (const auto& c : confCandidates2) {
|
||||
if (static_cast<int>(c.size()) == totalAnchors * 2) {
|
||||
merged2 = c;
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!found) return false;
|
||||
}
|
||||
|
||||
scoreOut->reserve(totalAnchors);
|
||||
for (int i = 0; i < totalAnchors; ++i) {
|
||||
scoreOut->push_back(merged2[i * 2 + 1]);
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
auto mergeScoreFrom1Class = [&]() -> bool {
|
||||
if (scoreCandidates1.empty()) return false;
|
||||
if (scoreCandidates1.size() >= 3 &&
|
||||
static_cast<int>(scoreCandidates1[0].size()) == expectedConf8_1 &&
|
||||
static_cast<int>(scoreCandidates1[1].size()) == expectedConf16_1 &&
|
||||
static_cast<int>(scoreCandidates1[2].size()) == expectedConf32_1) {
|
||||
scoreOut->reserve(totalAnchors);
|
||||
scoreOut->insert(scoreOut->end(), scoreCandidates1[0].begin(), scoreCandidates1[0].end());
|
||||
scoreOut->insert(scoreOut->end(), scoreCandidates1[1].begin(), scoreCandidates1[1].end());
|
||||
scoreOut->insert(scoreOut->end(), scoreCandidates1[2].begin(), scoreCandidates1[2].end());
|
||||
return true;
|
||||
}
|
||||
for (const auto& c : scoreCandidates1) {
|
||||
if (static_cast<int>(c.size()) == totalAnchors) {
|
||||
*scoreOut = c;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
const bool locOk = mergeLoc();
|
||||
bool scoreOk = mergeScoreFrom2Class();
|
||||
if (!scoreOk) {
|
||||
scoreOk = mergeScoreFrom1Class();
|
||||
}
|
||||
if (!locOk || !scoreOk) {
|
||||
LOGW("Unable to parse retina outputs, loc_candidates=%zu, conf2_candidates=%zu, conf1_candidates=%zu",
|
||||
locCandidates.size(), confCandidates2.size(), scoreCandidates1.size());
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
float RetinaFaceEngineRKNN::iou(const FaceCandidate& a, const FaceCandidate& b) {
|
||||
const float left = std::max(a.left, b.left);
|
||||
const float top = std::max(a.top, b.top);
|
||||
const float right = std::min(a.right, b.right);
|
||||
const float bottom = std::min(a.bottom, b.bottom);
|
||||
const float w = std::max(0.0f, right - left);
|
||||
const float h = std::max(0.0f, bottom - top);
|
||||
const float inter = w * h;
|
||||
const float areaA = std::max(0.0f, a.right - a.left) * std::max(0.0f, a.bottom - a.top);
|
||||
const float areaB = std::max(0.0f, b.right - b.left) * std::max(0.0f, b.bottom - b.top);
|
||||
const float uni = areaA + areaB - inter;
|
||||
return uni > 0.0f ? (inter / uni) : 0.0f;
|
||||
}
|
||||
|
||||
std::vector<RetinaFaceEngineRKNN::FaceCandidate> RetinaFaceEngineRKNN::nms(
|
||||
const std::vector<FaceCandidate>& boxes,
|
||||
float threshold) {
|
||||
std::vector<FaceCandidate> sorted = boxes;
|
||||
std::sort(sorted.begin(), sorted.end(), [](const FaceCandidate& a, const FaceCandidate& b) {
|
||||
return a.score > b.score;
|
||||
});
|
||||
|
||||
std::vector<FaceCandidate> keep;
|
||||
std::vector<char> removed(sorted.size(), 0);
|
||||
for (size_t i = 0; i < sorted.size(); ++i) {
|
||||
if (removed[i]) continue;
|
||||
keep.push_back(sorted[i]);
|
||||
for (size_t j = i + 1; j < sorted.size(); ++j) {
|
||||
if (removed[j]) continue;
|
||||
if (iou(sorted[i], sorted[j]) > threshold) {
|
||||
removed[j] = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
return keep;
|
||||
}
|
||||
|
||||
// Runs single-frame face detection on an ARGB (0xAARRGGBB) pixel buffer.
// `strideBytes` is the row pitch of `argbPixels` in bytes (4 bytes/pixel).
// Returns a flat vector of [left, top, right, bottom, score] per kept face
// in source-image pixel coordinates; returns an empty vector on any failure.
std::vector<float> RetinaFaceEngineRKNN::detect(
    const uint32_t* argbPixels,
    int width,
    int height,
    int strideBytes) {
    std::vector<float> empty;
    if (!initialized_ || ctx_ == 0 || argbPixels == nullptr || width <= 0 || height <= 0) {
        return empty;
    }

    // Nearest-neighbor resize of the ARGB source to inputSize_ x inputSize_
    // RGB. Aspect ratio is not preserved, but boxes are decoded back below
    // with the same independent x/y scaling, so coordinates stay consistent.
    std::vector<uint8_t> rgb(inputSize_ * inputSize_ * 3);
    const int srcStridePx = strideBytes / 4;  // row stride in whole pixels
    for (int y = 0; y < inputSize_; ++y) {
        const int sy = y * height / inputSize_;
        const uint32_t* srcRow = argbPixels + sy * srcStridePx;
        uint8_t* dst = rgb.data() + y * inputSize_ * 3;
        for (int x = 0; x < inputSize_; ++x) {
            const int sx = x * width / inputSize_;
            const uint32_t pixel = srcRow[sx];
            // Word layout is 0xAARRGGBB; alpha is dropped.
            const uint8_t r = (pixel >> 16) & 0xFF;
            const uint8_t g = (pixel >> 8) & 0xFF;
            const uint8_t b = pixel & 0xFF;
            dst[3 * x + 0] = r;
            dst[3 * x + 1] = g;
            dst[3 * x + 2] = b;
        }
    }

    // Hand the uint8 RGB tensor to the runtime; pass_through = 0 leaves
    // quantization/normalization to RKNN.
    rknn_input input{};
    input.index = 0;
    input.type = RKNN_TENSOR_UINT8;
    input.size = rgb.size();
    input.buf = rgb.data();
    input.pass_through = 0;
    input.fmt = (inputAttr_.fmt == RKNN_TENSOR_NCHW) ? RKNN_TENSOR_NCHW : RKNN_TENSOR_NHWC;

    // Repack HWC -> CHW when the model expects NCHW input. `nchw` must stay
    // alive until rknn_inputs_set() consumes input.buf below.
    std::vector<uint8_t> nchw;
    if (input.fmt == RKNN_TENSOR_NCHW) {
        nchw.resize(rgb.size());
        const int hw = inputSize_ * inputSize_;
        for (int i = 0; i < hw; ++i) {
            nchw[i] = rgb[3 * i + 0];
            nchw[hw + i] = rgb[3 * i + 1];
            nchw[2 * hw + i] = rgb[3 * i + 2];
        }
        input.buf = nchw.data();
    }

    int ret = rknn_inputs_set(ctx_, 1, &input);
    if (ret != RKNN_SUCC) {
        LOGW("rknn_inputs_set failed: %d", ret);
        return empty;
    }

    ret = rknn_run(ctx_, nullptr);
    if (ret != RKNN_SUCC) {
        LOGW("rknn_run failed: %d", ret);
        return empty;
    }

    // Fetch all model outputs as float. After a successful get, every exit
    // path must call rknn_outputs_release().
    std::vector<rknn_output> outputs(ioNum_.n_output);
    for (uint32_t i = 0; i < ioNum_.n_output; ++i) {
        std::memset(&outputs[i], 0, sizeof(rknn_output));
        outputs[i].want_float = 1;
    }
    ret = rknn_outputs_get(ctx_, ioNum_.n_output, outputs.data(), nullptr);
    if (ret != RKNN_SUCC) {
        LOGW("rknn_outputs_get failed: %d", ret);
        return empty;
    }

    // Merge the per-scale raw tensors into one box-regression (loc) vector
    // and one face-score vector.
    std::vector<float> loc;
    std::vector<float> scores;
    if (!parseRetinaOutputs(outputs.data(), &loc, &scores)) {
        rknn_outputs_release(ctx_, ioNum_.n_output, outputs.data());
        return empty;
    }

    const std::vector<PriorBox> priors = buildPriors();
    const size_t anchorCount = priors.size();
    if (loc.size() < anchorCount * 4 || scores.size() < anchorCount) {
        LOGW("Output size mismatch: priors=%zu loc=%zu scores=%zu", anchorCount, loc.size(), scores.size());
        rknn_outputs_release(ctx_, ioNum_.n_output, outputs.data());
        return empty;
    }

    // Decode SSD-style box regression against the priors (variance-scaled
    // center/size offsets), keeping only anchors above the score threshold.
    std::vector<FaceCandidate> candidates;
    candidates.reserve(anchorCount / 8);  // heuristic: most anchors are rejected
    for (size_t i = 0; i < anchorCount; ++i) {
        const float score = scores[i];
        if (score < scoreThreshold_) continue;

        const PriorBox& p = priors[i];
        const float dx = loc[i * 4 + 0];
        const float dy = loc[i * 4 + 1];
        const float dw = loc[i * 4 + 2];
        const float dh = loc[i * 4 + 3];

        const float cx = p.cx + dx * kVariance0 * p.w;
        const float cy = p.cy + dy * kVariance0 * p.h;
        const float w = p.w * std::exp(dw * kVariance1);
        const float h = p.h * std::exp(dh * kVariance1);

        // Convert normalized center/size to source-image pixel corners,
        // clamped to the image bounds.
        FaceCandidate box;
        box.left = std::max(0.0f, (cx - w * 0.5f) * width);
        box.top = std::max(0.0f, (cy - h * 0.5f) * height);
        box.right = std::min(static_cast<float>(width), (cx + w * 0.5f) * width);
        box.bottom = std::min(static_cast<float>(height), (cy + h * 0.5f) * height);
        box.score = score;
        candidates.push_back(box);
    }

    rknn_outputs_release(ctx_, ioNum_.n_output, outputs.data());

    // Suppress overlapping detections, then flatten to the JNI-friendly
    // [l, t, r, b, score] layout.
    std::vector<FaceCandidate> filtered = nms(candidates, nmsThreshold_);
    std::vector<float> result;
    result.reserve(filtered.size() * 5);
    for (const auto& f : filtered) {
        result.push_back(f.left);
        result.push_back(f.top);
        result.push_back(f.right);
        result.push_back(f.bottom);
        result.push_back(f.score);
    }
    return result;
}
|
||||
|
||||
void RetinaFaceEngineRKNN::release() {
|
||||
if (ctx_ != 0) {
|
||||
rknn_destroy(ctx_);
|
||||
ctx_ = 0;
|
||||
}
|
||||
outputAttrs_.clear();
|
||||
std::memset(&ioNum_, 0, sizeof(ioNum_));
|
||||
std::memset(&inputAttr_, 0, sizeof(inputAttr_));
|
||||
initialized_ = false;
|
||||
}
|
||||
55
app/src/main/cpp/RetinaFaceEngineRKNN.h
Normal file
55
app/src/main/cpp/RetinaFaceEngineRKNN.h
Normal file
@@ -0,0 +1,55 @@
|
||||
#ifndef DIGITAL_PERSON_RETINAFACE_ENGINE_RKNN_H
#define DIGITAL_PERSON_RETINAFACE_ENGINE_RKNN_H

#include <cstdint>
#include <string>
#include <vector>

#include "rknn_api.h"

// RetinaFace face detector backed by the Rockchip RKNN runtime.
// Lifecycle: init() once, detect() per frame, release() when done.
class RetinaFaceEngineRKNN {
public:
    RetinaFaceEngineRKNN();
    ~RetinaFaceEngineRKNN();

    // Non-copyable: the object owns a raw rknn_context handle, so an
    // implicit copy would lead to a double rknn_destroy().
    RetinaFaceEngineRKNN(const RetinaFaceEngineRKNN&) = delete;
    RetinaFaceEngineRKNN& operator=(const RetinaFaceEngineRKNN&) = delete;

    // Loads the RKNN model at `modelPath` and stores the detection
    // parameters. Returns 0 on success, non-zero on failure.
    int init(const char* modelPath, int inputSize, float scoreThreshold, float nmsThreshold);

    // Detects faces in an ARGB pixel buffer (`strideBytes` = row pitch in
    // bytes). Returns [left, top, right, bottom, score] per detected face
    // in source-image pixel coordinates; empty on failure.
    std::vector<float> detect(const uint32_t* argbPixels, int width, int height, int strideBytes);

    // Destroys the RKNN context and resets state; safe to call repeatedly.
    void release();

private:
    // Prior (anchor) box in normalized center/size form.
    struct PriorBox {
        float cx;
        float cy;
        float w;
        float h;
    };

    // Detection candidate in source-image pixel coordinates.
    struct FaceCandidate {
        float left;
        float top;
        float right;
        float bottom;
        float score;
    };

    static size_t tensorElemCount(const rknn_tensor_attr& attr);
    static float iou(const FaceCandidate& a, const FaceCandidate& b);
    static std::vector<FaceCandidate> nms(const std::vector<FaceCandidate>& boxes, float threshold);

    std::vector<PriorBox> buildPriors() const;
    // Splits/merges the raw model outputs into one box-regression vector
    // (locOut) and one per-anchor score vector (scoreOut).
    bool parseRetinaOutputs(
        rknn_output* outputs,
        std::vector<float>* locOut,
        std::vector<float>* scoreOut) const;

    rknn_context ctx_ = 0;          // owned RKNN context handle (0 = none)
    bool initialized_ = false;
    int inputSize_ = 320;           // square model input resolution
    float scoreThreshold_ = 0.6f;   // minimum face confidence
    float nmsThreshold_ = 0.4f;     // IoU threshold for NMS suppression
    rknn_input_output_num ioNum_{};
    rknn_tensor_attr inputAttr_{};
    std::vector<rknn_tensor_attr> outputAttrs_;
};

#endif  // DIGITAL_PERSON_RETINAFACE_ENGINE_RKNN_H
|
||||
100
app/src/main/cpp/RetinaFaceEngineRKNNJNI.cpp
Normal file
100
app/src/main/cpp/RetinaFaceEngineRKNNJNI.cpp
Normal file
@@ -0,0 +1,100 @@
|
||||
#include <jni.h>

#include <android/bitmap.h>
#include <android/log.h>

#include <new>
#include <vector>

#include "RetinaFaceEngineRKNN.h"
|
||||
|
||||
#define LOG_TAG "RetinaFaceJNI"
|
||||
#define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__)
|
||||
|
||||
extern "C" {
|
||||
|
||||
// Allocates a native engine and returns it to Java as an opaque handle
// (0 on allocation failure).
JNIEXPORT jlong JNICALL
Java_com_digitalperson_engine_RetinaFaceEngineRKNN_createEngineNative(JNIEnv* env, jobject thiz) {
    // Plain `new` never returns nullptr — it throws std::bad_alloc, which
    // would unwind through the JNI boundary (undefined behavior). Using
    // std::nothrow makes the null check below actually meaningful.
    auto* engine = new (std::nothrow) RetinaFaceEngineRKNN();
    if (engine == nullptr) {
        return 0;
    }
    return reinterpret_cast<jlong>(engine);
}
|
||||
|
||||
// Initializes the engine referenced by `ptr` with the given model file and
// thresholds. Returns the engine's init() status code, or -1 on invalid
// arguments or string-pinning failure.
JNIEXPORT jint JNICALL
Java_com_digitalperson_engine_RetinaFaceEngineRKNN_initNative(
    JNIEnv* env,
    jobject thiz,
    jlong ptr,
    jstring modelPath,
    jint inputSize,
    jfloat scoreThreshold,
    jfloat nmsThreshold) {
    auto* engine = reinterpret_cast<RetinaFaceEngineRKNN*>(ptr);
    if (engine == nullptr || modelPath == nullptr) {
        return -1;
    }

    const char* pathChars = env->GetStringUTFChars(modelPath, nullptr);
    if (pathChars == nullptr) {
        return -1;  // JVM could not pin the string (OOM)
    }

    const int status =
        engine->init(pathChars, static_cast<int>(inputSize), scoreThreshold, nmsThreshold);
    env->ReleaseStringUTFChars(modelPath, pathChars);
    return status;
}
|
||||
|
||||
// Runs face detection on an RGBA_8888 bitmap.
// Returns a float array of [left, top, right, bottom, score] per face; an
// empty array on recoverable errors; null only when array allocation fails
// (an OutOfMemoryError is then pending on the JVM side).
JNIEXPORT jfloatArray JNICALL
Java_com_digitalperson_engine_RetinaFaceEngineRKNN_detectNative(
    JNIEnv* env,
    jobject thiz,
    jlong ptr,
    jobject bitmapObj) {
    auto* engine = reinterpret_cast<RetinaFaceEngineRKNN*>(ptr);
    if (engine == nullptr || bitmapObj == nullptr) {
        return env->NewFloatArray(0);
    }

    AndroidBitmapInfo info{};
    if (AndroidBitmap_getInfo(env, bitmapObj, &info) < 0) {
        LOGE("AndroidBitmap_getInfo failed");
        return env->NewFloatArray(0);
    }
    // The native detector reads 4-byte pixel words; only RGBA_8888 bitmaps
    // are supported.
    if (info.format != ANDROID_BITMAP_FORMAT_RGBA_8888) {
        LOGE("Unsupported bitmap format: %d", info.format);
        return env->NewFloatArray(0);
    }

    void* pixels = nullptr;
    if (AndroidBitmap_lockPixels(env, bitmapObj, &pixels) < 0 || pixels == nullptr) {
        LOGE("AndroidBitmap_lockPixels failed");
        return env->NewFloatArray(0);
    }

    std::vector<float> result = engine->detect(
        reinterpret_cast<uint32_t*>(pixels),
        static_cast<int>(info.width),
        static_cast<int>(info.height),
        static_cast<int>(info.stride));

    AndroidBitmap_unlockPixels(env, bitmapObj);

    jfloatArray out = env->NewFloatArray(static_cast<jsize>(result.size()));
    if (out == nullptr) {
        // NewFloatArray already threw OutOfMemoryError. Per the JNI spec,
        // making further JNI calls (the old code called NewFloatArray
        // again) with a pending exception is undefined behavior — return
        // null and let the exception propagate to Java.
        return nullptr;
    }
    if (!result.empty()) {
        env->SetFloatArrayRegion(out, 0, static_cast<jsize>(result.size()), result.data());
    }
    return out;
}
|
||||
|
||||
// Releases the native engine referenced by `ptr` and frees it.
// A null handle is a no-op.
JNIEXPORT void JNICALL
Java_com_digitalperson_engine_RetinaFaceEngineRKNN_releaseNative(
    JNIEnv* env,
    jobject thiz,
    jlong ptr) {
    auto* engine = reinterpret_cast<RetinaFaceEngineRKNN*>(ptr);
    if (engine == nullptr) {
        return;
    }
    engine->release();
    delete engine;
}
|
||||
|
||||
} // extern "C"
|
||||
409
app/src/main/cpp/zipformer_headers/rkllm.h
Normal file
409
app/src/main/cpp/zipformer_headers/rkllm.h
Normal file
@@ -0,0 +1,409 @@
|
||||
#ifndef _RKLLM_H_
|
||||
#define _RKLLM_H_
|
||||
#include <cstdint>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#define CPU0 (1 << 0) // 0x01
|
||||
#define CPU1 (1 << 1) // 0x02
|
||||
#define CPU2 (1 << 2) // 0x04
|
||||
#define CPU3 (1 << 3) // 0x08
|
||||
#define CPU4 (1 << 4) // 0x10
|
||||
#define CPU5 (1 << 5) // 0x20
|
||||
#define CPU6 (1 << 6) // 0x40
|
||||
#define CPU7 (1 << 7) // 0x80
|
||||
|
||||
/**
|
||||
* @typedef LLMHandle
|
||||
* @brief A handle used to manage and interact with the large language model.
|
||||
*/
|
||||
typedef void* LLMHandle;
|
||||
|
||||
/**
|
||||
* @enum LLMCallState
|
||||
* @brief Describes the possible states of an LLM call.
|
||||
*/
|
||||
typedef enum {
|
||||
RKLLM_RUN_NORMAL = 0, /**< The LLM call is in a normal running state. */
|
||||
RKLLM_RUN_WAITING = 1, /**< The LLM call is waiting for complete UTF-8 encoded character. */
|
||||
RKLLM_RUN_FINISH = 2, /**< The LLM call has finished execution. */
|
||||
RKLLM_RUN_ERROR = 3, /**< An error occurred during the LLM call. */
|
||||
} LLMCallState;
|
||||
|
||||
/**
|
||||
* @enum RKLLMInputType
|
||||
* @brief Defines the types of inputs that can be fed into the LLM.
|
||||
*/
|
||||
typedef enum {
|
||||
RKLLM_INPUT_PROMPT = 0, /**< Input is a text prompt. */
|
||||
RKLLM_INPUT_TOKEN = 1, /**< Input is a sequence of tokens. */
|
||||
RKLLM_INPUT_EMBED = 2, /**< Input is an embedding vector. */
|
||||
RKLLM_INPUT_MULTIMODAL = 3, /**< Input is multimodal (e.g., text and image). */
|
||||
} RKLLMInputType;
|
||||
|
||||
/**
|
||||
* @enum RKLLMInferMode
|
||||
* @brief Specifies the inference modes of the LLM.
|
||||
*/
|
||||
typedef enum {
|
||||
RKLLM_INFER_GENERATE = 0, /**< The LLM generates text based on input. */
|
||||
RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1, /**< The LLM retrieves the last hidden layer for further processing. */
|
||||
RKLLM_INFER_GET_LOGITS = 2, /**< The LLM retrieves logits for further processing. */
|
||||
} RKLLMInferMode;
|
||||
|
||||
/**
|
||||
* @struct RKLLMExtendParam
|
||||
* @brief The extend parameters for configuring an LLM instance.
|
||||
*/
|
||||
typedef struct {
|
||||
int32_t base_domain_id; /**< base_domain_id */
|
||||
int8_t embed_flash; /**< Indicates whether to query word embedding vectors from flash memory (1) or not (0). */
|
||||
int8_t enabled_cpus_num; /**< Number of CPUs enabled for inference. */
|
||||
uint32_t enabled_cpus_mask; /**< Bitmask indicating which CPUs to enable for inference. */
|
||||
uint8_t n_batch; /**< Number of input samples processed concurrently in one forward pass. Set to >1 to enable batched inference. Default is 1. */
|
||||
int8_t use_cross_attn; /**< Whether to enable cross attention (non-zero to enable, 0 to disable). */
|
||||
uint8_t reserved[104]; /**< reserved */
|
||||
} RKLLMExtendParam;
|
||||
|
||||
/**
|
||||
* @struct RKLLMParam
|
||||
* @brief Defines the parameters for configuring an LLM instance.
|
||||
*/
|
||||
typedef struct {
|
||||
const char* model_path; /**< Path to the model file. */
|
||||
int32_t max_context_len; /**< Maximum number of tokens in the context window. */
|
||||
int32_t max_new_tokens; /**< Maximum number of new tokens to generate. */
|
||||
int32_t top_k; /**< Top-K sampling parameter for token generation. */
|
||||
int32_t n_keep; /** number of kv cache to keep at the beginning when shifting context window */
|
||||
float top_p; /**< Top-P (nucleus) sampling parameter. */
|
||||
float temperature; /**< Sampling temperature, affecting the randomness of token selection. */
|
||||
float repeat_penalty; /**< Penalty for repeating tokens in generation. */
|
||||
float frequency_penalty; /**< Penalizes frequent tokens during generation. */
|
||||
float presence_penalty; /**< Penalizes tokens based on their presence in the input. */
|
||||
int32_t mirostat; /**< Mirostat sampling strategy flag (0 to disable). */
|
||||
float mirostat_tau; /**< Tau parameter for Mirostat sampling. */
|
||||
float mirostat_eta; /**< Eta parameter for Mirostat sampling. */
|
||||
bool skip_special_token; /**< Whether to skip special tokens during generation. */
|
||||
bool is_async; /**< Whether to run inference asynchronously. */
|
||||
const char* img_start; /**< Starting position of an image in multimodal input. */
|
||||
const char* img_end; /**< Ending position of an image in multimodal input. */
|
||||
const char* img_content; /**< Pointer to the image content. */
|
||||
RKLLMExtendParam extend_param; /**< Extend parameters. */
|
||||
} RKLLMParam;
|
||||
|
||||
/**
|
||||
* @struct RKLLMLoraAdapter
|
||||
* @brief Defines parameters for a Lora adapter used in model fine-tuning.
|
||||
*/
|
||||
typedef struct {
|
||||
const char* lora_adapter_path; /**< Path to the Lora adapter file. */
|
||||
const char* lora_adapter_name; /**< Name of the Lora adapter. */
|
||||
float scale; /**< Scaling factor for applying the Lora adapter. */
|
||||
} RKLLMLoraAdapter;
|
||||
|
||||
/**
|
||||
* @struct RKLLMEmbedInput
|
||||
* @brief Represents an embedding input to the LLM.
|
||||
*/
|
||||
typedef struct {
|
||||
float* embed; /**< Pointer to the embedding vector (of size n_tokens * n_embed). */
|
||||
size_t n_tokens; /**< Number of tokens represented in the embedding. */
|
||||
} RKLLMEmbedInput;
|
||||
|
||||
/**
|
||||
* @struct RKLLMTokenInput
|
||||
* @brief Represents token input to the LLM.
|
||||
*/
|
||||
typedef struct {
|
||||
int32_t* input_ids; /**< Array of token IDs. */
|
||||
size_t n_tokens; /**< Number of tokens in the input. */
|
||||
} RKLLMTokenInput;
|
||||
|
||||
/**
|
||||
* @struct RKLLMMultiModalInput
|
||||
* @brief Represents multimodal input (e.g., text and image).
|
||||
*/
|
||||
typedef struct {
|
||||
char* prompt; /**< Text prompt input. */
|
||||
float* image_embed; /**< Embedding of the images (of size n_image * n_image_tokens * image_embed_length). */
|
||||
size_t n_image_tokens; /**< Number of image_token. */
|
||||
size_t n_image; /**< Number of image. */
|
||||
size_t image_width; /**< Width of image. */
|
||||
size_t image_height; /**< Height of image. */
|
||||
} RKLLMMultiModalInput;
|
||||
|
||||
/**
|
||||
* @struct RKLLMInput
|
||||
* @brief Represents different types of input to the LLM via a union.
|
||||
*/
|
||||
typedef struct {
|
||||
const char* role; /**< Message role: "user" (user input), "tool" (function result) */
|
||||
bool enable_thinking; /**< Controls whether "thinking mode" is enabled for the Qwen3 model. */
|
||||
RKLLMInputType input_type; /**< Specifies the type of input provided (e.g., prompt, token, embed, multimodal). */
|
||||
union {
|
||||
const char* prompt_input; /**< Text prompt input if input_type is RKLLM_INPUT_PROMPT. */
|
||||
RKLLMEmbedInput embed_input; /**< Embedding input if input_type is RKLLM_INPUT_EMBED. */
|
||||
RKLLMTokenInput token_input; /**< Token input if input_type is RKLLM_INPUT_TOKEN. */
|
||||
RKLLMMultiModalInput multimodal_input; /**< Multimodal input if input_type is RKLLM_INPUT_MULTIMODAL. */
|
||||
};
|
||||
} RKLLMInput;
|
||||
|
||||
/**
|
||||
* @struct RKLLMLoraParam
|
||||
* @brief Structure defining parameters for Lora adapters.
|
||||
*/
|
||||
typedef struct {
|
||||
const char* lora_adapter_name; /**< Name of the Lora adapter. */
|
||||
} RKLLMLoraParam;
|
||||
|
||||
/**
|
||||
* @struct RKLLMPromptCacheParam
|
||||
* @brief Structure to define parameters for caching prompts.
|
||||
*/
|
||||
typedef struct {
|
||||
int save_prompt_cache; /**< Flag to indicate whether to save the prompt cache (0 = don't save, 1 = save). */
|
||||
const char* prompt_cache_path; /**< Path to the prompt cache file. */
|
||||
} RKLLMPromptCacheParam;
|
||||
|
||||
/**
|
||||
* @struct RKLLMCrossAttnParam
|
||||
* @brief Structure holding parameters for cross-attention inference.
|
||||
*
|
||||
* This structure is used when performing cross-attention in the decoder.
|
||||
* It provides the encoder output (key/value caches), position indices,
|
||||
* and attention mask.
|
||||
*
|
||||
* - `encoder_k_cache` must be stored in contiguous memory with layout:
|
||||
* [num_layers][num_tokens][num_kv_heads][head_dim]
|
||||
* - `encoder_v_cache` must be stored in contiguous memory with layout:
|
||||
* [num_layers][num_kv_heads][head_dim][num_tokens]
|
||||
*/
|
||||
typedef struct {
|
||||
float* encoder_k_cache; /**< Pointer to encoder key cache (size: num_layers * num_tokens * num_kv_heads * head_dim). */
|
||||
float* encoder_v_cache; /**< Pointer to encoder value cache (size: num_layers * num_kv_heads * head_dim * num_tokens). */
|
||||
float* encoder_mask; /**< Pointer to encoder attention mask (array of size num_tokens). */
|
||||
int32_t* encoder_pos; /**< Pointer to encoder token positions (array of size num_tokens). */
|
||||
int num_tokens; /**< Number of tokens in the encoder sequence. */
|
||||
} RKLLMCrossAttnParam;
|
||||
|
||||
/**
|
||||
* @struct RKLLMInferParam
|
||||
* @brief Structure for defining parameters during inference.
|
||||
*/
|
||||
typedef struct {
|
||||
RKLLMInferMode mode; /**< Inference mode (e.g., generate or get last hidden layer). */
|
||||
RKLLMLoraParam* lora_params; /**< Pointer to Lora adapter parameters. */
|
||||
RKLLMPromptCacheParam* prompt_cache_params; /**< Pointer to prompt cache parameters. */
|
||||
int keep_history; /**Flag to determine history retention (1: keep history, 0: discard history).*/
|
||||
} RKLLMInferParam;
|
||||
|
||||
/**
|
||||
* @struct RKLLMResultLastHiddenLayer
|
||||
* @brief Structure to hold the hidden states from the last layer.
|
||||
*/
|
||||
typedef struct {
|
||||
const float* hidden_states; /**< Pointer to the hidden states (of size num_tokens * embd_size). */
|
||||
int embd_size; /**< Size of the embedding vector. */
|
||||
int num_tokens; /**< Number of tokens for which hidden states are stored. */
|
||||
} RKLLMResultLastHiddenLayer;
|
||||
|
||||
/**
|
||||
* @struct RKLLMResultLogits
|
||||
* @brief Structure to hold the logits.
|
||||
*/
|
||||
typedef struct {
|
||||
const float* logits; /**< Pointer to the logits (of size num_tokens * vocab_size). */
|
||||
int vocab_size; /**< Size of the vocab. */
|
||||
int num_tokens; /**< Number of tokens for which logits are stored. */
|
||||
} RKLLMResultLogits;
|
||||
|
||||
/**
|
||||
* @struct RKLLMPerfStat
|
||||
* @brief Structure to hold performance statistics for prefill and generate stages.
|
||||
*/
|
||||
typedef struct {
|
||||
float prefill_time_ms; /**< Total time taken for the prefill stage in milliseconds. */
|
||||
int prefill_tokens; /**< Number of tokens processed during the prefill stage. */
|
||||
float generate_time_ms; /**< Total time taken for the generate stage in milliseconds. */
|
||||
int generate_tokens; /**< Number of tokens processed during the generate stage. */
|
||||
float memory_usage_mb; /**< VmHWM resident memory usage during inference, in megabytes. */
|
||||
} RKLLMPerfStat;
|
||||
|
||||
/**
|
||||
* @struct RKLLMResult
|
||||
* @brief Structure to represent the result of LLM inference.
|
||||
*/
|
||||
typedef struct {
|
||||
const char* text; /**< Generated text result. */
|
||||
int32_t token_id; /**< ID of the generated token. */
|
||||
RKLLMResultLastHiddenLayer last_hidden_layer; /**< Hidden states of the last layer (if requested). */
|
||||
RKLLMResultLogits logits; /**< Model output logits. */
|
||||
RKLLMPerfStat perf; /**< Pointer to performance statistics (prefill and generate). */
|
||||
} RKLLMResult;
|
||||
|
||||
/**
|
||||
* @typedef LLMResultCallback
|
||||
* @brief Callback function to handle LLM results.
|
||||
* @param result Pointer to the LLM result.
|
||||
* @param userdata Pointer to user data for the callback.
|
||||
* @param state State of the LLM call (e.g., finished, error).
|
||||
* @return int Return value indicating the handling status:
|
||||
* - 0: Continue inference normally.
|
||||
* - 1: Pause inference. If the user wants to modify or intervene in the result (e.g., editing output, injecting new prompt),
|
||||
* return 1 to suspend the current inference. Later, call `rkllm_run` with updated content to resume inference.
|
||||
*/
|
||||
typedef int(*LLMResultCallback)(RKLLMResult* result, void* userdata, LLMCallState state);
|
||||
|
||||
/**
|
||||
* @brief Creates a default RKLLMParam structure with preset values.
|
||||
* @return A default RKLLMParam structure.
|
||||
*/
|
||||
RKLLMParam rkllm_createDefaultParam();
|
||||
|
||||
/**
|
||||
* @brief Initializes the LLM with the given parameters.
|
||||
* @param handle Pointer to the LLM handle.
|
||||
* @param param Configuration parameters for the LLM.
|
||||
* @param callback Callback function to handle LLM results.
|
||||
* @return Status code (0 for success, non-zero for failure).
|
||||
*/
|
||||
int rkllm_init(LLMHandle* handle, RKLLMParam* param, LLMResultCallback callback);
|
||||
|
||||
/**
|
||||
* @brief Loads a Lora adapter into the LLM.
|
||||
* @param handle LLM handle.
|
||||
* @param lora_adapter Pointer to the Lora adapter structure.
|
||||
* @return Status code (0 for success, non-zero for failure).
|
||||
*/
|
||||
int rkllm_load_lora(LLMHandle handle, RKLLMLoraAdapter* lora_adapter);
|
||||
|
||||
/**
|
||||
* @brief Loads a prompt cache from a file.
|
||||
* @param handle LLM handle.
|
||||
* @param prompt_cache_path Path to the prompt cache file.
|
||||
* @return Status code (0 for success, non-zero for failure).
|
||||
*/
|
||||
int rkllm_load_prompt_cache(LLMHandle handle, const char* prompt_cache_path);
|
||||
|
||||
/**
|
||||
* @brief Releases the prompt cache from memory.
|
||||
* @param handle LLM handle.
|
||||
* @return Status code (0 for success, non-zero for failure).
|
||||
*/
|
||||
int rkllm_release_prompt_cache(LLMHandle handle);
|
||||
|
||||
/**
|
||||
* @brief Destroys the LLM instance and releases resources.
|
||||
* @param handle LLM handle.
|
||||
* @return Status code (0 for success, non-zero for failure).
|
||||
*/
|
||||
int rkllm_destroy(LLMHandle handle);
|
||||
|
||||
/**
|
||||
* @brief Runs an LLM inference task synchronously.
|
||||
* @param handle LLM handle.
|
||||
* @param rkllm_input Input data for the LLM.
|
||||
* @param rkllm_infer_params Parameters for the inference task.
|
||||
* @param userdata Pointer to user data for the callback.
|
||||
* @return Status code (0 for success, non-zero for failure).
|
||||
*/
|
||||
int rkllm_run(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
|
||||
|
||||
/**
|
||||
* @brief Runs an LLM inference task asynchronously.
|
||||
* @param handle LLM handle.
|
||||
* @param rkllm_input Input data for the LLM.
|
||||
* @param rkllm_infer_params Parameters for the inference task.
|
||||
* @param userdata Pointer to user data for the callback.
|
||||
* @return Status code (0 for success, non-zero for failure).
|
||||
*/
|
||||
int rkllm_run_async(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
|
||||
|
||||
/**
|
||||
* @brief Aborts an ongoing LLM task.
|
||||
* @param handle LLM handle.
|
||||
* @return Status code (0 for success, non-zero for failure).
|
||||
*/
|
||||
int rkllm_abort(LLMHandle handle);
|
||||
|
||||
/**
|
||||
* @brief Checks if an LLM task is currently running.
|
||||
* @param handle LLM handle.
|
||||
* @return Status code (0 if a task is running, non-zero for otherwise).
|
||||
*/
|
||||
int rkllm_is_running(LLMHandle handle);
|
||||
|
||||
/**
|
||||
* @brief Clear the key-value cache for a given LLM handle.
|
||||
*
|
||||
* This function is used to clear part or all of the KV cache.
|
||||
*
|
||||
* @param handle LLM handle.
|
||||
* @param keep_system_prompt Flag indicating whether to retain the system prompt in the cache (1 to retain, 0 to clear).
|
||||
* This flag is ignored if a specific range [start_pos, end_pos) is provided.
|
||||
* @param start_pos Array of start positions (inclusive) of the KV cache ranges to clear, one per batch.
|
||||
* @param end_pos Array of end positions (exclusive) of the KV cache ranges to clear, one per batch.
|
||||
* If both start_pos and end_pos are set to nullptr, the entire cache will be cleared and keep_system_prompt will take effect,
|
||||
* If start_pos[i] < end_pos[i], only the specified range will be cleared, and keep_system_prompt will be ignored.
|
||||
* @note: start_pos or end_pos is only valid when keep_history == 0 and the generation has been paused by returning 1 in the callback
|
||||
* @return Status code (0 if cache was cleared successfully, non-zero otherwise).
|
||||
*/
|
||||
int rkllm_clear_kv_cache(LLMHandle handle, int keep_system_prompt, int* start_pos, int* end_pos);
|
||||
|
||||
/**
|
||||
* @brief Get the current size of the key-value cache for a given LLM handle.
|
||||
*
|
||||
* This function returns the total number of positions currently stored in the model's KV cache.
|
||||
*
|
||||
* @param handle LLM handle.
|
||||
* @param cache_sizes Pointer to an array where the per-batch cache sizes will be stored.
|
||||
* The array must be preallocated with space for `n_batch` elements.
|
||||
*/
|
||||
int rkllm_get_kv_cache_size(LLMHandle handle, int* cache_sizes);
|
||||
|
||||
/**
|
||||
* @brief Sets the chat template for the LLM, including system prompt, prefix, and postfix.
|
||||
*
|
||||
* This function allows you to customize the chat template by providing a system prompt, a prompt prefix, and a prompt postfix.
|
||||
* The system prompt is typically used to define the behavior or context of the language model,
|
||||
* while the prefix and postfix are used to format the user input and output respectively.
|
||||
*
|
||||
* @param handle LLM handle.
|
||||
* @param system_prompt The system prompt that defines the context or behavior of the language model.
|
||||
* @param prompt_prefix The prefix added before the user input in the chat.
|
||||
* @param prompt_postfix The postfix added after the user input in the chat.
|
||||
*
|
||||
* @return Status code (0 if the template was set successfully, non-zero for errors).
|
||||
*/
|
||||
int rkllm_set_chat_template(LLMHandle handle, const char* system_prompt, const char* prompt_prefix, const char* prompt_postfix);
|
||||
|
||||
/**
|
||||
* @brief Sets the function calling configuration for the LLM, including system prompt, tool definitions, and tool response token.
|
||||
*
|
||||
* @param handle LLM handle.
|
||||
* @param system_prompt The system prompt that defines the context or behavior of the language model.
|
||||
* @param tools A JSON-formatted string that defines the available functions, including their names, descriptions, and parameters.
|
||||
* @param tool_response_str A unique tag used to identify function call results within a conversation. It acts as the marker tag,
|
||||
* allowing tokenizer to recognize tool outputs separately from normal dialogue turns.
|
||||
* @return Status code (0 if the configuration was set successfully, non-zero for errors).
|
||||
*/
|
||||
int rkllm_set_function_tools(LLMHandle handle, const char* system_prompt, const char* tools, const char* tool_response_str);
|
||||
|
||||
/**
|
||||
* @brief Sets the cross-attention parameters for the LLM decoder.
|
||||
*
|
||||
* @param handle LLM handle.
|
||||
* @param cross_attn_params Pointer to the structure containing encoder-related input data
|
||||
* used for cross-attention (see RKLLMCrossAttnParam for details).
|
||||
*
|
||||
* @return Status code (0 if the parameters were set successfully, non-zero for errors).
|
||||
*/
|
||||
int rkllm_set_cross_attn_params(LLMHandle handle, RKLLMCrossAttnParam* cross_attn_params);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
409
app/src/main/cpp/zipformer_headers/rkllm.h.2
Normal file
409
app/src/main/cpp/zipformer_headers/rkllm.h.2
Normal file
@@ -0,0 +1,409 @@
|
||||
#ifndef _RKLLM_H_
|
||||
#define _RKLLM_H_
|
||||
#include <cstdint>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#define CPU0 (1 << 0) // 0x01
|
||||
#define CPU1 (1 << 1) // 0x02
|
||||
#define CPU2 (1 << 2) // 0x04
|
||||
#define CPU3 (1 << 3) // 0x08
|
||||
#define CPU4 (1 << 4) // 0x10
|
||||
#define CPU5 (1 << 5) // 0x20
|
||||
#define CPU6 (1 << 6) // 0x40
|
||||
#define CPU7 (1 << 7) // 0x80
|
||||
|
||||
/**
|
||||
* @typedef LLMHandle
|
||||
* @brief A handle used to manage and interact with the large language model.
|
||||
*/
|
||||
typedef void* LLMHandle;
|
||||
|
||||
/**
|
||||
* @enum LLMCallState
|
||||
* @brief Describes the possible states of an LLM call.
|
||||
*/
|
||||
typedef enum {
|
||||
RKLLM_RUN_NORMAL = 0, /**< The LLM call is in a normal running state. */
|
||||
RKLLM_RUN_WAITING = 1, /**< The LLM call is waiting for complete UTF-8 encoded character. */
|
||||
RKLLM_RUN_FINISH = 2, /**< The LLM call has finished execution. */
|
||||
RKLLM_RUN_ERROR = 3, /**< An error occurred during the LLM call. */
|
||||
} LLMCallState;
|
||||
|
||||
/**
|
||||
* @enum RKLLMInputType
|
||||
* @brief Defines the types of inputs that can be fed into the LLM.
|
||||
*/
|
||||
typedef enum {
|
||||
RKLLM_INPUT_PROMPT = 0, /**< Input is a text prompt. */
|
||||
RKLLM_INPUT_TOKEN = 1, /**< Input is a sequence of tokens. */
|
||||
RKLLM_INPUT_EMBED = 2, /**< Input is an embedding vector. */
|
||||
RKLLM_INPUT_MULTIMODAL = 3, /**< Input is multimodal (e.g., text and image). */
|
||||
} RKLLMInputType;
|
||||
|
||||
/**
|
||||
* @enum RKLLMInferMode
|
||||
* @brief Specifies the inference modes of the LLM.
|
||||
*/
|
||||
typedef enum {
|
||||
RKLLM_INFER_GENERATE = 0, /**< The LLM generates text based on input. */
|
||||
RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1, /**< The LLM retrieves the last hidden layer for further processing. */
|
||||
RKLLM_INFER_GET_LOGITS = 2, /**< The LLM retrieves logits for further processing. */
|
||||
} RKLLMInferMode;
|
||||
|
||||
/**
|
||||
* @struct RKLLMExtendParam
|
||||
* @brief The extend parameters for configuring an LLM instance.
|
||||
*/
|
||||
typedef struct {
|
||||
int32_t base_domain_id; /**< base_domain_id */
|
||||
int8_t embed_flash; /**< Indicates whether to query word embedding vectors from flash memory (1) or not (0). */
|
||||
int8_t enabled_cpus_num; /**< Number of CPUs enabled for inference. */
|
||||
uint32_t enabled_cpus_mask; /**< Bitmask indicating which CPUs to enable for inference. */
|
||||
uint8_t n_batch; /**< Number of input samples processed concurrently in one forward pass. Set to >1 to enable batched inference. Default is 1. */
|
||||
int8_t use_cross_attn; /**< Whether to enable cross attention (non-zero to enable, 0 to disable). */
|
||||
uint8_t reserved[104]; /**< reserved */
|
||||
} RKLLMExtendParam;
|
||||
|
||||
/**
|
||||
* @struct RKLLMParam
|
||||
* @brief Defines the parameters for configuring an LLM instance.
|
||||
*/
|
||||
typedef struct {
|
||||
const char* model_path; /**< Path to the model file. */
|
||||
int32_t max_context_len; /**< Maximum number of tokens in the context window. */
|
||||
int32_t max_new_tokens; /**< Maximum number of new tokens to generate. */
|
||||
int32_t top_k; /**< Top-K sampling parameter for token generation. */
|
||||
int32_t n_keep; /** number of kv cache to keep at the beginning when shifting context window */
|
||||
float top_p; /**< Top-P (nucleus) sampling parameter. */
|
||||
float temperature; /**< Sampling temperature, affecting the randomness of token selection. */
|
||||
float repeat_penalty; /**< Penalty for repeating tokens in generation. */
|
||||
float frequency_penalty; /**< Penalizes frequent tokens during generation. */
|
||||
float presence_penalty; /**< Penalizes tokens based on their presence in the input. */
|
||||
int32_t mirostat; /**< Mirostat sampling strategy flag (0 to disable). */
|
||||
float mirostat_tau; /**< Tau parameter for Mirostat sampling. */
|
||||
float mirostat_eta; /**< Eta parameter for Mirostat sampling. */
|
||||
bool skip_special_token; /**< Whether to skip special tokens during generation. */
|
||||
bool is_async; /**< Whether to run inference asynchronously. */
|
||||
const char* img_start; /**< Starting position of an image in multimodal input. */
|
||||
const char* img_end; /**< Ending position of an image in multimodal input. */
|
||||
const char* img_content; /**< Pointer to the image content. */
|
||||
RKLLMExtendParam extend_param; /**< Extend parameters. */
|
||||
} RKLLMParam;
|
||||
|
||||
/**
|
||||
* @struct RKLLMLoraAdapter
|
||||
* @brief Defines parameters for a Lora adapter used in model fine-tuning.
|
||||
*/
|
||||
typedef struct {
|
||||
const char* lora_adapter_path; /**< Path to the Lora adapter file. */
|
||||
const char* lora_adapter_name; /**< Name of the Lora adapter. */
|
||||
float scale; /**< Scaling factor for applying the Lora adapter. */
|
||||
} RKLLMLoraAdapter;
|
||||
|
||||
/**
|
||||
* @struct RKLLMEmbedInput
|
||||
* @brief Represents an embedding input to the LLM.
|
||||
*/
|
||||
typedef struct {
|
||||
float* embed; /**< Pointer to the embedding vector (of size n_tokens * n_embed). */
|
||||
size_t n_tokens; /**< Number of tokens represented in the embedding. */
|
||||
} RKLLMEmbedInput;
|
||||
|
||||
/**
|
||||
* @struct RKLLMTokenInput
|
||||
* @brief Represents token input to the LLM.
|
||||
*/
|
||||
typedef struct {
|
||||
int32_t* input_ids; /**< Array of token IDs. */
|
||||
size_t n_tokens; /**< Number of tokens in the input. */
|
||||
} RKLLMTokenInput;
|
||||
|
||||
/**
|
||||
* @struct RKLLMMultiModalInput
|
||||
* @brief Represents multimodal input (e.g., text and image).
|
||||
*/
|
||||
typedef struct {
|
||||
char* prompt; /**< Text prompt input. */
|
||||
float* image_embed; /**< Embedding of the images (of size n_image * n_image_tokens * image_embed_length). */
|
||||
size_t n_image_tokens; /**< Number of image_token. */
|
||||
size_t n_image; /**< Number of image. */
|
||||
size_t image_width; /**< Width of image. */
|
||||
size_t image_height; /**< Height of image. */
|
||||
} RKLLMMultiModalInput;
|
||||
|
||||
/**
|
||||
* @struct RKLLMInput
|
||||
* @brief Represents different types of input to the LLM via a union.
|
||||
*/
|
||||
typedef struct {
|
||||
const char* role; /**< Message role: "user" (user input), "tool" (function result) */
|
||||
bool enable_thinking; /**< Controls whether "thinking mode" is enabled for the Qwen3 model. */
|
||||
RKLLMInputType input_type; /**< Specifies the type of input provided (e.g., prompt, token, embed, multimodal). */
|
||||
union {
|
||||
const char* prompt_input; /**< Text prompt input if input_type is RKLLM_INPUT_PROMPT. */
|
||||
RKLLMEmbedInput embed_input; /**< Embedding input if input_type is RKLLM_INPUT_EMBED. */
|
||||
RKLLMTokenInput token_input; /**< Token input if input_type is RKLLM_INPUT_TOKEN. */
|
||||
RKLLMMultiModalInput multimodal_input; /**< Multimodal input if input_type is RKLLM_INPUT_MULTIMODAL. */
|
||||
};
|
||||
} RKLLMInput;
|
||||
|
||||
/**
|
||||
* @struct RKLLMLoraParam
|
||||
* @brief Structure defining parameters for Lora adapters.
|
||||
*/
|
||||
typedef struct {
|
||||
const char* lora_adapter_name; /**< Name of the Lora adapter. */
|
||||
} RKLLMLoraParam;
|
||||
|
||||
/**
|
||||
* @struct RKLLMPromptCacheParam
|
||||
* @brief Structure to define parameters for caching prompts.
|
||||
*/
|
||||
typedef struct {
|
||||
int save_prompt_cache; /**< Flag to indicate whether to save the prompt cache (0 = don't save, 1 = save). */
|
||||
const char* prompt_cache_path; /**< Path to the prompt cache file. */
|
||||
} RKLLMPromptCacheParam;
|
||||
|
||||
/**
|
||||
* @struct RKLLMCrossAttnParam
|
||||
* @brief Structure holding parameters for cross-attention inference.
|
||||
*
|
||||
* This structure is used when performing cross-attention in the decoder.
|
||||
* It provides the encoder output (key/value caches), position indices,
|
||||
* and attention mask.
|
||||
*
|
||||
* - `encoder_k_cache` must be stored in contiguous memory with layout:
|
||||
* [num_layers][num_tokens][num_kv_heads][head_dim]
|
||||
* - `encoder_v_cache` must be stored in contiguous memory with layout:
|
||||
* [num_layers][num_kv_heads][head_dim][num_tokens]
|
||||
*/
|
||||
typedef struct {
|
||||
float* encoder_k_cache; /**< Pointer to encoder key cache (size: num_layers * num_tokens * num_kv_heads * head_dim). */
|
||||
float* encoder_v_cache; /**< Pointer to encoder value cache (size: num_layers * num_kv_heads * head_dim * num_tokens). */
|
||||
float* encoder_mask; /**< Pointer to encoder attention mask (array of size num_tokens). */
|
||||
int32_t* encoder_pos; /**< Pointer to encoder token positions (array of size num_tokens). */
|
||||
int num_tokens; /**< Number of tokens in the encoder sequence. */
|
||||
} RKLLMCrossAttnParam;
|
||||
|
||||
/**
|
||||
* @struct RKLLMInferParam
|
||||
* @brief Structure for defining parameters during inference.
|
||||
*/
|
||||
typedef struct {
|
||||
RKLLMInferMode mode; /**< Inference mode (e.g., generate or get last hidden layer). */
|
||||
RKLLMLoraParam* lora_params; /**< Pointer to Lora adapter parameters. */
|
||||
RKLLMPromptCacheParam* prompt_cache_params; /**< Pointer to prompt cache parameters. */
|
||||
int keep_history; /**Flag to determine history retention (1: keep history, 0: discard history).*/
|
||||
} RKLLMInferParam;
|
||||
|
||||
/**
|
||||
* @struct RKLLMResultLastHiddenLayer
|
||||
* @brief Structure to hold the hidden states from the last layer.
|
||||
*/
|
||||
typedef struct {
|
||||
const float* hidden_states; /**< Pointer to the hidden states (of size num_tokens * embd_size). */
|
||||
int embd_size; /**< Size of the embedding vector. */
|
||||
int num_tokens; /**< Number of tokens for which hidden states are stored. */
|
||||
} RKLLMResultLastHiddenLayer;
|
||||
|
||||
/**
|
||||
* @struct RKLLMResultLogits
|
||||
* @brief Structure to hold the logits.
|
||||
*/
|
||||
typedef struct {
|
||||
const float* logits; /**< Pointer to the logits (of size num_tokens * vocab_size). */
|
||||
int vocab_size; /**< Size of the vocab. */
|
||||
int num_tokens; /**< Number of tokens for which logits are stored. */
|
||||
} RKLLMResultLogits;
|
||||
|
||||
/**
|
||||
* @struct RKLLMPerfStat
|
||||
* @brief Structure to hold performance statistics for prefill and generate stages.
|
||||
*/
|
||||
typedef struct {
|
||||
float prefill_time_ms; /**< Total time taken for the prefill stage in milliseconds. */
|
||||
int prefill_tokens; /**< Number of tokens processed during the prefill stage. */
|
||||
float generate_time_ms; /**< Total time taken for the generate stage in milliseconds. */
|
||||
int generate_tokens; /**< Number of tokens processed during the generate stage. */
|
||||
float memory_usage_mb; /**< VmHWM resident memory usage during inference, in megabytes. */
|
||||
} RKLLMPerfStat;
|
||||
|
||||
/**
|
||||
* @struct RKLLMResult
|
||||
* @brief Structure to represent the result of LLM inference.
|
||||
*/
|
||||
typedef struct {
|
||||
const char* text; /**< Generated text result. */
|
||||
int32_t token_id; /**< ID of the generated token. */
|
||||
RKLLMResultLastHiddenLayer last_hidden_layer; /**< Hidden states of the last layer (if requested). */
|
||||
RKLLMResultLogits logits; /**< Model output logits. */
|
||||
RKLLMPerfStat perf; /**< Pointer to performance statistics (prefill and generate). */
|
||||
} RKLLMResult;
|
||||
|
||||
/**
|
||||
* @typedef LLMResultCallback
|
||||
* @brief Callback function to handle LLM results.
|
||||
* @param result Pointer to the LLM result.
|
||||
* @param userdata Pointer to user data for the callback.
|
||||
* @param state State of the LLM call (e.g., finished, error).
|
||||
* @return int Return value indicating the handling status:
|
||||
* - 0: Continue inference normally.
|
||||
* - 1: Pause inference. If the user wants to modify or intervene in the result (e.g., editing output, injecting new prompt),
|
||||
* return 1 to suspend the current inference. Later, call `rkllm_run` with updated content to resume inference.
|
||||
*/
|
||||
typedef int(*LLMResultCallback)(RKLLMResult* result, void* userdata, LLMCallState state);
|
||||
|
||||
/**
|
||||
* @brief Creates a default RKLLMParam structure with preset values.
|
||||
* @return A default RKLLMParam structure.
|
||||
*/
|
||||
RKLLMParam rkllm_createDefaultParam();
|
||||
|
||||
/**
|
||||
* @brief Initializes the LLM with the given parameters.
|
||||
* @param handle Pointer to the LLM handle.
|
||||
* @param param Configuration parameters for the LLM.
|
||||
* @param callback Callback function to handle LLM results.
|
||||
* @return Status code (0 for success, non-zero for failure).
|
||||
*/
|
||||
int rkllm_init(LLMHandle* handle, RKLLMParam* param, LLMResultCallback callback);
|
||||
|
||||
/**
|
||||
* @brief Loads a Lora adapter into the LLM.
|
||||
* @param handle LLM handle.
|
||||
* @param lora_adapter Pointer to the Lora adapter structure.
|
||||
* @return Status code (0 for success, non-zero for failure).
|
||||
*/
|
||||
int rkllm_load_lora(LLMHandle handle, RKLLMLoraAdapter* lora_adapter);
|
||||
|
||||
/**
|
||||
* @brief Loads a prompt cache from a file.
|
||||
* @param handle LLM handle.
|
||||
* @param prompt_cache_path Path to the prompt cache file.
|
||||
* @return Status code (0 for success, non-zero for failure).
|
||||
*/
|
||||
int rkllm_load_prompt_cache(LLMHandle handle, const char* prompt_cache_path);
|
||||
|
||||
/**
|
||||
* @brief Releases the prompt cache from memory.
|
||||
* @param handle LLM handle.
|
||||
* @return Status code (0 for success, non-zero for failure).
|
||||
*/
|
||||
int rkllm_release_prompt_cache(LLMHandle handle);
|
||||
|
||||
/**
|
||||
* @brief Destroys the LLM instance and releases resources.
|
||||
* @param handle LLM handle.
|
||||
* @return Status code (0 for success, non-zero for failure).
|
||||
*/
|
||||
int rkllm_destroy(LLMHandle handle);
|
||||
|
||||
/**
|
||||
* @brief Runs an LLM inference task synchronously.
|
||||
* @param handle LLM handle.
|
||||
* @param rkllm_input Input data for the LLM.
|
||||
* @param rkllm_infer_params Parameters for the inference task.
|
||||
* @param userdata Pointer to user data for the callback.
|
||||
* @return Status code (0 for success, non-zero for failure).
|
||||
*/
|
||||
int rkllm_run(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
|
||||
|
||||
/**
|
||||
* @brief Runs an LLM inference task asynchronously.
|
||||
* @param handle LLM handle.
|
||||
* @param rkllm_input Input data for the LLM.
|
||||
* @param rkllm_infer_params Parameters for the inference task.
|
||||
* @param userdata Pointer to user data for the callback.
|
||||
* @return Status code (0 for success, non-zero for failure).
|
||||
*/
|
||||
int rkllm_run_async(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
|
||||
|
||||
/**
|
||||
* @brief Aborts an ongoing LLM task.
|
||||
* @param handle LLM handle.
|
||||
* @return Status code (0 for success, non-zero for failure).
|
||||
*/
|
||||
int rkllm_abort(LLMHandle handle);
|
||||
|
||||
/**
|
||||
* @brief Checks if an LLM task is currently running.
|
||||
* @param handle LLM handle.
|
||||
* @return Status code (0 if a task is running, non-zero for otherwise).
|
||||
*/
|
||||
int rkllm_is_running(LLMHandle handle);
|
||||
|
||||
/**
|
||||
* @brief Clear the key-value cache for a given LLM handle.
|
||||
*
|
||||
* This function is used to clear part or all of the KV cache.
|
||||
*
|
||||
* @param handle LLM handle.
|
||||
* @param keep_system_prompt Flag indicating whether to retain the system prompt in the cache (1 to retain, 0 to clear).
|
||||
* This flag is ignored if a specific range [start_pos, end_pos) is provided.
|
||||
* @param start_pos Array of start positions (inclusive) of the KV cache ranges to clear, one per batch.
|
||||
* @param end_pos Array of end positions (exclusive) of the KV cache ranges to clear, one per batch.
|
||||
* If both start_pos and end_pos are set to nullptr, the entire cache will be cleared and keep_system_prompt will take effect,
|
||||
* If start_pos[i] < end_pos[i], only the specified range will be cleared, and keep_system_prompt will be ignored.
|
||||
* @note: start_pos or end_pos is only valid when keep_history == 0 and the generation has been paused by returning 1 in the callback
|
||||
* @return Status code (0 if cache was cleared successfully, non-zero otherwise).
|
||||
*/
|
||||
int rkllm_clear_kv_cache(LLMHandle handle, int keep_system_prompt, int* start_pos, int* end_pos);
|
||||
|
||||
/**
|
||||
* @brief Get the current size of the key-value cache for a given LLM handle.
|
||||
*
|
||||
* This function returns the total number of positions currently stored in the model's KV cache.
|
||||
*
|
||||
* @param handle LLM handle.
|
||||
* @param cache_sizes Pointer to an array where the per-batch cache sizes will be stored.
|
||||
* The array must be preallocated with space for `n_batch` elements.
|
||||
*/
|
||||
int rkllm_get_kv_cache_size(LLMHandle handle, int* cache_sizes);
|
||||
|
||||
/**
|
||||
* @brief Sets the chat template for the LLM, including system prompt, prefix, and postfix.
|
||||
*
|
||||
* This function allows you to customize the chat template by providing a system prompt, a prompt prefix, and a prompt postfix.
|
||||
* The system prompt is typically used to define the behavior or context of the language model,
|
||||
* while the prefix and postfix are used to format the user input and output respectively.
|
||||
*
|
||||
* @param handle LLM handle.
|
||||
* @param system_prompt The system prompt that defines the context or behavior of the language model.
|
||||
* @param prompt_prefix The prefix added before the user input in the chat.
|
||||
* @param prompt_postfix The postfix added after the user input in the chat.
|
||||
*
|
||||
* @return Status code (0 if the template was set successfully, non-zero for errors).
|
||||
*/
|
||||
int rkllm_set_chat_template(LLMHandle handle, const char* system_prompt, const char* prompt_prefix, const char* prompt_postfix);
|
||||
|
||||
/**
|
||||
* @brief Sets the function calling configuration for the LLM, including system prompt, tool definitions, and tool response token.
|
||||
*
|
||||
* @param handle LLM handle.
|
||||
* @param system_prompt The system prompt that defines the context or behavior of the language model.
|
||||
* @param tools A JSON-formatted string that defines the available functions, including their names, descriptions, and parameters.
|
||||
* @param tool_response_str A unique tag used to identify function call results within a conversation. It acts as the marker tag,
|
||||
* allowing tokenizer to recognize tool outputs separately from normal dialogue turns.
|
||||
* @return Status code (0 if the configuration was set successfully, non-zero for errors).
|
||||
*/
|
||||
int rkllm_set_function_tools(LLMHandle handle, const char* system_prompt, const char* tools, const char* tool_response_str);
|
||||
|
||||
/**
|
||||
* @brief Sets the cross-attention parameters for the LLM decoder.
|
||||
*
|
||||
* @param handle LLM handle.
|
||||
* @param cross_attn_params Pointer to the structure containing encoder-related input data
|
||||
* used for cross-attention (see RKLLMCrossAttnParam for details).
|
||||
*
|
||||
* @return Status code (0 if the parameters were set successfully, non-zero for errors).
|
||||
*/
|
||||
int rkllm_set_cross_attn_params(LLMHandle handle, RKLLMCrossAttnParam* cross_attn_params);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
@@ -2,10 +2,15 @@ package com.digitalperson
|
||||
|
||||
import android.content.Intent
|
||||
import android.os.Bundle
|
||||
import android.util.Log
|
||||
import androidx.appcompat.app.AppCompatActivity
|
||||
import com.digitalperson.config.AppConfig
|
||||
|
||||
class EntryActivity : AppCompatActivity() {
|
||||
companion object {
|
||||
private const val TAG = "EntryActivity"
|
||||
}
|
||||
|
||||
override fun onCreate(savedInstanceState: Bundle?) {
|
||||
super.onCreate(savedInstanceState)
|
||||
|
||||
@@ -14,6 +19,7 @@ class EntryActivity : AppCompatActivity() {
|
||||
} else {
|
||||
MainActivity::class.java
|
||||
}
|
||||
Log.i(TAG, "USE_LIVE2D=${AppConfig.Avatar.USE_LIVE2D}, target=${target.simpleName}")
|
||||
startActivity(Intent(this, target))
|
||||
finish()
|
||||
}
|
||||
|
||||
@@ -2,20 +2,37 @@ package com.digitalperson
|
||||
|
||||
import android.Manifest
|
||||
import android.content.pm.PackageManager
|
||||
import android.graphics.Bitmap
|
||||
import android.os.Bundle
|
||||
import android.util.Log
|
||||
import android.widget.Toast
|
||||
import androidx.camera.core.CameraSelector
|
||||
import androidx.camera.core.ImageAnalysis
|
||||
import androidx.camera.core.ImageProxy
|
||||
import androidx.camera.core.Preview
|
||||
import androidx.camera.lifecycle.ProcessCameraProvider
|
||||
import androidx.camera.view.PreviewView
|
||||
import androidx.appcompat.app.AppCompatActivity
|
||||
import androidx.core.app.ActivityCompat
|
||||
import androidx.core.content.ContextCompat
|
||||
import com.digitalperson.cloud.CloudApiManager
|
||||
import com.digitalperson.audio.AudioProcessor
|
||||
import com.digitalperson.vad.VadManager
|
||||
import com.digitalperson.asr.AsrManager
|
||||
import com.digitalperson.tts.TtsManager
|
||||
import com.digitalperson.ui.Live2DUiManager
|
||||
import com.digitalperson.config.AppConfig
|
||||
import com.digitalperson.face.FaceDetectionPipeline
|
||||
import com.digitalperson.face.FaceOverlayView
|
||||
import com.digitalperson.face.ImageProxyBitmapConverter
|
||||
import com.digitalperson.metrics.TraceManager
|
||||
import com.digitalperson.metrics.TraceSession
|
||||
import com.digitalperson.tts.TtsController
|
||||
import com.digitalperson.llm.LLMManager
|
||||
import com.digitalperson.llm.LLMManagerCallback
|
||||
import com.digitalperson.util.FileHelper
|
||||
import java.io.File
|
||||
import java.util.concurrent.ExecutorService
|
||||
import java.util.concurrent.Executors
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.Job
|
||||
@@ -26,14 +43,24 @@ import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.withContext
|
||||
|
||||
class Live2DChatActivity : AppCompatActivity() {
|
||||
companion object {
|
||||
private const val TAG_ACTIVITY = "Live2DChatActivity"
|
||||
private const val TAG_LLM = "LLM_ROUTE"
|
||||
}
|
||||
|
||||
private lateinit var uiManager: Live2DUiManager
|
||||
private lateinit var vadManager: VadManager
|
||||
private lateinit var asrManager: AsrManager
|
||||
private lateinit var ttsManager: TtsManager
|
||||
private lateinit var ttsController: TtsController
|
||||
private lateinit var audioProcessor: AudioProcessor
|
||||
private var llmManager: LLMManager? = null
|
||||
private var useLocalLLM = false // 默认使用云端 LLM
|
||||
|
||||
private val permissions: Array<String> = arrayOf(Manifest.permission.RECORD_AUDIO)
|
||||
private val appPermissions: Array<String> = arrayOf(
|
||||
Manifest.permission.RECORD_AUDIO,
|
||||
Manifest.permission.CAMERA
|
||||
)
|
||||
private val micPermissions: Array<String> = arrayOf(Manifest.permission.RECORD_AUDIO)
|
||||
|
||||
@Volatile
|
||||
private var isRecording: Boolean = false
|
||||
@@ -55,23 +82,46 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
@Volatile private var llmInFlight: Boolean = false
|
||||
private var enableStreaming = false
|
||||
|
||||
private lateinit var cameraPreviewView: PreviewView
|
||||
private lateinit var faceOverlayView: FaceOverlayView
|
||||
private lateinit var faceDetectionPipeline: FaceDetectionPipeline
|
||||
private var facePipelineReady: Boolean = false
|
||||
private var cameraProvider: ProcessCameraProvider? = null
|
||||
private lateinit var cameraAnalyzerExecutor: ExecutorService
|
||||
|
||||
override fun onRequestPermissionsResult(
|
||||
requestCode: Int,
|
||||
permissions: Array<String>,
|
||||
grantResults: IntArray
|
||||
) {
|
||||
super.onRequestPermissionsResult(requestCode, permissions, grantResults)
|
||||
val ok = requestCode == AppConfig.REQUEST_RECORD_AUDIO_PERMISSION &&
|
||||
grantResults.isNotEmpty() &&
|
||||
grantResults[0] == PackageManager.PERMISSION_GRANTED
|
||||
if (!ok) {
|
||||
if (requestCode != AppConfig.REQUEST_RECORD_AUDIO_PERMISSION) return
|
||||
if (grantResults.isEmpty()) {
|
||||
finish()
|
||||
return
|
||||
}
|
||||
val granted = permissions.zip(grantResults.toTypedArray()).associate { it.first to it.second }
|
||||
val micGranted = granted[Manifest.permission.RECORD_AUDIO] == PackageManager.PERMISSION_GRANTED
|
||||
val cameraGranted = granted[Manifest.permission.CAMERA] == PackageManager.PERMISSION_GRANTED
|
||||
|
||||
if (!micGranted) {
|
||||
Log.e(AppConfig.TAG, "Audio record is disallowed")
|
||||
finish()
|
||||
return
|
||||
}
|
||||
if (!cameraGranted) {
|
||||
uiManager.showToast("未授予相机权限,暂不启用人脸检测")
|
||||
Log.w(AppConfig.TAG, "Camera permission denied")
|
||||
return
|
||||
}
|
||||
if (facePipelineReady) {
|
||||
startCameraPreviewAndDetection()
|
||||
}
|
||||
}
|
||||
|
||||
override fun onCreate(savedInstanceState: Bundle?) {
|
||||
super.onCreate(savedInstanceState)
|
||||
Log.i(TAG_ACTIVITY, "onCreate")
|
||||
setContentView(R.layout.activity_live2d_chat)
|
||||
|
||||
uiManager = Live2DUiManager(this)
|
||||
@@ -82,11 +132,29 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
stopButtonId = R.id.stop_button,
|
||||
recordButtonId = R.id.record_button,
|
||||
traditionalButtonsId = R.id.traditional_buttons,
|
||||
llmModeSwitchId = R.id.llm_mode_switch,
|
||||
llmModeSwitchRowId = R.id.llm_mode_switch_row,
|
||||
silentPlayerViewId = 0,
|
||||
speakingPlayerViewId = 0,
|
||||
live2dViewId = R.id.live2d_view
|
||||
)
|
||||
|
||||
cameraPreviewView = findViewById(R.id.camera_preview)
|
||||
cameraPreviewView.implementationMode = PreviewView.ImplementationMode.COMPATIBLE
|
||||
faceOverlayView = findViewById(R.id.face_overlay)
|
||||
cameraAnalyzerExecutor = Executors.newSingleThreadExecutor()
|
||||
faceDetectionPipeline = FaceDetectionPipeline(
|
||||
context = applicationContext,
|
||||
onResult = { result ->
|
||||
faceOverlayView.updateResult(result)
|
||||
},
|
||||
onGreeting = { greeting ->
|
||||
uiManager.appendToUi("\n[Face] $greeting\n")
|
||||
ttsController.enqueueSegment(greeting)
|
||||
ttsController.enqueueEnd()
|
||||
}
|
||||
)
|
||||
|
||||
// 根据配置选择交互方式
|
||||
uiManager.setUseHoldToSpeak(AppConfig.USE_HOLD_TO_SPEAK)
|
||||
|
||||
@@ -105,7 +173,7 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
uiManager.setStopButtonListener { onStopClicked(userInitiated = true) }
|
||||
}
|
||||
|
||||
ActivityCompat.requestPermissions(this, permissions, AppConfig.REQUEST_RECORD_AUDIO_PERMISSION)
|
||||
ActivityCompat.requestPermissions(this, appPermissions, AppConfig.REQUEST_RECORD_AUDIO_PERMISSION)
|
||||
|
||||
try {
|
||||
val streamingSwitch = findViewById<android.widget.Switch>(R.id.streaming_switch)
|
||||
@@ -119,6 +187,27 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
Log.w(AppConfig.TAG, "Streaming switch not found in layout: ${e.message}")
|
||||
}
|
||||
|
||||
try {
|
||||
val ttsModeSwitch = findViewById<android.widget.Switch>(R.id.tts_mode_switch)
|
||||
ttsModeSwitch.isChecked = false // 默认使用本地TTS
|
||||
ttsModeSwitch.setOnCheckedChangeListener { _, isChecked ->
|
||||
ttsController.setUseQCloudTts(isChecked)
|
||||
uiManager.showToast("TTS模式已切换到${if (isChecked) "腾讯云" else "本地"}")
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
Log.w(AppConfig.TAG, "TTS mode switch not found in layout: ${e.message}")
|
||||
}
|
||||
|
||||
// 设置 LLM 模式开关
|
||||
uiManager.setLLMSwitchListener { isChecked ->
|
||||
useLocalLLM = isChecked
|
||||
Log.i(TAG_LLM, "LLM mode switched: useLocalLLM=$useLocalLLM")
|
||||
uiManager.showToast("LLM模式已切换到${if (isChecked) "本地" else "云端"}")
|
||||
// 重新初始化 LLM
|
||||
initLLM()
|
||||
}
|
||||
// 默认不显示 LLM 开关,等模型下载完成后再显示
|
||||
|
||||
if (AppConfig.USE_HOLD_TO_SPEAK) {
|
||||
uiManager.setButtonsEnabled(recordEnabled = false)
|
||||
} else {
|
||||
@@ -127,8 +216,8 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
uiManager.setText("初始化中…")
|
||||
|
||||
audioProcessor = AudioProcessor(this)
|
||||
ttsManager = TtsManager(this)
|
||||
ttsManager.setCallback(createTtsCallback())
|
||||
ttsController = TtsController(this)
|
||||
ttsController.setCallback(createTtsCallback())
|
||||
|
||||
asrManager = AsrManager(this)
|
||||
asrManager.setAudioProcessor(audioProcessor)
|
||||
@@ -137,6 +226,64 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
vadManager = VadManager(this)
|
||||
vadManager.setCallback(createVadCallback())
|
||||
|
||||
// 初始化 LLM 管理器
|
||||
initLLM()
|
||||
|
||||
// 检查是否需要下载模型
|
||||
if (!FileHelper.isLocalLLMAvailable(this)) {
|
||||
// 显示下载进度对话框
|
||||
uiManager.showDownloadProgressDialog()
|
||||
|
||||
// 异步下载模型文件
|
||||
FileHelper.downloadModelFilesWithProgress(
|
||||
this,
|
||||
onProgress = { fileName, downloaded, total, progress ->
|
||||
runOnUiThread {
|
||||
val downloadedMB = downloaded / (1024 * 1024)
|
||||
val totalMB = total / (1024 * 1024)
|
||||
uiManager.updateDownloadProgress(
|
||||
fileName,
|
||||
downloadedMB,
|
||||
totalMB,
|
||||
progress
|
||||
)
|
||||
}
|
||||
},
|
||||
onComplete = { success, message ->
|
||||
runOnUiThread {
|
||||
uiManager.dismissDownloadProgressDialog()
|
||||
if (success) {
|
||||
Log.i(AppConfig.TAG, "Model files downloaded successfully")
|
||||
uiManager.showToast("模型下载完成", Toast.LENGTH_SHORT)
|
||||
// 检查本地 LLM 是否可用
|
||||
if (FileHelper.isLocalLLMAvailable(this)) {
|
||||
Log.i(AppConfig.TAG, "Local LLM is available, enabling local LLM switch")
|
||||
// 显示本地 LLM 开关,并同步状态
|
||||
uiManager.showLLMSwitch(true)
|
||||
uiManager.setLLMSwitchChecked(useLocalLLM)
|
||||
}
|
||||
} else {
|
||||
Log.e(AppConfig.TAG, "Failed to download model files: $message")
|
||||
uiManager.showToast("模型下载失败: $message", Toast.LENGTH_LONG)
|
||||
}
|
||||
// 下载完成后初始化其他组件
|
||||
initializeOtherComponents()
|
||||
}
|
||||
}
|
||||
)
|
||||
} else {
|
||||
// 模型已存在,直接初始化其他组件
|
||||
initializeOtherComponents()
|
||||
// 显示本地 LLM 开关,并同步状态
|
||||
uiManager.showLLMSwitch(true)
|
||||
uiManager.setLLMSwitchChecked(useLocalLLM)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 初始化其他组件(VAD、ASR、TTS、人脸检测等)
|
||||
*/
|
||||
private fun initializeOtherComponents() {
|
||||
ioScope.launch {
|
||||
try {
|
||||
Log.i(AppConfig.TAG, "Init VAD + SenseVoice(RKNN) + TTS (background)")
|
||||
@@ -144,7 +291,8 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
vadManager.initVadModel()
|
||||
asrManager.initSenseVoiceModel()
|
||||
}
|
||||
val ttsOk = ttsManager.initTtsAndAudioTrack()
|
||||
val ttsOk = ttsController.init()
|
||||
facePipelineReady = faceDetectionPipeline.initialize()
|
||||
withContext(Dispatchers.Main) {
|
||||
if (!ttsOk) {
|
||||
uiManager.showToast(
|
||||
@@ -152,6 +300,11 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
Toast.LENGTH_LONG
|
||||
)
|
||||
}
|
||||
if (!facePipelineReady) {
|
||||
uiManager.showToast("RetinaFace 初始化失败,请检查模型和 rknn 运行库", Toast.LENGTH_LONG)
|
||||
} else if (allPermissionsGranted()) {
|
||||
startCameraPreviewAndDetection()
|
||||
}
|
||||
uiManager.setText(getString(R.string.hint))
|
||||
if (AppConfig.USE_HOLD_TO_SPEAK) {
|
||||
uiManager.setButtonsEnabled(recordEnabled = true)
|
||||
@@ -203,14 +356,22 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
Log.d(AppConfig.TAG, "ASR segment skipped: $reason")
|
||||
}
|
||||
|
||||
override fun shouldSkipAsr(): Boolean = ttsManager.isPlaying()
|
||||
override fun shouldSkipAsr(): Boolean = ttsController.isPlaying()
|
||||
|
||||
override fun isLlmInFlight(): Boolean = llmInFlight
|
||||
|
||||
override fun onLlmCalled(text: String) {
|
||||
llmInFlight = true
|
||||
Log.d(AppConfig.TAG, "Calling LLM with text: $text")
|
||||
cloudApiManager.callLLM(text)
|
||||
if (useLocalLLM) {
|
||||
Log.i(TAG_LLM, "Routing to LOCAL LLM")
|
||||
// 使用本地 LLM 生成回复
|
||||
generateResponse(text)
|
||||
} else {
|
||||
Log.i(TAG_LLM, "Routing to CLOUD LLM")
|
||||
// 使用云端 LLM 生成回复
|
||||
cloudApiManager.callLLM(text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -220,7 +381,7 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
asrManager.enqueueAudioSegment(originalAudio, processedAudio)
|
||||
}
|
||||
|
||||
override fun shouldSkipProcessing(): Boolean = ttsManager.isPlaying() || llmInFlight
|
||||
override fun shouldSkipProcessing(): Boolean = ttsController.isPlaying() || llmInFlight
|
||||
}
|
||||
|
||||
private fun createCloudApiListener() = object : CloudApiManager.CloudApiListener {
|
||||
@@ -232,9 +393,9 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
|
||||
if (enableStreaming) {
|
||||
for (seg in segmenter.flush()) {
|
||||
ttsManager.enqueueSegment(seg)
|
||||
ttsController.enqueueSegment(seg)
|
||||
}
|
||||
ttsManager.enqueueEnd()
|
||||
ttsController.enqueueEnd()
|
||||
} else {
|
||||
val previousMood = com.digitalperson.mood.MoodManager.getCurrentMood()
|
||||
val (filteredText, mood) = com.digitalperson.mood.MoodManager.extractAndFilterMood(response)
|
||||
@@ -247,8 +408,8 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
runOnUiThread {
|
||||
uiManager.appendToUi("${filteredText}\n")
|
||||
}
|
||||
ttsManager.enqueueSegment(filteredText)
|
||||
ttsManager.enqueueEnd()
|
||||
ttsController.enqueueSegment(filteredText)
|
||||
ttsController.enqueueEnd()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -271,7 +432,7 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
|
||||
val segments = segmenter.processChunk(filteredText)
|
||||
for (seg in segments) {
|
||||
ttsManager.enqueueSegment(seg)
|
||||
ttsController.enqueueSegment(seg)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -285,7 +446,7 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
}
|
||||
}
|
||||
|
||||
private fun createTtsCallback() = object : TtsManager.TtsCallback {
|
||||
private fun createTtsCallback() = object : TtsController.TtsCallback {
|
||||
override fun onTtsStarted(text: String) {
|
||||
runOnUiThread {
|
||||
uiManager.appendToUi("\n[TTS] 开始合成...\n")
|
||||
@@ -310,32 +471,6 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
uiManager.setSpeaking(speaking)
|
||||
}
|
||||
|
||||
override fun getCurrentTrace(): TraceSession? = currentTrace
|
||||
|
||||
override fun onTraceMarkTtsRequestEnqueued() {
|
||||
currentTrace?.markTtsRequestEnqueued()
|
||||
}
|
||||
|
||||
override fun onTraceMarkTtsSynthesisStart() {
|
||||
currentTrace?.markTtsSynthesisStart()
|
||||
}
|
||||
|
||||
override fun onTraceMarkTtsFirstPcmReady() {
|
||||
currentTrace?.markTtsFirstPcmReady()
|
||||
}
|
||||
|
||||
override fun onTraceMarkTtsFirstAudioPlay() {
|
||||
currentTrace?.markTtsFirstAudioPlay()
|
||||
}
|
||||
|
||||
override fun onTraceMarkTtsDone() {
|
||||
currentTrace?.markTtsDone()
|
||||
}
|
||||
|
||||
override fun onTraceAddDuration(name: String, value: Long) {
|
||||
currentTrace?.addDuration(name, value)
|
||||
}
|
||||
|
||||
override fun onEndTurn() {
|
||||
TraceManager.getInstance().endTurn()
|
||||
currentTrace = null
|
||||
@@ -344,27 +479,97 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
|
||||
override fun onDestroy() {
|
||||
super.onDestroy()
|
||||
stopCameraPreviewAndDetection()
|
||||
onStopClicked(userInitiated = false)
|
||||
ioScope.cancel()
|
||||
synchronized(nativeLock) {
|
||||
try { vadManager.release() } catch (_: Throwable) {}
|
||||
try { asrManager.release() } catch (_: Throwable) {}
|
||||
}
|
||||
try { ttsManager.release() } catch (_: Throwable) {}
|
||||
try { faceDetectionPipeline.release() } catch (_: Throwable) {}
|
||||
try { cameraAnalyzerExecutor.shutdown() } catch (_: Throwable) {}
|
||||
try { ttsController.release() } catch (_: Throwable) {}
|
||||
try { llmManager?.destroy() } catch (_: Throwable) {}
|
||||
try { uiManager.release() } catch (_: Throwable) {}
|
||||
try { audioProcessor.release() } catch (_: Throwable) {}
|
||||
}
|
||||
|
||||
override fun onResume() {
|
||||
super.onResume()
|
||||
Log.i(TAG_ACTIVITY, "onResume")
|
||||
uiManager.onResume()
|
||||
if (facePipelineReady && allPermissionsGranted()) {
|
||||
startCameraPreviewAndDetection()
|
||||
}
|
||||
}
|
||||
|
||||
override fun onPause() {
|
||||
Log.i(TAG_ACTIVITY, "onPause")
|
||||
stopCameraPreviewAndDetection()
|
||||
uiManager.onPause()
|
||||
super.onPause()
|
||||
}
|
||||
|
||||
private fun allPermissionsGranted(): Boolean {
|
||||
return appPermissions.all {
|
||||
ContextCompat.checkSelfPermission(this, it) == PackageManager.PERMISSION_GRANTED
|
||||
}
|
||||
}
|
||||
|
||||
private fun startCameraPreviewAndDetection() {
|
||||
val cameraProviderFuture = ProcessCameraProvider.getInstance(this)
|
||||
cameraProviderFuture.addListener({
|
||||
try {
|
||||
val provider = cameraProviderFuture.get()
|
||||
cameraProvider = provider
|
||||
provider.unbindAll()
|
||||
|
||||
val preview = Preview.Builder().build().apply {
|
||||
setSurfaceProvider(cameraPreviewView.surfaceProvider)
|
||||
}
|
||||
cameraPreviewView.scaleType = PreviewView.ScaleType.FIT_CENTER
|
||||
|
||||
val analyzer = ImageAnalysis.Builder()
|
||||
.setBackpressureStrategy(ImageAnalysis.STRATEGY_KEEP_ONLY_LATEST)
|
||||
.build()
|
||||
|
||||
analyzer.setAnalyzer(cameraAnalyzerExecutor) { imageProxy ->
|
||||
analyzeCameraFrame(imageProxy)
|
||||
}
|
||||
|
||||
val selector = CameraSelector.Builder()
|
||||
.requireLensFacing(CameraSelector.LENS_FACING_FRONT)
|
||||
.build()
|
||||
|
||||
provider.bindToLifecycle(this, selector, preview, analyzer)
|
||||
} catch (t: Throwable) {
|
||||
Log.e(AppConfig.TAG, "startCameraPreviewAndDetection failed: ${t.message}", t)
|
||||
}
|
||||
}, ContextCompat.getMainExecutor(this))
|
||||
}
|
||||
|
||||
private fun stopCameraPreviewAndDetection() {
|
||||
try {
|
||||
cameraProvider?.unbindAll()
|
||||
} catch (_: Throwable) {
|
||||
} finally {
|
||||
cameraProvider = null
|
||||
}
|
||||
}
|
||||
|
||||
private fun analyzeCameraFrame(imageProxy: ImageProxy) {
|
||||
try {
|
||||
val bitmap: Bitmap? = ImageProxyBitmapConverter.toBitmap(imageProxy)
|
||||
if (bitmap != null) {
|
||||
faceDetectionPipeline.submitFrame(bitmap)
|
||||
}
|
||||
} catch (t: Throwable) {
|
||||
Log.w(AppConfig.TAG, "analyzeCameraFrame error: ${t.message}")
|
||||
} finally {
|
||||
imageProxy.close()
|
||||
}
|
||||
}
|
||||
|
||||
private fun onStartClicked() {
|
||||
Log.d(AppConfig.TAG, "onStartClicked called")
|
||||
if (isRecording) {
|
||||
@@ -372,7 +577,7 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
return
|
||||
}
|
||||
|
||||
if (!audioProcessor.initMicrophone(permissions, AppConfig.REQUEST_RECORD_AUDIO_PERMISSION)) {
|
||||
if (!audioProcessor.initMicrophone(micPermissions, AppConfig.REQUEST_RECORD_AUDIO_PERMISSION)) {
|
||||
uiManager.showToast("麦克风初始化失败/无权限")
|
||||
return
|
||||
}
|
||||
@@ -383,8 +588,7 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
|
||||
uiManager.clearText()
|
||||
|
||||
ttsManager.reset()
|
||||
ttsManager.setCurrentTrace(currentTrace)
|
||||
ttsController.reset()
|
||||
segmenter.reset()
|
||||
|
||||
vadManager.reset()
|
||||
@@ -409,12 +613,12 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
}
|
||||
|
||||
// 如果TTS正在播放,打断它
|
||||
val interrupted = ttsManager.interruptForNewTurn()
|
||||
val interrupted = ttsController.interruptForNewTurn()
|
||||
if (interrupted) {
|
||||
uiManager.appendToUi("\n[LOG] 已打断TTS播放\n")
|
||||
}
|
||||
|
||||
if (!audioProcessor.initMicrophone(permissions, AppConfig.REQUEST_RECORD_AUDIO_PERMISSION)) {
|
||||
if (!audioProcessor.initMicrophone(micPermissions, AppConfig.REQUEST_RECORD_AUDIO_PERMISSION)) {
|
||||
uiManager.showToast("麦克风初始化失败/无权限")
|
||||
return
|
||||
}
|
||||
@@ -427,7 +631,7 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
|
||||
// interruptForNewTurn() already prepared TTS state for next turn.
|
||||
// Keep reset() only for non-interrupt entry points.
|
||||
ttsManager.setCurrentTrace(currentTrace)
|
||||
|
||||
segmenter.reset()
|
||||
|
||||
// 启动按住说话的动作
|
||||
@@ -479,7 +683,7 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
recordingJob?.cancel()
|
||||
recordingJob = null
|
||||
|
||||
ttsManager.stop()
|
||||
ttsController.stop()
|
||||
|
||||
if (AppConfig.USE_HOLD_TO_SPEAK) {
|
||||
uiManager.setButtonsEnabled(recordEnabled = true)
|
||||
@@ -515,10 +719,10 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
while (isRecording && ioScope.coroutineContext.isActive) {
|
||||
loopCount++
|
||||
if (loopCount % 100 == 0) {
|
||||
Log.d(AppConfig.TAG, "processSamplesLoop running, loopCount=$loopCount, ttsPlaying=${ttsManager.isPlaying()}")
|
||||
Log.d(AppConfig.TAG, "processSamplesLoop running, loopCount=$loopCount, ttsPlaying=${ttsController.isPlaying()}")
|
||||
}
|
||||
|
||||
if (ttsManager.isPlaying()) {
|
||||
if (ttsController.isPlaying()) {
|
||||
if (vadManager.isInSpeech()) {
|
||||
Log.d(AppConfig.TAG, "TTS playing, resetting VAD state")
|
||||
vadManager.clearState()
|
||||
@@ -546,11 +750,134 @@ class Live2DChatActivity : AppCompatActivity() {
|
||||
}
|
||||
|
||||
val forced = segmenter.maybeForceByTime()
|
||||
for (seg in forced) ttsManager.enqueueSegment(seg)
|
||||
for (seg in forced) ttsController.enqueueSegment(seg)
|
||||
}
|
||||
|
||||
vadManager.forceFinalize()
|
||||
}
|
||||
Log.d(AppConfig.TAG, "processSamplesLoop stopped")
|
||||
}
|
||||
|
||||
/**
|
||||
* 初始化 LLM 管理器
|
||||
*/
|
||||
private fun initLLM() {
|
||||
try {
|
||||
Log.i(TAG_LLM, "initLLM called, useLocalLLM=$useLocalLLM")
|
||||
llmManager?.destroy()
|
||||
llmManager = null
|
||||
if (useLocalLLM) {
|
||||
// // 本地 LLM 初始化前,先暂停/释放重模块
|
||||
// Log.i(AppConfig.TAG, "Pausing camera and releasing face detection before LLM initialization")
|
||||
// stopCameraPreviewAndDetection()
|
||||
// try {
|
||||
// faceDetectionPipeline.release()
|
||||
// Log.i(AppConfig.TAG, "Face detection pipeline released")
|
||||
// } catch (e: Exception) {
|
||||
// Log.w(AppConfig.TAG, "Failed to release face detection pipeline: ${e.message}")
|
||||
// }
|
||||
|
||||
// // 释放 VAD 管理器
|
||||
// try {
|
||||
// vadManager.release()
|
||||
// Log.i(AppConfig.TAG, "VAD manager released")
|
||||
// } catch (e: Exception) {
|
||||
// Log.w(AppConfig.TAG, "Failed to release VAD manager: ${e.message}")
|
||||
// }
|
||||
|
||||
val modelPath = FileHelper.getLLMModelPath(applicationContext)
|
||||
if (!File(modelPath).exists()) {
|
||||
throw IllegalStateException("RKLLM model file missing: $modelPath")
|
||||
}
|
||||
Log.i(AppConfig.TAG, "Initializing LLM with model path: $modelPath")
|
||||
val localLlmResponseBuffer = StringBuilder()
|
||||
llmManager = LLMManager(modelPath, object : LLMManagerCallback {
|
||||
override fun onThinking(msg: String, finished: Boolean) {
|
||||
// 处理思考过程
|
||||
Log.d(TAG_LLM, "LOCAL onThinking finished=$finished msg=${msg.take(60)}")
|
||||
runOnUiThread {
|
||||
if (!finished && enableStreaming) {
|
||||
uiManager.appendToUi("\n[LLM] 思考中: $msg\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun onResult(msg: String, finished: Boolean) {
|
||||
// 处理生成结果
|
||||
Log.d(TAG_LLM, "LOCAL onResult finished=$finished len=${msg.length}")
|
||||
runOnUiThread {
|
||||
if (!finished) {
|
||||
localLlmResponseBuffer.append(msg)
|
||||
if (enableStreaming) {
|
||||
uiManager.appendToUi(msg)
|
||||
}
|
||||
} else {
|
||||
val finalText = localLlmResponseBuffer.toString().trim()
|
||||
localLlmResponseBuffer.setLength(0)
|
||||
if (!enableStreaming && finalText.isNotEmpty()) {
|
||||
uiManager.appendToUi("$finalText\n")
|
||||
}
|
||||
uiManager.appendToUi("\n\n[LLM] 生成完成\n")
|
||||
llmInFlight = false
|
||||
if (finalText.isNotEmpty()) {
|
||||
ttsController.enqueueSegment(finalText)
|
||||
ttsController.enqueueEnd()
|
||||
} else {
|
||||
Log.w(TAG_LLM, "LOCAL final text is empty, skip TTS enqueue")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
Log.i(AppConfig.TAG, "LLM initialized successfully")
|
||||
Log.i(TAG_LLM, "LOCAL LLM initialized")
|
||||
} else {
|
||||
// 使用云端 LLM,不需要初始化本地 LLM
|
||||
Log.i(AppConfig.TAG, "Using cloud LLM, skipping local LLM initialization")
|
||||
Log.i(TAG_LLM, "CLOUD mode active")
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
Log.e(AppConfig.TAG, "Failed to initialize LLM: ${e.message}", e)
|
||||
Log.e(TAG_LLM, "LOCAL init failed: ${e.message}", e)
|
||||
useLocalLLM = false
|
||||
runOnUiThread {
|
||||
uiManager.setLLMSwitchChecked(false)
|
||||
uiManager.showToast("LLM 初始化失败: ${e.message}", Toast.LENGTH_LONG)
|
||||
uiManager.appendToUi("\n[错误] LLM 初始化失败: ${e.message}\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 使用 LLM 生成回复
|
||||
*/
|
||||
private fun generateResponse(userInput: String) {
|
||||
try {
|
||||
if (useLocalLLM) {
|
||||
val systemPrompt = "你是一个友好的数字人助手,回答要简洁明了。"
|
||||
Log.d(AppConfig.TAG, "Generating response for: $userInput")
|
||||
val local = llmManager
|
||||
if (local == null) {
|
||||
Log.e(TAG_LLM, "LOCAL LLM manager is null, fallback to CLOUD")
|
||||
cloudApiManager.callLLM(userInput)
|
||||
return
|
||||
}
|
||||
Log.i(TAG_LLM, "LOCAL generateResponseWithSystem")
|
||||
local.generateResponseWithSystem(systemPrompt, userInput)
|
||||
} else {
|
||||
// 使用云端 LLM
|
||||
Log.d(AppConfig.TAG, "Using cloud LLM for response: $userInput")
|
||||
Log.i(TAG_LLM, "CLOUD callLLM")
|
||||
// 调用云端 LLM
|
||||
cloudApiManager.callLLM(userInput)
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
Log.e(AppConfig.TAG, "Failed to generate response: ${e.message}", e)
|
||||
Log.e(TAG_LLM, "generateResponse failed: ${e.message}", e)
|
||||
runOnUiThread {
|
||||
uiManager.appendToUi("\n\n[Error] LLM 生成失败: ${e.message}\n")
|
||||
llmInFlight = false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -35,6 +35,25 @@ object AppConfig {
|
||||
const val MODEL_DIR = "sensevoice_models"
|
||||
}
|
||||
|
||||
object Face {
|
||||
const val MODEL_DIR = "RetinaFace"
|
||||
const val MODEL_NAME = "RetinaFace_mobile320.rknn"
|
||||
const val INPUT_SIZE = 320
|
||||
const val SCORE_THRESHOLD = 0.6f
|
||||
const val NMS_THRESHOLD = 0.4f
|
||||
const val TRACK_IOU_THRESHOLD = 0.45f
|
||||
const val STABLE_MS = 1000L
|
||||
const val FRONTAL_MIN_FACE_SIZE = 90f
|
||||
const val FRONTAL_MAX_ASPECT_DIFF = 0.35f
|
||||
}
|
||||
|
||||
object FaceRecognition {
|
||||
const val MODEL_DIR = "Insightface"
|
||||
const val MODEL_NAME = "ms1mv3_arcface_r18.rknn"
|
||||
const val SIMILARITY_THRESHOLD = 0.5f
|
||||
const val GREETING_COOLDOWN_MS = 6000L
|
||||
}
|
||||
|
||||
object Audio {
|
||||
const val GAIN_SMOOTHING_FACTOR = 0.1f
|
||||
const val TARGET_RMS = 0.1f
|
||||
@@ -48,4 +67,10 @@ object AppConfig {
|
||||
const val MODEL_DIR = "live2d_model/Haru_pro_jp"
|
||||
const val MODEL_JSON = "haru_greeter_t05.model3.json"
|
||||
}
|
||||
|
||||
object QCloud {
|
||||
const val APP_ID = "1302849512" // 替换为你的腾讯云APP_ID
|
||||
const val SECRET_ID = "AKIDbBdyBGE5oPuIGA1iDlDYlFallaJ0YODB" // 替换为你的腾讯云SECRET_ID
|
||||
const val SECRET_KEY = "32vhIl9OQIRclmLjvuleLp9LLAnFVYEp" // 替换为你的腾讯云SECRET_KEY
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
package com.digitalperson.engine;
|
||||
|
||||
import android.content.Context;
|
||||
import android.graphics.Bitmap;
|
||||
import android.util.Log;
|
||||
|
||||
import com.digitalperson.config.AppConfig;
|
||||
import com.digitalperson.util.FileHelper;
|
||||
|
||||
import java.io.File;
|
||||
|
||||
public class ArcFaceEngineRKNN {
|
||||
private static final String TAG = "ArcFaceEngineRKNN";
|
||||
|
||||
static {
|
||||
try {
|
||||
System.loadLibrary("rknnrt");
|
||||
System.loadLibrary("sensevoiceEngine");
|
||||
Log.d(TAG, "Loaded native libs for ArcFace RKNN");
|
||||
} catch (UnsatisfiedLinkError e) {
|
||||
Log.e(TAG, "Failed to load native libraries for ArcFace", e);
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
private final long nativePtr;
|
||||
private boolean initialized = false;
|
||||
private boolean released = false;
|
||||
|
||||
public ArcFaceEngineRKNN() {
|
||||
nativePtr = createEngineNative();
|
||||
if (nativePtr == 0) {
|
||||
throw new RuntimeException("Failed to create native ArcFace engine");
|
||||
}
|
||||
}
|
||||
|
||||
public boolean initialize(Context context) {
|
||||
if (released) return false;
|
||||
File modelDir = FileHelper.copyInsightFaceAssets(context);
|
||||
File modelFile = new File(modelDir, AppConfig.FaceRecognition.MODEL_NAME);
|
||||
int ret = initNative(nativePtr, modelFile.getAbsolutePath());
|
||||
initialized = ret == 0;
|
||||
if (!initialized) {
|
||||
Log.e(TAG, "ArcFace init failed, code=" + ret + ", model=" + modelFile.getAbsolutePath());
|
||||
}
|
||||
return initialized;
|
||||
}
|
||||
|
||||
public float[] extractEmbedding(Bitmap bitmap, float left, float top, float right, float bottom) {
|
||||
Log.d(TAG, "extractEmbedding called: initialized=" + initialized + ", released=" + released + ", bitmap=" + (bitmap != null));
|
||||
if (!initialized || released || bitmap == null) {
|
||||
Log.w(TAG, "extractEmbedding failed: initialized=" + initialized + ", released=" + released + ", bitmap=" + (bitmap != null));
|
||||
return new float[0];
|
||||
}
|
||||
float[] emb = extractEmbeddingNative(nativePtr, bitmap, left, top, right, bottom);
|
||||
Log.d(TAG, "extractEmbeddingNative returned: " + (emb != null ? emb.length : "null"));
|
||||
return emb != null ? emb : new float[0];
|
||||
}
|
||||
|
||||
public void release() {
|
||||
if (!released && nativePtr != 0) {
|
||||
releaseNative(nativePtr);
|
||||
}
|
||||
released = true;
|
||||
initialized = false;
|
||||
}
|
||||
|
||||
private native long createEngineNative();
|
||||
private native int initNative(long ptr, String modelPath);
|
||||
private native float[] extractEmbeddingNative(
|
||||
long ptr,
|
||||
Bitmap bitmap,
|
||||
float left,
|
||||
float top,
|
||||
float right,
|
||||
float bottom
|
||||
);
|
||||
private native void releaseNative(long ptr);
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
package com.digitalperson.engine;
|
||||
|
||||
import android.content.Context;
|
||||
import android.graphics.Bitmap;
|
||||
import android.util.Log;
|
||||
|
||||
import com.digitalperson.config.AppConfig;
|
||||
import com.digitalperson.util.FileHelper;
|
||||
|
||||
import java.io.File;
|
||||
|
||||
public class RetinaFaceEngineRKNN {
|
||||
private static final String TAG = "RetinaFaceEngineRKNN";
|
||||
|
||||
static {
|
||||
try {
|
||||
System.loadLibrary("rknnrt");
|
||||
System.loadLibrary("sensevoiceEngine");
|
||||
Log.d(TAG, "Loaded native libs for RetinaFace RKNN");
|
||||
} catch (UnsatisfiedLinkError e) {
|
||||
Log.e(TAG, "Failed to load native libraries for RetinaFace", e);
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
private final long nativePtr;
|
||||
private boolean initialized = false;
|
||||
private boolean released = false;
|
||||
|
||||
public RetinaFaceEngineRKNN() {
|
||||
nativePtr = createEngineNative();
|
||||
if (nativePtr == 0) {
|
||||
throw new RuntimeException("Failed to create native RetinaFace engine");
|
||||
}
|
||||
}
|
||||
|
||||
public boolean initialize(Context context) {
|
||||
if (released) {
|
||||
return false;
|
||||
}
|
||||
File modelDir = FileHelper.copyRetinaFaceAssets(context);
|
||||
File modelFile = new File(modelDir, AppConfig.Face.MODEL_NAME);
|
||||
int ret = initNative(
|
||||
nativePtr,
|
||||
modelFile.getAbsolutePath(),
|
||||
AppConfig.Face.INPUT_SIZE,
|
||||
AppConfig.Face.SCORE_THRESHOLD,
|
||||
AppConfig.Face.NMS_THRESHOLD
|
||||
);
|
||||
initialized = ret == 0;
|
||||
if (!initialized) {
|
||||
Log.e(TAG, "RetinaFace init failed, code=" + ret + ", model=" + modelFile.getAbsolutePath());
|
||||
}
|
||||
return initialized;
|
||||
}
|
||||
|
||||
public float[] detect(Bitmap bitmap) {
|
||||
if (!initialized || released || bitmap == null) {
|
||||
return new float[0];
|
||||
}
|
||||
float[] raw = detectNative(nativePtr, bitmap);
|
||||
return raw != null ? raw : new float[0];
|
||||
}
|
||||
|
||||
public void release() {
|
||||
if (!released && nativePtr != 0) {
|
||||
releaseNative(nativePtr);
|
||||
}
|
||||
released = true;
|
||||
initialized = false;
|
||||
}
|
||||
|
||||
private native long createEngineNative();
|
||||
private native int initNative(long ptr, String modelPath, int inputSize, float scoreThreshold, float nmsThreshold);
|
||||
private native float[] detectNative(long ptr, Bitmap bitmap);
|
||||
private native void releaseNative(long ptr);
|
||||
}
|
||||
@@ -0,0 +1,223 @@
|
||||
package com.digitalperson.face
|
||||
|
||||
import android.content.Context
|
||||
import android.graphics.Bitmap
|
||||
import android.util.Log
|
||||
import com.digitalperson.config.AppConfig
|
||||
import com.digitalperson.engine.RetinaFaceEngineRKNN
|
||||
import java.util.concurrent.atomic.AtomicBoolean
|
||||
import kotlin.math.abs
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.SupervisorJob
|
||||
import kotlinx.coroutines.cancel
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.withContext
|
||||
|
||||
data class FaceBox(
|
||||
val left: Float,
|
||||
val top: Float,
|
||||
val right: Float,
|
||||
val bottom: Float,
|
||||
val score: Float,
|
||||
)
|
||||
|
||||
data class FaceDetectionResult(
|
||||
val sourceWidth: Int,
|
||||
val sourceHeight: Int,
|
||||
val faces: List<FaceBox>,
|
||||
)
|
||||
|
||||
class FaceDetectionPipeline(
|
||||
private val context: Context,
|
||||
private val onResult: (FaceDetectionResult) -> Unit,
|
||||
private val onGreeting: (String) -> Unit,
|
||||
) {
|
||||
private val engine = RetinaFaceEngineRKNN()
|
||||
private val recognizer = FaceRecognizer(context)
|
||||
private val scope = CoroutineScope(SupervisorJob() + Dispatchers.Default)
|
||||
private val frameInFlight = AtomicBoolean(false)
|
||||
private val initialized = AtomicBoolean(false)
|
||||
private var trackFace: FaceBox? = null
|
||||
private var trackId: Long = 0
|
||||
private var trackStableSinceMs: Long = 0
|
||||
private var greetedTrackId: Long = -1
|
||||
private var lastGreetMs: Long = 0
|
||||
|
||||
fun initialize(): Boolean {
|
||||
val detectorOk = engine.initialize(context)
|
||||
val recognizerOk = recognizer.initialize()
|
||||
val ok = detectorOk && recognizerOk
|
||||
initialized.set(ok)
|
||||
Log.i(AppConfig.TAG, "Face pipeline initialize result=$ok detector=$detectorOk recognizer=$recognizerOk")
|
||||
return ok
|
||||
}
|
||||
|
||||
fun submitFrame(bitmap: Bitmap) {
|
||||
if (!initialized.get()) {
|
||||
bitmap.recycle()
|
||||
return
|
||||
}
|
||||
if (!frameInFlight.compareAndSet(false, true)) {
|
||||
bitmap.recycle()
|
||||
return
|
||||
}
|
||||
|
||||
scope.launch {
|
||||
try {
|
||||
val width = bitmap.width
|
||||
val height = bitmap.height
|
||||
val raw = engine.detect(bitmap)
|
||||
|
||||
val faceCount = raw.size / 5
|
||||
val faces = ArrayList<FaceBox>(faceCount)
|
||||
var i = 0
|
||||
while (i + 4 < raw.size) {
|
||||
faces.add(
|
||||
FaceBox(
|
||||
left = raw[i],
|
||||
top = raw[i + 1],
|
||||
right = raw[i + 2],
|
||||
bottom = raw[i + 3],
|
||||
score = raw[i + 4],
|
||||
)
|
||||
)
|
||||
i += 5
|
||||
}
|
||||
// 过滤太小的人脸
|
||||
val minFaceSize = 50 // 最小人脸大小(像素)
|
||||
val filteredFaces = faces.filter { face ->
|
||||
val width = face.right - face.left
|
||||
val height = face.bottom - face.top
|
||||
width > minFaceSize && height > minFaceSize
|
||||
}
|
||||
|
||||
// if (filteredFaces.isNotEmpty()) {
|
||||
// Log.d(
|
||||
// AppConfig.TAG,"[Face] filtered detected ${filteredFaces.size} face(s)"
|
||||
// )
|
||||
// }
|
||||
|
||||
maybeRecognizeAndGreet(bitmap, filteredFaces)
|
||||
withContext(Dispatchers.Main) {
|
||||
onResult(FaceDetectionResult(width, height, filteredFaces))
|
||||
}
|
||||
} catch (t: Throwable) {
|
||||
Log.e(AppConfig.TAG, "Face detection pipeline failed: ${t.message}", t)
|
||||
} finally {
|
||||
bitmap.recycle()
|
||||
frameInFlight.set(false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private suspend fun maybeRecognizeAndGreet(bitmap: Bitmap, faces: List<FaceBox>) {
|
||||
val now = System.currentTimeMillis()
|
||||
if (faces.isEmpty()) {
|
||||
trackFace = null
|
||||
trackStableSinceMs = 0
|
||||
return
|
||||
}
|
||||
|
||||
val primary = faces.maxByOrNull { (it.right - it.left) * (it.bottom - it.top) } ?: return
|
||||
val prev = trackFace
|
||||
if (prev == null || iou(prev, primary) < AppConfig.Face.TRACK_IOU_THRESHOLD) {
|
||||
trackId += 1
|
||||
greetedTrackId = -1
|
||||
trackStableSinceMs = now
|
||||
}
|
||||
trackFace = primary
|
||||
|
||||
val stableMs = now - trackStableSinceMs
|
||||
val frontal = isFrontal(primary, bitmap.width, bitmap.height)
|
||||
val coolingDown = (now - lastGreetMs) < AppConfig.FaceRecognition.GREETING_COOLDOWN_MS
|
||||
if (stableMs < AppConfig.Face.STABLE_MS || !frontal || greetedTrackId == trackId || coolingDown) {
|
||||
return
|
||||
}
|
||||
|
||||
val match = recognizer.identify(bitmap, primary)
|
||||
|
||||
Log.d(AppConfig.TAG, "[Face] Recognition result: matchedName=${match.matchedName}, similarity=${match.similarity}")
|
||||
|
||||
// 检查是否需要保存新人脸
|
||||
if (match.matchedName.isNullOrBlank()) {
|
||||
Log.d(AppConfig.TAG, "[Face] No match found, attempting to add new face")
|
||||
// 提取人脸特征
|
||||
val embedding = extractEmbedding(bitmap, primary)
|
||||
Log.d(AppConfig.TAG, "[Face] Extracted embedding size: ${embedding.size}")
|
||||
if (embedding.isNotEmpty()) {
|
||||
// 尝试添加新人脸
|
||||
val added = recognizer.addNewFace(embedding)
|
||||
Log.d(AppConfig.TAG, "[Face] Add new face result: $added")
|
||||
if (added) {
|
||||
Log.i(AppConfig.TAG, "[Face] New face added to database")
|
||||
} else {
|
||||
Log.i(AppConfig.TAG, "[Face] Face already exists in database (similar face found)")
|
||||
}
|
||||
} else {
|
||||
Log.w(AppConfig.TAG, "[Face] Failed to extract embedding")
|
||||
}
|
||||
} else {
|
||||
Log.d(AppConfig.TAG, "[Face] Matched existing face: ${match.matchedName}")
|
||||
}
|
||||
|
||||
val greeting = if (!match.matchedName.isNullOrBlank()) {
|
||||
"你好,${match.matchedName}!"
|
||||
} else {
|
||||
"你好,很高兴见到你。"
|
||||
}
|
||||
greetedTrackId = trackId
|
||||
lastGreetMs = now
|
||||
Log.i(
|
||||
AppConfig.TAG,
|
||||
"[Face] greeting track=$trackId stable=${stableMs}ms frontal=$frontal matched=${match.matchedName} score=${match.similarity}"
|
||||
)
|
||||
withContext(Dispatchers.Main) {
|
||||
onGreeting(greeting)
|
||||
}
|
||||
}
|
||||
|
||||
private fun extractEmbedding(bitmap: Bitmap, face: FaceBox): FloatArray {
|
||||
return recognizer.extractEmbedding(bitmap, face)
|
||||
}
|
||||
|
||||
private fun isFrontal(face: FaceBox, frameW: Int, frameH: Int): Boolean {
|
||||
val w = face.right - face.left
|
||||
val h = face.bottom - face.top
|
||||
if (w < AppConfig.Face.FRONTAL_MIN_FACE_SIZE || h < AppConfig.Face.FRONTAL_MIN_FACE_SIZE) {
|
||||
return false
|
||||
}
|
||||
val aspectDiff = abs((w / h) - 1f)
|
||||
if (aspectDiff > AppConfig.Face.FRONTAL_MAX_ASPECT_DIFF) {
|
||||
return false
|
||||
}
|
||||
val cx = (face.left + face.right) * 0.5f
|
||||
val cy = (face.top + face.bottom) * 0.5f
|
||||
val minX = frameW * 0.15f
|
||||
val maxX = frameW * 0.85f
|
||||
val minY = frameH * 0.15f
|
||||
val maxY = frameH * 0.85f
|
||||
return cx in minX..maxX && cy in minY..maxY
|
||||
}
|
||||
|
||||
private fun iou(a: FaceBox, b: FaceBox): Float {
|
||||
val left = maxOf(a.left, b.left)
|
||||
val top = maxOf(a.top, b.top)
|
||||
val right = minOf(a.right, b.right)
|
||||
val bottom = minOf(a.bottom, b.bottom)
|
||||
val w = maxOf(0f, right - left)
|
||||
val h = maxOf(0f, bottom - top)
|
||||
val inter = w * h
|
||||
val areaA = maxOf(0f, a.right - a.left) * maxOf(0f, a.bottom - a.top)
|
||||
val areaB = maxOf(0f, b.right - b.left) * maxOf(0f, b.bottom - b.top)
|
||||
val union = areaA + areaB - inter
|
||||
return if (union <= 0f) 0f else inter / union
|
||||
}
|
||||
|
||||
fun release() {
|
||||
scope.cancel()
|
||||
engine.release()
|
||||
recognizer.release()
|
||||
initialized.set(false)
|
||||
}
|
||||
}
|
||||
93
app/src/main/java/com/digitalperson/face/FaceFeatureStore.kt
Normal file
93
app/src/main/java/com/digitalperson/face/FaceFeatureStore.kt
Normal file
@@ -0,0 +1,93 @@
|
||||
package com.digitalperson.face
|
||||
|
||||
import android.content.ContentValues
|
||||
import android.content.Context
|
||||
import android.database.sqlite.SQLiteDatabase
|
||||
import android.database.sqlite.SQLiteOpenHelper
|
||||
import android.util.Log
|
||||
import com.digitalperson.config.AppConfig
|
||||
import java.nio.ByteBuffer
|
||||
import java.nio.ByteOrder
|
||||
|
||||
data class FaceProfile(
|
||||
val id: Long,
|
||||
val name: String,
|
||||
val embedding: FloatArray,
|
||||
)
|
||||
|
||||
class FaceFeatureStore(context: Context) : SQLiteOpenHelper(context, DB_NAME, null, DB_VERSION) {
|
||||
override fun onCreate(db: SQLiteDatabase) {
|
||||
db.execSQL(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS face_profiles (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL UNIQUE,
|
||||
embedding BLOB NOT NULL,
|
||||
updated_at INTEGER NOT NULL
|
||||
)
|
||||
""".trimIndent()
|
||||
)
|
||||
}
|
||||
|
||||
override fun onUpgrade(db: SQLiteDatabase, oldVersion: Int, newVersion: Int) {
|
||||
db.execSQL("DROP TABLE IF EXISTS face_profiles")
|
||||
onCreate(db)
|
||||
}
|
||||
|
||||
fun loadAllProfiles(): List<FaceProfile> {
|
||||
val db = readableDatabase
|
||||
val list = ArrayList<FaceProfile>()
|
||||
db.rawQuery("SELECT id, name, embedding FROM face_profiles", null).use { c ->
|
||||
val idIdx = c.getColumnIndexOrThrow("id")
|
||||
val nameIdx = c.getColumnIndexOrThrow("name")
|
||||
val embIdx = c.getColumnIndexOrThrow("embedding")
|
||||
while (c.moveToNext()) {
|
||||
val embBlob = c.getBlob(embIdx) ?: continue
|
||||
list.add(
|
||||
FaceProfile(
|
||||
id = c.getLong(idIdx),
|
||||
name = c.getString(nameIdx),
|
||||
embedding = blobToFloatArray(embBlob),
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
return list
|
||||
}
|
||||
|
||||
fun upsertProfile(name: String, embedding: FloatArray) {
|
||||
// 确保名字不为null,使用空字符串作为默认值
|
||||
val safeName = name.takeIf { it.isNotBlank() } ?: ""
|
||||
val values = ContentValues().apply {
|
||||
put("name", safeName)
|
||||
put("embedding", floatArrayToBlob(embedding))
|
||||
put("updated_at", System.currentTimeMillis())
|
||||
}
|
||||
val rowId = writableDatabase.insertWithOnConflict(
|
||||
"face_profiles",
|
||||
null,
|
||||
values,
|
||||
SQLiteDatabase.CONFLICT_REPLACE
|
||||
)
|
||||
Log.i(AppConfig.TAG, "[FaceFeatureStore] upsertProfile name='$safeName' rowId=$rowId dim=${embedding.size}")
|
||||
}
|
||||
|
||||
private fun floatArrayToBlob(values: FloatArray): ByteArray {
|
||||
val buf = ByteBuffer.allocate(values.size * 4).order(ByteOrder.LITTLE_ENDIAN)
|
||||
for (v in values) buf.putFloat(v)
|
||||
return buf.array()
|
||||
}
|
||||
|
||||
private fun blobToFloatArray(blob: ByteArray): FloatArray {
|
||||
if (blob.isEmpty()) return FloatArray(0)
|
||||
val buf = ByteBuffer.wrap(blob).order(ByteOrder.LITTLE_ENDIAN)
|
||||
val out = FloatArray(blob.size / 4)
|
||||
for (i in out.indices) out[i] = buf.getFloat()
|
||||
return out
|
||||
}
|
||||
|
||||
companion object {
|
||||
private const val DB_NAME = "face_feature.db"
|
||||
private const val DB_VERSION = 1
|
||||
}
|
||||
}
|
||||
61
app/src/main/java/com/digitalperson/face/FaceOverlayView.kt
Normal file
61
app/src/main/java/com/digitalperson/face/FaceOverlayView.kt
Normal file
@@ -0,0 +1,61 @@
|
||||
package com.digitalperson.face
|
||||
|
||||
import android.content.Context
|
||||
import android.graphics.Canvas
|
||||
import android.graphics.Color
|
||||
import android.graphics.Paint
|
||||
import android.graphics.RectF
|
||||
import android.util.AttributeSet
|
||||
import android.view.View
|
||||
|
||||
class FaceOverlayView @JvmOverloads constructor(
|
||||
context: Context,
|
||||
attrs: AttributeSet? = null,
|
||||
) : View(context, attrs) {
|
||||
|
||||
private val boxPaint = Paint(Paint.ANTI_ALIAS_FLAG).apply {
|
||||
color = Color.GREEN
|
||||
style = Paint.Style.STROKE
|
||||
strokeWidth = 4f
|
||||
}
|
||||
|
||||
private val textPaint = Paint(Paint.ANTI_ALIAS_FLAG).apply {
|
||||
color = Color.GREEN
|
||||
textSize = 28f
|
||||
}
|
||||
|
||||
@Volatile
|
||||
private var latestResult: FaceDetectionResult? = null
|
||||
|
||||
fun updateResult(result: FaceDetectionResult) {
|
||||
latestResult = result
|
||||
postInvalidateOnAnimation()
|
||||
}
|
||||
|
||||
override fun onDraw(canvas: Canvas) {
|
||||
super.onDraw(canvas)
|
||||
val result = latestResult ?: return
|
||||
if (result.sourceWidth <= 0 || result.sourceHeight <= 0) return
|
||||
|
||||
val srcW = result.sourceWidth.toFloat()
|
||||
val srcH = result.sourceHeight.toFloat()
|
||||
val viewW = width.toFloat()
|
||||
val viewH = height.toFloat()
|
||||
if (viewW <= 0f || viewH <= 0f) return
|
||||
|
||||
val scale = minOf(viewW / srcW, viewH / srcH)
|
||||
val dx = (viewW - srcW * scale) / 2f
|
||||
val dy = (viewH - srcH * scale) / 2f
|
||||
|
||||
for (face in result.faces) {
|
||||
val rect = RectF(
|
||||
dx + face.left * scale,
|
||||
dy + face.top * scale,
|
||||
dx + face.right * scale,
|
||||
dy + face.bottom * scale,
|
||||
)
|
||||
canvas.drawRect(rect, boxPaint)
|
||||
canvas.drawText(String.format("%.2f", face.score), rect.left, rect.top - 8f, textPaint)
|
||||
}
|
||||
}
|
||||
}
|
||||
129
app/src/main/java/com/digitalperson/face/FaceRecognizer.kt
Normal file
129
app/src/main/java/com/digitalperson/face/FaceRecognizer.kt
Normal file
@@ -0,0 +1,129 @@
|
||||
package com.digitalperson.face
|
||||
|
||||
import android.content.Context
|
||||
import android.graphics.Bitmap
|
||||
import android.util.Log
|
||||
import com.digitalperson.config.AppConfig
|
||||
import com.digitalperson.engine.ArcFaceEngineRKNN
|
||||
import kotlin.math.sqrt
|
||||
|
||||
data class FaceRecognitionResult(
|
||||
val matchedName: String?,
|
||||
val similarity: Float,
|
||||
val embeddingDim: Int,
|
||||
)
|
||||
|
||||
class FaceRecognizer(context: Context) {
|
||||
private val appContext = context.applicationContext
|
||||
private val engine = ArcFaceEngineRKNN()
|
||||
private val store = FaceFeatureStore(appContext)
|
||||
private val cache = ArrayList<FaceProfile>()
|
||||
|
||||
@Volatile
|
||||
private var initialized = false
|
||||
|
||||
fun initialize(): Boolean {
|
||||
Log.d(AppConfig.TAG, "[FaceRecognizer] initialize: starting...")
|
||||
val ok = engine.initialize(appContext)
|
||||
Log.d(AppConfig.TAG, "[FaceRecognizer] initialize: engine.initialize() returned $ok")
|
||||
if (!ok) {
|
||||
initialized = false
|
||||
Log.e(AppConfig.TAG, "[FaceRecognizer] initialize: failed - engine initialization failed")
|
||||
return false
|
||||
}
|
||||
cache.clear()
|
||||
val profiles = store.loadAllProfiles()
|
||||
cache.addAll(profiles)
|
||||
initialized = true
|
||||
Log.i(AppConfig.TAG, "[FaceRecognizer] initialized, profiles=${cache.size}")
|
||||
return true
|
||||
}
|
||||
|
||||
fun identify(bitmap: Bitmap, face: FaceBox): FaceRecognitionResult {
|
||||
if (!initialized) return FaceRecognitionResult(null, 0f, 0)
|
||||
val embedding = extractEmbedding(bitmap, face)
|
||||
if (embedding.isEmpty()) return FaceRecognitionResult(null, 0f, 0)
|
||||
|
||||
var bestName: String? = null
|
||||
var bestScore = -1f
|
||||
for (p in cache) {
|
||||
if (p.embedding.size != embedding.size) continue
|
||||
val score = cosineSimilarity(embedding, p.embedding)
|
||||
if (score > bestScore) {
|
||||
bestScore = score
|
||||
bestName = p.name
|
||||
}
|
||||
}
|
||||
if (bestScore >= AppConfig.FaceRecognition.SIMILARITY_THRESHOLD) {
|
||||
return FaceRecognitionResult(bestName, bestScore, embedding.size)
|
||||
}
|
||||
return FaceRecognitionResult(null, bestScore, embedding.size)
|
||||
}
|
||||
|
||||
fun extractEmbedding(bitmap: Bitmap, face: FaceBox): FloatArray {
|
||||
if (!initialized) return FloatArray(0)
|
||||
return engine.extractEmbedding(bitmap, face.left, face.top, face.right, face.bottom)
|
||||
}
|
||||
|
||||
fun addOrUpdateProfile(name: String?, embedding: FloatArray) {
|
||||
val normalized = normalize(embedding)
|
||||
store.upsertProfile(name ?: "", normalized)
|
||||
// 移除旧的记录(如果存在)
|
||||
if (name != null) {
|
||||
cache.removeAll { it.name == name }
|
||||
}
|
||||
cache.add(FaceProfile(id = -1L, name = name ?: "", embedding = normalized))
|
||||
}
|
||||
|
||||
fun addNewFace(embedding: FloatArray): Boolean {
|
||||
Log.d(AppConfig.TAG, "[FaceRecognizer] addNewFace: embedding size=${embedding.size}, cache size=${cache.size}")
|
||||
|
||||
// 检查是否已经存在相似的人脸
|
||||
for (p in cache) {
|
||||
if (p.embedding.size != embedding.size) {
|
||||
Log.d(AppConfig.TAG, "[FaceRecognizer] Skipping profile with different embedding size: ${p.embedding.size}")
|
||||
continue
|
||||
}
|
||||
val score = cosineSimilarity(embedding, p.embedding)
|
||||
Log.d(AppConfig.TAG, "[FaceRecognizer] Comparing with profile '${p.name}': similarity=$score, threshold=${AppConfig.FaceRecognition.SIMILARITY_THRESHOLD}")
|
||||
if (score >= AppConfig.FaceRecognition.SIMILARITY_THRESHOLD) {
|
||||
// 已经存在相似的人脸,不需要添加
|
||||
Log.i(AppConfig.TAG, "[FaceRecognizer] Similar face found: ${p.name} with similarity=$score, not adding new face")
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// 添加新人脸,名字为null
|
||||
Log.i(AppConfig.TAG, "[FaceRecognizer] No similar face found, adding new face")
|
||||
addOrUpdateProfile(null, embedding)
|
||||
return true
|
||||
}
|
||||
|
||||
fun release() {
|
||||
initialized = false
|
||||
engine.release()
|
||||
store.close()
|
||||
}
|
||||
|
||||
private fun cosineSimilarity(a: FloatArray, b: FloatArray): Float {
|
||||
var dot = 0f
|
||||
var na = 0f
|
||||
var nb = 0f
|
||||
for (i in a.indices) {
|
||||
dot += a[i] * b[i]
|
||||
na += a[i] * a[i]
|
||||
nb += b[i] * b[i]
|
||||
}
|
||||
if (na <= 1e-12f || nb <= 1e-12f) return -1f
|
||||
return (dot / (sqrt(na) * sqrt(nb))).coerceIn(-1f, 1f)
|
||||
}
|
||||
|
||||
private fun normalize(v: FloatArray): FloatArray {
|
||||
var sum = 0f
|
||||
for (x in v) sum += x * x
|
||||
val norm = sqrt(sum.coerceAtLeast(1e-12f))
|
||||
val out = FloatArray(v.size)
|
||||
for (i in v.indices) out[i] = v[i] / norm
|
||||
return out
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,87 @@
|
||||
package com.digitalperson.face
|
||||
|
||||
import android.graphics.Bitmap
|
||||
import android.graphics.BitmapFactory
|
||||
import android.graphics.ImageFormat
|
||||
import android.graphics.Matrix
|
||||
import android.graphics.Rect
|
||||
import android.graphics.YuvImage
|
||||
import androidx.camera.core.ImageProxy
|
||||
import java.io.ByteArrayOutputStream
|
||||
|
||||
object ImageProxyBitmapConverter {
|
||||
fun toBitmap(image: ImageProxy): Bitmap? {
|
||||
val nv21 = yuv420ToNv21(image) ?: return null
|
||||
val yuvImage = YuvImage(nv21, ImageFormat.NV21, image.width, image.height, null)
|
||||
val out = ByteArrayOutputStream()
|
||||
if (!yuvImage.compressToJpeg(Rect(0, 0, image.width, image.height), 80, out)) {
|
||||
return null
|
||||
}
|
||||
val bytes = out.toByteArray()
|
||||
var bitmap = BitmapFactory.decodeByteArray(bytes, 0, bytes.size) ?: return null
|
||||
if (bitmap.config != Bitmap.Config.ARGB_8888) {
|
||||
val converted = bitmap.copy(Bitmap.Config.ARGB_8888, false)
|
||||
bitmap.recycle()
|
||||
bitmap = converted
|
||||
}
|
||||
|
||||
val matrix = Matrix()
|
||||
|
||||
// 前置摄像头需要水平翻转
|
||||
// 注意:这里假设我们使用的是前置摄像头
|
||||
// 如果需要支持后置摄像头,需要根据实际使用的摄像头类型来决定是否翻转
|
||||
matrix.postScale(-1f, 1f, bitmap.width / 2f, bitmap.height / 2f)
|
||||
|
||||
// 处理旋转
|
||||
val rotation = image.imageInfo.rotationDegrees
|
||||
if (rotation != 0) {
|
||||
matrix.postRotate(rotation.toFloat())
|
||||
}
|
||||
|
||||
// 应用变换
|
||||
val transformed = Bitmap.createBitmap(bitmap, 0, 0, bitmap.width, bitmap.height, matrix, true)
|
||||
bitmap.recycle()
|
||||
bitmap = transformed
|
||||
|
||||
return bitmap
|
||||
}
|
||||
|
||||
private fun yuv420ToNv21(image: ImageProxy): ByteArray? {
|
||||
val planes = image.planes
|
||||
if (planes.size < 3) return null
|
||||
val width = image.width
|
||||
val height = image.height
|
||||
val ySize = width * height
|
||||
val uvSize = width * height / 4
|
||||
val nv21 = ByteArray(ySize + uvSize * 2)
|
||||
|
||||
val yPlane = planes[0]
|
||||
val yBuffer = yPlane.buffer
|
||||
val yRowStride = yPlane.rowStride
|
||||
var dst = 0
|
||||
for (row in 0 until height) {
|
||||
yBuffer.position(row * yRowStride)
|
||||
yBuffer.get(nv21, dst, width)
|
||||
dst += width
|
||||
}
|
||||
|
||||
val uPlane = planes[1]
|
||||
val vPlane = planes[2]
|
||||
val uBuffer = uPlane.buffer
|
||||
val vBuffer = vPlane.buffer
|
||||
val uRowStride = uPlane.rowStride
|
||||
val vRowStride = vPlane.rowStride
|
||||
val uPixelStride = uPlane.pixelStride
|
||||
val vPixelStride = vPlane.pixelStride
|
||||
|
||||
for (row in 0 until height / 2) {
|
||||
for (col in 0 until width / 2) {
|
||||
val uIndex = row * uRowStride + col * uPixelStride
|
||||
val vIndex = row * vRowStride + col * vPixelStride
|
||||
nv21[dst++] = vBuffer.get(vIndex)
|
||||
nv21[dst++] = uBuffer.get(uIndex)
|
||||
}
|
||||
}
|
||||
return nv21
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,7 @@ class Live2DAvatarManager(private val glSurfaceView: GLSurfaceView) {
|
||||
|
||||
init {
|
||||
glSurfaceView.setEGLContextClientVersion(2)
|
||||
glSurfaceView.setPreserveEGLContextOnPause(true)
|
||||
glSurfaceView.setRenderer(renderer)
|
||||
glSurfaceView.renderMode = GLSurfaceView.RENDERMODE_CONTINUOUSLY
|
||||
}
|
||||
@@ -16,11 +17,15 @@ class Live2DAvatarManager(private val glSurfaceView: GLSurfaceView) {
|
||||
}
|
||||
|
||||
fun setMood(mood: String) {
|
||||
renderer.setMood(mood)
|
||||
glSurfaceView.queueEvent {
|
||||
renderer.setMood(mood)
|
||||
}
|
||||
}
|
||||
|
||||
fun startSpecificMotion(motionName: String) {
|
||||
renderer.startSpecificMotion(motionName)
|
||||
glSurfaceView.queueEvent {
|
||||
renderer.startSpecificMotion(motionName)
|
||||
}
|
||||
}
|
||||
|
||||
fun onResume() {
|
||||
@@ -32,6 +37,8 @@ class Live2DAvatarManager(private val glSurfaceView: GLSurfaceView) {
|
||||
}
|
||||
|
||||
fun release() {
|
||||
renderer.release()
|
||||
glSurfaceView.queueEvent {
|
||||
renderer.release()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -214,32 +214,8 @@ class Live2DCharacter : CubismUserModel() {
|
||||
}
|
||||
|
||||
private fun loadMoodMotions(assets: AssetManager, modelDir: String) {
|
||||
// 开心心情动作
|
||||
moodMotions["开心"] = listOf(
|
||||
"haru_g_m22.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m22.motion3.json"),
|
||||
"haru_g_m21.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m21.motion3.json"),
|
||||
"haru_g_m18.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m18.motion3.json")
|
||||
).mapNotNull { (fileName, motion) ->
|
||||
motion?.let {
|
||||
motionFileMap[it] = fileName
|
||||
it
|
||||
}
|
||||
}
|
||||
|
||||
// 伤心心情动作
|
||||
moodMotions["伤心"] = listOf(
|
||||
"haru_g_m25.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m25.motion3.json"),
|
||||
"haru_g_m24.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m24.motion3.json"),
|
||||
"haru_g_m05.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m05.motion3.json")
|
||||
).mapNotNull { (fileName, motion) ->
|
||||
motion?.let {
|
||||
motionFileMap[it] = fileName
|
||||
it
|
||||
}
|
||||
}
|
||||
|
||||
// 平和心情动作
|
||||
moodMotions["平和"] = listOf(
|
||||
// 中性心情动作
|
||||
moodMotions["中性"] = listOf(
|
||||
"haru_g_m15.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m15.motion3.json"),
|
||||
"haru_g_m07.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m07.motion3.json"),
|
||||
"haru_g_m06.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m06.motion3.json"),
|
||||
@@ -252,8 +228,50 @@ class Live2DCharacter : CubismUserModel() {
|
||||
}
|
||||
}
|
||||
|
||||
// 惊讶心情动作
|
||||
moodMotions["惊讶"] = listOf(
|
||||
// 悲伤心情动作
|
||||
moodMotions["悲伤"] = listOf(
|
||||
"haru_g_m25.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m25.motion3.json"),
|
||||
"haru_g_m24.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m24.motion3.json"),
|
||||
"haru_g_m05.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m05.motion3.json"),
|
||||
"haru_g_m16.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m16.motion3.json"),
|
||||
"haru_g_m20.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m20.motion3.json"),
|
||||
).mapNotNull { (fileName, motion) ->
|
||||
motion?.let {
|
||||
motionFileMap[it] = fileName
|
||||
it
|
||||
}
|
||||
}
|
||||
|
||||
// 高兴心情动作
|
||||
moodMotions["高兴"] = listOf(
|
||||
"haru_g_m22.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m22.motion3.json"),
|
||||
"haru_g_m21.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m21.motion3.json"),
|
||||
"haru_g_m18.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m18.motion3.json"),
|
||||
"haru_g_m09.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m09.motion3.json"),
|
||||
"haru_g_m08.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m08.motion3.json")
|
||||
).mapNotNull { (fileName, motion) ->
|
||||
motion?.let {
|
||||
motionFileMap[it] = fileName
|
||||
it
|
||||
}
|
||||
}
|
||||
|
||||
// 生气心情动作
|
||||
moodMotions["生气"] = listOf(
|
||||
"haru_g_m10.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m10.motion3.json"),
|
||||
"haru_g_m11.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m11.motion3.json"),
|
||||
"haru_g_m04.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m04.motion3.json"),
|
||||
"haru_g_m03.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m03.motion3.json"),
|
||||
|
||||
).mapNotNull { (fileName, motion) ->
|
||||
motion?.let {
|
||||
motionFileMap[it] = fileName
|
||||
it
|
||||
}
|
||||
}
|
||||
|
||||
// 恐惧心情动作
|
||||
moodMotions["恐惧"] = listOf(
|
||||
"haru_g_m26.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m26.motion3.json"),
|
||||
"haru_g_m12.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m12.motion3.json")
|
||||
).mapNotNull { (fileName, motion) ->
|
||||
@@ -263,18 +281,8 @@ class Live2DCharacter : CubismUserModel() {
|
||||
}
|
||||
}
|
||||
|
||||
// 关心心情动作
|
||||
moodMotions["关心"] = listOf(
|
||||
"haru_g_m17.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m17.motion3.json")
|
||||
).mapNotNull { (fileName, motion) ->
|
||||
motion?.let {
|
||||
motionFileMap[it] = fileName
|
||||
it
|
||||
}
|
||||
}
|
||||
|
||||
// 害羞心情动作
|
||||
moodMotions["害羞"] = listOf(
|
||||
// 撒娇心情动作
|
||||
moodMotions["撒娇"] = listOf(
|
||||
"haru_g_m19.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m19.motion3.json")
|
||||
).mapNotNull { (fileName, motion) ->
|
||||
motion?.let {
|
||||
@@ -282,6 +290,38 @@ class Live2DCharacter : CubismUserModel() {
|
||||
it
|
||||
}
|
||||
}
|
||||
|
||||
// 震惊心情动作
|
||||
moodMotions["震惊"] = listOf(
|
||||
"haru_g_m26.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m26.motion3.json"),
|
||||
"haru_g_m12.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m12.motion3.json")
|
||||
).mapNotNull { (fileName, motion) ->
|
||||
motion?.let {
|
||||
motionFileMap[it] = fileName
|
||||
it
|
||||
}
|
||||
}
|
||||
|
||||
// 厌恶心情动作
|
||||
moodMotions["厌恶"] = listOf(
|
||||
"haru_g_m14.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m14.motion3.json"),
|
||||
"haru_g_m13.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m13.motion3.json")
|
||||
).mapNotNull { (fileName, motion) ->
|
||||
motion?.let {
|
||||
motionFileMap[it] = fileName
|
||||
it
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
// 兼容旧的心情名称
|
||||
moodMotions["开心"] = moodMotions["高兴"] ?: emptyList()
|
||||
moodMotions["伤心"] = moodMotions["悲伤"] ?: emptyList()
|
||||
moodMotions["平和"] = moodMotions["平静"] ?: emptyList()
|
||||
moodMotions["惊讶"] = moodMotions["震惊"] ?: emptyList()
|
||||
moodMotions["关心"] = moodMotions["中性"] ?: emptyList()
|
||||
moodMotions["害羞"] = moodMotions["撒娇"] ?: emptyList()
|
||||
}
|
||||
|
||||
private fun loadSpecificMotions(assets: AssetManager, modelDir: String) {
|
||||
|
||||
@@ -14,6 +14,9 @@ import javax.microedition.khronos.opengles.GL10
|
||||
class Live2DRenderer(
|
||||
private val context: Context
|
||||
) : GLSurfaceView.Renderer {
|
||||
companion object {
|
||||
private const val TAG = "Live2DRenderer"
|
||||
}
|
||||
@Volatile
|
||||
private var speaking = false
|
||||
|
||||
@@ -25,6 +28,7 @@ class Live2DRenderer(
|
||||
GLES20.glClearColor(0f, 0f, 0f, 0f)
|
||||
ensureFrameworkInitialized()
|
||||
startTimeMs = SystemClock.elapsedRealtime()
|
||||
Log.i(TAG, "onSurfaceCreated")
|
||||
|
||||
runCatching {
|
||||
val model = Live2DCharacter()
|
||||
@@ -35,6 +39,7 @@ class Live2DRenderer(
|
||||
)
|
||||
model.bindTextures(context.assets, AppConfig.Avatar.MODEL_DIR)
|
||||
character = model
|
||||
Log.i(TAG, "Live2D model loaded and textures bound")
|
||||
}.onFailure {
|
||||
Log.e(AppConfig.TAG, "Load Live2D model failed: ${it.message}", it)
|
||||
character = null
|
||||
|
||||
46
app/src/main/java/com/digitalperson/llm/LLMManager.kt
Normal file
46
app/src/main/java/com/digitalperson/llm/LLMManager.kt
Normal file
@@ -0,0 +1,46 @@
|
||||
package com.digitalperson.llm
|
||||
|
||||
interface LLMManagerCallback {
|
||||
fun onThinking(msg: String, finished: Boolean)
|
||||
fun onResult(msg: String, finished: Boolean)
|
||||
}
|
||||
|
||||
class LLMManager(modelPath: String, callback: LLMManagerCallback) :
|
||||
RKLLM(modelPath, object : LLMCallback {
|
||||
var inThinking = false
|
||||
|
||||
override fun onCallback(data: String, state: LLMCallback.State) {
|
||||
if (state == LLMCallback.State.NORMAL) {
|
||||
if (data == "<think>") {
|
||||
inThinking = true
|
||||
return
|
||||
} else if (data == "</think>") {
|
||||
inThinking = false
|
||||
callback.onThinking("", true)
|
||||
return
|
||||
}
|
||||
|
||||
if (inThinking) {
|
||||
callback.onThinking(data, false)
|
||||
} else {
|
||||
if (data == "\n") return
|
||||
callback.onResult(data, false)
|
||||
}
|
||||
} else {
|
||||
callback.onThinking("", true)
|
||||
callback.onResult("", true)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
{
|
||||
fun generateResponse(prompt: String) {
|
||||
val msg = "<|User|>$prompt<|Assistant|>"
|
||||
say(msg)
|
||||
}
|
||||
|
||||
fun generateResponseWithSystem(systemPrompt: String, userPrompt: String) {
|
||||
val msg = "<|System|>$systemPrompt<|User|>$userPrompt<|Assistant|>"
|
||||
say(msg)
|
||||
}
|
||||
}
|
||||
52
app/src/main/java/com/digitalperson/llm/RKLLM.kt
Normal file
52
app/src/main/java/com/digitalperson/llm/RKLLM.kt
Normal file
@@ -0,0 +1,52 @@
|
||||
package com.digitalperson.llm
|
||||
|
||||
interface LLMCallback {
|
||||
enum class State {
|
||||
ERROR, NORMAL, FINISH
|
||||
}
|
||||
fun onCallback(data: String, state: State)
|
||||
}
|
||||
|
||||
open class RKLLM(modelPath: String, callback: LLMCallback) {
|
||||
companion object {
|
||||
init {
|
||||
System.loadLibrary("rkllmrt")
|
||||
}
|
||||
}
|
||||
|
||||
private var mInstance: Long
|
||||
private var mCallback: LLMCallback
|
||||
|
||||
init {
|
||||
mInstance = initLLM(modelPath)
|
||||
mCallback = callback
|
||||
if (mInstance == 0L) {
|
||||
throw IllegalStateException("RKLLM init failed: native handle is null")
|
||||
}
|
||||
}
|
||||
|
||||
fun destroy() {
|
||||
deinitLLM(mInstance)
|
||||
mInstance = 0
|
||||
}
|
||||
|
||||
protected fun say(text: String) {
|
||||
if (mInstance == 0L) {
|
||||
mCallback.onCallback("RKLLM is not initialized", LLMCallback.State.ERROR)
|
||||
return
|
||||
}
|
||||
infer(mInstance, text)
|
||||
}
|
||||
|
||||
fun callbackFromNative(data: String, state: Int) {
|
||||
var s = LLMCallback.State.ERROR
|
||||
s = if (state == 0) LLMCallback.State.FINISH
|
||||
else if (state < 0) LLMCallback.State.ERROR
|
||||
else LLMCallback.State.NORMAL
|
||||
mCallback.onCallback(data, s)
|
||||
}
|
||||
|
||||
private external fun initLLM(modelPath: String): Long
|
||||
private external fun deinitLLM(handle: Long)
|
||||
private external fun infer(handle: Long, text: String)
|
||||
}
|
||||
330
app/src/main/java/com/digitalperson/tts/QCloudTtsManager.kt
Normal file
330
app/src/main/java/com/digitalperson/tts/QCloudTtsManager.kt
Normal file
@@ -0,0 +1,330 @@
|
||||
package com.digitalperson.tts
|
||||
|
||||
import android.content.Context
|
||||
import android.media.AudioAttributes
|
||||
import android.media.AudioFormat
|
||||
import android.media.AudioManager
|
||||
import android.media.AudioTrack
|
||||
import android.util.Log
|
||||
import com.digitalperson.config.AppConfig
|
||||
import com.digitalperson.mood.MoodManager
|
||||
import com.tencent.cloud.realtime.tts.RealTimeSpeechSynthesizer
|
||||
import com.tencent.cloud.realtime.tts.RealTimeSpeechSynthesizerListener
|
||||
import com.tencent.cloud.realtime.tts.RealTimeSpeechSynthesizerRequest
|
||||
import com.tencent.cloud.realtime.tts.SpeechSynthesizerResponse
|
||||
import com.tencent.cloud.realtime.tts.core.ws.Credential
|
||||
import com.tencent.cloud.realtime.tts.core.ws.SpeechClient
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.launch
|
||||
import java.nio.ByteBuffer
|
||||
import java.util.UUID
|
||||
import java.util.concurrent.LinkedBlockingQueue
|
||||
import java.util.concurrent.atomic.AtomicBoolean
|
||||
|
||||
class QCloudTtsManager(private val context: Context) {
|
||||
|
||||
companion object {
|
||||
private const val TAG = "QCloudTtsManager"
|
||||
private const val SAMPLE_RATE = 16000
|
||||
private val proxy = SpeechClient()
|
||||
}
|
||||
private var audioTrack: AudioTrack? = null
|
||||
private var synthesizer: RealTimeSpeechSynthesizer? = null
|
||||
|
||||
private sealed class TtsQueueItem {
|
||||
data class Segment(val text: String) : TtsQueueItem()
|
||||
data object End : TtsQueueItem()
|
||||
}
|
||||
|
||||
private val ttsQueue = LinkedBlockingQueue<TtsQueueItem>()
|
||||
private val ttsStopped = AtomicBoolean(false)
|
||||
private val ttsWorkerRunning = AtomicBoolean(false)
|
||||
private val ttsPlaying = AtomicBoolean(false)
|
||||
private val interrupting = AtomicBoolean(false)
|
||||
|
||||
private val ioScope = CoroutineScope(Dispatchers.IO)
|
||||
|
||||
interface TtsCallback {
|
||||
fun onTtsStarted(text: String)
|
||||
fun onTtsCompleted()
|
||||
fun onTtsSegmentCompleted(durationMs: Long)
|
||||
fun isTtsStopped(): Boolean
|
||||
fun onClearAsrQueue()
|
||||
fun onSetSpeaking(speaking: Boolean)
|
||||
fun onEndTurn()
|
||||
}
|
||||
|
||||
private var callback: TtsCallback? = null
|
||||
|
||||
fun setCallback(callback: TtsCallback) {
|
||||
this.callback = callback
|
||||
}
|
||||
|
||||
fun init(): Boolean {
|
||||
return try {
|
||||
initAudioTrack()
|
||||
true
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Init QCloud TTS failed: ${e.message}", e)
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
private fun initAudioTrack() {
|
||||
val bufferSize = AudioTrack.getMinBufferSize(
|
||||
SAMPLE_RATE,
|
||||
AudioFormat.CHANNEL_OUT_MONO,
|
||||
AudioFormat.ENCODING_PCM_16BIT
|
||||
)
|
||||
val attr = AudioAttributes.Builder()
|
||||
.setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
|
||||
.setUsage(AudioAttributes.USAGE_MEDIA)
|
||||
.build()
|
||||
val format = AudioFormat.Builder()
|
||||
.setEncoding(AudioFormat.ENCODING_PCM_16BIT)
|
||||
.setChannelMask(AudioFormat.CHANNEL_OUT_MONO)
|
||||
.setSampleRate(SAMPLE_RATE)
|
||||
.build()
|
||||
audioTrack = AudioTrack(
|
||||
attr,
|
||||
format,
|
||||
bufferSize,
|
||||
AudioTrack.MODE_STREAM,
|
||||
AudioManager.AUDIO_SESSION_ID_GENERATE
|
||||
)
|
||||
}
|
||||
|
||||
fun enqueueSegment(seg: String) {
|
||||
if (ttsStopped.get()) {
|
||||
ttsStopped.set(false)
|
||||
}
|
||||
val cleanedSeg = seg.trimEnd('.', '。', '!', '!', '?', '?', ',', ',', ';', ';', ':', ':')
|
||||
ttsQueue.offer(TtsQueueItem.Segment(cleanedSeg))
|
||||
ensureTtsWorker()
|
||||
}
|
||||
|
||||
fun enqueueEnd() {
|
||||
ttsQueue.offer(TtsQueueItem.End)
|
||||
}
|
||||
|
||||
fun isPlaying(): Boolean = ttsPlaying.get()
|
||||
|
||||
fun reset() {
|
||||
val workerRunning = ttsWorkerRunning.get()
|
||||
val wasStopped = ttsStopped.get()
|
||||
|
||||
ttsStopped.set(false)
|
||||
ttsPlaying.set(false)
|
||||
ttsQueue.clear()
|
||||
|
||||
if (wasStopped && workerRunning) {
|
||||
ttsQueue.offer(TtsQueueItem.End)
|
||||
}
|
||||
}
|
||||
|
||||
fun stop() {
|
||||
ttsStopped.set(true)
|
||||
ttsPlaying.set(false)
|
||||
ttsQueue.clear()
|
||||
ttsQueue.offer(TtsQueueItem.End)
|
||||
|
||||
try {
|
||||
synthesizer?.cancel()
|
||||
synthesizer = null
|
||||
audioTrack?.pause()
|
||||
audioTrack?.flush()
|
||||
} catch (_: Throwable) {
|
||||
}
|
||||
}
|
||||
|
||||
fun interruptForNewTurn(waitTimeoutMs: Long = 300): Boolean {
|
||||
if (!interrupting.compareAndSet(false, true)) return false
|
||||
try {
|
||||
val hadPendingPlayback = ttsPlaying.get() || ttsWorkerRunning.get() || ttsQueue.isNotEmpty()
|
||||
if (!hadPendingPlayback) {
|
||||
ttsStopped.set(false)
|
||||
ttsPlaying.set(false)
|
||||
return false
|
||||
}
|
||||
|
||||
ttsStopped.set(true)
|
||||
ttsPlaying.set(false)
|
||||
ttsQueue.clear()
|
||||
ttsQueue.offer(TtsQueueItem.End)
|
||||
|
||||
try {
|
||||
synthesizer?.cancel()
|
||||
synthesizer = null
|
||||
audioTrack?.pause()
|
||||
audioTrack?.flush()
|
||||
} catch (_: Throwable) {
|
||||
}
|
||||
|
||||
val deadline = System.currentTimeMillis() + waitTimeoutMs
|
||||
while (ttsWorkerRunning.get() && System.currentTimeMillis() < deadline) {
|
||||
Thread.sleep(10)
|
||||
}
|
||||
|
||||
if (ttsWorkerRunning.get()) {
|
||||
Log.w(TAG, "interruptForNewTurn timeout: worker still running")
|
||||
}
|
||||
|
||||
ttsQueue.clear()
|
||||
ttsStopped.set(false)
|
||||
ttsPlaying.set(false)
|
||||
callback?.onSetSpeaking(false)
|
||||
return true
|
||||
} finally {
|
||||
interrupting.set(false)
|
||||
}
|
||||
}
|
||||
|
||||
fun release() {
|
||||
try {
|
||||
synthesizer?.cancel()
|
||||
synthesizer = null
|
||||
} catch (_: Throwable) {
|
||||
}
|
||||
try {
|
||||
audioTrack?.release()
|
||||
audioTrack = null
|
||||
} catch (_: Throwable) {
|
||||
}
|
||||
}
|
||||
|
||||
private fun ensureTtsWorker() {
|
||||
if (!ttsWorkerRunning.compareAndSet(false, true)) return
|
||||
ioScope.launch {
|
||||
try {
|
||||
runTtsWorker()
|
||||
} finally {
|
||||
ttsWorkerRunning.set(false)
|
||||
if (!ttsStopped.get() && ttsQueue.isNotEmpty()) {
|
||||
ensureTtsWorker()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Blocking worker loop: drains [ttsQueue], streaming each text segment
 * through Tencent Cloud real-time speech synthesis and writing the
 * returned PCM straight into [audioTrack].
 *
 * Runs until a [TtsQueueItem.End] sentinel is taken (normal end of turn)
 * or [ttsStopped] is set (barge-in / shutdown). Must only be invoked from
 * the single worker guarded by ensureTtsWorker().
 */
private fun runTtsWorker() {
    val audioTrack = audioTrack ?: return

    while (true) {
        val item = ttsQueue.take()
        if (ttsStopped.get()) break

        when (item) {
            is TtsQueueItem.Segment -> {
                ttsPlaying.set(true)
                callback?.onSetSpeaking(true)
                Log.d(TAG, "QCloud TTS started: processing segment '${item.text}'")
                callback?.onTtsStarted(item.text)

                val startMs = System.currentTimeMillis()

                try {
                    if (audioTrack.playState != AudioTrack.PLAYSTATE_PLAYING) {
                        audioTrack.play()
                    }

                    val credential = Credential(
                        AppConfig.QCloud.APP_ID,
                        AppConfig.QCloud.SECRET_ID,
                        AppConfig.QCloud.SECRET_KEY,
                        ""
                    )

                    val request = RealTimeSpeechSynthesizerRequest()
                    request.setVolume(0f)               // volume, range [-10, 10]
                    request.setSpeed(0f)                // speed, range [-2, 6]
                    request.setCodec("pcm")             // raw PCM so we can feed AudioTrack directly
                    request.setSampleRate(SAMPLE_RATE)  // must match the AudioTrack's sample rate
                    request.setVoiceType(601010)        // voice/timbre ID
                    request.setEnableSubtitle(true)     // enable word-level timestamps
                    // Map the avatar's current mood onto a QCloud emotion category.
                    request.setEmotionCategory(moodToEmotionCategory(MoodManager.getCurrentMood()))
                    request.setEmotionIntensity(100)    // strongest emotion expression
                    request.setSessionId(UUID.randomUUID().toString())
                    request.setText(item.text)

                    val listener = object : RealTimeSpeechSynthesizerListener() {
                        override fun onSynthesisStart(response: SpeechSynthesizerResponse) {
                            Log.d(TAG, "onSynthesisStart: ${response.sessionId}")
                        }

                        override fun onSynthesisEnd(response: SpeechSynthesizerResponse) {
                            Log.d(TAG, "onSynthesisEnd: ${response.sessionId}")
                            val ttsMs = System.currentTimeMillis() - startMs
                            callback?.onTtsSegmentCompleted(ttsMs)
                        }

                        override fun onAudioResult(buffer: ByteBuffer) {
                            val data = ByteArray(buffer.remaining())
                            buffer.get(data)
                            // Play the PCM chunk as it arrives (blocking write).
                            audioTrack.write(data, 0, data.size)
                        }

                        override fun onTextResult(response: SpeechSynthesizerResponse) {
                            Log.d(TAG, "onTextResult: ${response.sessionId}")
                        }

                        override fun onSynthesisCancel() {
                            Log.d(TAG, "onSynthesisCancel")
                        }

                        override fun onSynthesisFail(response: SpeechSynthesizerResponse) {
                            Log.e(TAG, "onSynthesisFail: ${response.sessionId}, error: ${response.message}")
                        }
                    }

                    // NOTE(review): start() is assumed to drive synthesis for this
                    // segment; the SDK callbacks above deliver the audio. Confirm
                    // whether start() blocks until completion — if it does not,
                    // consecutive segments may overlap.
                    synthesizer = RealTimeSpeechSynthesizer(proxy, credential, request, listener)
                    synthesizer?.start()

                } catch (e: Exception) {
                    Log.e(TAG, "QCloud TTS error: ${e.message}", e)
                }
            }

            TtsQueueItem.End -> {
                callback?.onClearAsrQueue()

                waitForPlaybackComplete(audioTrack)

                callback?.onTtsCompleted()

                ttsPlaying.set(false)
                callback?.onSetSpeaking(false)
                callback?.onEndTurn()
                break
            }
        }
    }
}

/**
 * Maps the avatar's mood label (Chinese) onto the QCloud TTS emotion
 * category. Includes legacy mood names for backward compatibility;
 * anything unrecognized falls back to "neutral".
 */
private fun moodToEmotionCategory(mood: String): String = when (mood) {
    "中性" -> "neutral"
    "悲伤" -> "sad"
    "高兴" -> "happy"
    "生气" -> "angry"
    "恐惧" -> "fear"
    "撒娇" -> "sajiao"
    "震惊" -> "amaze"
    "厌恶" -> "disgusted"
    "平静" -> "peaceful"
    // Legacy mood names kept for compatibility with older prompts.
    "开心" -> "happy"
    "伤心" -> "sad"
    "平和" -> "peaceful"
    "惊讶" -> "amaze"
    "关心" -> "neutral"
    "害羞" -> "sajiao"
    else -> "neutral"
}
|
||||
|
||||
/**
 * Blocks until the AudioTrack appears to have drained its buffered PCM.
 *
 * The previous implementation slept a fixed 1000 ms, which both truncated
 * long tails and wasted time on short ones. The QCloud SDK gives no explicit
 * "playback finished" signal, so instead we poll the playback head and treat
 * a position that stops advancing as completion, bounded by a timeout so a
 * stalled or paused track can never hang the worker thread.
 */
private fun waitForPlaybackComplete(audioTrack: AudioTrack) {
    val timeoutMs = 3000L   // hard upper bound on the wait
    val pollMs = 100L       // head-position sampling interval
    val deadline = System.currentTimeMillis() + timeoutMs
    var lastHead = -1
    while (System.currentTimeMillis() < deadline) {
        val head = try {
            audioTrack.playbackHeadPosition
        } catch (_: Throwable) {
            break // track released/invalid — nothing left to wait for
        }
        if (head == lastHead) break // head stopped advancing: buffer drained
        lastHead = head
        Thread.sleep(pollMs)
    }
}
|
||||
}
|
||||
181
app/src/main/java/com/digitalperson/tts/TtsController.kt
Normal file
181
app/src/main/java/com/digitalperson/tts/TtsController.kt
Normal file
@@ -0,0 +1,181 @@
|
||||
package com.digitalperson.tts
|
||||
|
||||
import android.content.Context
|
||||
import android.util.Log
|
||||
|
||||
/**
 * Facade that routes TTS requests to either the on-device engine
 * ([TtsManager]) or Tencent Cloud streaming TTS ([QCloudTtsManager]),
 * switchable at runtime via [setUseQCloudTts].
 *
 * Fix: previously a callback registered via [setCallback] BEFORE [init]
 * was silently lost, because both engine references are null until init()
 * constructs them. init() now re-attaches any stored callback, removing
 * the hidden ordering dependency.
 */
class TtsController(private val context: Context) {

    companion object {
        private const val TAG = "TtsController"
    }

    private var localTts: TtsManager? = null
    private var qcloudTts: QCloudTtsManager? = null
    // When true, all calls route to the QCloud engine; otherwise the local one.
    private var useQCloudTts = false

    /** Unified callback surface exposed to the UI layer. */
    interface TtsCallback {
        fun onTtsStarted(text: String)
        fun onTtsCompleted()
        fun onTtsSegmentCompleted(durationMs: Long)
        fun isTtsStopped(): Boolean
        fun onClearAsrQueue()
        fun onSetSpeaking(speaking: Boolean)
        fun onEndTurn()
    }

    private var callback: TtsCallback? = null

    /**
     * Registers the callback and wires it into both engines.
     * Safe to call before or after [init]; init() re-attaches it.
     */
    fun setCallback(callback: TtsCallback) {
        this.callback = callback
        attachCallbacks(callback)
    }

    /** Forwards [cb] to whichever engines currently exist. */
    private fun attachCallbacks(cb: TtsCallback) {
        localTts?.setCallback(object : TtsManager.TtsCallback {
            override fun onTtsStarted(text: String) = cb.onTtsStarted(text)

            override fun onTtsCompleted() = cb.onTtsCompleted()

            override fun onTtsSegmentCompleted(durationMs: Long) =
                cb.onTtsSegmentCompleted(durationMs)

            override fun isTtsStopped(): Boolean = cb.isTtsStopped()

            override fun onClearAsrQueue() = cb.onClearAsrQueue()

            override fun onSetSpeaking(speaking: Boolean) = cb.onSetSpeaking(speaking)

            // Tracing hooks are unused by this controller; provide no-op stubs.
            override fun getCurrentTrace() = null

            override fun onTraceMarkTtsRequestEnqueued() {
            }

            override fun onTraceMarkTtsSynthesisStart() {
            }

            override fun onTraceMarkTtsFirstPcmReady() {
            }

            override fun onTraceMarkTtsFirstAudioPlay() {
            }

            override fun onTraceMarkTtsDone() {
            }

            override fun onTraceAddDuration(name: String, value: Long) {
            }

            override fun onEndTurn() = cb.onEndTurn()
        })
        qcloudTts?.setCallback(object : QCloudTtsManager.TtsCallback {
            override fun onTtsStarted(text: String) = cb.onTtsStarted(text)

            override fun onTtsCompleted() = cb.onTtsCompleted()

            override fun onTtsSegmentCompleted(durationMs: Long) =
                cb.onTtsSegmentCompleted(durationMs)

            override fun isTtsStopped(): Boolean = cb.isTtsStopped()

            override fun onClearAsrQueue() = cb.onClearAsrQueue()

            override fun onSetSpeaking(speaking: Boolean) = cb.onSetSpeaking(speaking)

            override fun onEndTurn() = cb.onEndTurn()
        })
    }

    /**
     * Initializes both engines and re-attaches any previously registered
     * callback. Returns true if at least one engine is usable.
     */
    fun init(): Boolean {
        // Initialize the local on-device TTS engine.
        localTts = TtsManager(context)
        val localInit = localTts?.initTtsAndAudioTrack() ?: false
        Log.d(TAG, "Local TTS init: $localInit")

        // Initialize the Tencent Cloud TTS engine.
        qcloudTts = QCloudTtsManager(context)
        val qcloudInit = qcloudTts?.init() ?: false
        Log.d(TAG, "QCloud TTS init: $qcloudInit")

        // Re-attach a callback registered before init() (fixes lost-callback bug).
        callback?.let { attachCallbacks(it) }

        return localInit || qcloudInit
    }

    /** Selects which engine subsequent calls are routed to. */
    fun setUseQCloudTts(useQCloud: Boolean) {
        this.useQCloudTts = useQCloud
        Log.d(TAG, "TTS mode changed: ${if (useQCloud) "QCloud" else "Local"}")
    }

    /** Queues a text segment for synthesis on the active engine. */
    fun enqueueSegment(seg: String) {
        if (useQCloudTts) {
            qcloudTts?.enqueueSegment(seg)
        } else {
            localTts?.enqueueSegment(seg)
        }
    }

    /** Queues the end-of-turn sentinel on the active engine. */
    fun enqueueEnd() {
        if (useQCloudTts) {
            qcloudTts?.enqueueEnd()
        } else {
            localTts?.enqueueEnd()
        }
    }

    /** True while the active engine is producing or playing audio. */
    fun isPlaying(): Boolean {
        return if (useQCloudTts) {
            qcloudTts?.isPlaying() ?: false
        } else {
            localTts?.isPlaying() ?: false
        }
    }

    /** Resets the active engine's queue/state for a fresh turn. */
    fun reset() {
        if (useQCloudTts) {
            qcloudTts?.reset()
        } else {
            localTts?.reset()
        }
    }

    /** Stops playback on the active engine. */
    fun stop() {
        if (useQCloudTts) {
            qcloudTts?.stop()
        } else {
            localTts?.stop()
        }
    }

    /**
     * Interrupts ongoing speech so a new turn can begin, waiting up to
     * [waitTimeoutMs] for the engine to quiesce. Returns the engine's result,
     * or false when no engine is active.
     */
    fun interruptForNewTurn(waitTimeoutMs: Long = 300): Boolean {
        return if (useQCloudTts) {
            qcloudTts?.interruptForNewTurn(waitTimeoutMs) ?: false
        } else {
            localTts?.interruptForNewTurn(waitTimeoutMs) ?: false
        }
    }

    /** Releases both engines; the controller must not be reused afterwards. */
    fun release() {
        localTts?.release()
        qcloudTts?.release()
    }
}
|
||||
@@ -1,12 +1,14 @@
|
||||
package com.digitalperson.ui
|
||||
|
||||
import android.app.Activity
|
||||
import android.app.ProgressDialog
|
||||
import android.opengl.GLSurfaceView
|
||||
import android.text.method.ScrollingMovementMethod
|
||||
import android.view.MotionEvent
|
||||
import android.widget.Button
|
||||
import android.widget.LinearLayout
|
||||
import android.widget.ScrollView
|
||||
import android.widget.Switch
|
||||
import android.widget.TextView
|
||||
import android.widget.Toast
|
||||
import com.digitalperson.live2d.Live2DAvatarManager
|
||||
@@ -18,7 +20,10 @@ class Live2DUiManager(private val activity: Activity) {
|
||||
private var stopButton: Button? = null
|
||||
private var recordButton: Button? = null
|
||||
private var traditionalButtons: LinearLayout? = null
|
||||
private var llmModeSwitch: Switch? = null
|
||||
private var llmModeSwitchRow: LinearLayout? = null
|
||||
private var avatarManager: Live2DAvatarManager? = null
|
||||
private var downloadProgressDialog: ProgressDialog? = null
|
||||
|
||||
private var lastUiText: String = ""
|
||||
|
||||
@@ -29,6 +34,8 @@ class Live2DUiManager(private val activity: Activity) {
|
||||
stopButtonId: Int = -1,
|
||||
recordButtonId: Int = -1,
|
||||
traditionalButtonsId: Int = -1,
|
||||
llmModeSwitchId: Int = -1,
|
||||
llmModeSwitchRowId: Int = -1,
|
||||
silentPlayerViewId: Int,
|
||||
speakingPlayerViewId: Int,
|
||||
live2dViewId: Int
|
||||
@@ -39,12 +46,17 @@ class Live2DUiManager(private val activity: Activity) {
|
||||
if (stopButtonId != -1) stopButton = activity.findViewById(stopButtonId)
|
||||
if (recordButtonId != -1) recordButton = activity.findViewById(recordButtonId)
|
||||
if (traditionalButtonsId != -1) traditionalButtons = activity.findViewById(traditionalButtonsId)
|
||||
if (llmModeSwitchId != -1) llmModeSwitch = activity.findViewById(llmModeSwitchId)
|
||||
if (llmModeSwitchRowId != -1) llmModeSwitchRow = activity.findViewById(llmModeSwitchRowId)
|
||||
|
||||
textView?.movementMethod = ScrollingMovementMethod()
|
||||
|
||||
val glView = activity.findViewById<GLSurfaceView>(live2dViewId)
|
||||
avatarManager = Live2DAvatarManager(glView)
|
||||
avatarManager?.setSpeaking(false)
|
||||
|
||||
// 默认隐藏本地 LLM 开关
|
||||
llmModeSwitchRow?.visibility = LinearLayout.GONE
|
||||
}
|
||||
|
||||
fun setStartButtonListener(listener: () -> Unit) {
|
||||
@@ -131,6 +143,72 @@ class Live2DUiManager(private val activity: Activity) {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 显示或隐藏本地 LLM 开关
|
||||
*/
|
||||
fun showLLMSwitch(show: Boolean) {
|
||||
activity.runOnUiThread {
|
||||
llmModeSwitchRow?.visibility = if (show) LinearLayout.VISIBLE else LinearLayout.GONE
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置 LLM 模式开关的监听器
|
||||
*/
|
||||
fun setLLMSwitchListener(listener: (Boolean) -> Unit) {
|
||||
llmModeSwitch?.setOnCheckedChangeListener { _, isChecked ->
|
||||
listener(isChecked)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置 LLM 模式开关的状态
|
||||
*/
|
||||
fun setLLMSwitchChecked(checked: Boolean) {
|
||||
activity.runOnUiThread {
|
||||
llmModeSwitch?.isChecked = checked
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 显示下载进度对话框
|
||||
*/
|
||||
fun showDownloadProgressDialog() {
|
||||
activity.runOnUiThread {
|
||||
downloadProgressDialog = ProgressDialog(activity).apply {
|
||||
setTitle("下载模型")
|
||||
setMessage("正在下载 LLM 模型文件,请稍候...")
|
||||
setProgressStyle(ProgressDialog.STYLE_HORIZONTAL)
|
||||
isIndeterminate = false
|
||||
setCancelable(false)
|
||||
setCanceledOnTouchOutside(false)
|
||||
show()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 更新下载进度
|
||||
*/
|
||||
fun updateDownloadProgress(fileName: String, downloadedMB: Long, totalMB: Long, progress: Int) {
|
||||
activity.runOnUiThread {
|
||||
downloadProgressDialog?.apply {
|
||||
setMessage("正在下载: $fileName\n$downloadedMB MB / $totalMB MB")
|
||||
setProgress(progress)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 关闭下载进度对话框
|
||||
*/
|
||||
fun dismissDownloadProgressDialog() {
|
||||
activity.runOnUiThread {
|
||||
downloadProgressDialog?.dismiss()
|
||||
downloadProgressDialog = null
|
||||
}
|
||||
}
|
||||
|
||||
fun onResume() {
|
||||
avatarManager?.onResume()
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package com.digitalperson.util
|
||||
|
||||
import android.content.ContentUris
|
||||
import android.content.Context
|
||||
import android.provider.MediaStore
|
||||
import android.util.Log
|
||||
import com.digitalperson.config.AppConfig
|
||||
import java.io.File
|
||||
@@ -49,12 +51,269 @@ object FileHelper {
|
||||
return copyAssetsToInternal(context, AppConfig.Asr.MODEL_DIR, outDir, files)
|
||||
}
|
||||
|
||||
/** Copies the RetinaFace detection model from assets into internal storage. */
@JvmStatic
fun copyRetinaFaceAssets(context: Context): File {
    val destination = File(context.filesDir, AppConfig.Face.MODEL_DIR)
    return copyAssetsToInternal(
        context,
        AppConfig.Face.MODEL_DIR,
        destination,
        arrayOf(AppConfig.Face.MODEL_NAME)
    )
}
|
||||
|
||||
/** Copies the InsightFace recognition model from assets into internal storage. */
@JvmStatic
fun copyInsightFaceAssets(context: Context): File {
    val destination = File(context.filesDir, AppConfig.FaceRecognition.MODEL_DIR)
    return copyAssetsToInternal(
        context,
        AppConfig.FaceRecognition.MODEL_DIR,
        destination,
        arrayOf(AppConfig.FaceRecognition.MODEL_NAME)
    )
}
|
||||
|
||||
/**
 * Ensures [dir] exists, creating it (and any missing parents) if necessary,
 * and returns it.
 *
 * Fixes two defects in the previous version: a redundant double
 * exists()/mkdirs() sequence, and a failure fallback that fabricated a path
 * from `dir.parentFile?.parentFile?.name` — which produced a nonsensical
 * location. On failure we now log and return [dir] unchanged so callers fail
 * with a clear I/O error at the point of use rather than silently writing to
 * a guessed directory.
 */
fun ensureDir(dir: File): File {
    if (!dir.exists()) {
        val created = dir.mkdirs()
        // Re-check exists(): a concurrent creator winning the race makes
        // mkdirs() return false even though the directory now exists.
        if (!created && !dir.exists()) {
            Log.e(TAG, "Failed to create directory: ${dir.absolutePath}")
        }
    }
    return dir
}
|
||||
|
||||
/** Returns (creating if needed) the app-private directory for ASR audio clips. */
fun getAsrAudioDir(context: Context): File =
    ensureDir(File(context.filesDir, "asr_audio"))
|
||||
|
||||
// @JvmStatic
|
||||
// 当前使用的模型文件名
|
||||
private const val MODEL_FILE_NAME = "Qwen3-0.6B-rk3588-w8a8.rkllm"
|
||||
|
||||
/**
 * Resolves the absolute path of the RKLLM model file inside the app-private
 * "llm" directory. The path is returned even when the file is missing — the
 * absence is only logged, and the caller is expected to handle load failure.
 */
fun getLLMModelPath(context: Context): String {
    Log.d(TAG, "=== getLLMModelPath START ===")

    // Models live under <filesDir>/llm; create the directory on first use.
    val llmDir = ensureDir(File(context.filesDir, "llm"))
    Log.d(TAG, "Loading models from: ${llmDir.absolutePath}")

    val rkllmFile = File(llmDir, MODEL_FILE_NAME)
    when {
        !rkllmFile.exists() ->
            Log.e(TAG, "RKLLM model not found: ${rkllmFile.absolutePath}")
        else ->
            Log.i(TAG, "RKLLM model exists, size: ${rkllmFile.length() / (1024*1024)} MB")
    }

    val modelPath = rkllmFile.absolutePath
    Log.i(TAG, "Using RKLLM model path: $modelPath")
    Log.d(TAG, "=== getLLMModelPath END ===")
    return modelPath
}
|
||||
|
||||
/**
|
||||
* 异步下载模型文件,带进度回调
|
||||
* @param context 上下文
|
||||
* @param onProgress 进度回调 (currentFile, downloadedBytes, totalBytes, progressPercent)
|
||||
* @param onComplete 完成回调 (success, message)
|
||||
*/
|
||||
/**
 * Downloads any missing LLM model files on a background thread, reporting
 * progress as it goes.
 *
 * Fixes: the "estimated size" log previously printed the cumulative total
 * instead of the per-file estimate; the dead when-branch (both arms added
 * the same 1 GB) is collapsed; the download URL is built in one place.
 *
 * @param context   application context providing filesDir
 * @param onProgress (currentFileName, downloadedBytes, totalBytes, overallPercent)
 * @param onComplete (success, user-facing message)
 */
@JvmStatic
fun downloadModelFilesWithProgress(
    context: Context,
    onProgress: (String, Long, Long, Int) -> Unit,
    onComplete: (Boolean, String) -> Unit
) {
    Log.d(TAG, "=== downloadModelFilesWithProgress START ===")

    val llmDir = ensureDir(File(context.filesDir, "llm"))
    // TODO(review): model server is hard-coded to a LAN address; move to
    // AppConfig / build config before shipping.
    val baseUrl = "http://192.168.1.19:5000/download"
    val modelFiles = listOf(MODEL_FILE_NAME)

    // All network I/O happens off the caller's thread.
    Thread {
        try {
            var allSuccess = true
            var totalDownloaded: Long = 0
            var totalSize: Long = 0

            // Pass 1: compute the total byte count of everything still missing,
            // so per-chunk progress can be reported as an overall percentage.
            for (fileName in modelFiles) {
                val modelFile = File(llmDir, fileName)
                if (modelFile.exists() && modelFile.length() > 0L) continue

                val size = getFileSizeFromServer("$baseUrl/$fileName")
                if (size > 0) {
                    totalSize += size
                } else {
                    // HEAD request failed — fall back to a rough estimate.
                    val estimated = 1L * 1024 * 1024 * 1024 // ~1 GB per model
                    totalSize += estimated
                    Log.i(TAG, "Using estimated size for $fileName: ${estimated / (1024 * 1024)} MB")
                }
            }

            // Pass 2: download whatever is missing or zero-length.
            for (fileName in modelFiles) {
                val modelFile = File(llmDir, fileName)
                if (!modelFile.exists() || modelFile.length() == 0L) {
                    Log.i(TAG, "Downloading model file: $fileName")
                    try {
                        downloadFileWithProgress(
                            "$baseUrl/$fileName",
                            modelFile
                        ) { downloaded, total ->
                            val progress = if (totalSize > 0) {
                                ((totalDownloaded + downloaded) * 100 / totalSize).toInt()
                            } else 0
                            onProgress(fileName, downloaded, total, progress)
                        }
                        totalDownloaded += modelFile.length()
                        Log.i(TAG, "Downloaded model file: $fileName, size: ${modelFile.length() / (1024 * 1024)} MB")
                    } catch (e: Exception) {
                        // One failed file does not abort the rest of the list.
                        Log.e(TAG, "Failed to download model file $fileName: ${e.message}")
                        allSuccess = false
                    }
                } else {
                    totalDownloaded += modelFile.length()
                    Log.i(TAG, "Model file exists: $fileName, size: ${modelFile.length() / (1024 * 1024)} MB")
                }
            }
            Log.d(TAG, "=== downloadModelFilesWithProgress END ===")
            if (allSuccess) {
                onComplete(true, "模型下载完成")
            } else {
                onComplete(false, "部分模型下载失败")
            }
        } catch (e: Exception) {
            Log.e(TAG, "Download failed: ${e.message}")
            onComplete(false, "下载失败: ${e.message}")
        }
    }.start()
}
|
||||
|
||||
/**
|
||||
* 从服务器获取文件大小
|
||||
*/
|
||||
/**
 * Issues an HTTP HEAD request to [url] and returns the reported content
 * length in bytes, or 0 when it is missing, malformed, negative, or the
 * request fails entirely.
 */
private fun getFileSizeFromServer(url: String): Long {
    return try {
        val connection = java.net.URL(url).openConnection() as java.net.HttpURLConnection
        connection.requestMethod = "HEAD"
        connection.connectTimeout = 15000
        connection.readTimeout = 15000

        // Read the raw Content-Length header so sizes > Int.MAX_VALUE are
        // not truncated by HttpURLConnection.contentLength (an int).
        val headerValue = connection.getHeaderField("Content-Length")
        var size = 0L

        if (headerValue != null) {
            size = try {
                headerValue.toLong()
            } catch (e: NumberFormatException) {
                Log.w(TAG, "Invalid Content-Length format: $headerValue")
                0L
            }
            if (size < 0) {
                Log.w(TAG, "Invalid Content-Length value: $size")
                size = 0
            }
        } else {
            // No header at all — fall back to the int-valued accessor.
            val fallback = connection.contentLength
            size = if (fallback > 0) {
                fallback.toLong()
            } else {
                Log.w(TAG, "Content-Length not available or invalid: $fallback")
                0L
            }
        }

        connection.disconnect()
        Log.i(TAG, "File size for $url: $size bytes")
        size
    } catch (e: Exception) {
        Log.w(TAG, "Failed to get file size: ${e.message}")
        0
    }
}
|
||||
|
||||
/**
|
||||
* 从网络下载文件,带进度回调
|
||||
*/
|
||||
/**
 * Streams [url] into [destination], invoking [onProgress] with
 * (bytesDownloaded, totalBytes) after every chunk. totalBytes is 0 when the
 * server did not report a usable Content-Length.
 *
 * Fix: the catch branch of the untyped `totalSize` expression returned `0`
 * (an Int) while the other paths produced Long, leaving the expression with
 * an inconsistent type. The variable is now explicitly Long and the literal
 * is 0L.
 *
 * @throws java.io.IOException on connection or write failure (the connection
 *         is always disconnected via the finally block).
 */
private fun downloadFileWithProgress(
    url: String,
    destination: File,
    onProgress: (Long, Long) -> Unit
) {
    val connection = java.net.URL(url).openConnection() as java.net.HttpURLConnection
    connection.connectTimeout = 30000
    // Model files are ~1 GB; allow a very long read window for slow links.
    connection.readTimeout = 6000000

    // Parse Content-Length from the raw header to avoid the int truncation
    // in HttpURLConnection.contentLength for files larger than 2 GB.
    val contentLengthStr = connection.getHeaderField("Content-Length")
    val totalSize: Long = if (contentLengthStr != null) {
        try {
            contentLengthStr.toLong()
        } catch (e: NumberFormatException) {
            Log.w(TAG, "Invalid Content-Length format: $contentLengthStr")
            0L
        }
    } else {
        connection.contentLength.toLong()
    }
    Log.i(TAG, "Downloading file $url, size: $totalSize bytes")

    try {
        connection.inputStream.use { input ->
            FileOutputStream(destination).use { output ->
                val buffer = ByteArray(8192)
                var downloaded: Long = 0
                var bytesRead: Int

                while (input.read(buffer).also { bytesRead = it } != -1) {
                    output.write(buffer, 0, bytesRead)
                    downloaded += bytesRead
                    onProgress(downloaded, totalSize)
                }
            }
        }
    } finally {
        connection.disconnect()
    }
}
|
||||
|
||||
/**
|
||||
* 检查本地 LLM 模型是否可用
|
||||
*/
|
||||
/**
 * Reports whether the RKLLM model file is already present (and non-empty)
 * in the app-private "llm" directory, logging its path and size.
 */
@JvmStatic
fun isLocalLLMAvailable(context: Context): Boolean {
    val modelFile = File(File(context.filesDir, "llm"), MODEL_FILE_NAME)
    val available = modelFile.exists() && modelFile.length() > 0

    Log.i(TAG, "LLM model check: rkllm=$available")
    val sizeMb = if (modelFile.exists()) modelFile.length() / (1024*1024) else 0
    Log.i(TAG, "RKLLM file: ${modelFile.absolutePath}, size: $sizeMb MB")

    return available
}
|
||||
|
||||
/**
|
||||
* 从网络下载文件
|
||||
*/
|
||||
/**
 * Downloads [url] into [destination] in one shot, with no progress
 * reporting. The connection is always released, even on failure.
 *
 * @throws java.io.IOException on connection or write failure.
 */
private fun downloadFile(url: String, destination: File) {
    val connection = java.net.URL(url).openConnection() as java.net.HttpURLConnection
    connection.connectTimeout = 30000 // 30 s connect timeout
    connection.readTimeout = 60000    // 60 s read timeout

    try {
        connection.inputStream.use { source ->
            FileOutputStream(destination).use { sink ->
                source.copyTo(sink)
            }
        }
    } finally {
        connection.disconnect()
    }
}
|
||||
}
|
||||
|
||||
BIN
app/src/main/jniLibs/arm64-v8a/libomp.so
Normal file
BIN
app/src/main/jniLibs/arm64-v8a/libomp.so
Normal file
Binary file not shown.
BIN
app/src/main/jniLibs/arm64-v8a/librkllmrt.so
Normal file
BIN
app/src/main/jniLibs/arm64-v8a/librkllmrt.so
Normal file
Binary file not shown.
BIN
app/src/main/jniLibs/arm64-v8a/librknnrt.so.new
Normal file
BIN
app/src/main/jniLibs/arm64-v8a/librknnrt.so.new
Normal file
Binary file not shown.
@@ -16,11 +16,36 @@
|
||||
app:layout_constraintStart_toStartOf="parent"
|
||||
app:layout_constraintTop_toTopOf="parent" />
|
||||
|
||||
<FrameLayout
|
||||
android:id="@+id/face_preview_container"
|
||||
android:layout_width="220dp"
|
||||
android:layout_height="300dp"
|
||||
android:layout_margin="12dp"
|
||||
android:background="#55000000"
|
||||
app:layout_constraintEnd_toEndOf="parent"
|
||||
app:layout_constraintTop_toTopOf="parent">
|
||||
|
||||
<androidx.camera.view.PreviewView
|
||||
android:id="@+id/camera_preview"
|
||||
android:layout_width="match_parent"
|
||||
android:layout_height="match_parent"
|
||||
app:implementationMode="compatible"
|
||||
app:scaleType="fitCenter" />
|
||||
|
||||
<com.digitalperson.face.FaceOverlayView
|
||||
android:id="@+id/face_overlay"
|
||||
android:layout_width="match_parent"
|
||||
android:layout_height="match_parent" />
|
||||
</FrameLayout>
|
||||
|
||||
<ScrollView
|
||||
android:id="@+id/scroll_view"
|
||||
android:layout_width="match_parent"
|
||||
android:layout_width="0dp"
|
||||
android:layout_height="200dp"
|
||||
android:fillViewport="true">
|
||||
android:fillViewport="true"
|
||||
app:layout_constraintBottom_toTopOf="@+id/llm_mode_switch_row"
|
||||
app:layout_constraintEnd_toEndOf="parent"
|
||||
app:layout_constraintStart_toStartOf="parent">
|
||||
|
||||
<TextView
|
||||
android:id="@+id/my_text"
|
||||
@@ -33,6 +58,31 @@
|
||||
android:textIsSelectable="true" />
|
||||
</ScrollView>
|
||||
|
||||
<LinearLayout
|
||||
android:id="@+id/llm_mode_switch_row"
|
||||
android:layout_width="0dp"
|
||||
android:layout_height="wrap_content"
|
||||
android:gravity="center_vertical"
|
||||
android:orientation="horizontal"
|
||||
android:padding="16dp"
|
||||
app:layout_constraintBottom_toTopOf="@+id/streaming_switch_row"
|
||||
app:layout_constraintEnd_toEndOf="parent"
|
||||
app:layout_constraintStart_toStartOf="parent">
|
||||
|
||||
<TextView
|
||||
android:layout_width="wrap_content"
|
||||
android:layout_height="wrap_content"
|
||||
android:layout_marginEnd="16dp"
|
||||
android:text="本地LLM"
|
||||
android:textSize="16sp" />
|
||||
|
||||
<Switch
|
||||
android:id="@+id/llm_mode_switch"
|
||||
android:layout_width="wrap_content"
|
||||
android:layout_height="wrap_content"
|
||||
android:checked="false" />
|
||||
</LinearLayout>
|
||||
|
||||
<LinearLayout
|
||||
android:id="@+id/streaming_switch_row"
|
||||
android:layout_width="0dp"
|
||||
@@ -40,6 +90,31 @@
|
||||
android:gravity="center_vertical"
|
||||
android:orientation="horizontal"
|
||||
android:padding="16dp"
|
||||
app:layout_constraintBottom_toTopOf="@+id/tts_mode_switch_row"
|
||||
app:layout_constraintEnd_toEndOf="parent"
|
||||
app:layout_constraintStart_toStartOf="parent">
|
||||
|
||||
<TextView
|
||||
android:layout_width="wrap_content"
|
||||
android:layout_height="wrap_content"
|
||||
android:layout_marginEnd="16dp"
|
||||
android:text="流式输出"
|
||||
android:textSize="16sp" />
|
||||
|
||||
<Switch
|
||||
android:id="@+id/streaming_switch"
|
||||
android:layout_width="wrap_content"
|
||||
android:layout_height="wrap_content"
|
||||
android:checked="false" />
|
||||
</LinearLayout>
|
||||
|
||||
<LinearLayout
|
||||
android:id="@+id/tts_mode_switch_row"
|
||||
android:layout_width="0dp"
|
||||
android:layout_height="wrap_content"
|
||||
android:gravity="center_vertical"
|
||||
android:orientation="horizontal"
|
||||
android:padding="16dp"
|
||||
app:layout_constraintBottom_toTopOf="@+id/button_row"
|
||||
app:layout_constraintEnd_toEndOf="parent"
|
||||
app:layout_constraintStart_toStartOf="parent">
|
||||
@@ -48,11 +123,11 @@
|
||||
android:layout_width="wrap_content"
|
||||
android:layout_height="wrap_content"
|
||||
android:layout_marginEnd="16dp"
|
||||
android:text="流式输出"
|
||||
android:text="腾讯云TTS"
|
||||
android:textSize="16sp" />
|
||||
|
||||
<Switch
|
||||
android:id="@+id/streaming_switch"
|
||||
android:id="@+id/tts_mode_switch"
|
||||
android:layout_width="wrap_content"
|
||||
android:layout_height="wrap_content"
|
||||
android:checked="false" />
|
||||
|
||||
@@ -3,5 +3,5 @@
|
||||
<string name="start">开始</string>
|
||||
<string name="stop">结束</string>
|
||||
<string name="hint">点击“开始”说话;识别后会请求大模型并用 TTS 播放回复。</string>
|
||||
<string name="system_prompt">你是一名小学女老师,喜欢回答学生的各种问题,请简洁但温柔地回答,每个回答不超过30字。在每次回复的最前面,用方括号标注你的心情,格式为[开心/伤心/愤怒/平和/惊讶/关心/害羞],例如:[开心]同学你好呀!请问有什么问题吗?</string>
|
||||
<string name="system_prompt">你是一名小学女老师,喜欢回答学生的各种问题,请简洁但温柔地回答,每个回答不超过30字。在每次回复的最前面,用方括号标注你的心情,格式为[中性、悲伤、高兴、生气、恐惧、撒娇、震惊、厌恶],例如:[高兴]同学你好呀!请问有什么问题吗?</string>
|
||||
</resources>
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
# http://www.gradle.org/docs/current/userguide/build_environment.html
|
||||
# Specifies the JVM arguments used for the daemon process.
|
||||
# The setting is particularly useful for tweaking memory settings.
|
||||
org.gradle.jvmargs=-Xmx6g -Dfile.encoding=UTF-8
|
||||
org.gradle.jvmargs=-Xmx8g -Dfile.encoding=UTF-8
|
||||
# When configured, Gradle will run in incubating parallel mode.
|
||||
# This option should only be used with decoupled projects. More details, visit
|
||||
# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects
|
||||
|
||||
Reference in New Issue
Block a user