local llm supported

This commit is contained in:
gcw_4spBpAfv
2026-03-05 13:55:57 +08:00
parent 1701ecfb7f
commit bd07a7526a
43 changed files with 4258 additions and 115 deletions

View File

@@ -1,3 +0,0 @@
{
"java.configuration.updateBuildConfiguration": "interactive"
}

View File

@@ -5,7 +5,7 @@ plugins {
android {
namespace 'com.digitalperson'
compileSdk 33
compileSdk 34
buildFeatures {
buildConfig true
@@ -24,6 +24,11 @@ android {
}
}
// 不压缩大文件,避免内存不足错误
aaptOptions {
noCompress 'rknn', 'rkllm', 'onnx', 'model', 'bin', 'json'
}
defaultConfig {
applicationId "com.digitalperson"
minSdk 21
@@ -40,7 +45,7 @@ android {
buildConfigField "String", "LLM_API_URL", "\"${(project.findProperty('LLM_API_URL') ?: 'https://ark.cn-beijing.volces.com/api/v3/chat/completions').toString()}\""
buildConfigField "String", "LLM_API_KEY", "\"${(project.findProperty('LLM_API_KEY') ?: '').toString()}\""
buildConfigField "String", "LLM_MODEL", "\"${(project.findProperty('LLM_MODEL') ?: 'doubao-1-5-pro-32k-character-250228').toString()}\""
buildConfigField "boolean", "USE_LIVE2D", "${(project.findProperty('USE_LIVE2D') ?: 'false').toString()}"
buildConfigField "boolean", "USE_LIVE2D", "${(project.findProperty('USE_LIVE2D') ?: 'true').toString()}"
ndk {
abiFilters "arm64-v8a"
@@ -63,6 +68,8 @@ android {
}
dependencies {
implementation 'androidx.core:core-ktx:1.7.0'
implementation 'androidx.appcompat:appcompat:1.6.1'
implementation 'com.google.android.material:material:1.9.0'
@@ -73,6 +80,15 @@ dependencies {
androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.1'
// ExoPlayer for video playback (used to show silent / speaking videos)
implementation 'com.google.android.exoplayer:exoplayer:2.18.6'
implementation 'androidx.camera:camera-core:1.3.4'
implementation 'androidx.camera:camera-camera2:1.3.4'
implementation 'androidx.camera:camera-lifecycle:1.3.4'
implementation 'androidx.camera:camera-view:1.3.4'
implementation project(':framework')
implementation files('../Live2DFramework/Core/android/Live2DCubismCore.aar')
// Tencent Cloud TTS SDK
implementation files('libs/realtime_tts-release-v2.0.16-20260128-d80cafe.aar')
implementation 'com.google.code.gson:gson:2.8.9'
implementation 'com.squareup.okhttp3:okhttp:4.9.3'
}

View File

@@ -71,6 +71,7 @@ TTS Sherpa-ONNX VITS .onnx ❌ 否 CPU ONNX Runtime
- haru_g_m25 - 扁嘴
- haru_g_m24 - 低头斜看地板,收手到背后
- haru_g_m05 扁嘴,张开双手
- haru_g_m16 双手捧腮,思考
### 😠 愤怒类情绪
- haru_g_m11 双手交叉,摇头,扁嘴
@@ -90,8 +91,6 @@ TTS Sherpa-ONNX VITS .onnx ❌ 否 CPU ONNX Runtime
- haru_g_m12 摆手,摇头
### 😕 困惑类情绪
- haru_g_m20 手指点腮,思考,皱眉
- haru_g_m16 双手捧腮,思考
- haru_g_m14 身体前倾,皱眉
- haru_g_m13 身体前倾,双手分开
@@ -99,6 +98,22 @@ TTS Sherpa-ONNX VITS .onnx ❌ 否 CPU ONNX Runtime
- haru_g_m19 脸红微笑
### 担心
- haru_g_m20 手指点腮,思考,皱眉
### ❤️ 关心类情绪
- haru_g_m17 靠近侧脸
6. 其实可以抄一下讯飞的超脑平台的功能:
https://aiui-doc.xf-yun.com/project-2/doc-397/
7. 人脸检测使用的是RKNN zoo里的 retinaface 模型,转成了rknn格式,并且使用了wider_face数据集的验证集进行了校准。下载地址:
https://www.modelscope.cn/datasets/shaoxuan/WIDER_FACE/files
8. 人脸识别模型是insightface的r18模型转成了rknn格式并且使用了 lfw 的数据集进行了校准,下载地址:
https://tianchi.aliyun.com/dataset/93864
9.

View File

@@ -2,13 +2,17 @@
<manifest xmlns:android="http://schemas.android.com/apk/res/android">
<uses-permission android:name="android.permission.RECORD_AUDIO" />
<uses-permission android:name="android.permission.CAMERA" />
<uses-permission android:name="android.permission.INTERNET" />
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE" />
<uses-feature android:name="android.hardware.camera.any" />
<application
android:allowBackup="true"
android:label="@string/app_name"
android:supportsRtl="true"
android:theme="@style/Theme.DigitalPerson">
android:theme="@style/Theme.DigitalPerson"
android:usesCleartextTraffic="true">
<activity
android:name="com.digitalperson.EntryActivity"

View File

@@ -0,0 +1,210 @@
#include "ArcFaceEngineRKNN.h"
#include <algorithm>
#include <android/log.h>
#include <cmath>
#include <cstring>
#define LOG_TAG "ArcFaceRKNN"
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__)
#define LOGW(...) __android_log_print(ANDROID_LOG_WARN, LOG_TAG, __VA_ARGS__)
#define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__)
// Construct with no model loaded; call init() before extractEmbedding().
ArcFaceEngineRKNN::ArcFaceEngineRKNN() = default;
// Destructor frees the RKNN context if one is still held (RAII cleanup).
ArcFaceEngineRKNN::~ArcFaceEngineRKNN() {
    release();
}
/**
 * Load an ArcFace .rknn model and cache its input/output tensor attributes.
 *
 * @param modelPath filesystem path to the .rknn model file
 * @return 0 on success; an RKNN error code (or -1) on failure
 */
int ArcFaceEngineRKNN::init(const char* modelPath) {
    // Re-initialization is allowed: drop any previously held context first.
    release();
    int ret = rknn_init(&ctx_, (void*)modelPath, 0, 0, nullptr);
    if (ret != RKNN_SUCC) {
        LOGE("rknn_init failed: %d model=%s", ret, modelPath);
        return ret;
    }
    std::memset(&ioNum_, 0, sizeof(ioNum_));
    ret = rknn_query(ctx_, RKNN_QUERY_IN_OUT_NUM, &ioNum_, sizeof(ioNum_));
    if (ret != RKNN_SUCC || ioNum_.n_input < 1 || ioNum_.n_output < 1) {
        LOGE("query io num failed: ret=%d in=%u out=%u", ret, ioNum_.n_input, ioNum_.n_output);
        release();
        // Preserve the RKNN error code when the query itself failed.
        return ret != RKNN_SUCC ? ret : -1;
    }
    std::memset(&inputAttr_, 0, sizeof(inputAttr_));
    inputAttr_.index = 0;
    ret = rknn_query(ctx_, RKNN_QUERY_INPUT_ATTR, &inputAttr_, sizeof(inputAttr_));
    if (ret != RKNN_SUCC) {
        LOGE("query input attr failed: %d", ret);
        release();
        return ret;
    }
    // Derive the model's expected input resolution from the tensor layout:
    // NHWC keeps H/W in dims[1]/dims[2], NCHW in dims[2]/dims[3].
    if (inputAttr_.n_dims == 4) {
        if (inputAttr_.fmt == RKNN_TENSOR_NHWC) {
            inputH_ = static_cast<int>(inputAttr_.dims[1]);
            inputW_ = static_cast<int>(inputAttr_.dims[2]);
        } else {
            inputH_ = static_cast<int>(inputAttr_.dims[2]);
            inputW_ = static_cast<int>(inputAttr_.dims[3]);
        }
    } else if (inputAttr_.n_dims == 3) {
        // 3-D input: treated as HWC, so dims[1]/dims[2] hold H/W.
        inputH_ = static_cast<int>(inputAttr_.dims[1]);
        inputW_ = static_cast<int>(inputAttr_.dims[2]);
    }
    // Output-attr failure is non-fatal: extractEmbedding() falls back to the
    // dims product when n_elems is unavailable.
    std::memset(&outputAttr_, 0, sizeof(outputAttr_));
    outputAttr_.index = 0;
    ret = rknn_query(ctx_, RKNN_QUERY_OUTPUT_ATTR, &outputAttr_, sizeof(outputAttr_));
    if (ret != RKNN_SUCC) {
        LOGW("query output attr failed: %d", ret);
    }
    initialized_ = true;
    LOGI("ArcFace initialized input=%dx%d", inputW_, inputH_);
    return 0;
}
/**
 * Crop the face box out of an ARGB frame, resize it to the model input with
 * nearest-neighbor sampling, run the RKNN network, and return an
 * L2-normalized embedding suitable for cosine-similarity comparison.
 *
 * @param argbPixels  packed 0xAARRGGBB pixels, one uint32 per pixel
 * @param width       frame width in pixels
 * @param height      frame height in pixels
 * @param strideBytes bytes per source row (assumed to be a multiple of 4)
 * @param left,top,right,bottom face box in frame pixel coordinates
 * @return embedding vector, or an empty vector on any failure
 */
std::vector<float> ArcFaceEngineRKNN::extractEmbedding(
    const uint32_t* argbPixels,
    int width,
    int height,
    int strideBytes,
    float left,
    float top,
    float right,
    float bottom) {
    // FIX: rknn_context is an integer handle, not a pointer. The original
    // passed the raw integer to %p, which is undefined behavior in a varargs
    // call; cast it to a pointer explicitly.
    LOGI("extractEmbedding called: initialized=%d, ctx=%p, pixels=%p, width=%d, height=%d, left=%.2f, top=%.2f, right=%.2f, bottom=%.2f",
         initialized_, reinterpret_cast<void*>(static_cast<uintptr_t>(ctx_)), argbPixels,
         width, height, left, top, right, bottom);
    std::vector<float> empty;
    // FIX: also reject degenerate model input sizes — inputW_/inputH_ divide
    // the sampling loop below, so zero would crash.
    if (!initialized_ || ctx_ == 0 || argbPixels == nullptr || width <= 0 || height <= 0 ||
        inputW_ <= 0 || inputH_ <= 0) {
        LOGW("extractEmbedding failed: invalid parameters");
        return empty;
    }
    const float faceW = right - left;
    const float faceH = bottom - top;
    LOGI("Face size: width=%.2f, height=%.2f", faceW, faceH);
    if (faceW < 4.0f || faceH < 4.0f) {
        LOGW("extractEmbedding failed: face too small");
        return empty;
    }
    // Expand the box by 15% on each side (clamped to the frame) so the crop
    // keeps a little context around the detected face.
    const float pad = 0.15f;
    float x1f = std::max(0.0f, left - faceW * pad);
    float y1f = std::max(0.0f, top - faceH * pad);
    float x2f = std::min(static_cast<float>(width), right + faceW * pad);
    float y2f = std::min(static_cast<float>(height), bottom + faceH * pad);
    int x1 = static_cast<int>(std::floor(x1f));
    int y1 = static_cast<int>(std::floor(y1f));
    int x2 = static_cast<int>(std::ceil(x2f));
    int y2 = static_cast<int>(std::ceil(y2f));
    if (x2 <= x1 || y2 <= y1) {
        return empty;
    }
    const int cropW = x2 - x1;
    const int cropH = y2 - y1;
    const int srcStridePx = strideBytes / 4;
    // Nearest-neighbor resize of the crop into an interleaved RGB byte buffer
    // matching the model's input resolution.
    std::vector<uint8_t> rgb(inputW_ * inputH_ * 3);
    for (int y = 0; y < inputH_; ++y) {
        const int sy = y1 + (y * cropH / inputH_);
        const uint32_t* srcRow = argbPixels + sy * srcStridePx;
        uint8_t* dst = rgb.data() + y * inputW_ * 3;
        for (int x = 0; x < inputW_; ++x) {
            const int sx = x1 + (x * cropW / inputW_);
            const uint32_t pixel = srcRow[sx];
            dst[3 * x + 0] = (pixel >> 16) & 0xFF;  // R
            dst[3 * x + 1] = (pixel >> 8) & 0xFF;   // G
            dst[3 * x + 2] = pixel & 0xFF;          // B
        }
    }
    rknn_input input{};
    input.index = 0;
    input.type = RKNN_TENSOR_UINT8;
    input.size = rgb.size();
    input.buf = rgb.data();
    input.pass_through = 0;
    input.fmt = (inputAttr_.fmt == RKNN_TENSOR_NCHW) ? RKNN_TENSOR_NCHW : RKNN_TENSOR_NHWC;
    // NCHW models need the interleaved buffer de-interleaved into planes.
    // Declared at this scope so it stays alive through rknn_inputs_set.
    std::vector<uint8_t> nchw;
    if (input.fmt == RKNN_TENSOR_NCHW) {
        nchw.resize(rgb.size());
        const int hw = inputW_ * inputH_;
        for (int i = 0; i < hw; ++i) {
            nchw[i] = rgb[3 * i + 0];
            nchw[hw + i] = rgb[3 * i + 1];
            nchw[2 * hw + i] = rgb[3 * i + 2];
        }
        input.buf = nchw.data();
    }
    int ret = rknn_inputs_set(ctx_, 1, &input);
    if (ret != RKNN_SUCC) {
        LOGW("rknn_inputs_set failed: %d", ret);
        return empty;
    }
    ret = rknn_run(ctx_, nullptr);
    if (ret != RKNN_SUCC) {
        LOGW("rknn_run failed: %d", ret);
        return empty;
    }
    // Request float output regardless of the model's native (possibly
    // quantized) output type.
    std::vector<rknn_output> outputs(ioNum_.n_output);
    for (uint32_t i = 0; i < ioNum_.n_output; ++i) {
        std::memset(&outputs[i], 0, sizeof(rknn_output));
        outputs[i].want_float = 1;
    }
    ret = rknn_outputs_get(ctx_, ioNum_.n_output, outputs.data(), nullptr);
    if (ret != RKNN_SUCC) {
        LOGW("rknn_outputs_get failed: %d", ret);
        return empty;
    }
    if (outputs[0].buf == nullptr) {
        LOGW("Output buffer is null");
        rknn_outputs_release(ctx_, ioNum_.n_output, outputs.data());
        return empty;
    }
    // Prefer the cached n_elems; fall back to the product of dims when the
    // output-attr query failed during init().
    size_t elems = outputAttr_.n_elems > 0 ? static_cast<size_t>(outputAttr_.n_elems) : 0;
    if (elems == 0) {
        elems = 1;
        for (uint32_t d = 0; d < outputAttr_.n_dims; ++d) {
            if (outputAttr_.dims[d] > 0) {
                elems *= static_cast<size_t>(outputAttr_.dims[d]);
            }
        }
    }
    if (elems == 0) {
        rknn_outputs_release(ctx_, ioNum_.n_output, outputs.data());
        return empty;
    }
    const float* ptr = reinterpret_cast<const float*>(outputs[0].buf);
    std::vector<float> embedding(ptr, ptr + elems);
    rknn_outputs_release(ctx_, ioNum_.n_output, outputs.data());
    // L2 normalize for cosine similarity; the epsilon avoids division by zero
    // for a pathological all-zero output.
    float sum = 0.0f;
    for (float v : embedding) sum += v * v;
    const float norm = std::sqrt(std::max(sum, 1e-12f));
    for (float& v : embedding) v /= norm;
    return embedding;
}
// Tear down the RKNN context (if any) and reset all cached model metadata.
// Safe to call repeatedly; leaves the engine in the not-initialized state.
void ArcFaceEngineRKNN::release() {
    if (ctx_ != 0) {
        rknn_destroy(ctx_);
        ctx_ = 0;
    }
    initialized_ = false;
    std::memset(&outputAttr_, 0, sizeof(outputAttr_));
    std::memset(&inputAttr_, 0, sizeof(inputAttr_));
    std::memset(&ioNum_, 0, sizeof(ioNum_));
}

View File

@@ -0,0 +1,36 @@
#ifndef DIGITAL_PERSON_ARCFACE_ENGINE_RKNN_H
#define DIGITAL_PERSON_ARCFACE_ENGINE_RKNN_H
#include <cstdint>
#include <vector>
#include "rknn_api.h"
// ArcFace face-embedding extractor backed by the RKNN runtime.
// NOTE(review): members are mutated without synchronization — callers appear
// expected to serialize access to one instance; confirm usage.
class ArcFaceEngineRKNN {
public:
    ArcFaceEngineRKNN();
    ~ArcFaceEngineRKNN();
    // Load the .rknn model at modelPath. Returns 0 on success, an RKNN error
    // code (or -1) on failure. Releases any previously loaded model first.
    int init(const char* modelPath);
    // Crop the face box (frame pixel coordinates) out of a packed-ARGB frame,
    // run the network and return an L2-normalized embedding.
    // Returns an empty vector on any failure.
    std::vector<float> extractEmbedding(
        const uint32_t* argbPixels,
        int width,
        int height,
        int strideBytes,
        float left,
        float top,
        float right,
        float bottom);
    // Free the RKNN context and reset cached attributes; safe to call twice.
    void release();
private:
    rknn_context ctx_ = 0;           // RKNN runtime handle (0 = not loaded)
    bool initialized_ = false;       // true only after a successful init()
    rknn_input_output_num ioNum_{};  // input/output tensor counts
    rknn_tensor_attr inputAttr_{};   // attributes of input tensor 0
    rknn_tensor_attr outputAttr_{};  // attributes of output tensor 0
    int inputW_ = 112;               // model input width (overwritten by init)
    int inputH_ = 112;               // model input height (overwritten by init)
};
#endif  // DIGITAL_PERSON_ARCFACE_ENGINE_RKNN_H

View File

@@ -0,0 +1,109 @@
#include <jni.h>
#include <android/bitmap.h>
#include <android/log.h>
#include "ArcFaceEngineRKNN.h"
#define LOG_TAG "ArcFaceJNI"
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__)
#define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__)
extern "C" {
// Allocate a native ArcFace engine and return its address as a jlong handle.
// FIX: plain operator new throws std::bad_alloc instead of returning nullptr,
// so the original "if (engine == nullptr) return 0;" was dead code — removed.
JNIEXPORT jlong JNICALL
Java_com_digitalperson_engine_ArcFaceEngineRKNN_createEngineNative(JNIEnv* env, jobject thiz) {
    return reinterpret_cast<jlong>(new ArcFaceEngineRKNN());
}
// JNI bridge: initialize the native engine from a model file path.
// Returns the engine's init() result, or -1 for invalid arguments.
JNIEXPORT jint JNICALL
Java_com_digitalperson_engine_ArcFaceEngineRKNN_initNative(
        JNIEnv* env,
        jobject thiz,
        jlong ptr,
        jstring modelPath) {
    auto* engine = reinterpret_cast<ArcFaceEngineRKNN*>(ptr);
    if (engine == nullptr || modelPath == nullptr) {
        return -1;
    }
    const char* path = env->GetStringUTFChars(modelPath, nullptr);
    if (path == nullptr) {
        return -1;
    }
    const int result = engine->init(path);
    env->ReleaseStringUTFChars(modelPath, path);
    return result;
}
/**
 * JNI bridge: run ArcFace embedding extraction on a region of an Android
 * Bitmap (RGBA_8888 only). Returns a float[]; an EMPTY array signals a
 * recoverable failure, and nullptr is returned only when the result array
 * could not be allocated (an OutOfMemoryError is then pending on the JVM).
 */
JNIEXPORT jfloatArray JNICALL
Java_com_digitalperson_engine_ArcFaceEngineRKNN_extractEmbeddingNative(
        JNIEnv* env,
        jobject thiz,
        jlong ptr,
        jobject bitmapObj,
        jfloat left,
        jfloat top,
        jfloat right,
        jfloat bottom) {
    // FIX: jlong is always 64-bit; %ld mismatches it on 32-bit ABIs.
    // Cast explicitly and use %lld.
    LOGI("extractEmbeddingNative called: ptr=%lld, left=%.2f, top=%.2f, right=%.2f, bottom=%.2f",
         static_cast<long long>(ptr), left, top, right, bottom);
    auto* engine = reinterpret_cast<ArcFaceEngineRKNN*>(ptr);
    if (engine == nullptr || bitmapObj == nullptr) {
        LOGE("Engine or bitmap is null: engine=%p, bitmap=%p", engine, bitmapObj);
        return env->NewFloatArray(0);
    }
    AndroidBitmapInfo info{};
    if (AndroidBitmap_getInfo(env, bitmapObj, &info) < 0) {
        LOGE("AndroidBitmap_getInfo failed");
        return env->NewFloatArray(0);
    }
    LOGI("Bitmap info: width=%d, height=%d, stride=%d, format=%d", info.width, info.height, info.stride, info.format);
    // The native engine reads packed 32-bit pixels, so only RGBA_8888 works.
    if (info.format != ANDROID_BITMAP_FORMAT_RGBA_8888) {
        LOGE("Unsupported bitmap format: %d", info.format);
        return env->NewFloatArray(0);
    }
    void* pixels = nullptr;
    if (AndroidBitmap_lockPixels(env, bitmapObj, &pixels) < 0 || pixels == nullptr) {
        LOGE("AndroidBitmap_lockPixels failed");
        return env->NewFloatArray(0);
    }
    std::vector<float> emb = engine->extractEmbedding(
        reinterpret_cast<uint32_t*>(pixels),
        static_cast<int>(info.width),
        static_cast<int>(info.height),
        static_cast<int>(info.stride),
        left,
        top,
        right,
        bottom);
    LOGI("Engine extractEmbedding returned: size=%zu", emb.size());
    AndroidBitmap_unlockPixels(env, bitmapObj);
    jfloatArray out = env->NewFloatArray(static_cast<jsize>(emb.size()));
    if (out == nullptr) {
        // FIX: allocation failed, so an exception is already pending — calling
        // NewFloatArray(0) again here (as the original did) is invalid JNI.
        // Return nullptr and let the pending OutOfMemoryError propagate.
        LOGE("Failed to create float array");
        return nullptr;
    }
    if (!emb.empty()) {
        env->SetFloatArrayRegion(out, 0, static_cast<jsize>(emb.size()), emb.data());
    }
    return out;
}
// JNI bridge: destroy a previously created native engine.
// A zero handle is silently ignored.
JNIEXPORT void JNICALL
Java_com_digitalperson_engine_ArcFaceEngineRKNN_releaseNative(
        JNIEnv* env,
        jobject thiz,
        jlong ptr) {
    auto* engine = reinterpret_cast<ArcFaceEngineRKNN*>(ptr);
    if (engine == nullptr) {
        return;
    }
    engine->release();
    delete engine;
}
} // extern "C"

View File

@@ -14,6 +14,11 @@ if (ANDROID)
set_target_properties(sentencepiece PROPERTIES IMPORTED_LOCATION
${JNI_LIBS_DIR}/libsentencepiece.so)
# 导入 rkllm 库
add_library(rkllmrt SHARED IMPORTED)
set_target_properties(rkllmrt PROPERTIES IMPORTED_LOCATION
${JNI_LIBS_DIR}/librkllmrt.so)
# Imported static libs
add_library(kaldi_native_fbank STATIC IMPORTED)
set_target_properties(kaldi_native_fbank PROPERTIES IMPORTED_LOCATION
@@ -26,6 +31,12 @@ if (ANDROID)
add_library(sensevoiceEngine SHARED
SenseVoiceEngineRKNN.cpp
SenseVoiceEngineRKNNJNI.cpp
RetinaFaceEngineRKNN.cpp
RetinaFaceEngineRKNNJNI.cpp
ArcFaceEngineRKNN.cpp
ArcFaceEngineRKNNJNI.cpp
RKLLMEngine.cpp
RKLLMEngineJNI.cpp
utils/audio_utils.c
)
@@ -40,9 +51,11 @@ if (ANDROID)
target_link_libraries(sensevoiceEngine
rknnrt
rkllmrt
kaldi_native_fbank
sndfile
sentencepiece
jnigraphics
log
)
endif()

View File

@@ -0,0 +1,118 @@
#include <android/log.h>
#include <jni.h>
#include "zipformer_headers/rkllm.h"
#define TAG "RKLLMEngine"
#define LOGD(...) __android_log_print(ANDROID_LOG_DEBUG, TAG, __VA_ARGS__)
#define LOGW(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__)
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
#define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
namespace {
// Keep these conservative first; too-large values may fail during init on RK3588.
constexpr int kDefaultMaxContextLen = 1024;
constexpr int kDefaultMaxNewTokens = 128;
}
// Per-request JNI context passed to the RKLLM callback as userdata and freed
// by the callback on FINISH/ERROR.
// NOTE(review): a JNIEnv* and these references are only valid on the calling
// thread while the originating JNI call is live — this assumes rkllm_run
// invokes the callback synchronously; confirm against the rkllm runtime docs.
struct LLmJniEnv {
    JNIEnv *env;    // env of the thread that started the inference
    jobject thiz;   // Java object exposing callbackFromNative(String, int)
    jclass clazz;   // class of thiz, used for the method-ID lookup
};
void callbackToJava(const char *text, int state, LLmJniEnv *jenv) {
jmethodID method = jenv->env->GetMethodID(jenv->clazz, "callbackFromNative", "(Ljava/lang/String;I)V");
jstring jText = text ? jenv->env->NewStringUTF(text) : jenv->env->NewStringUTF("");
jenv->env->CallVoidMethod(jenv->thiz, method, jText, state);
}
// RKLLM streaming callback: relays each event to Java and frees the
// per-request JNI context once the run terminates (FINISH or ERROR).
// Returning 0 tells the runtime to continue normally.
int callback(RKLLMResult *result, void *userdata, LLMCallState state) {
    auto *jenv = static_cast<LLmJniEnv *>(userdata);
    switch (state) {
        case RKLLM_RUN_FINISH:
            LOGI("<FINISH/>");
            callbackToJava(nullptr, 0, jenv);
            delete jenv;
            break;
        case RKLLM_RUN_ERROR:
            LOGE("<ERROR/>");
            callbackToJava(nullptr, -1, jenv);
            delete jenv;
            break;
        case RKLLM_RUN_NORMAL:
            // One decoded token chunk; state 1 marks "still streaming".
            callbackToJava(result->text, 1, jenv);
            break;
        default:
            break;
    }
    return 0;
}
// JNI 方法实现
extern "C" {
/**
 * Create an RKLLM handle for the model at model_path.
 * Returns the native handle as a jlong, or 0 on failure.
 */
jlong initLLM(JNIEnv *env, jobject thiz, jstring model_path) {
    // FIX: guard against a null Java string and a failed UTF conversion —
    // the original passed the results straight to rkllm (undefined behavior).
    if (model_path == nullptr) {
        LOGE("rkllm init called with null model path");
        return 0;
    }
    const char* modelPath = env->GetStringUTFChars(model_path, nullptr);
    if (modelPath == nullptr) {
        LOGE("GetStringUTFChars failed for model path");  // OOM pending on JVM
        return 0;
    }
    LLMHandle llmHandle = nullptr;
    // Start from library defaults, then override sampling and context limits.
    RKLLMParam param = rkllm_createDefaultParam();
    // NOTE(review): assumes rkllm_init copies model_path before returning,
    // since the UTF chars are released below — confirm with the rkllm docs.
    param.model_path = modelPath;
    // Sampling parameters (top_k=1 makes decoding effectively greedy).
    param.top_k = 1;
    param.top_p = 0.95;
    param.temperature = 0.8;
    param.repeat_penalty = 1.1;
    param.frequency_penalty = 0.0;
    param.presence_penalty = 0.0;
    param.max_new_tokens = kDefaultMaxNewTokens;
    param.max_context_len = kDefaultMaxContextLen;
    param.skip_special_token = true;
    param.extend_param.base_domain_id = 0;
    LOGI("rkllm init with module: %s", modelPath);
    LOGI("rkllm init params: max_context_len=%d, max_new_tokens=%d, top_k=%d, top_p=%f, temp=%f",
         param.max_context_len, param.max_new_tokens, param.top_k, param.top_p, param.temperature);
    int ret = rkllm_init(&llmHandle, &param, callback);
    if (ret == 0) {
        LOGI("rkllm init success, handle=%p", llmHandle);
    } else {
        LOGE("rkllm init failed, ret=%d", ret);
        llmHandle = nullptr;
    }
    env->ReleaseStringUTFChars(model_path, modelPath);
    return (jlong)llmHandle;
}
// Destroy an RKLLM handle.
// FIX: a zero handle (failed or absent init) is now a no-op instead of being
// forwarded to rkllm_destroy.
void deinitLLM(JNIEnv *env, jobject thiz, jlong handle) {
    if (handle == 0) {
        LOGW("rkllm deinit called with null handle");
        return;
    }
    rkllm_destroy((LLMHandle)handle);
}
/**
 * Run one blocking generation request. Tokens stream back to Java through
 * callback() -> callbackFromNative(text, state): state 1 = token chunk,
 * 0 = finished, -1 = error.
 */
void infer(JNIEnv *env, jobject thiz, jlong handle, jstring text) {
    // FIX: also reject a null prompt; the original dereferenced it blindly.
    if (handle == 0 || text == nullptr) {
        LOGE("rkllm infer called with null handle");
        jclass clazz = env->GetObjectClass(thiz);
        jmethodID method = env->GetMethodID(clazz, "callbackFromNative", "(Ljava/lang/String;I)V");
        jstring jText = env->NewStringUTF("RKLLM handle is null");
        env->CallVoidMethod(thiz, method, jText, -1);
        env->DeleteLocalRef(jText);
        return;
    }
    const char* sText = env->GetStringUTFChars(text, nullptr);
    if (sText == nullptr) {
        LOGE("GetStringUTFChars failed for prompt");  // OOM pending on JVM
        return;
    }
    // Freed by callback() when it sees RKLLM_RUN_FINISH / RKLLM_RUN_ERROR.
    // NOTE(review): env/thiz/clazz are valid only on this thread for the
    // duration of this call — assumes rkllm_run is synchronous; confirm.
    auto *jnienv = new LLmJniEnv {
        .env = env,
        .thiz = thiz,
        .clazz = env->GetObjectClass(thiz),
    };
    RKLLMInput rkllm_input = {};
    RKLLMInferParam rkllm_infer_params = {};
    rkllm_infer_params.mode = RKLLM_INFER_GENERATE;
    rkllm_input.input_type = RKLLM_INPUT_PROMPT;
    rkllm_input.prompt_input = (char *)sText;
    // FIX: the original discarded this result; surface failures in the log.
    // NOTE(review): if rkllm_run fails without ever invoking the callback,
    // jnienv leaks — whether the ERROR callback fires on failure is
    // runtime-defined, so deleting it here would risk a double free.
    const int ret = rkllm_run((LLMHandle)handle, &rkllm_input, &rkllm_infer_params, jnienv);
    if (ret != 0) {
        LOGE("rkllm_run failed, ret=%d", ret);
    }
    env->ReleaseStringUTFChars(text, sText);
}
}

View File

@@ -0,0 +1,37 @@
#include <jni.h>
// JNI export declarations for the Java class com.digitalperson.llm.RKLLM.
extern "C" {
JNIEXPORT jlong JNICALL
Java_com_digitalperson_llm_RKLLM_initLLM(JNIEnv *env, jobject thiz, jstring model_path);
JNIEXPORT void JNICALL
Java_com_digitalperson_llm_RKLLM_deinitLLM(JNIEnv *env, jobject thiz, jlong handle);
JNIEXPORT void JNICALL
Java_com_digitalperson_llm_RKLLM_infer(JNIEnv *env, jobject thiz, jlong handle, jstring text);
}
// Worker functions implemented in RKLLMEngine.cpp.
extern "C" {
jlong initLLM(JNIEnv *env, jobject thiz, jstring model_path);
void deinitLLM(JNIEnv *env, jobject thiz, jlong handle);
void infer(JNIEnv *env, jobject thiz, jlong handle, jstring text);
}
// Thin JNI stubs that forward directly to the workers above.
extern "C" {
JNIEXPORT jlong JNICALL
Java_com_digitalperson_llm_RKLLM_initLLM(JNIEnv *env, jobject thiz, jstring model_path) {
    return initLLM(env, thiz, model_path);
}
JNIEXPORT void JNICALL
Java_com_digitalperson_llm_RKLLM_deinitLLM(JNIEnv *env, jobject thiz, jlong handle) {
    deinitLLM(env, thiz, handle);
}
JNIEXPORT void JNICALL
Java_com_digitalperson_llm_RKLLM_infer(JNIEnv *env, jobject thiz, jlong handle, jstring text) {
    infer(env, thiz, handle, text);
}
}

View File

@@ -0,0 +1,435 @@
#include "RetinaFaceEngineRKNN.h"
#include <algorithm>
#include <android/log.h>
#include <cmath>
#include <cstring>
#define LOG_TAG "RetinaFaceRKNN"
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__)
#define LOGW(...) __android_log_print(ANDROID_LOG_WARN, LOG_TAG, __VA_ARGS__)
#define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__)
namespace {
// RetinaFace bounding-box decode variances used by detect() when converting
// network deltas back to box coordinates.
constexpr float kVariance0 = 0.1f;
constexpr float kVariance1 = 0.2f;
} // namespace
// Construct with no model loaded; call init() before detect().
RetinaFaceEngineRKNN::RetinaFaceEngineRKNN() = default;
// Destructor frees the RKNN context if one is still held.
RetinaFaceEngineRKNN::~RetinaFaceEngineRKNN() {
    release();
}
// Element count of a tensor: prefer the runtime-provided n_elems and fall
// back to the product of the non-zero dimension sizes.
size_t RetinaFaceEngineRKNN::tensorElemCount(const rknn_tensor_attr& attr) {
    if (attr.n_elems > 0) {
        return static_cast<size_t>(attr.n_elems);
    }
    if (attr.n_dims <= 0) {
        return 0;
    }
    size_t count = 1;
    for (uint32_t d = 0; d < attr.n_dims; ++d) {
        const auto dim = attr.dims[d];
        if (dim != 0) {
            count *= static_cast<size_t>(dim);
        }
    }
    return count;
}
/**
 * Load a RetinaFace .rknn model and cache its tensor attributes.
 *
 * @param modelPath      filesystem path to the .rknn model
 * @param inputSize      square network input resolution in pixels
 * @param scoreThreshold minimum face confidence kept by detect()
 * @param nmsThreshold   IoU threshold used by non-maximum suppression
 * @return 0 on success; an RKNN error code or -1 on failure
 */
int RetinaFaceEngineRKNN::init(
    const char* modelPath,
    int inputSize,
    float scoreThreshold,
    float nmsThreshold) {
    // Re-initialization is allowed: drop any previous context first.
    release();
    inputSize_ = inputSize;
    scoreThreshold_ = scoreThreshold;
    nmsThreshold_ = nmsThreshold;
    int ret = rknn_init(&ctx_, (void*)modelPath, 0, 0, nullptr);
    if (ret != RKNN_SUCC) {
        LOGE("rknn_init failed: ret=%d, model=%s", ret, modelPath);
        return ret;
    }
    std::memset(&ioNum_, 0, sizeof(ioNum_));
    ret = rknn_query(ctx_, RKNN_QUERY_IN_OUT_NUM, &ioNum_, sizeof(ioNum_));
    if (ret != RKNN_SUCC) {
        LOGE("rknn_query(RKNN_QUERY_IN_OUT_NUM) failed: %d", ret);
        release();
        return ret;
    }
    if (ioNum_.n_input < 1 || ioNum_.n_output < 1) {
        LOGE("invalid io num: input=%u output=%u", ioNum_.n_input, ioNum_.n_output);
        release();
        return -1;
    }
    std::memset(&inputAttr_, 0, sizeof(inputAttr_));
    inputAttr_.index = 0;
    ret = rknn_query(ctx_, RKNN_QUERY_INPUT_ATTR, &inputAttr_, sizeof(inputAttr_));
    if (ret != RKNN_SUCC) {
        LOGE("rknn_query input attr failed: %d", ret);
        release();
        return ret;
    }
    // Cache every output tensor's attributes. A failed per-tensor query is
    // non-fatal: parseRetinaOutputs() identifies tensors by element count.
    outputAttrs_.clear();
    outputAttrs_.resize(ioNum_.n_output);
    for (uint32_t i = 0; i < ioNum_.n_output; ++i) {
        std::memset(&outputAttrs_[i], 0, sizeof(rknn_tensor_attr));
        outputAttrs_[i].index = i;
        int qret = rknn_query(ctx_, RKNN_QUERY_OUTPUT_ATTR, &outputAttrs_[i], sizeof(rknn_tensor_attr));
        if (qret != RKNN_SUCC) {
            LOGW("query output attr[%u] failed: %d", i, qret);
        }
        LOGI("output[%u] n_elems=%u n_dims=%u type=%d qnt=%d",
             i, outputAttrs_[i].n_elems, outputAttrs_[i].n_dims, outputAttrs_[i].type, outputAttrs_[i].qnt_type);
        for (uint32_t d = 0; d < outputAttrs_[i].n_dims; ++d) {
            LOGI(" output[%u] dim[%u]=%u", i, d, outputAttrs_[i].dims[d]);
        }
    }
    initialized_ = true;
    LOGI("RetinaFace initialized, input_size=%d, outputs=%u", inputSize_, ioNum_.n_output);
    return 0;
}
std::vector<RetinaFaceEngineRKNN::PriorBox> RetinaFaceEngineRKNN::buildPriors() const {
std::vector<PriorBox> priors;
const int steps[3] = {8, 16, 32};
const int minSizes[3][2] = {{16, 32}, {64, 128}, {256, 512}};
for (int s = 0; s < 3; ++s) {
const int step = steps[s];
const int featW = static_cast<int>(std::ceil(static_cast<float>(inputSize_) / step));
const int featH = static_cast<int>(std::ceil(static_cast<float>(inputSize_) / step));
for (int y = 0; y < featH; ++y) {
for (int x = 0; x < featW; ++x) {
for (int k = 0; k < 2; ++k) {
const float minSize = static_cast<float>(minSizes[s][k]);
PriorBox p;
p.cx = (x + 0.5f) * step / inputSize_;
p.cy = (y + 0.5f) * step / inputSize_;
p.w = minSize / inputSize_;
p.h = minSize / inputSize_;
priors.push_back(p);
}
}
}
}
return priors;
}
/**
 * Identify box-regression (loc) and confidence tensors among the raw RKNN
 * outputs by matching each tensor's element count against the expected anchor
 * counts, then merge per-stride tensors into one flat array per kind.
 *
 * Handles both export layouts: three per-stride tensors (8/16/32) or a single
 * concatenated tensor, and both 2-class softmax confidences (face prob in the
 * second channel) and single-channel face scores.
 *
 * @param outputs  rknn outputs already converted to float (want_float = 1)
 * @param locOut   receives totalAnchors*4 box deltas on success
 * @param scoreOut receives totalAnchors face scores on success
 * @return true when both a loc and a score tensor set were identified
 */
bool RetinaFaceEngineRKNN::parseRetinaOutputs(
    rknn_output* outputs,
    std::vector<float>* locOut,
    std::vector<float>* scoreOut) const {
    std::vector<std::vector<float>> locCandidates;
    std::vector<std::vector<float>> confCandidates2;
    std::vector<std::vector<float>> scoreCandidates1;
    // Two anchors per feature-map cell at each stride.
    // NOTE(review): uses integer division while buildPriors() uses ceil() —
    // these agree only when inputSize_ is a multiple of 32; the size check in
    // detect() rejects any mismatch.
    const int anchors8 = (inputSize_ / 8) * (inputSize_ / 8) * 2;
    const int anchors16 = (inputSize_ / 16) * (inputSize_ / 16) * 2;
    const int anchors32 = (inputSize_ / 32) * (inputSize_ / 32) * 2;
    const int totalAnchors = anchors8 + anchors16 + anchors32;
    // Expected element counts: 4 box deltas / 2 class logits / 1 score per anchor.
    const int expectedLoc8 = anchors8 * 4;
    const int expectedLoc16 = anchors16 * 4;
    const int expectedLoc32 = anchors32 * 4;
    const int expectedConf8_2 = anchors8 * 2;
    const int expectedConf16_2 = anchors16 * 2;
    const int expectedConf32_2 = anchors32 * 2;
    const int expectedConf8_1 = anchors8;
    const int expectedConf16_1 = anchors16;
    const int expectedConf32_1 = anchors32;
    // Bucket every output tensor by its element count.
    for (uint32_t i = 0; i < ioNum_.n_output; ++i) {
        const size_t elems = tensorElemCount(outputAttrs_[i]);
        if (elems == 0 || outputs[i].buf == nullptr) continue;
        const float* ptr = reinterpret_cast<const float*>(outputs[i].buf);
        std::vector<float> data(ptr, ptr + elems);
        const int e = static_cast<int>(elems);
        if (e == expectedLoc8 || e == expectedLoc16 || e == expectedLoc32 || e == totalAnchors * 4) {
            locCandidates.push_back(std::move(data));
            continue;
        }
        if (e == expectedConf8_2 || e == expectedConf16_2 || e == expectedConf32_2 || e == totalAnchors * 2) {
            confCandidates2.push_back(std::move(data));
            continue;
        }
        if (e == expectedConf8_1 || e == expectedConf16_1 || e == expectedConf32_1 || e == totalAnchors) {
            scoreCandidates1.push_back(std::move(data));
            continue;
        }
    }
    locOut->clear();
    scoreOut->clear();
    // Largest tensor first, i.e. stride 8 before 16 before 32 — the same
    // order in which buildPriors() emits anchors.
    auto sortBySize = [](const std::vector<float>& a, const std::vector<float>& b) {
        return a.size() > b.size();
    };
    std::sort(locCandidates.begin(), locCandidates.end(), sortBySize);
    std::sort(confCandidates2.begin(), confCandidates2.end(), sortBySize);
    std::sort(scoreCandidates1.begin(), scoreCandidates1.end(), sortBySize);
    // Concatenate the three per-stride loc tensors, or accept a single
    // pre-concatenated one.
    auto mergeLoc = [&]() -> bool {
        if (locCandidates.empty()) return false;
        if (locCandidates.size() >= 3 &&
            static_cast<int>(locCandidates[0].size()) == anchors8 * 4 &&
            static_cast<int>(locCandidates[1].size()) == anchors16 * 4 &&
            static_cast<int>(locCandidates[2].size()) == anchors32 * 4) {
            locOut->reserve(static_cast<size_t>(totalAnchors) * 4);
            locOut->insert(locOut->end(), locCandidates[0].begin(), locCandidates[0].end());
            locOut->insert(locOut->end(), locCandidates[1].begin(), locCandidates[1].end());
            locOut->insert(locOut->end(), locCandidates[2].begin(), locCandidates[2].end());
            return true;
        }
        for (const auto& c : locCandidates) {
            if (static_cast<int>(c.size()) == totalAnchors * 4) {
                *locOut = c;
                return true;
            }
        }
        return false;
    };
    // 2-class layout: [background, face] per anchor — keep the face channel.
    auto mergeScoreFrom2Class = [&]() -> bool {
        if (confCandidates2.empty()) return false;
        std::vector<float> merged2;
        if (confCandidates2.size() >= 3 &&
            static_cast<int>(confCandidates2[0].size()) == expectedConf8_2 &&
            static_cast<int>(confCandidates2[1].size()) == expectedConf16_2 &&
            static_cast<int>(confCandidates2[2].size()) == expectedConf32_2) {
            merged2.reserve(static_cast<size_t>(totalAnchors) * 2);
            merged2.insert(merged2.end(), confCandidates2[0].begin(), confCandidates2[0].end());
            merged2.insert(merged2.end(), confCandidates2[1].begin(), confCandidates2[1].end());
            merged2.insert(merged2.end(), confCandidates2[2].begin(), confCandidates2[2].end());
        } else {
            bool found = false;
            for (const auto& c : confCandidates2) {
                if (static_cast<int>(c.size()) == totalAnchors * 2) {
                    merged2 = c;
                    found = true;
                    break;
                }
            }
            if (!found) return false;
        }
        scoreOut->reserve(totalAnchors);
        for (int i = 0; i < totalAnchors; ++i) {
            scoreOut->push_back(merged2[i * 2 + 1]);
        }
        return true;
    };
    // 1-class layout: the tensor already holds one face score per anchor.
    auto mergeScoreFrom1Class = [&]() -> bool {
        if (scoreCandidates1.empty()) return false;
        if (scoreCandidates1.size() >= 3 &&
            static_cast<int>(scoreCandidates1[0].size()) == expectedConf8_1 &&
            static_cast<int>(scoreCandidates1[1].size()) == expectedConf16_1 &&
            static_cast<int>(scoreCandidates1[2].size()) == expectedConf32_1) {
            scoreOut->reserve(totalAnchors);
            scoreOut->insert(scoreOut->end(), scoreCandidates1[0].begin(), scoreCandidates1[0].end());
            scoreOut->insert(scoreOut->end(), scoreCandidates1[1].begin(), scoreCandidates1[1].end());
            scoreOut->insert(scoreOut->end(), scoreCandidates1[2].begin(), scoreCandidates1[2].end());
            return true;
        }
        for (const auto& c : scoreCandidates1) {
            if (static_cast<int>(c.size()) == totalAnchors) {
                *scoreOut = c;
                return true;
            }
        }
        return false;
    };
    const bool locOk = mergeLoc();
    // Prefer the 2-class layout; fall back to single-channel scores.
    bool scoreOk = mergeScoreFrom2Class();
    if (!scoreOk) {
        scoreOk = mergeScoreFrom1Class();
    }
    if (!locOk || !scoreOk) {
        LOGW("Unable to parse retina outputs, loc_candidates=%zu, conf2_candidates=%zu, conf1_candidates=%zu",
             locCandidates.size(), confCandidates2.size(), scoreCandidates1.size());
        return false;
    }
    return true;
}
// Intersection-over-union of two candidate boxes; 0 when the union is empty.
float RetinaFaceEngineRKNN::iou(const FaceCandidate& a, const FaceCandidate& b) {
    const float ix1 = std::max(a.left, b.left);
    const float iy1 = std::max(a.top, b.top);
    const float ix2 = std::min(a.right, b.right);
    const float iy2 = std::min(a.bottom, b.bottom);
    const float interW = std::max(0.0f, ix2 - ix1);
    const float interH = std::max(0.0f, iy2 - iy1);
    const float inter = interW * interH;
    const float areaA = std::max(0.0f, a.right - a.left) * std::max(0.0f, a.bottom - a.top);
    const float areaB = std::max(0.0f, b.right - b.left) * std::max(0.0f, b.bottom - b.top);
    const float uni = areaA + areaB - inter;
    if (uni <= 0.0f) {
        return 0.0f;
    }
    return inter / uni;
}
std::vector<RetinaFaceEngineRKNN::FaceCandidate> RetinaFaceEngineRKNN::nms(
const std::vector<FaceCandidate>& boxes,
float threshold) {
std::vector<FaceCandidate> sorted = boxes;
std::sort(sorted.begin(), sorted.end(), [](const FaceCandidate& a, const FaceCandidate& b) {
return a.score > b.score;
});
std::vector<FaceCandidate> keep;
std::vector<char> removed(sorted.size(), 0);
for (size_t i = 0; i < sorted.size(); ++i) {
if (removed[i]) continue;
keep.push_back(sorted[i]);
for (size_t j = i + 1; j < sorted.size(); ++j) {
if (removed[j]) continue;
if (iou(sorted[i], sorted[j]) > threshold) {
removed[j] = 1;
}
}
}
return keep;
}
/**
 * Detect faces in a packed-ARGB frame.
 *
 * Resizes the whole frame (nearest-neighbor, aspect ratio NOT preserved) to
 * the square model input, runs the network, decodes anchor deltas and applies
 * NMS.
 *
 * @param argbPixels  packed 0xAARRGGBB pixels, one uint32 per pixel
 * @param width       frame width in pixels
 * @param height      frame height in pixels
 * @param strideBytes bytes per source row (assumed to be a multiple of 4)
 * @return flat array of 5 floats per face (left, top, right, bottom, score)
 *         in frame pixel coordinates; empty on failure or when no face passes
 *         the score threshold
 */
std::vector<float> RetinaFaceEngineRKNN::detect(
    const uint32_t* argbPixels,
    int width,
    int height,
    int strideBytes) {
    std::vector<float> empty;
    if (!initialized_ || ctx_ == 0 || argbPixels == nullptr || width <= 0 || height <= 0) {
        return empty;
    }
    // Nearest-neighbor resize of the full frame into an interleaved RGB
    // buffer sized for the model input.
    std::vector<uint8_t> rgb(inputSize_ * inputSize_ * 3);
    const int srcStridePx = strideBytes / 4;
    for (int y = 0; y < inputSize_; ++y) {
        const int sy = y * height / inputSize_;
        const uint32_t* srcRow = argbPixels + sy * srcStridePx;
        uint8_t* dst = rgb.data() + y * inputSize_ * 3;
        for (int x = 0; x < inputSize_; ++x) {
            const int sx = x * width / inputSize_;
            const uint32_t pixel = srcRow[sx];
            const uint8_t r = (pixel >> 16) & 0xFF;
            const uint8_t g = (pixel >> 8) & 0xFF;
            const uint8_t b = pixel & 0xFF;
            dst[3 * x + 0] = r;
            dst[3 * x + 1] = g;
            dst[3 * x + 2] = b;
        }
    }
    rknn_input input{};
    input.index = 0;
    input.type = RKNN_TENSOR_UINT8;
    input.size = rgb.size();
    input.buf = rgb.data();
    input.pass_through = 0;
    input.fmt = (inputAttr_.fmt == RKNN_TENSOR_NCHW) ? RKNN_TENSOR_NCHW : RKNN_TENSOR_NHWC;
    // NCHW models need the interleaved buffer de-interleaved into planes;
    // the buffer must outlive rknn_inputs_set.
    std::vector<uint8_t> nchw;
    if (input.fmt == RKNN_TENSOR_NCHW) {
        nchw.resize(rgb.size());
        const int hw = inputSize_ * inputSize_;
        for (int i = 0; i < hw; ++i) {
            nchw[i] = rgb[3 * i + 0];
            nchw[hw + i] = rgb[3 * i + 1];
            nchw[2 * hw + i] = rgb[3 * i + 2];
        }
        input.buf = nchw.data();
    }
    int ret = rknn_inputs_set(ctx_, 1, &input);
    if (ret != RKNN_SUCC) {
        LOGW("rknn_inputs_set failed: %d", ret);
        return empty;
    }
    ret = rknn_run(ctx_, nullptr);
    if (ret != RKNN_SUCC) {
        LOGW("rknn_run failed: %d", ret);
        return empty;
    }
    // Request float output regardless of the model's native output type.
    std::vector<rknn_output> outputs(ioNum_.n_output);
    for (uint32_t i = 0; i < ioNum_.n_output; ++i) {
        std::memset(&outputs[i], 0, sizeof(rknn_output));
        outputs[i].want_float = 1;
    }
    ret = rknn_outputs_get(ctx_, ioNum_.n_output, outputs.data(), nullptr);
    if (ret != RKNN_SUCC) {
        LOGW("rknn_outputs_get failed: %d", ret);
        return empty;
    }
    std::vector<float> loc;
    std::vector<float> scores;
    if (!parseRetinaOutputs(outputs.data(), &loc, &scores)) {
        rknn_outputs_release(ctx_, ioNum_.n_output, outputs.data());
        return empty;
    }
    const std::vector<PriorBox> priors = buildPriors();
    const size_t anchorCount = priors.size();
    if (loc.size() < anchorCount * 4 || scores.size() < anchorCount) {
        LOGW("Output size mismatch: priors=%zu loc=%zu scores=%zu", anchorCount, loc.size(), scores.size());
        rknn_outputs_release(ctx_, ioNum_.n_output, outputs.data());
        return empty;
    }
    std::vector<FaceCandidate> candidates;
    candidates.reserve(anchorCount / 8);
    // Decode anchor deltas into boxes (standard RetinaFace/SSD decoding with
    // variances 0.1/0.2), scaled to frame pixels and clamped to the frame.
    for (size_t i = 0; i < anchorCount; ++i) {
        const float score = scores[i];
        if (score < scoreThreshold_) continue;
        const PriorBox& p = priors[i];
        const float dx = loc[i * 4 + 0];
        const float dy = loc[i * 4 + 1];
        const float dw = loc[i * 4 + 2];
        const float dh = loc[i * 4 + 3];
        const float cx = p.cx + dx * kVariance0 * p.w;
        const float cy = p.cy + dy * kVariance0 * p.h;
        const float w = p.w * std::exp(dw * kVariance1);
        const float h = p.h * std::exp(dh * kVariance1);
        FaceCandidate box;
        box.left = std::max(0.0f, (cx - w * 0.5f) * width);
        box.top = std::max(0.0f, (cy - h * 0.5f) * height);
        box.right = std::min(static_cast<float>(width), (cx + w * 0.5f) * width);
        box.bottom = std::min(static_cast<float>(height), (cy + h * 0.5f) * height);
        box.score = score;
        candidates.push_back(box);
    }
    rknn_outputs_release(ctx_, ioNum_.n_output, outputs.data());
    std::vector<FaceCandidate> filtered = nms(candidates, nmsThreshold_);
    // Flatten to 5 floats per face for the JNI boundary.
    std::vector<float> result;
    result.reserve(filtered.size() * 5);
    for (const auto& f : filtered) {
        result.push_back(f.left);
        result.push_back(f.top);
        result.push_back(f.right);
        result.push_back(f.bottom);
        result.push_back(f.score);
    }
    return result;
}
void RetinaFaceEngineRKNN::release() {
    // Destroy the RKNN context first (if one was created), then wipe all
    // cached model metadata so the engine can be re-initialized safely.
    if (ctx_ != 0) {
        rknn_destroy(ctx_);
        ctx_ = 0;
    }
    initialized_ = false;
    std::memset(&inputAttr_, 0, sizeof(inputAttr_));
    std::memset(&ioNum_, 0, sizeof(ioNum_));
    outputAttrs_.clear();
}

View File

@@ -0,0 +1,55 @@
#ifndef DIGITAL_PERSON_RETINAFACE_ENGINE_RKNN_H
#define DIGITAL_PERSON_RETINAFACE_ENGINE_RKNN_H
#include <cstddef>  // size_t is used below; do not rely on transitive includes
#include <cstdint>
#include <string>
#include <vector>
#include "rknn_api.h"
/**
 * RetinaFace face detector backed by the Rockchip RKNN runtime.
 *
 * Typical lifecycle: init() once, detect() repeatedly, release() when done.
 * detect() returns a flat vector of [left, top, right, bottom, score]
 * quintuples, one per detected face, in input-image pixel coordinates.
 */
class RetinaFaceEngineRKNN {
public:
    RetinaFaceEngineRKNN();
    ~RetinaFaceEngineRKNN();
    // Non-copyable: the instance owns a raw rknn_context; copying would lead
    // to a double rknn_destroy() when both copies are released.
    RetinaFaceEngineRKNN(const RetinaFaceEngineRKNN&) = delete;
    RetinaFaceEngineRKNN& operator=(const RetinaFaceEngineRKNN&) = delete;
    /**
     * Loads the RKNN model and prepares the runtime.
     * @param modelPath      Filesystem path to the .rknn model file.
     * @param inputSize      Square input resolution expected by the model (pixels).
     * @param scoreThreshold Minimum confidence for a candidate box.
     * @param nmsThreshold   IoU threshold used during non-maximum suppression.
     * @return 0 on success, non-zero on failure.
     */
    int init(const char* modelPath, int inputSize, float scoreThreshold, float nmsThreshold);
    /**
     * Runs face detection on a 32-bit pixel buffer.
     * @param argbPixels  Pointer to 32-bit pixels (one uint32_t per pixel).
     * @param width       Image width in pixels.
     * @param height      Image height in pixels.
     * @param strideBytes Row stride of the buffer in bytes.
     * @return Flat vector of [left, top, right, bottom, score] per face;
     *         empty on error or when no candidate passes the thresholds.
     */
    std::vector<float> detect(const uint32_t* argbPixels, int width, int height, int strideBytes);
    /** Destroys the RKNN context and resets all cached model metadata. */
    void release();
private:
    // Anchor box in normalized [0,1] center/size form.
    struct PriorBox {
        float cx;
        float cy;
        float w;
        float h;
    };
    // Decoded detection candidate in pixel coordinates (used before and after NMS).
    struct FaceCandidate {
        float left;
        float top;
        float right;
        float bottom;
        float score;
    };
    static size_t tensorElemCount(const rknn_tensor_attr& attr);
    static float iou(const FaceCandidate& a, const FaceCandidate& b);
    static std::vector<FaceCandidate> nms(const std::vector<FaceCandidate>& boxes, float threshold);
    std::vector<PriorBox> buildPriors() const;
    bool parseRetinaOutputs(
        rknn_output* outputs,
        std::vector<float>* locOut,
        std::vector<float>* scoreOut) const;
    rknn_context ctx_ = 0;
    bool initialized_ = false;
    int inputSize_ = 320;
    float scoreThreshold_ = 0.6f;
    float nmsThreshold_ = 0.4f;
    rknn_input_output_num ioNum_{};
    rknn_tensor_attr inputAttr_{};
    std::vector<rknn_tensor_attr> outputAttrs_;
};
#endif  // DIGITAL_PERSON_RETINAFACE_ENGINE_RKNN_H

View File

@@ -0,0 +1,100 @@
#include <jni.h>
#include <android/bitmap.h>
#include <android/log.h>

#include <new>

#include "RetinaFaceEngineRKNN.h"
#define LOG_TAG "RetinaFaceJNI"
#define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__)
extern "C" {
JNIEXPORT jlong JNICALL
Java_com_digitalperson_engine_RetinaFaceEngineRKNN_createEngineNative(JNIEnv* env, jobject thiz) {
    // Use the non-throwing allocator: plain `new` never returns nullptr (it
    // throws std::bad_alloc), so the original null-check was dead code — and
    // a C++ exception escaping through the JNI boundary aborts the process.
    // With std::nothrow the check is meaningful and allocation failure is
    // reported to Java as a 0 handle.
    auto* engine = new (std::nothrow) RetinaFaceEngineRKNN();
    if (engine == nullptr) {
        return 0;
    }
    return reinterpret_cast<jlong>(engine);
}
JNIEXPORT jint JNICALL
Java_com_digitalperson_engine_RetinaFaceEngineRKNN_initNative(
    JNIEnv* env,
    jobject thiz,
    jlong ptr,
    jstring modelPath,
    jint inputSize,
    jfloat scoreThreshold,
    jfloat nmsThreshold) {
    // Initializes the native engine with the given model and thresholds.
    // Returns the engine's status code, or -1 on invalid arguments.
    auto* engine = reinterpret_cast<RetinaFaceEngineRKNN*>(ptr);
    if (engine == nullptr || modelPath == nullptr) {
        return -1;
    }
    const char* pathChars = env->GetStringUTFChars(modelPath, nullptr);
    if (pathChars == nullptr) {
        return -1;  // JVM could not provide the UTF chars (OOM pending).
    }
    const int status = engine->init(pathChars, static_cast<int>(inputSize), scoreThreshold, nmsThreshold);
    env->ReleaseStringUTFChars(modelPath, pathChars);
    return status;
}
JNIEXPORT jfloatArray JNICALL
Java_com_digitalperson_engine_RetinaFaceEngineRKNN_detectNative(
    JNIEnv* env,
    jobject thiz,
    jlong ptr,
    jobject bitmapObj) {
    // Runs face detection on an RGBA_8888 bitmap and returns a flat float
    // array of [left, top, right, bottom, score] per detected face.
    auto* engine = reinterpret_cast<RetinaFaceEngineRKNN*>(ptr);
    if (engine == nullptr || bitmapObj == nullptr) {
        return env->NewFloatArray(0);
    }
    AndroidBitmapInfo info{};
    if (AndroidBitmap_getInfo(env, bitmapObj, &info) < 0) {
        LOGE("AndroidBitmap_getInfo failed");
        return env->NewFloatArray(0);
    }
    // The native engine expects 32-bit pixels; reject other bitmap configs.
    if (info.format != ANDROID_BITMAP_FORMAT_RGBA_8888) {
        LOGE("Unsupported bitmap format: %d", info.format);
        return env->NewFloatArray(0);
    }
    void* pixels = nullptr;
    if (AndroidBitmap_lockPixels(env, bitmapObj, &pixels) < 0 || pixels == nullptr) {
        LOGE("AndroidBitmap_lockPixels failed");
        return env->NewFloatArray(0);
    }
    std::vector<float> result = engine->detect(
        reinterpret_cast<uint32_t*>(pixels),
        static_cast<int>(info.width),
        static_cast<int>(info.height),
        static_cast<int>(info.stride));
    AndroidBitmap_unlockPixels(env, bitmapObj);
    jfloatArray out = env->NewFloatArray(static_cast<jsize>(result.size()));
    if (out == nullptr) {
        // Allocation failed: an OutOfMemoryError is already pending on this
        // thread. Calling further JNI functions (the original retried
        // NewFloatArray(0) here) with a pending exception is undefined
        // behavior per the JNI spec — return null and let Java see the OOM.
        return nullptr;
    }
    if (!result.empty()) {
        env->SetFloatArrayRegion(out, 0, static_cast<jsize>(result.size()), result.data());
    }
    return out;
}
JNIEXPORT void JNICALL
Java_com_digitalperson_engine_RetinaFaceEngineRKNN_releaseNative(
    JNIEnv* env,
    jobject thiz,
    jlong ptr) {
    // Tears down and frees the native engine. Safe to call with a 0 handle.
    auto* engine = reinterpret_cast<RetinaFaceEngineRKNN*>(ptr);
    if (engine == nullptr) {
        return;
    }
    engine->release();
    delete engine;
}
} // extern "C"

View File

@@ -0,0 +1,409 @@
#ifndef _RKLLM_H_
#define _RKLLM_H_
#include <cstddef>
#include <cstdint>
#ifdef __cplusplus
extern "C" {
#endif
#define CPU0 (1 << 0) // 0x01
#define CPU1 (1 << 1) // 0x02
#define CPU2 (1 << 2) // 0x04
#define CPU3 (1 << 3) // 0x08
#define CPU4 (1 << 4) // 0x10
#define CPU5 (1 << 5) // 0x20
#define CPU6 (1 << 6) // 0x40
#define CPU7 (1 << 7) // 0x80
/**
* @typedef LLMHandle
* @brief A handle used to manage and interact with the large language model.
*/
typedef void* LLMHandle;
/**
* @enum LLMCallState
* @brief Describes the possible states of an LLM call.
*/
typedef enum {
RKLLM_RUN_NORMAL = 0, /**< The LLM call is in a normal running state. */
RKLLM_RUN_WAITING = 1, /**< The LLM call is waiting for complete UTF-8 encoded character. */
RKLLM_RUN_FINISH = 2, /**< The LLM call has finished execution. */
RKLLM_RUN_ERROR = 3, /**< An error occurred during the LLM call. */
} LLMCallState;
/**
* @enum RKLLMInputType
* @brief Defines the types of inputs that can be fed into the LLM.
*/
typedef enum {
RKLLM_INPUT_PROMPT = 0, /**< Input is a text prompt. */
RKLLM_INPUT_TOKEN = 1, /**< Input is a sequence of tokens. */
RKLLM_INPUT_EMBED = 2, /**< Input is an embedding vector. */
RKLLM_INPUT_MULTIMODAL = 3, /**< Input is multimodal (e.g., text and image). */
} RKLLMInputType;
/**
* @enum RKLLMInferMode
* @brief Specifies the inference modes of the LLM.
*/
typedef enum {
RKLLM_INFER_GENERATE = 0, /**< The LLM generates text based on input. */
RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1, /**< The LLM retrieves the last hidden layer for further processing. */
RKLLM_INFER_GET_LOGITS = 2, /**< The LLM retrieves logits for further processing. */
} RKLLMInferMode;
/**
* @struct RKLLMExtendParam
* @brief The extend parameters for configuring an LLM instance.
*/
typedef struct {
int32_t base_domain_id; /**< base_domain_id */
int8_t embed_flash; /**< Indicates whether to query word embedding vectors from flash memory (1) or not (0). */
int8_t enabled_cpus_num; /**< Number of CPUs enabled for inference. */
uint32_t enabled_cpus_mask; /**< Bitmask indicating which CPUs to enable for inference. */
uint8_t n_batch; /**< Number of input samples processed concurrently in one forward pass. Set to >1 to enable batched inference. Default is 1. */
int8_t use_cross_attn; /**< Whether to enable cross attention (non-zero to enable, 0 to disable). */
uint8_t reserved[104]; /**< reserved */
} RKLLMExtendParam;
/**
 * @struct RKLLMParam
 * @brief Defines the parameters for configuring an LLM instance.
 */
typedef struct {
    const char* model_path; /**< Path to the model file. */
    int32_t max_context_len; /**< Maximum number of tokens in the context window. */
    int32_t max_new_tokens; /**< Maximum number of new tokens to generate. */
    int32_t top_k; /**< Top-K sampling parameter for token generation. */
    int32_t n_keep; /**< Number of KV-cache entries to keep at the beginning when shifting the context window. */
    float top_p; /**< Top-P (nucleus) sampling parameter. */
    float temperature; /**< Sampling temperature, affecting the randomness of token selection. */
    float repeat_penalty; /**< Penalty for repeating tokens in generation. */
    float frequency_penalty; /**< Penalizes frequent tokens during generation. */
    float presence_penalty; /**< Penalizes tokens based on their presence in the input. */
    int32_t mirostat; /**< Mirostat sampling strategy flag (0 to disable). */
    float mirostat_tau; /**< Tau parameter for Mirostat sampling. */
    float mirostat_eta; /**< Eta parameter for Mirostat sampling. */
    bool skip_special_token; /**< Whether to skip special tokens during generation. */
    bool is_async; /**< Whether to run inference asynchronously. */
    const char* img_start; /**< Starting position of an image in multimodal input. */
    const char* img_end; /**< Ending position of an image in multimodal input. */
    const char* img_content; /**< Pointer to the image content. */
    RKLLMExtendParam extend_param; /**< Extend parameters. */
} RKLLMParam;
/**
* @struct RKLLMLoraAdapter
* @brief Defines parameters for a Lora adapter used in model fine-tuning.
*/
typedef struct {
const char* lora_adapter_path; /**< Path to the Lora adapter file. */
const char* lora_adapter_name; /**< Name of the Lora adapter. */
float scale; /**< Scaling factor for applying the Lora adapter. */
} RKLLMLoraAdapter;
/**
* @struct RKLLMEmbedInput
* @brief Represents an embedding input to the LLM.
*/
typedef struct {
float* embed; /**< Pointer to the embedding vector (of size n_tokens * n_embed). */
size_t n_tokens; /**< Number of tokens represented in the embedding. */
} RKLLMEmbedInput;
/**
* @struct RKLLMTokenInput
* @brief Represents token input to the LLM.
*/
typedef struct {
int32_t* input_ids; /**< Array of token IDs. */
size_t n_tokens; /**< Number of tokens in the input. */
} RKLLMTokenInput;
/**
* @struct RKLLMMultiModalInput
* @brief Represents multimodal input (e.g., text and image).
*/
typedef struct {
char* prompt; /**< Text prompt input. */
float* image_embed; /**< Embedding of the images (of size n_image * n_image_tokens * image_embed_length). */
size_t n_image_tokens; /**< Number of image_token. */
size_t n_image; /**< Number of image. */
size_t image_width; /**< Width of image. */
size_t image_height; /**< Height of image. */
} RKLLMMultiModalInput;
/**
* @struct RKLLMInput
* @brief Represents different types of input to the LLM via a union.
*/
typedef struct {
const char* role; /**< Message role: "user" (user input), "tool" (function result) */
bool enable_thinking; /**< Controls whether "thinking mode" is enabled for the Qwen3 model. */
RKLLMInputType input_type; /**< Specifies the type of input provided (e.g., prompt, token, embed, multimodal). */
union {
const char* prompt_input; /**< Text prompt input if input_type is RKLLM_INPUT_PROMPT. */
RKLLMEmbedInput embed_input; /**< Embedding input if input_type is RKLLM_INPUT_EMBED. */
RKLLMTokenInput token_input; /**< Token input if input_type is RKLLM_INPUT_TOKEN. */
RKLLMMultiModalInput multimodal_input; /**< Multimodal input if input_type is RKLLM_INPUT_MULTIMODAL. */
};
} RKLLMInput;
/**
* @struct RKLLMLoraParam
* @brief Structure defining parameters for Lora adapters.
*/
typedef struct {
const char* lora_adapter_name; /**< Name of the Lora adapter. */
} RKLLMLoraParam;
/**
* @struct RKLLMPromptCacheParam
* @brief Structure to define parameters for caching prompts.
*/
typedef struct {
int save_prompt_cache; /**< Flag to indicate whether to save the prompt cache (0 = don't save, 1 = save). */
const char* prompt_cache_path; /**< Path to the prompt cache file. */
} RKLLMPromptCacheParam;
/**
* @struct RKLLMCrossAttnParam
* @brief Structure holding parameters for cross-attention inference.
*
* This structure is used when performing cross-attention in the decoder.
* It provides the encoder output (key/value caches), position indices,
* and attention mask.
*
* - `encoder_k_cache` must be stored in contiguous memory with layout:
* [num_layers][num_tokens][num_kv_heads][head_dim]
* - `encoder_v_cache` must be stored in contiguous memory with layout:
* [num_layers][num_kv_heads][head_dim][num_tokens]
*/
typedef struct {
float* encoder_k_cache; /**< Pointer to encoder key cache (size: num_layers * num_tokens * num_kv_heads * head_dim). */
float* encoder_v_cache; /**< Pointer to encoder value cache (size: num_layers * num_kv_heads * head_dim * num_tokens). */
float* encoder_mask; /**< Pointer to encoder attention mask (array of size num_tokens). */
int32_t* encoder_pos; /**< Pointer to encoder token positions (array of size num_tokens). */
int num_tokens; /**< Number of tokens in the encoder sequence. */
} RKLLMCrossAttnParam;
/**
 * @struct RKLLMInferParam
 * @brief Structure for defining parameters during inference.
 */
typedef struct {
    RKLLMInferMode mode; /**< Inference mode (e.g., generate or get last hidden layer). */
    RKLLMLoraParam* lora_params; /**< Pointer to Lora adapter parameters. */
    RKLLMPromptCacheParam* prompt_cache_params; /**< Pointer to prompt cache parameters. */
    int keep_history; /**< Flag to determine history retention (1: keep history, 0: discard history). */
} RKLLMInferParam;
/**
* @struct RKLLMResultLastHiddenLayer
* @brief Structure to hold the hidden states from the last layer.
*/
typedef struct {
const float* hidden_states; /**< Pointer to the hidden states (of size num_tokens * embd_size). */
int embd_size; /**< Size of the embedding vector. */
int num_tokens; /**< Number of tokens for which hidden states are stored. */
} RKLLMResultLastHiddenLayer;
/**
* @struct RKLLMResultLogits
* @brief Structure to hold the logits.
*/
typedef struct {
const float* logits; /**< Pointer to the logits (of size num_tokens * vocab_size). */
int vocab_size; /**< Size of the vocab. */
int num_tokens; /**< Number of tokens for which logits are stored. */
} RKLLMResultLogits;
/**
* @struct RKLLMPerfStat
* @brief Structure to hold performance statistics for prefill and generate stages.
*/
typedef struct {
float prefill_time_ms; /**< Total time taken for the prefill stage in milliseconds. */
int prefill_tokens; /**< Number of tokens processed during the prefill stage. */
float generate_time_ms; /**< Total time taken for the generate stage in milliseconds. */
int generate_tokens; /**< Number of tokens processed during the generate stage. */
float memory_usage_mb; /**< VmHWM resident memory usage during inference, in megabytes. */
} RKLLMPerfStat;
/**
 * @struct RKLLMResult
 * @brief Structure to represent the result of LLM inference.
 */
typedef struct {
    const char* text; /**< Generated text result. */
    int32_t token_id; /**< ID of the generated token. */
    RKLLMResultLastHiddenLayer last_hidden_layer; /**< Hidden states of the last layer (if requested). */
    RKLLMResultLogits logits; /**< Model output logits. */
    RKLLMPerfStat perf; /**< Performance statistics for the prefill and generate stages (stored by value, not a pointer). */
} RKLLMResult;
/**
* @typedef LLMResultCallback
* @brief Callback function to handle LLM results.
* @param result Pointer to the LLM result.
* @param userdata Pointer to user data for the callback.
* @param state State of the LLM call (e.g., finished, error).
* @return int Return value indicating the handling status:
* - 0: Continue inference normally.
* - 1: Pause inference. If the user wants to modify or intervene in the result (e.g., editing output, injecting new prompt),
* return 1 to suspend the current inference. Later, call `rkllm_run` with updated content to resume inference.
*/
typedef int(*LLMResultCallback)(RKLLMResult* result, void* userdata, LLMCallState state);
/**
* @brief Creates a default RKLLMParam structure with preset values.
* @return A default RKLLMParam structure.
*/
RKLLMParam rkllm_createDefaultParam();
/**
* @brief Initializes the LLM with the given parameters.
* @param handle Pointer to the LLM handle.
* @param param Configuration parameters for the LLM.
* @param callback Callback function to handle LLM results.
* @return Status code (0 for success, non-zero for failure).
*/
int rkllm_init(LLMHandle* handle, RKLLMParam* param, LLMResultCallback callback);
/**
* @brief Loads a Lora adapter into the LLM.
* @param handle LLM handle.
* @param lora_adapter Pointer to the Lora adapter structure.
* @return Status code (0 for success, non-zero for failure).
*/
int rkllm_load_lora(LLMHandle handle, RKLLMLoraAdapter* lora_adapter);
/**
* @brief Loads a prompt cache from a file.
* @param handle LLM handle.
* @param prompt_cache_path Path to the prompt cache file.
* @return Status code (0 for success, non-zero for failure).
*/
int rkllm_load_prompt_cache(LLMHandle handle, const char* prompt_cache_path);
/**
* @brief Releases the prompt cache from memory.
* @param handle LLM handle.
* @return Status code (0 for success, non-zero for failure).
*/
int rkllm_release_prompt_cache(LLMHandle handle);
/**
* @brief Destroys the LLM instance and releases resources.
* @param handle LLM handle.
* @return Status code (0 for success, non-zero for failure).
*/
int rkllm_destroy(LLMHandle handle);
/**
* @brief Runs an LLM inference task synchronously.
* @param handle LLM handle.
* @param rkllm_input Input data for the LLM.
* @param rkllm_infer_params Parameters for the inference task.
* @param userdata Pointer to user data for the callback.
* @return Status code (0 for success, non-zero for failure).
*/
int rkllm_run(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
/**
* @brief Runs an LLM inference task asynchronously.
* @param handle LLM handle.
* @param rkllm_input Input data for the LLM.
* @param rkllm_infer_params Parameters for the inference task.
* @param userdata Pointer to user data for the callback.
* @return Status code (0 for success, non-zero for failure).
*/
int rkllm_run_async(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
/**
* @brief Aborts an ongoing LLM task.
* @param handle LLM handle.
* @return Status code (0 for success, non-zero for failure).
*/
int rkllm_abort(LLMHandle handle);
/**
* @brief Checks if an LLM task is currently running.
* @param handle LLM handle.
* @return Status code (0 if a task is running, non-zero for otherwise).
*/
int rkllm_is_running(LLMHandle handle);
/**
* @brief Clear the key-value cache for a given LLM handle.
*
* This function is used to clear part or all of the KV cache.
*
* @param handle LLM handle.
* @param keep_system_prompt Flag indicating whether to retain the system prompt in the cache (1 to retain, 0 to clear).
* This flag is ignored if a specific range [start_pos, end_pos) is provided.
* @param start_pos Array of start positions (inclusive) of the KV cache ranges to clear, one per batch.
* @param end_pos Array of end positions (exclusive) of the KV cache ranges to clear, one per batch.
* If both start_pos and end_pos are set to nullptr, the entire cache will be cleared and keep_system_prompt will take effect,
* If start_pos[i] < end_pos[i], only the specified range will be cleared, and keep_system_prompt will be ignored.
* @note: start_pos or end_pos is only valid when keep_history == 0 and the generation has been paused by returning 1 in the callback
* @return Status code (0 if cache was cleared successfully, non-zero otherwise).
*/
int rkllm_clear_kv_cache(LLMHandle handle, int keep_system_prompt, int* start_pos, int* end_pos);
/**
* @brief Get the current size of the key-value cache for a given LLM handle.
*
* This function returns the total number of positions currently stored in the model's KV cache.
*
* @param handle LLM handle.
* @param cache_sizes Pointer to an array where the per-batch cache sizes will be stored.
* The array must be preallocated with space for `n_batch` elements.
*/
int rkllm_get_kv_cache_size(LLMHandle handle, int* cache_sizes);
/**
* @brief Sets the chat template for the LLM, including system prompt, prefix, and postfix.
*
* This function allows you to customize the chat template by providing a system prompt, a prompt prefix, and a prompt postfix.
* The system prompt is typically used to define the behavior or context of the language model,
* while the prefix and postfix are used to format the user input and output respectively.
*
* @param handle LLM handle.
* @param system_prompt The system prompt that defines the context or behavior of the language model.
* @param prompt_prefix The prefix added before the user input in the chat.
* @param prompt_postfix The postfix added after the user input in the chat.
*
* @return Status code (0 if the template was set successfully, non-zero for errors).
*/
int rkllm_set_chat_template(LLMHandle handle, const char* system_prompt, const char* prompt_prefix, const char* prompt_postfix);
/**
* @brief Sets the function calling configuration for the LLM, including system prompt, tool definitions, and tool response token.
*
* @param handle LLM handle.
* @param system_prompt The system prompt that defines the context or behavior of the language model.
* @param tools A JSON-formatted string that defines the available functions, including their names, descriptions, and parameters.
* @param tool_response_str A unique tag used to identify function call results within a conversation. It acts as the marker tag,
* allowing tokenizer to recognize tool outputs separately from normal dialogue turns.
* @return Status code (0 if the configuration was set successfully, non-zero for errors).
*/
int rkllm_set_function_tools(LLMHandle handle, const char* system_prompt, const char* tools, const char* tool_response_str);
/**
* @brief Sets the cross-attention parameters for the LLM decoder.
*
* @param handle LLM handle.
* @param cross_attn_params Pointer to the structure containing encoder-related input data
* used for cross-attention (see RKLLMCrossAttnParam for details).
*
* @return Status code (0 if the parameters were set successfully, non-zero for errors).
*/
int rkllm_set_cross_attn_params(LLMHandle handle, RKLLMCrossAttnParam* cross_attn_params);
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,409 @@
#ifndef _RKLLM_H_
#define _RKLLM_H_
#include <cstddef>
#include <cstdint>
#ifdef __cplusplus
extern "C" {
#endif
#define CPU0 (1 << 0) // 0x01
#define CPU1 (1 << 1) // 0x02
#define CPU2 (1 << 2) // 0x04
#define CPU3 (1 << 3) // 0x08
#define CPU4 (1 << 4) // 0x10
#define CPU5 (1 << 5) // 0x20
#define CPU6 (1 << 6) // 0x40
#define CPU7 (1 << 7) // 0x80
/**
* @typedef LLMHandle
* @brief A handle used to manage and interact with the large language model.
*/
typedef void* LLMHandle;
/**
* @enum LLMCallState
* @brief Describes the possible states of an LLM call.
*/
typedef enum {
RKLLM_RUN_NORMAL = 0, /**< The LLM call is in a normal running state. */
RKLLM_RUN_WAITING = 1, /**< The LLM call is waiting for complete UTF-8 encoded character. */
RKLLM_RUN_FINISH = 2, /**< The LLM call has finished execution. */
RKLLM_RUN_ERROR = 3, /**< An error occurred during the LLM call. */
} LLMCallState;
/**
* @enum RKLLMInputType
* @brief Defines the types of inputs that can be fed into the LLM.
*/
typedef enum {
RKLLM_INPUT_PROMPT = 0, /**< Input is a text prompt. */
RKLLM_INPUT_TOKEN = 1, /**< Input is a sequence of tokens. */
RKLLM_INPUT_EMBED = 2, /**< Input is an embedding vector. */
RKLLM_INPUT_MULTIMODAL = 3, /**< Input is multimodal (e.g., text and image). */
} RKLLMInputType;
/**
* @enum RKLLMInferMode
* @brief Specifies the inference modes of the LLM.
*/
typedef enum {
RKLLM_INFER_GENERATE = 0, /**< The LLM generates text based on input. */
RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1, /**< The LLM retrieves the last hidden layer for further processing. */
RKLLM_INFER_GET_LOGITS = 2, /**< The LLM retrieves logits for further processing. */
} RKLLMInferMode;
/**
* @struct RKLLMExtendParam
* @brief The extend parameters for configuring an LLM instance.
*/
typedef struct {
int32_t base_domain_id; /**< base_domain_id */
int8_t embed_flash; /**< Indicates whether to query word embedding vectors from flash memory (1) or not (0). */
int8_t enabled_cpus_num; /**< Number of CPUs enabled for inference. */
uint32_t enabled_cpus_mask; /**< Bitmask indicating which CPUs to enable for inference. */
uint8_t n_batch; /**< Number of input samples processed concurrently in one forward pass. Set to >1 to enable batched inference. Default is 1. */
int8_t use_cross_attn; /**< Whether to enable cross attention (non-zero to enable, 0 to disable). */
uint8_t reserved[104]; /**< reserved */
} RKLLMExtendParam;
/**
 * @struct RKLLMParam
 * @brief Defines the parameters for configuring an LLM instance.
 */
typedef struct {
    const char* model_path; /**< Path to the model file. */
    int32_t max_context_len; /**< Maximum number of tokens in the context window. */
    int32_t max_new_tokens; /**< Maximum number of new tokens to generate. */
    int32_t top_k; /**< Top-K sampling parameter for token generation. */
    int32_t n_keep; /**< Number of KV-cache entries to keep at the beginning when shifting the context window. */
    float top_p; /**< Top-P (nucleus) sampling parameter. */
    float temperature; /**< Sampling temperature, affecting the randomness of token selection. */
    float repeat_penalty; /**< Penalty for repeating tokens in generation. */
    float frequency_penalty; /**< Penalizes frequent tokens during generation. */
    float presence_penalty; /**< Penalizes tokens based on their presence in the input. */
    int32_t mirostat; /**< Mirostat sampling strategy flag (0 to disable). */
    float mirostat_tau; /**< Tau parameter for Mirostat sampling. */
    float mirostat_eta; /**< Eta parameter for Mirostat sampling. */
    bool skip_special_token; /**< Whether to skip special tokens during generation. */
    bool is_async; /**< Whether to run inference asynchronously. */
    const char* img_start; /**< Starting position of an image in multimodal input. */
    const char* img_end; /**< Ending position of an image in multimodal input. */
    const char* img_content; /**< Pointer to the image content. */
    RKLLMExtendParam extend_param; /**< Extend parameters. */
} RKLLMParam;
/**
* @struct RKLLMLoraAdapter
* @brief Defines parameters for a Lora adapter used in model fine-tuning.
*/
typedef struct {
const char* lora_adapter_path; /**< Path to the Lora adapter file. */
const char* lora_adapter_name; /**< Name of the Lora adapter. */
float scale; /**< Scaling factor for applying the Lora adapter. */
} RKLLMLoraAdapter;
/**
* @struct RKLLMEmbedInput
* @brief Represents an embedding input to the LLM.
*/
typedef struct {
float* embed; /**< Pointer to the embedding vector (of size n_tokens * n_embed). */
size_t n_tokens; /**< Number of tokens represented in the embedding. */
} RKLLMEmbedInput;
/**
* @struct RKLLMTokenInput
* @brief Represents token input to the LLM.
*/
typedef struct {
int32_t* input_ids; /**< Array of token IDs. */
size_t n_tokens; /**< Number of tokens in the input. */
} RKLLMTokenInput;
/**
* @struct RKLLMMultiModalInput
* @brief Represents multimodal input (e.g., text and image).
*/
typedef struct {
char* prompt; /**< Text prompt input. */
float* image_embed; /**< Embedding of the images (of size n_image * n_image_tokens * image_embed_length). */
size_t n_image_tokens; /**< Number of image_token. */
size_t n_image; /**< Number of image. */
size_t image_width; /**< Width of image. */
size_t image_height; /**< Height of image. */
} RKLLMMultiModalInput;
/**
* @struct RKLLMInput
* @brief Represents different types of input to the LLM via a union.
*/
typedef struct {
const char* role; /**< Message role: "user" (user input), "tool" (function result) */
bool enable_thinking; /**< Controls whether "thinking mode" is enabled for the Qwen3 model. */
RKLLMInputType input_type; /**< Specifies the type of input provided (e.g., prompt, token, embed, multimodal). */
union {
const char* prompt_input; /**< Text prompt input if input_type is RKLLM_INPUT_PROMPT. */
RKLLMEmbedInput embed_input; /**< Embedding input if input_type is RKLLM_INPUT_EMBED. */
RKLLMTokenInput token_input; /**< Token input if input_type is RKLLM_INPUT_TOKEN. */
RKLLMMultiModalInput multimodal_input; /**< Multimodal input if input_type is RKLLM_INPUT_MULTIMODAL. */
};
} RKLLMInput;
/**
* @struct RKLLMLoraParam
* @brief Structure defining parameters for Lora adapters.
*/
typedef struct {
const char* lora_adapter_name; /**< Name of the Lora adapter. */
} RKLLMLoraParam;
/**
* @struct RKLLMPromptCacheParam
* @brief Structure to define parameters for caching prompts.
*/
typedef struct {
int save_prompt_cache; /**< Flag to indicate whether to save the prompt cache (0 = don't save, 1 = save). */
const char* prompt_cache_path; /**< Path to the prompt cache file. */
} RKLLMPromptCacheParam;
/**
* @struct RKLLMCrossAttnParam
* @brief Structure holding parameters for cross-attention inference.
*
* This structure is used when performing cross-attention in the decoder.
* It provides the encoder output (key/value caches), position indices,
* and attention mask.
*
* - `encoder_k_cache` must be stored in contiguous memory with layout:
* [num_layers][num_tokens][num_kv_heads][head_dim]
* - `encoder_v_cache` must be stored in contiguous memory with layout:
* [num_layers][num_kv_heads][head_dim][num_tokens]
*/
typedef struct {
float* encoder_k_cache; /**< Pointer to encoder key cache (size: num_layers * num_tokens * num_kv_heads * head_dim). */
float* encoder_v_cache; /**< Pointer to encoder value cache (size: num_layers * num_kv_heads * head_dim * num_tokens). */
float* encoder_mask; /**< Pointer to encoder attention mask (array of size num_tokens). */
int32_t* encoder_pos; /**< Pointer to encoder token positions (array of size num_tokens). */
int num_tokens; /**< Number of tokens in the encoder sequence. */
} RKLLMCrossAttnParam;
/**
 * @struct RKLLMInferParam
 * @brief Structure for defining parameters during inference.
 */
typedef struct {
    RKLLMInferMode mode; /**< Inference mode (e.g., generate or get last hidden layer). */
    RKLLMLoraParam* lora_params; /**< Pointer to Lora adapter parameters. */
    RKLLMPromptCacheParam* prompt_cache_params; /**< Pointer to prompt cache parameters. */
    int keep_history; /**< Flag to determine history retention (1: keep history, 0: discard history). */
} RKLLMInferParam;
/**
* @struct RKLLMResultLastHiddenLayer
* @brief Structure to hold the hidden states from the last layer.
*/
typedef struct {
const float* hidden_states; /**< Pointer to the hidden states (of size num_tokens * embd_size). */
int embd_size; /**< Size of the embedding vector. */
int num_tokens; /**< Number of tokens for which hidden states are stored. */
} RKLLMResultLastHiddenLayer;
/**
* @struct RKLLMResultLogits
* @brief Structure to hold the logits.
*/
typedef struct {
const float* logits; /**< Pointer to the logits (of size num_tokens * vocab_size). */
int vocab_size; /**< Size of the vocab. */
int num_tokens; /**< Number of tokens for which logits are stored. */
} RKLLMResultLogits;
/**
* @struct RKLLMPerfStat
* @brief Structure to hold performance statistics for prefill and generate stages.
*/
typedef struct {
float prefill_time_ms; /**< Total time taken for the prefill stage in milliseconds. */
int prefill_tokens; /**< Number of tokens processed during the prefill stage. */
float generate_time_ms; /**< Total time taken for the generate stage in milliseconds. */
int generate_tokens; /**< Number of tokens processed during the generate stage. */
float memory_usage_mb; /**< VmHWM resident memory usage during inference, in megabytes. */
} RKLLMPerfStat;
/**
 * @struct RKLLMResult
 * @brief Structure to represent the result of LLM inference.
 */
typedef struct {
    const char* text; /**< Generated text result. */
    int32_t token_id; /**< ID of the generated token. */
    RKLLMResultLastHiddenLayer last_hidden_layer; /**< Hidden states of the last layer (if requested). */
    RKLLMResultLogits logits; /**< Model output logits. */
    RKLLMPerfStat perf; /**< Performance statistics for the prefill and generate stages (stored by value, not a pointer). */
} RKLLMResult;
/**
 * @typedef LLMResultCallback
 * @brief Callback function to handle LLM results.
 * @param result Pointer to the LLM result (text, token id, optional hidden states / logits, perf stats).
 * @param userdata Pointer to user data for the callback (the value passed to rkllm_run / rkllm_run_async).
 * @param state State of the LLM call (e.g., finished, error).
 * @return int Return value indicating the handling status:
 *         - 0: Continue inference normally.
 *         - 1: Pause inference. If the user wants to modify or intervene in the result (e.g., editing output, injecting new prompt),
 *              return 1 to suspend the current inference. Later, call `rkllm_run` with updated content to resume inference.
 */
typedef int(*LLMResultCallback)(RKLLMResult* result, void* userdata, LLMCallState state);
/**
 * @brief Creates a default RKLLMParam structure with preset values.
 * @return A default RKLLMParam structure.
 */
RKLLMParam rkllm_createDefaultParam();
/**
 * @brief Initializes the LLM with the given parameters.
 * @param handle Pointer to the LLM handle (out parameter — the handle is written here for use in subsequent calls).
 * @param param Configuration parameters for the LLM.
 * @param callback Callback function to handle LLM results.
 * @return Status code (0 for success, non-zero for failure).
 */
int rkllm_init(LLMHandle* handle, RKLLMParam* param, LLMResultCallback callback);
/**
 * @brief Loads a Lora adapter into the LLM.
 * @param handle LLM handle.
 * @param lora_adapter Pointer to the Lora adapter structure.
 * @return Status code (0 for success, non-zero for failure).
 */
int rkllm_load_lora(LLMHandle handle, RKLLMLoraAdapter* lora_adapter);
/**
 * @brief Loads a prompt cache from a file.
 * @param handle LLM handle.
 * @param prompt_cache_path Path to the prompt cache file.
 * @return Status code (0 for success, non-zero for failure).
 */
int rkllm_load_prompt_cache(LLMHandle handle, const char* prompt_cache_path);
/**
 * @brief Releases the prompt cache from memory.
 * @param handle LLM handle.
 * @return Status code (0 for success, non-zero for failure).
 */
int rkllm_release_prompt_cache(LLMHandle handle);
/**
 * @brief Destroys the LLM instance and releases resources.
 * @param handle LLM handle.
 * @return Status code (0 for success, non-zero for failure).
 */
int rkllm_destroy(LLMHandle handle);
/**
 * @brief Runs an LLM inference task synchronously (blocks until generation completes, pauses, or is aborted).
 * @param handle LLM handle.
 * @param rkllm_input Input data for the LLM.
 * @param rkllm_infer_params Parameters for the inference task.
 * @param userdata Pointer to user data forwarded to the result callback.
 * @return Status code (0 for success, non-zero for failure).
 */
int rkllm_run(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
/**
 * @brief Runs an LLM inference task asynchronously (returns immediately; results are delivered via the callback).
 * @param handle LLM handle.
 * @param rkllm_input Input data for the LLM.
 * @param rkllm_infer_params Parameters for the inference task.
 * @param userdata Pointer to user data forwarded to the result callback.
 * @return Status code (0 for success, non-zero for failure).
 */
int rkllm_run_async(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
/**
 * @brief Aborts an ongoing LLM task.
 * @param handle LLM handle.
 * @return Status code (0 for success, non-zero for failure).
 */
int rkllm_abort(LLMHandle handle);
/**
 * @brief Checks if an LLM task is currently running.
 * @param handle LLM handle.
 * @return Status code (0 if a task is running, non-zero otherwise).
 */
int rkllm_is_running(LLMHandle handle);
/**
 * @brief Clear the key-value cache for a given LLM handle.
 *
 * This function is used to clear part or all of the KV cache.
 *
 * @param handle LLM handle.
 * @param keep_system_prompt Flag indicating whether to retain the system prompt in the cache (1 to retain, 0 to clear).
 *                           This flag is ignored if a specific range [start_pos, end_pos) is provided.
 * @param start_pos Array of start positions (inclusive) of the KV cache ranges to clear, one per batch.
 * @param end_pos Array of end positions (exclusive) of the KV cache ranges to clear, one per batch.
 *                If both start_pos and end_pos are set to nullptr, the entire cache will be cleared and keep_system_prompt will take effect.
 *                If start_pos[i] < end_pos[i], only the specified range will be cleared, and keep_system_prompt will be ignored.
 * @note start_pos or end_pos is only valid when keep_history == 0 and the generation has been paused by returning 1 in the callback.
 * @return Status code (0 if cache was cleared successfully, non-zero otherwise).
 */
int rkllm_clear_kv_cache(LLMHandle handle, int keep_system_prompt, int* start_pos, int* end_pos);
/**
 * @brief Get the current size of the key-value cache for a given LLM handle.
 *
 * This function returns the total number of positions currently stored in the model's KV cache.
 *
 * @param handle LLM handle.
 * @param cache_sizes Pointer to an array where the per-batch cache sizes will be stored.
 *                    The array must be preallocated with space for `n_batch` elements.
 * @return Status code (0 on success, non-zero otherwise — return convention inferred from the sibling APIs above; confirm against SDK docs).
 */
int rkllm_get_kv_cache_size(LLMHandle handle, int* cache_sizes);
/**
 * @brief Sets the chat template for the LLM, including system prompt, prefix, and postfix.
 *
 * This function allows you to customize the chat template by providing a system prompt, a prompt prefix, and a prompt postfix.
 * The system prompt is typically used to define the behavior or context of the language model,
 * while the prefix and postfix are used to format the user input and output respectively.
 *
 * @param handle LLM handle.
 * @param system_prompt The system prompt that defines the context or behavior of the language model.
 * @param prompt_prefix The prefix added before the user input in the chat.
 * @param prompt_postfix The postfix added after the user input in the chat.
 *
 * @return Status code (0 if the template was set successfully, non-zero for errors).
 */
int rkllm_set_chat_template(LLMHandle handle, const char* system_prompt, const char* prompt_prefix, const char* prompt_postfix);
/**
 * @brief Sets the function calling configuration for the LLM, including system prompt, tool definitions, and tool response token.
 *
 * @param handle LLM handle.
 * @param system_prompt The system prompt that defines the context or behavior of the language model.
 * @param tools A JSON-formatted string that defines the available functions, including their names, descriptions, and parameters.
 * @param tool_response_str A unique tag used to identify function call results within a conversation. It acts as the marker tag,
 *                          allowing the tokenizer to recognize tool outputs separately from normal dialogue turns.
 * @return Status code (0 if the configuration was set successfully, non-zero for errors).
 */
int rkllm_set_function_tools(LLMHandle handle, const char* system_prompt, const char* tools, const char* tool_response_str);
/**
 * @brief Sets the cross-attention parameters for the LLM decoder.
 *
 * @param handle LLM handle.
 * @param cross_attn_params Pointer to the structure containing encoder-related input data
 *                          used for cross-attention (see RKLLMCrossAttnParam for details).
 *
 * @return Status code (0 if the parameters were set successfully, non-zero for errors).
 */
int rkllm_set_cross_attn_params(LLMHandle handle, RKLLMCrossAttnParam* cross_attn_params);
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -2,10 +2,15 @@ package com.digitalperson
import android.content.Intent
import android.os.Bundle
import android.util.Log
import androidx.appcompat.app.AppCompatActivity
import com.digitalperson.config.AppConfig
class EntryActivity : AppCompatActivity() {
companion object {
private const val TAG = "EntryActivity"
}
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
@@ -14,6 +19,7 @@ class EntryActivity : AppCompatActivity() {
} else {
MainActivity::class.java
}
Log.i(TAG, "USE_LIVE2D=${AppConfig.Avatar.USE_LIVE2D}, target=${target.simpleName}")
startActivity(Intent(this, target))
finish()
}

View File

@@ -2,20 +2,37 @@ package com.digitalperson
import android.Manifest
import android.content.pm.PackageManager
import android.graphics.Bitmap
import android.os.Bundle
import android.util.Log
import android.widget.Toast
import androidx.camera.core.CameraSelector
import androidx.camera.core.ImageAnalysis
import androidx.camera.core.ImageProxy
import androidx.camera.core.Preview
import androidx.camera.lifecycle.ProcessCameraProvider
import androidx.camera.view.PreviewView
import androidx.appcompat.app.AppCompatActivity
import androidx.core.app.ActivityCompat
import androidx.core.content.ContextCompat
import com.digitalperson.cloud.CloudApiManager
import com.digitalperson.audio.AudioProcessor
import com.digitalperson.vad.VadManager
import com.digitalperson.asr.AsrManager
import com.digitalperson.tts.TtsManager
import com.digitalperson.ui.Live2DUiManager
import com.digitalperson.config.AppConfig
import com.digitalperson.face.FaceDetectionPipeline
import com.digitalperson.face.FaceOverlayView
import com.digitalperson.face.ImageProxyBitmapConverter
import com.digitalperson.metrics.TraceManager
import com.digitalperson.metrics.TraceSession
import com.digitalperson.tts.TtsController
import com.digitalperson.llm.LLMManager
import com.digitalperson.llm.LLMManagerCallback
import com.digitalperson.util.FileHelper
import java.io.File
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
@@ -26,14 +43,24 @@ import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
class Live2DChatActivity : AppCompatActivity() {
companion object {
private const val TAG_ACTIVITY = "Live2DChatActivity"
private const val TAG_LLM = "LLM_ROUTE"
}
private lateinit var uiManager: Live2DUiManager
private lateinit var vadManager: VadManager
private lateinit var asrManager: AsrManager
private lateinit var ttsManager: TtsManager
private lateinit var ttsController: TtsController
private lateinit var audioProcessor: AudioProcessor
private var llmManager: LLMManager? = null
private var useLocalLLM = false // 默认使用云端 LLM
private val permissions: Array<String> = arrayOf(Manifest.permission.RECORD_AUDIO)
private val appPermissions: Array<String> = arrayOf(
Manifest.permission.RECORD_AUDIO,
Manifest.permission.CAMERA
)
private val micPermissions: Array<String> = arrayOf(Manifest.permission.RECORD_AUDIO)
@Volatile
private var isRecording: Boolean = false
@@ -55,23 +82,46 @@ class Live2DChatActivity : AppCompatActivity() {
@Volatile private var llmInFlight: Boolean = false
private var enableStreaming = false
private lateinit var cameraPreviewView: PreviewView
private lateinit var faceOverlayView: FaceOverlayView
private lateinit var faceDetectionPipeline: FaceDetectionPipeline
private var facePipelineReady: Boolean = false
private var cameraProvider: ProcessCameraProvider? = null
private lateinit var cameraAnalyzerExecutor: ExecutorService
override fun onRequestPermissionsResult(
requestCode: Int,
permissions: Array<String>,
grantResults: IntArray
) {
super.onRequestPermissionsResult(requestCode, permissions, grantResults)
val ok = requestCode == AppConfig.REQUEST_RECORD_AUDIO_PERMISSION &&
grantResults.isNotEmpty() &&
grantResults[0] == PackageManager.PERMISSION_GRANTED
if (!ok) {
if (requestCode != AppConfig.REQUEST_RECORD_AUDIO_PERMISSION) return
if (grantResults.isEmpty()) {
finish()
return
}
val granted = permissions.zip(grantResults.toTypedArray()).associate { it.first to it.second }
val micGranted = granted[Manifest.permission.RECORD_AUDIO] == PackageManager.PERMISSION_GRANTED
val cameraGranted = granted[Manifest.permission.CAMERA] == PackageManager.PERMISSION_GRANTED
if (!micGranted) {
Log.e(AppConfig.TAG, "Audio record is disallowed")
finish()
return
}
if (!cameraGranted) {
uiManager.showToast("未授予相机权限,暂不启用人脸检测")
Log.w(AppConfig.TAG, "Camera permission denied")
return
}
if (facePipelineReady) {
startCameraPreviewAndDetection()
}
}
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
Log.i(TAG_ACTIVITY, "onCreate")
setContentView(R.layout.activity_live2d_chat)
uiManager = Live2DUiManager(this)
@@ -82,10 +132,28 @@ class Live2DChatActivity : AppCompatActivity() {
stopButtonId = R.id.stop_button,
recordButtonId = R.id.record_button,
traditionalButtonsId = R.id.traditional_buttons,
llmModeSwitchId = R.id.llm_mode_switch,
llmModeSwitchRowId = R.id.llm_mode_switch_row,
silentPlayerViewId = 0,
speakingPlayerViewId = 0,
live2dViewId = R.id.live2d_view
)
cameraPreviewView = findViewById(R.id.camera_preview)
cameraPreviewView.implementationMode = PreviewView.ImplementationMode.COMPATIBLE
faceOverlayView = findViewById(R.id.face_overlay)
cameraAnalyzerExecutor = Executors.newSingleThreadExecutor()
faceDetectionPipeline = FaceDetectionPipeline(
context = applicationContext,
onResult = { result ->
faceOverlayView.updateResult(result)
},
onGreeting = { greeting ->
uiManager.appendToUi("\n[Face] $greeting\n")
ttsController.enqueueSegment(greeting)
ttsController.enqueueEnd()
}
)
// 根据配置选择交互方式
uiManager.setUseHoldToSpeak(AppConfig.USE_HOLD_TO_SPEAK)
@@ -105,7 +173,7 @@ class Live2DChatActivity : AppCompatActivity() {
uiManager.setStopButtonListener { onStopClicked(userInitiated = true) }
}
ActivityCompat.requestPermissions(this, permissions, AppConfig.REQUEST_RECORD_AUDIO_PERMISSION)
ActivityCompat.requestPermissions(this, appPermissions, AppConfig.REQUEST_RECORD_AUDIO_PERMISSION)
try {
val streamingSwitch = findViewById<android.widget.Switch>(R.id.streaming_switch)
@@ -119,6 +187,27 @@ class Live2DChatActivity : AppCompatActivity() {
Log.w(AppConfig.TAG, "Streaming switch not found in layout: ${e.message}")
}
try {
val ttsModeSwitch = findViewById<android.widget.Switch>(R.id.tts_mode_switch)
ttsModeSwitch.isChecked = false // 默认使用本地TTS
ttsModeSwitch.setOnCheckedChangeListener { _, isChecked ->
ttsController.setUseQCloudTts(isChecked)
uiManager.showToast("TTS模式已切换到${if (isChecked) "腾讯云" else "本地"}")
}
} catch (e: Exception) {
Log.w(AppConfig.TAG, "TTS mode switch not found in layout: ${e.message}")
}
// 设置 LLM 模式开关
uiManager.setLLMSwitchListener { isChecked ->
useLocalLLM = isChecked
Log.i(TAG_LLM, "LLM mode switched: useLocalLLM=$useLocalLLM")
uiManager.showToast("LLM模式已切换到${if (isChecked) "本地" else "云端"}")
// 重新初始化 LLM
initLLM()
}
// 默认不显示 LLM 开关,等模型下载完成后再显示
if (AppConfig.USE_HOLD_TO_SPEAK) {
uiManager.setButtonsEnabled(recordEnabled = false)
} else {
@@ -127,8 +216,8 @@ class Live2DChatActivity : AppCompatActivity() {
uiManager.setText("初始化中…")
audioProcessor = AudioProcessor(this)
ttsManager = TtsManager(this)
ttsManager.setCallback(createTtsCallback())
ttsController = TtsController(this)
ttsController.setCallback(createTtsCallback())
asrManager = AsrManager(this)
asrManager.setAudioProcessor(audioProcessor)
@@ -137,6 +226,64 @@ class Live2DChatActivity : AppCompatActivity() {
vadManager = VadManager(this)
vadManager.setCallback(createVadCallback())
// 初始化 LLM 管理器
initLLM()
// 检查是否需要下载模型
if (!FileHelper.isLocalLLMAvailable(this)) {
// 显示下载进度对话框
uiManager.showDownloadProgressDialog()
// 异步下载模型文件
FileHelper.downloadModelFilesWithProgress(
this,
onProgress = { fileName, downloaded, total, progress ->
runOnUiThread {
val downloadedMB = downloaded / (1024 * 1024)
val totalMB = total / (1024 * 1024)
uiManager.updateDownloadProgress(
fileName,
downloadedMB,
totalMB,
progress
)
}
},
onComplete = { success, message ->
runOnUiThread {
uiManager.dismissDownloadProgressDialog()
if (success) {
Log.i(AppConfig.TAG, "Model files downloaded successfully")
uiManager.showToast("模型下载完成", Toast.LENGTH_SHORT)
// 检查本地 LLM 是否可用
if (FileHelper.isLocalLLMAvailable(this)) {
Log.i(AppConfig.TAG, "Local LLM is available, enabling local LLM switch")
// 显示本地 LLM 开关,并同步状态
uiManager.showLLMSwitch(true)
uiManager.setLLMSwitchChecked(useLocalLLM)
}
} else {
Log.e(AppConfig.TAG, "Failed to download model files: $message")
uiManager.showToast("模型下载失败: $message", Toast.LENGTH_LONG)
}
// 下载完成后初始化其他组件
initializeOtherComponents()
}
}
)
} else {
// 模型已存在,直接初始化其他组件
initializeOtherComponents()
// 显示本地 LLM 开关,并同步状态
uiManager.showLLMSwitch(true)
uiManager.setLLMSwitchChecked(useLocalLLM)
}
}
/**
* 初始化其他组件VAD、ASR、TTS、人脸检测等
*/
private fun initializeOtherComponents() {
ioScope.launch {
try {
Log.i(AppConfig.TAG, "Init VAD + SenseVoice(RKNN) + TTS (background)")
@@ -144,7 +291,8 @@ class Live2DChatActivity : AppCompatActivity() {
vadManager.initVadModel()
asrManager.initSenseVoiceModel()
}
val ttsOk = ttsManager.initTtsAndAudioTrack()
val ttsOk = ttsController.init()
facePipelineReady = faceDetectionPipeline.initialize()
withContext(Dispatchers.Main) {
if (!ttsOk) {
uiManager.showToast(
@@ -152,6 +300,11 @@ class Live2DChatActivity : AppCompatActivity() {
Toast.LENGTH_LONG
)
}
if (!facePipelineReady) {
uiManager.showToast("RetinaFace 初始化失败,请检查模型和 rknn 运行库", Toast.LENGTH_LONG)
} else if (allPermissionsGranted()) {
startCameraPreviewAndDetection()
}
uiManager.setText(getString(R.string.hint))
if (AppConfig.USE_HOLD_TO_SPEAK) {
uiManager.setButtonsEnabled(recordEnabled = true)
@@ -203,14 +356,22 @@ class Live2DChatActivity : AppCompatActivity() {
Log.d(AppConfig.TAG, "ASR segment skipped: $reason")
}
override fun shouldSkipAsr(): Boolean = ttsManager.isPlaying()
override fun shouldSkipAsr(): Boolean = ttsController.isPlaying()
override fun isLlmInFlight(): Boolean = llmInFlight
override fun onLlmCalled(text: String) {
llmInFlight = true
Log.d(AppConfig.TAG, "Calling LLM with text: $text")
cloudApiManager.callLLM(text)
if (useLocalLLM) {
Log.i(TAG_LLM, "Routing to LOCAL LLM")
// 使用本地 LLM 生成回复
generateResponse(text)
} else {
Log.i(TAG_LLM, "Routing to CLOUD LLM")
// 使用云端 LLM 生成回复
cloudApiManager.callLLM(text)
}
}
}
@@ -220,7 +381,7 @@ class Live2DChatActivity : AppCompatActivity() {
asrManager.enqueueAudioSegment(originalAudio, processedAudio)
}
override fun shouldSkipProcessing(): Boolean = ttsManager.isPlaying() || llmInFlight
override fun shouldSkipProcessing(): Boolean = ttsController.isPlaying() || llmInFlight
}
private fun createCloudApiListener() = object : CloudApiManager.CloudApiListener {
@@ -232,9 +393,9 @@ class Live2DChatActivity : AppCompatActivity() {
if (enableStreaming) {
for (seg in segmenter.flush()) {
ttsManager.enqueueSegment(seg)
ttsController.enqueueSegment(seg)
}
ttsManager.enqueueEnd()
ttsController.enqueueEnd()
} else {
val previousMood = com.digitalperson.mood.MoodManager.getCurrentMood()
val (filteredText, mood) = com.digitalperson.mood.MoodManager.extractAndFilterMood(response)
@@ -247,8 +408,8 @@ class Live2DChatActivity : AppCompatActivity() {
runOnUiThread {
uiManager.appendToUi("${filteredText}\n")
}
ttsManager.enqueueSegment(filteredText)
ttsManager.enqueueEnd()
ttsController.enqueueSegment(filteredText)
ttsController.enqueueEnd()
}
}
@@ -271,7 +432,7 @@ class Live2DChatActivity : AppCompatActivity() {
val segments = segmenter.processChunk(filteredText)
for (seg in segments) {
ttsManager.enqueueSegment(seg)
ttsController.enqueueSegment(seg)
}
}
}
@@ -285,7 +446,7 @@ class Live2DChatActivity : AppCompatActivity() {
}
}
private fun createTtsCallback() = object : TtsManager.TtsCallback {
private fun createTtsCallback() = object : TtsController.TtsCallback {
override fun onTtsStarted(text: String) {
runOnUiThread {
uiManager.appendToUi("\n[TTS] 开始合成...\n")
@@ -310,32 +471,6 @@ class Live2DChatActivity : AppCompatActivity() {
uiManager.setSpeaking(speaking)
}
override fun getCurrentTrace(): TraceSession? = currentTrace
override fun onTraceMarkTtsRequestEnqueued() {
currentTrace?.markTtsRequestEnqueued()
}
override fun onTraceMarkTtsSynthesisStart() {
currentTrace?.markTtsSynthesisStart()
}
override fun onTraceMarkTtsFirstPcmReady() {
currentTrace?.markTtsFirstPcmReady()
}
override fun onTraceMarkTtsFirstAudioPlay() {
currentTrace?.markTtsFirstAudioPlay()
}
override fun onTraceMarkTtsDone() {
currentTrace?.markTtsDone()
}
override fun onTraceAddDuration(name: String, value: Long) {
currentTrace?.addDuration(name, value)
}
override fun onEndTurn() {
TraceManager.getInstance().endTurn()
currentTrace = null
@@ -344,27 +479,97 @@ class Live2DChatActivity : AppCompatActivity() {
override fun onDestroy() {
super.onDestroy()
stopCameraPreviewAndDetection()
onStopClicked(userInitiated = false)
ioScope.cancel()
synchronized(nativeLock) {
try { vadManager.release() } catch (_: Throwable) {}
try { asrManager.release() } catch (_: Throwable) {}
}
try { ttsManager.release() } catch (_: Throwable) {}
try { faceDetectionPipeline.release() } catch (_: Throwable) {}
try { cameraAnalyzerExecutor.shutdown() } catch (_: Throwable) {}
try { ttsController.release() } catch (_: Throwable) {}
try { llmManager?.destroy() } catch (_: Throwable) {}
try { uiManager.release() } catch (_: Throwable) {}
try { audioProcessor.release() } catch (_: Throwable) {}
}
override fun onResume() {
    super.onResume()
    Log.i(TAG_ACTIVITY, "onResume")
    uiManager.onResume()
    // Rebind the camera only when the face pipeline initialized successfully
    // and mic + camera permissions are still granted.
    val canRunCamera = facePipelineReady && allPermissionsGranted()
    if (canRunCamera) {
        startCameraPreviewAndDetection()
    }
}
override fun onPause() {
    Log.i(TAG_ACTIVITY, "onPause")
    // Unbind the camera and pause the UI before handing control back to the
    // framework; super.onPause() is intentionally called last.
    stopCameraPreviewAndDetection()
    uiManager.onPause()
    super.onPause()
}
private fun allPermissionsGranted(): Boolean {
    // True only when every runtime permission in appPermissions (mic + camera)
    // is currently granted.
    for (permission in appPermissions) {
        val state = ContextCompat.checkSelfPermission(this, permission)
        if (state != PackageManager.PERMISSION_GRANTED) {
            return false
        }
    }
    return true
}
// Binds the front camera to this activity's lifecycle with two use cases:
// a Preview surface for the on-screen feed and an ImageAnalysis stream that
// feeds frames into the face-detection pipeline. Safe to call repeatedly:
// unbindAll() is invoked before re-binding. Runs the binding on the main
// executor, as required by CameraX.
private fun startCameraPreviewAndDetection() {
    val cameraProviderFuture = ProcessCameraProvider.getInstance(this)
    cameraProviderFuture.addListener({
        try {
            // get() does not block here: the listener only fires once the
            // future is complete.
            val provider = cameraProviderFuture.get()
            cameraProvider = provider
            // Drop any previous bindings so repeated calls don't double-bind.
            provider.unbindAll()
            val preview = Preview.Builder().build().apply {
                setSurfaceProvider(cameraPreviewView.surfaceProvider)
            }
            cameraPreviewView.scaleType = PreviewView.ScaleType.FIT_CENTER
            // KEEP_ONLY_LATEST: skip stale frames rather than queueing them,
            // so detection latency stays bounded.
            val analyzer = ImageAnalysis.Builder()
                .setBackpressureStrategy(ImageAnalysis.STRATEGY_KEEP_ONLY_LATEST)
                .build()
            // Frames are analyzed off the main thread on a dedicated executor.
            analyzer.setAnalyzer(cameraAnalyzerExecutor) { imageProxy ->
                analyzeCameraFrame(imageProxy)
            }
            val selector = CameraSelector.Builder()
                .requireLensFacing(CameraSelector.LENS_FACING_FRONT)
                .build()
            provider.bindToLifecycle(this, selector, preview, analyzer)
        } catch (t: Throwable) {
            Log.e(AppConfig.TAG, "startCameraPreviewAndDetection failed: ${t.message}", t)
        }
    }, ContextCompat.getMainExecutor(this))
}
private fun stopCameraPreviewAndDetection() {
    // Release all camera use cases; the provider reference is always cleared,
    // even if unbinding throws.
    val provider = cameraProvider
    cameraProvider = null
    if (provider != null) {
        try {
            provider.unbindAll()
        } catch (_: Throwable) {
        }
    }
}
private fun analyzeCameraFrame(imageProxy: ImageProxy) {
    try {
        // Convert the camera frame to a Bitmap; a null result simply drops
        // the frame.
        val frame: Bitmap? = ImageProxyBitmapConverter.toBitmap(imageProxy)
        frame?.let { faceDetectionPipeline.submitFrame(it) }
    } catch (t: Throwable) {
        Log.w(AppConfig.TAG, "analyzeCameraFrame error: ${t.message}")
    } finally {
        // The proxy must always be closed or CameraX stops delivering frames.
        imageProxy.close()
    }
}
private fun onStartClicked() {
Log.d(AppConfig.TAG, "onStartClicked called")
if (isRecording) {
@@ -372,7 +577,7 @@ class Live2DChatActivity : AppCompatActivity() {
return
}
if (!audioProcessor.initMicrophone(permissions, AppConfig.REQUEST_RECORD_AUDIO_PERMISSION)) {
if (!audioProcessor.initMicrophone(micPermissions, AppConfig.REQUEST_RECORD_AUDIO_PERMISSION)) {
uiManager.showToast("麦克风初始化失败/无权限")
return
}
@@ -383,8 +588,7 @@ class Live2DChatActivity : AppCompatActivity() {
uiManager.clearText()
ttsManager.reset()
ttsManager.setCurrentTrace(currentTrace)
ttsController.reset()
segmenter.reset()
vadManager.reset()
@@ -409,12 +613,12 @@ class Live2DChatActivity : AppCompatActivity() {
}
// 如果TTS正在播放打断它
val interrupted = ttsManager.interruptForNewTurn()
val interrupted = ttsController.interruptForNewTurn()
if (interrupted) {
uiManager.appendToUi("\n[LOG] 已打断TTS播放\n")
}
if (!audioProcessor.initMicrophone(permissions, AppConfig.REQUEST_RECORD_AUDIO_PERMISSION)) {
if (!audioProcessor.initMicrophone(micPermissions, AppConfig.REQUEST_RECORD_AUDIO_PERMISSION)) {
uiManager.showToast("麦克风初始化失败/无权限")
return
}
@@ -427,7 +631,7 @@ class Live2DChatActivity : AppCompatActivity() {
// interruptForNewTurn() already prepared TTS state for next turn.
// Keep reset() only for non-interrupt entry points.
ttsManager.setCurrentTrace(currentTrace)
segmenter.reset()
// 启动按住说话的动作
@@ -479,7 +683,7 @@ class Live2DChatActivity : AppCompatActivity() {
recordingJob?.cancel()
recordingJob = null
ttsManager.stop()
ttsController.stop()
if (AppConfig.USE_HOLD_TO_SPEAK) {
uiManager.setButtonsEnabled(recordEnabled = true)
@@ -515,10 +719,10 @@ class Live2DChatActivity : AppCompatActivity() {
while (isRecording && ioScope.coroutineContext.isActive) {
loopCount++
if (loopCount % 100 == 0) {
Log.d(AppConfig.TAG, "processSamplesLoop running, loopCount=$loopCount, ttsPlaying=${ttsManager.isPlaying()}")
Log.d(AppConfig.TAG, "processSamplesLoop running, loopCount=$loopCount, ttsPlaying=${ttsController.isPlaying()}")
}
if (ttsManager.isPlaying()) {
if (ttsController.isPlaying()) {
if (vadManager.isInSpeech()) {
Log.d(AppConfig.TAG, "TTS playing, resetting VAD state")
vadManager.clearState()
@@ -546,11 +750,134 @@ class Live2DChatActivity : AppCompatActivity() {
}
val forced = segmenter.maybeForceByTime()
for (seg in forced) ttsManager.enqueueSegment(seg)
for (seg in forced) ttsController.enqueueSegment(seg)
}
vadManager.forceFinalize()
}
Log.d(AppConfig.TAG, "processSamplesLoop stopped")
}
/**
 * (Re)initializes the LLM backend according to [useLocalLLM].
 *
 * Any previously created local engine is destroyed first so that switching
 * modes never leaks a native model instance. In local mode the RKLLM model
 * file is validated on disk before the engine is constructed; on any failure
 * the activity falls back to cloud mode and the UI switch is reset.
 */
private fun initLLM() {
    try {
        Log.i(TAG_LLM, "initLLM called, useLocalLLM=$useLocalLLM")
        // Always tear down the old engine before building a new one.
        llmManager?.destroy()
        llmManager = null
        if (useLocalLLM) {
            val modelPath = FileHelper.getLLMModelPath(applicationContext)
            if (!File(modelPath).exists()) {
                throw IllegalStateException("RKLLM model file missing: $modelPath")
            }
            Log.i(AppConfig.TAG, "Initializing LLM with model path: $modelPath")
            // Accumulates streamed tokens so non-streaming mode can display and
            // speak the complete reply at once. Only touched on the UI thread
            // (all callback work is posted via runOnUiThread).
            val localLlmResponseBuffer = StringBuilder()
            llmManager = LLMManager(modelPath, object : LLMManagerCallback {
                override fun onThinking(msg: String, finished: Boolean) {
                    // Intermediate "thinking" output; shown only while streaming.
                    Log.d(TAG_LLM, "LOCAL onThinking finished=$finished msg=${msg.take(60)}")
                    runOnUiThread {
                        if (!finished && enableStreaming) {
                            uiManager.appendToUi("\n[LLM] 思考中: $msg\n")
                        }
                    }
                }
                override fun onResult(msg: String, finished: Boolean) {
                    // Generated tokens; on finish the buffered reply is flushed
                    // to the UI and enqueued for TTS.
                    Log.d(TAG_LLM, "LOCAL onResult finished=$finished len=${msg.length}")
                    runOnUiThread {
                        if (!finished) {
                            localLlmResponseBuffer.append(msg)
                            if (enableStreaming) {
                                uiManager.appendToUi(msg)
                            }
                        } else {
                            val finalText = localLlmResponseBuffer.toString().trim()
                            localLlmResponseBuffer.setLength(0)
                            if (!enableStreaming && finalText.isNotEmpty()) {
                                uiManager.appendToUi("$finalText\n")
                            }
                            uiManager.appendToUi("\n\n[LLM] 生成完成\n")
                            llmInFlight = false
                            if (finalText.isNotEmpty()) {
                                ttsController.enqueueSegment(finalText)
                                ttsController.enqueueEnd()
                            } else {
                                Log.w(TAG_LLM, "LOCAL final text is empty, skip TTS enqueue")
                            }
                        }
                    }
                }
            })
            Log.i(AppConfig.TAG, "LLM initialized successfully")
            Log.i(TAG_LLM, "LOCAL LLM initialized")
        } else {
            // Cloud mode needs no local engine.
            Log.i(AppConfig.TAG, "Using cloud LLM, skipping local LLM initialization")
            Log.i(TAG_LLM, "CLOUD mode active")
        }
    } catch (e: Exception) {
        Log.e(AppConfig.TAG, "Failed to initialize LLM: ${e.message}", e)
        Log.e(TAG_LLM, "LOCAL init failed: ${e.message}", e)
        // Fall back to cloud mode and keep the UI switch state consistent.
        useLocalLLM = false
        runOnUiThread {
            uiManager.setLLMSwitchChecked(false)
            uiManager.showToast("LLM 初始化失败: ${e.message}", Toast.LENGTH_LONG)
            uiManager.appendToUi("\n[错误] LLM 初始化失败: ${e.message}\n")
        }
    }
}
/**
 * Generates a reply for [userInput], routing to the local RKLLM engine or the
 * cloud API according to [useLocalLLM]. Falls back to the cloud when the local
 * engine is unavailable.
 */
private fun generateResponse(userInput: String) {
    try {
        if (!useLocalLLM) {
            Log.d(AppConfig.TAG, "Using cloud LLM for response: $userInput")
            Log.i(TAG_LLM, "CLOUD callLLM")
            cloudApiManager.callLLM(userInput)
            return
        }
        val systemPrompt = "你是一个友好的数字人助手,回答要简洁明了。"
        Log.d(AppConfig.TAG, "Generating response for: $userInput")
        val local = llmManager ?: run {
            // Local engine missing (init failed or destroyed) — degrade to cloud.
            Log.e(TAG_LLM, "LOCAL LLM manager is null, fallback to CLOUD")
            cloudApiManager.callLLM(userInput)
            return
        }
        Log.i(TAG_LLM, "LOCAL generateResponseWithSystem")
        local.generateResponseWithSystem(systemPrompt, userInput)
    } catch (e: Exception) {
        Log.e(AppConfig.TAG, "Failed to generate response: ${e.message}", e)
        Log.e(TAG_LLM, "generateResponse failed: ${e.message}", e)
        runOnUiThread {
            uiManager.appendToUi("\n\n[Error] LLM 生成失败: ${e.message}\n")
            llmInFlight = false
        }
    }
}
}

View File

@@ -34,6 +34,25 @@ object AppConfig {
const val MAX_TEXT_LENGTH = 50
const val MODEL_DIR = "sensevoice_models"
}
// Tuning constants for the RetinaFace detector and the face tracker.
object Face {
    const val MODEL_DIR = "RetinaFace"                 // asset subdirectory holding the model
    const val MODEL_NAME = "RetinaFace_mobile320.rknn" // RKNN-compiled RetinaFace detector
    const val INPUT_SIZE = 320                         // square network input size, pixels (matches the "320" model)
    const val SCORE_THRESHOLD = 0.6f                   // minimum detection confidence
    const val NMS_THRESHOLD = 0.4f                     // IoU threshold for non-max suppression
    const val TRACK_IOU_THRESHOLD = 0.45f              // IoU for associating detections across frames — presumed; confirm in tracker
    const val STABLE_MS = 1000L                        // how long a track must persist to count as stable, ms — presumed; confirm
    const val FRONTAL_MIN_FACE_SIZE = 90f              // minimum face-box size to treat as frontal; units presumably pixels — confirm
    const val FRONTAL_MAX_ASPECT_DIFF = 0.35f          // max box aspect-ratio deviation allowed for a frontal face — presumed; confirm
}
// Tuning constants for the ArcFace (InsightFace) recognition stage.
object FaceRecognition {
    const val MODEL_DIR = "Insightface"              // asset subdirectory holding the model
    const val MODEL_NAME = "ms1mv3_arcface_r18.rknn" // ArcFace ResNet-18 embedding model, RKNN-compiled
    const val SIMILARITY_THRESHOLD = 0.5f            // minimum embedding similarity to accept a match — metric (cosine?) not shown here; confirm
    const val GREETING_COOLDOWN_MS = 6000L           // minimum interval between greetings, ms — per-person scope presumed; confirm
}
object Audio {
const val GAIN_SMOOTHING_FACTOR = 0.1f
@@ -48,4 +67,10 @@ object AppConfig {
const val MODEL_DIR = "live2d_model/Haru_pro_jp"
const val MODEL_JSON = "haru_greeter_t05.model3.json"
}
object QCloud {
    // SECURITY(review): real-looking Tencent Cloud credentials are hardcoded
    // here and committed to source control. These keys must be considered
    // compromised — rotate them immediately and load replacements from gradle
    // properties / BuildConfig (as LLM_API_KEY already does) instead of source.
    const val APP_ID = "1302849512" // Replace with your Tencent Cloud APP_ID
    const val SECRET_ID = "AKIDbBdyBGE5oPuIGA1iDlDYlFallaJ0YODB" // Replace with your Tencent Cloud SECRET_ID
    const val SECRET_KEY = "32vhIl9OQIRclmLjvuleLp9LLAnFVYEp" // Replace with your Tencent Cloud SECRET_KEY
}
}

View File

@@ -0,0 +1,79 @@
package com.digitalperson.engine;
import android.content.Context;
import android.graphics.Bitmap;
import android.util.Log;
import com.digitalperson.config.AppConfig;
import com.digitalperson.util.FileHelper;
import java.io.File;
/**
 * Thin JNI wrapper around the native ArcFace (InsightFace) RKNN engine used to
 * extract face embeddings for recognition.
 *
 * <p>Lifecycle: construct → {@link #initialize} → {@link #extractEmbedding}
 * (repeatedly) → {@link #release}. The native engine is owned through
 * {@code nativePtr}; there is no finalizer/Cleaner, so a missed
 * {@code release()} leaks the native engine. No internal synchronization —
 * callers are expected to serialize access (thread-safety not guaranteed).
 */
public class ArcFaceEngineRKNN {
    private static final String TAG = "ArcFaceEngineRKNN";

    static {
        try {
            // rknnrt is the RKNN runtime; sensevoiceEngine presumably also
            // hosts the ArcFace JNI entry points — confirm against the native
            // build. A load failure is rethrown so the class fails fast.
            System.loadLibrary("rknnrt");
            System.loadLibrary("sensevoiceEngine");
            Log.d(TAG, "Loaded native libs for ArcFace RKNN");
        } catch (UnsatisfiedLinkError e) {
            Log.e(TAG, "Failed to load native libraries for ArcFace", e);
            throw e;
        }
    }

    // Opaque handle to the native engine; never 0 after a successful ctor.
    private final long nativePtr;
    private boolean initialized = false;
    private boolean released = false;

    /** Creates the native engine; throws if native allocation fails. */
    public ArcFaceEngineRKNN() {
        nativePtr = createEngineNative();
        if (nativePtr == 0) {
            throw new RuntimeException("Failed to create native ArcFace engine");
        }
    }

    /**
     * Copies the model assets to local storage and loads the RKNN model.
     *
     * @param context used to resolve/copy the model assets
     * @return true when the native engine initialized successfully
     */
    public boolean initialize(Context context) {
        if (released) return false;
        File modelDir = FileHelper.copyInsightFaceAssets(context);
        File modelFile = new File(modelDir, AppConfig.FaceRecognition.MODEL_NAME);
        int ret = initNative(nativePtr, modelFile.getAbsolutePath());
        initialized = ret == 0;
        if (!initialized) {
            Log.e(TAG, "ArcFace init failed, code=" + ret + ", model=" + modelFile.getAbsolutePath());
        }
        return initialized;
    }

    /**
     * Extracts an embedding for the face inside the given bounding box.
     *
     * @param bitmap source frame; must be non-null
     * @param left/top/right/bottom face box coordinates — coordinate space
     *        (pixels vs. normalized) is defined by the native side; confirm there
     * @return the embedding vector, or an empty array when the engine is not
     *         ready or the native call returns null
     */
    public float[] extractEmbedding(Bitmap bitmap, float left, float top, float right, float bottom) {
        Log.d(TAG, "extractEmbedding called: initialized=" + initialized + ", released=" + released + ", bitmap=" + (bitmap != null));
        if (!initialized || released || bitmap == null) {
            Log.w(TAG, "extractEmbedding failed: initialized=" + initialized + ", released=" + released + ", bitmap=" + (bitmap != null));
            return new float[0];
        }
        float[] emb = extractEmbeddingNative(nativePtr, bitmap, left, top, right, bottom);
        Log.d(TAG, "extractEmbeddingNative returned: " + (emb != null ? emb.length : "null"));
        return emb != null ? emb : new float[0];
    }

    /** Frees the native engine. Idempotent; the instance is unusable afterwards. */
    public void release() {
        if (!released && nativePtr != 0) {
            releaseNative(nativePtr);
        }
        released = true;
        initialized = false;
    }

    private native long createEngineNative();
    private native int initNative(long ptr, String modelPath);
    private native float[] extractEmbeddingNative(
            long ptr,
            Bitmap bitmap,
            float left,
            float top,
            float right,
            float bottom
    );
    private native void releaseNative(long ptr);
}

View File

@@ -0,0 +1,77 @@
package com.digitalperson.engine;
import android.content.Context;
import android.graphics.Bitmap;
import android.util.Log;
import com.digitalperson.config.AppConfig;
import com.digitalperson.util.FileHelper;
import java.io.File;
/**
 * Thin JNI wrapper around the RKNN-based RetinaFace face detector.
 *
 * Lifecycle: construct -> initialize(context) -> detect(bitmap) -> release().
 * All native state lives behind {@link #nativePtr}, created in the constructor.
 */
public class RetinaFaceEngineRKNN {
    private static final String TAG = "RetinaFaceEngineRKNN";
    static {
        try {
            // The RKNN runtime must be loaded before the JNI library that links against it.
            System.loadLibrary("rknnrt");
            // NOTE(review): the RetinaFace JNI entry points appear to live in the
            // "sensevoiceEngine" library shared with the ASR engine — confirm.
            System.loadLibrary("sensevoiceEngine");
            Log.d(TAG, "Loaded native libs for RetinaFace RKNN");
        } catch (UnsatisfiedLinkError e) {
            Log.e(TAG, "Failed to load native libraries for RetinaFace", e);
            throw e;
        }
    }
    // Opaque handle to the native engine; 0 means native creation failed.
    private final long nativePtr;
    private boolean initialized = false;
    private boolean released = false;

    /** Creates the native engine; throws when the native allocation fails. */
    public RetinaFaceEngineRKNN() {
        nativePtr = createEngineNative();
        if (nativePtr == 0) {
            throw new RuntimeException("Failed to create native RetinaFace engine");
        }
    }

    /**
     * Copies the model assets to local storage and initializes the native engine
     * with the configured input size, score threshold and NMS threshold.
     *
     * @return true when the native init returned 0 (success); false after release()
     *         or when native initialization fails
     */
    public boolean initialize(Context context) {
        if (released) {
            return false;
        }
        File modelDir = FileHelper.copyRetinaFaceAssets(context);
        File modelFile = new File(modelDir, AppConfig.Face.MODEL_NAME);
        int ret = initNative(
            nativePtr,
            modelFile.getAbsolutePath(),
            AppConfig.Face.INPUT_SIZE,
            AppConfig.Face.SCORE_THRESHOLD,
            AppConfig.Face.NMS_THRESHOLD
        );
        initialized = ret == 0;
        if (!initialized) {
            Log.e(TAG, "RetinaFace init failed, code=" + ret + ", model=" + modelFile.getAbsolutePath());
        }
        return initialized;
    }

    /**
     * Runs face detection on the given bitmap.
     *
     * @return a flat float array, 5 values per detected face
     *         (left, top, right, bottom, score — as consumed by FaceDetectionPipeline),
     *         or an empty array when not initialized, released, bitmap is null,
     *         or the native call returns null
     */
    public float[] detect(Bitmap bitmap) {
        if (!initialized || released || bitmap == null) {
            return new float[0];
        }
        float[] raw = detectNative(nativePtr, bitmap);
        return raw != null ? raw : new float[0];
    }

    /** Releases the native engine. Idempotent: subsequent calls are no-ops. */
    public void release() {
        if (!released && nativePtr != 0) {
            releaseNative(nativePtr);
        }
        released = true;
        initialized = false;
    }

    private native long createEngineNative();
    private native int initNative(long ptr, String modelPath, int inputSize, float scoreThreshold, float nmsThreshold);
    private native float[] detectNative(long ptr, Bitmap bitmap);
    private native void releaseNative(long ptr);
}

View File

@@ -0,0 +1,223 @@
package com.digitalperson.face
import android.content.Context
import android.graphics.Bitmap
import android.util.Log
import com.digitalperson.config.AppConfig
import com.digitalperson.engine.RetinaFaceEngineRKNN
import java.util.concurrent.atomic.AtomicBoolean
import kotlin.math.abs
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.cancel
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
/**
 * One detected face box in source-frame pixel coordinates,
 * plus the detector's confidence score.
 */
data class FaceBox(
    val left: Float,
    val top: Float,
    val right: Float,
    val bottom: Float,
    val score: Float,
)
/**
 * Detection output for one camera frame: the frame dimensions the boxes were
 * computed in (needed to map them onto a differently sized view) and the faces.
 */
data class FaceDetectionResult(
    val sourceWidth: Int,
    val sourceHeight: Int,
    val faces: List<FaceBox>,
)
/**
 * Camera-frame face pipeline: runs RetinaFace detection on submitted frames,
 * tracks the primary (largest) face across frames, and — once that face has
 * been stable and frontal long enough — runs ArcFace recognition, optionally
 * enrolls unknown faces, and emits a one-shot greeting.
 *
 * Frames are processed strictly one at a time ([frameInFlight]); extra frames
 * arriving while a detection runs are dropped and recycled immediately.
 *
 * @param onResult   delivered on the main dispatcher with each frame's result
 * @param onGreeting delivered on the main dispatcher with the greeting text
 */
class FaceDetectionPipeline(
    private val context: Context,
    private val onResult: (FaceDetectionResult) -> Unit,
    private val onGreeting: (String) -> Unit,
) {
    private val engine = RetinaFaceEngineRKNN()
    private val recognizer = FaceRecognizer(context)
    private val scope = CoroutineScope(SupervisorJob() + Dispatchers.Default)
    // True while a frame is being processed; used to drop frames instead of queueing them.
    private val frameInFlight = AtomicBoolean(false)
    private val initialized = AtomicBoolean(false)
    // Tracking state for the current primary face. Only mutated from the single
    // in-flight detection coroutine, so no extra synchronization is applied.
    private var trackFace: FaceBox? = null
    private var trackId: Long = 0
    private var trackStableSinceMs: Long = 0
    private var greetedTrackId: Long = -1
    private var lastGreetMs: Long = 0

    /**
     * Initializes both the detector and the recognizer.
     * @return true only when both succeed.
     */
    fun initialize(): Boolean {
        val detectorOk = engine.initialize(context)
        val recognizerOk = recognizer.initialize()
        val ok = detectorOk && recognizerOk
        initialized.set(ok)
        Log.i(AppConfig.TAG, "Face pipeline initialize result=$ok detector=$detectorOk recognizer=$recognizerOk")
        return ok
    }

    /**
     * Submits one frame for asynchronous processing. Takes ownership of
     * [bitmap] and always recycles it (immediately when dropped, or after
     * detection/recognition completes).
     */
    fun submitFrame(bitmap: Bitmap) {
        if (!initialized.get()) {
            bitmap.recycle()
            return
        }
        // Drop the frame if one is already being processed.
        if (!frameInFlight.compareAndSet(false, true)) {
            bitmap.recycle()
            return
        }
        scope.launch {
            try {
                val width = bitmap.width
                val height = bitmap.height
                // Raw detector output: 5 floats per face (l, t, r, b, score).
                val raw = engine.detect(bitmap)
                val faceCount = raw.size / 5
                val faces = ArrayList<FaceBox>(faceCount)
                var i = 0
                while (i + 4 < raw.size) {
                    faces.add(
                        FaceBox(
                            left = raw[i],
                            top = raw[i + 1],
                            right = raw[i + 2],
                            bottom = raw[i + 3],
                            score = raw[i + 4],
                        )
                    )
                    i += 5
                }
                // Filter out faces that are too small to be useful.
                val minFaceSize = 50 // minimum face size in pixels
                val filteredFaces = faces.filter { face ->
                    val width = face.right - face.left
                    val height = face.bottom - face.top
                    width > minFaceSize && height > minFaceSize
                }
                // if (filteredFaces.isNotEmpty()) {
                //     Log.d(
                //         AppConfig.TAG,"[Face] filtered detected ${filteredFaces.size} face(s)"
                //     )
                // }
                maybeRecognizeAndGreet(bitmap, filteredFaces)
                withContext(Dispatchers.Main) {
                    onResult(FaceDetectionResult(width, height, filteredFaces))
                }
            } catch (t: Throwable) {
                Log.e(AppConfig.TAG, "Face detection pipeline failed: ${t.message}", t)
            } finally {
                bitmap.recycle()
                frameInFlight.set(false)
            }
        }
    }

    /**
     * Tracks the largest face across frames and, when it has been stable for
     * [AppConfig.Face.STABLE_MS], is frontal, not yet greeted for this track,
     * and the greeting cooldown has elapsed, runs recognition. Unknown faces
     * are enrolled into the store; then exactly one greeting is emitted.
     */
    private suspend fun maybeRecognizeAndGreet(bitmap: Bitmap, faces: List<FaceBox>) {
        val now = System.currentTimeMillis()
        if (faces.isEmpty()) {
            // Lost the face: reset stability, keep the track id so a brief
            // dropout does not immediately re-trigger a greeting.
            trackFace = null
            trackStableSinceMs = 0
            return
        }
        // Primary face = largest box by area.
        val primary = faces.maxByOrNull { (it.right - it.left) * (it.bottom - it.top) } ?: return
        val prev = trackFace
        // A new track starts when there was no previous face or it moved too far (low IoU).
        if (prev == null || iou(prev, primary) < AppConfig.Face.TRACK_IOU_THRESHOLD) {
            trackId += 1
            greetedTrackId = -1
            trackStableSinceMs = now
        }
        trackFace = primary
        val stableMs = now - trackStableSinceMs
        val frontal = isFrontal(primary, bitmap.width, bitmap.height)
        val coolingDown = (now - lastGreetMs) < AppConfig.FaceRecognition.GREETING_COOLDOWN_MS
        if (stableMs < AppConfig.Face.STABLE_MS || !frontal || greetedTrackId == trackId || coolingDown) {
            return
        }
        val match = recognizer.identify(bitmap, primary)
        Log.d(AppConfig.TAG, "[Face] Recognition result: matchedName=${match.matchedName}, similarity=${match.similarity}")
        // Check whether this unknown face should be saved.
        if (match.matchedName.isNullOrBlank()) {
            Log.d(AppConfig.TAG, "[Face] No match found, attempting to add new face")
            // Extract the face embedding.
            val embedding = extractEmbedding(bitmap, primary)
            Log.d(AppConfig.TAG, "[Face] Extracted embedding size: ${embedding.size}")
            if (embedding.isNotEmpty()) {
                // Try to enroll it as a new face.
                val added = recognizer.addNewFace(embedding)
                Log.d(AppConfig.TAG, "[Face] Add new face result: $added")
                if (added) {
                    Log.i(AppConfig.TAG, "[Face] New face added to database")
                } else {
                    Log.i(AppConfig.TAG, "[Face] Face already exists in database (similar face found)")
                }
            } else {
                Log.w(AppConfig.TAG, "[Face] Failed to extract embedding")
            }
        } else {
            Log.d(AppConfig.TAG, "[Face] Matched existing face: ${match.matchedName}")
        }
        val greeting = if (!match.matchedName.isNullOrBlank()) {
            "你好,${match.matchedName}"
        } else {
            "你好,很高兴见到你。"
        }
        greetedTrackId = trackId
        lastGreetMs = now
        Log.i(
            AppConfig.TAG,
            "[Face] greeting track=$trackId stable=${stableMs}ms frontal=$frontal matched=${match.matchedName} score=${match.similarity}"
        )
        withContext(Dispatchers.Main) {
            onGreeting(greeting)
        }
    }

    // Delegates embedding extraction to the recognizer.
    private fun extractEmbedding(bitmap: Bitmap, face: FaceBox): FloatArray {
        return recognizer.extractEmbedding(bitmap, face)
    }

    /**
     * Heuristic frontal check: the face must be large enough, roughly square
     * (aspect ratio close to 1), and centered in the middle 70% of the frame.
     */
    private fun isFrontal(face: FaceBox, frameW: Int, frameH: Int): Boolean {
        val w = face.right - face.left
        val h = face.bottom - face.top
        if (w < AppConfig.Face.FRONTAL_MIN_FACE_SIZE || h < AppConfig.Face.FRONTAL_MIN_FACE_SIZE) {
            return false
        }
        val aspectDiff = abs((w / h) - 1f)
        if (aspectDiff > AppConfig.Face.FRONTAL_MAX_ASPECT_DIFF) {
            return false
        }
        val cx = (face.left + face.right) * 0.5f
        val cy = (face.top + face.bottom) * 0.5f
        val minX = frameW * 0.15f
        val maxX = frameW * 0.85f
        val minY = frameH * 0.15f
        val maxY = frameH * 0.85f
        return cx in minX..maxX && cy in minY..maxY
    }

    /** Intersection-over-union of two boxes; 0 when the union is degenerate. */
    private fun iou(a: FaceBox, b: FaceBox): Float {
        val left = maxOf(a.left, b.left)
        val top = maxOf(a.top, b.top)
        val right = minOf(a.right, b.right)
        val bottom = minOf(a.bottom, b.bottom)
        val w = maxOf(0f, right - left)
        val h = maxOf(0f, bottom - top)
        val inter = w * h
        val areaA = maxOf(0f, a.right - a.left) * maxOf(0f, a.bottom - a.top)
        val areaB = maxOf(0f, b.right - b.left) * maxOf(0f, b.bottom - b.top)
        val union = areaA + areaB - inter
        return if (union <= 0f) 0f else inter / union
    }

    /** Cancels the processing scope and releases both engines. */
    fun release() {
        scope.cancel()
        engine.release()
        recognizer.release()
        initialized.set(false)
    }
}

View File

@@ -0,0 +1,93 @@
package com.digitalperson.face
import android.content.ContentValues
import android.content.Context
import android.database.sqlite.SQLiteDatabase
import android.database.sqlite.SQLiteOpenHelper
import android.util.Log
import com.digitalperson.config.AppConfig
import java.nio.ByteBuffer
import java.nio.ByteOrder
/**
 * A stored identity: database row id, display name and the face embedding
 * used for cosine-similarity matching.
 */
data class FaceProfile(
    val id: Long,
    val name: String,
    val embedding: FloatArray,
) {
    // FloatArray is an array type, so the compiler-generated equals/hashCode
    // would compare it by reference (Kotlin warns ARRAY_IN_DATA_CLASS).
    // Override both to use structural content comparison.
    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        if (other !is FaceProfile) return false
        return id == other.id &&
            name == other.name &&
            embedding.contentEquals(other.embedding)
    }

    override fun hashCode(): Int {
        var result = id.hashCode()
        result = 31 * result + name.hashCode()
        result = 31 * result + embedding.contentHashCode()
        return result
    }
}
/**
 * SQLite-backed store for face profiles. Embeddings are serialized as
 * little-endian float32 BLOBs. Schema upgrades simply drop and recreate
 * the table (all stored faces are lost on DB_VERSION bump).
 */
class FaceFeatureStore(context: Context) : SQLiteOpenHelper(context, DB_NAME, null, DB_VERSION) {
    override fun onCreate(db: SQLiteDatabase) {
        db.execSQL(
            """
            CREATE TABLE IF NOT EXISTS face_profiles (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL UNIQUE,
                embedding BLOB NOT NULL,
                updated_at INTEGER NOT NULL
            )
            """.trimIndent()
        )
    }

    // Destructive upgrade: drop everything and start fresh.
    override fun onUpgrade(db: SQLiteDatabase, oldVersion: Int, newVersion: Int) {
        db.execSQL("DROP TABLE IF EXISTS face_profiles")
        onCreate(db)
    }

    /** Loads every stored profile; rows with a null embedding blob are skipped. */
    fun loadAllProfiles(): List<FaceProfile> {
        val db = readableDatabase
        val list = ArrayList<FaceProfile>()
        db.rawQuery("SELECT id, name, embedding FROM face_profiles", null).use { c ->
            val idIdx = c.getColumnIndexOrThrow("id")
            val nameIdx = c.getColumnIndexOrThrow("name")
            val embIdx = c.getColumnIndexOrThrow("embedding")
            while (c.moveToNext()) {
                val embBlob = c.getBlob(embIdx) ?: continue
                list.add(
                    FaceProfile(
                        id = c.getLong(idIdx),
                        name = c.getString(nameIdx),
                        embedding = blobToFloatArray(embBlob),
                    )
                )
            }
        }
        return list
    }

    /**
     * Inserts or replaces a profile keyed by name (UNIQUE + CONFLICT_REPLACE).
     *
     * NOTE(review): unknown faces are stored with name '' (see FaceRecognizer);
     * because name is UNIQUE with CONFLICT_REPLACE, every new unnamed face
     * replaces the previous unnamed row — confirm this is intended.
     */
    fun upsertProfile(name: String, embedding: FloatArray) {
        // Guard against blank names by falling back to the empty string.
        val safeName = name.takeIf { it.isNotBlank() } ?: ""
        val values = ContentValues().apply {
            put("name", safeName)
            put("embedding", floatArrayToBlob(embedding))
            put("updated_at", System.currentTimeMillis())
        }
        val rowId = writableDatabase.insertWithOnConflict(
            "face_profiles",
            null,
            values,
            SQLiteDatabase.CONFLICT_REPLACE
        )
        Log.i(AppConfig.TAG, "[FaceFeatureStore] upsertProfile name='$safeName' rowId=$rowId dim=${embedding.size}")
    }

    // Serializes floats to a little-endian byte blob (4 bytes per value).
    private fun floatArrayToBlob(values: FloatArray): ByteArray {
        val buf = ByteBuffer.allocate(values.size * 4).order(ByteOrder.LITTLE_ENDIAN)
        for (v in values) buf.putFloat(v)
        return buf.array()
    }

    // Inverse of floatArrayToBlob; truncates any trailing bytes that do not
    // form a complete float.
    private fun blobToFloatArray(blob: ByteArray): FloatArray {
        if (blob.isEmpty()) return FloatArray(0)
        val buf = ByteBuffer.wrap(blob).order(ByteOrder.LITTLE_ENDIAN)
        val out = FloatArray(blob.size / 4)
        for (i in out.indices) out[i] = buf.getFloat()
        return out
    }

    companion object {
        private const val DB_NAME = "face_feature.db"
        private const val DB_VERSION = 1
    }
}

View File

@@ -0,0 +1,61 @@
package com.digitalperson.face
import android.content.Context
import android.graphics.Canvas
import android.graphics.Color
import android.graphics.Paint
import android.graphics.RectF
import android.util.AttributeSet
import android.view.View
/**
 * Overlay view that draws detection boxes (with their confidence score) on
 * top of the camera preview. Box coordinates arrive in source-frame pixels
 * and are mapped into the view with a centered, aspect-preserving
 * (letterbox) transform.
 */
class FaceOverlayView @JvmOverloads constructor(
    context: Context,
    attrs: AttributeSet? = null,
) : View(context, attrs) {

    private val boxPaint = Paint(Paint.ANTI_ALIAS_FLAG).apply {
        color = Color.GREEN
        style = Paint.Style.STROKE
        strokeWidth = 4f
    }

    private val textPaint = Paint(Paint.ANTI_ALIAS_FLAG).apply {
        color = Color.GREEN
        textSize = 28f
    }

    // Written from the detection thread, read on the UI thread during draw.
    @Volatile
    private var latestResult: FaceDetectionResult? = null

    /** Publishes a new detection result and schedules a redraw. */
    fun updateResult(result: FaceDetectionResult) {
        latestResult = result
        postInvalidateOnAnimation()
    }

    override fun onDraw(canvas: Canvas) {
        super.onDraw(canvas)
        val detection = latestResult ?: return
        if (detection.sourceWidth <= 0 || detection.sourceHeight <= 0) return

        val frameW = detection.sourceWidth.toFloat()
        val frameH = detection.sourceHeight.toFloat()
        val canvasW = width.toFloat()
        val canvasH = height.toFloat()
        if (canvasW <= 0f || canvasH <= 0f) return

        // Uniform scale that fits the frame inside the view, centered.
        val fit = minOf(canvasW / frameW, canvasH / frameH)
        val offsetX = (canvasW - frameW * fit) / 2f
        val offsetY = (canvasH - frameH * fit) / 2f

        detection.faces.forEach { face ->
            val box = RectF(
                offsetX + face.left * fit,
                offsetY + face.top * fit,
                offsetX + face.right * fit,
                offsetY + face.bottom * fit,
            )
            canvas.drawRect(box, boxPaint)
            // Score label sits just above the box's top-left corner.
            canvas.drawText(String.format("%.2f", face.score), box.left, box.top - 8f, textPaint)
        }
    }
}

View File

@@ -0,0 +1,129 @@
package com.digitalperson.face
import android.content.Context
import android.graphics.Bitmap
import android.util.Log
import com.digitalperson.config.AppConfig
import com.digitalperson.engine.ArcFaceEngineRKNN
import kotlin.math.sqrt
/**
 * Outcome of one recognition attempt: the best-matching stored name (null
 * when no profile reached the similarity threshold), the best cosine
 * similarity found, and the dimensionality of the extracted embedding.
 */
data class FaceRecognitionResult(
    val matchedName: String?,
    val similarity: Float,
    val embeddingDim: Int,
)
/**
 * ArcFace-based recognizer: extracts embeddings, matches them against an
 * in-memory cache of stored profiles by cosine similarity, and enrolls
 * unknown faces into the persistent store.
 *
 * NOTE(review): [cache] is a plain ArrayList mutated without locking —
 * safe only while all calls come from the pipeline's single in-flight
 * detection coroutine; confirm no other caller exists.
 */
class FaceRecognizer(context: Context) {
    private val appContext = context.applicationContext
    private val engine = ArcFaceEngineRKNN()
    private val store = FaceFeatureStore(appContext)
    // In-memory mirror of the stored profiles, loaded at initialize().
    private val cache = ArrayList<FaceProfile>()
    @Volatile
    private var initialized = false

    /** Initializes the native engine and loads all stored profiles into [cache]. */
    fun initialize(): Boolean {
        Log.d(AppConfig.TAG, "[FaceRecognizer] initialize: starting...")
        val ok = engine.initialize(appContext)
        Log.d(AppConfig.TAG, "[FaceRecognizer] initialize: engine.initialize() returned $ok")
        if (!ok) {
            initialized = false
            Log.e(AppConfig.TAG, "[FaceRecognizer] initialize: failed - engine initialization failed")
            return false
        }
        cache.clear()
        val profiles = store.loadAllProfiles()
        cache.addAll(profiles)
        initialized = true
        Log.i(AppConfig.TAG, "[FaceRecognizer] initialized, profiles=${cache.size}")
        return true
    }

    /**
     * Extracts an embedding for [face] and returns the best cached match.
     * The matched name is only reported when the best cosine similarity
     * reaches SIMILARITY_THRESHOLD; otherwise matchedName is null but the
     * best score is still returned for logging.
     */
    fun identify(bitmap: Bitmap, face: FaceBox): FaceRecognitionResult {
        if (!initialized) return FaceRecognitionResult(null, 0f, 0)
        val embedding = extractEmbedding(bitmap, face)
        if (embedding.isEmpty()) return FaceRecognitionResult(null, 0f, 0)
        var bestName: String? = null
        var bestScore = -1f
        for (p in cache) {
            // Skip profiles stored with a different embedding dimensionality.
            if (p.embedding.size != embedding.size) continue
            val score = cosineSimilarity(embedding, p.embedding)
            if (score > bestScore) {
                bestScore = score
                bestName = p.name
            }
        }
        if (bestScore >= AppConfig.FaceRecognition.SIMILARITY_THRESHOLD) {
            return FaceRecognitionResult(bestName, bestScore, embedding.size)
        }
        return FaceRecognitionResult(null, bestScore, embedding.size)
    }

    /** Extracts the raw (un-normalized) embedding for the given face box. */
    fun extractEmbedding(bitmap: Bitmap, face: FaceBox): FloatArray {
        if (!initialized) return FloatArray(0)
        return engine.extractEmbedding(bitmap, face.left, face.top, face.right, face.bottom)
    }

    /**
     * Persists (L2-normalized) and caches a profile. A null name is stored
     * as the empty string.
     *
     * NOTE(review): unnamed profiles are never removed from [cache] here
     * (removal is keyed on non-null names only), while the store replaces
     * the single '' row — the cache and DB can drift; confirm intended.
     */
    fun addOrUpdateProfile(name: String?, embedding: FloatArray) {
        val normalized = normalize(embedding)
        store.upsertProfile(name ?: "", normalized)
        // Drop any stale cached entry with the same (non-null) name.
        if (name != null) {
            cache.removeAll { it.name == name }
        }
        cache.add(FaceProfile(id = -1L, name = name ?: "", embedding = normalized))
    }

    /**
     * Enrolls [embedding] as a new unnamed face unless a sufficiently
     * similar profile already exists.
     *
     * @return true when a new profile was added, false when a similar face
     *         was already present
     */
    fun addNewFace(embedding: FloatArray): Boolean {
        Log.d(AppConfig.TAG, "[FaceRecognizer] addNewFace: embedding size=${embedding.size}, cache size=${cache.size}")
        // First check whether a similar face is already enrolled.
        for (p in cache) {
            if (p.embedding.size != embedding.size) {
                Log.d(AppConfig.TAG, "[FaceRecognizer] Skipping profile with different embedding size: ${p.embedding.size}")
                continue
            }
            val score = cosineSimilarity(embedding, p.embedding)
            Log.d(AppConfig.TAG, "[FaceRecognizer] Comparing with profile '${p.name}': similarity=$score, threshold=${AppConfig.FaceRecognition.SIMILARITY_THRESHOLD}")
            if (score >= AppConfig.FaceRecognition.SIMILARITY_THRESHOLD) {
                // A similar face already exists; do not add a duplicate.
                Log.i(AppConfig.TAG, "[FaceRecognizer] Similar face found: ${p.name} with similarity=$score, not adding new face")
                return false
            }
        }
        // No similar face: enroll with a null (empty) name.
        Log.i(AppConfig.TAG, "[FaceRecognizer] No similar face found, adding new face")
        addOrUpdateProfile(null, embedding)
        return true
    }

    /** Releases the native engine and closes the database. */
    fun release() {
        initialized = false
        engine.release()
        store.close()
    }

    // Cosine similarity clamped to [-1, 1]; returns -1 for degenerate
    // (near-zero) vectors. Assumes a and b have equal length (callers check).
    private fun cosineSimilarity(a: FloatArray, b: FloatArray): Float {
        var dot = 0f
        var na = 0f
        var nb = 0f
        for (i in a.indices) {
            dot += a[i] * b[i]
            na += a[i] * a[i]
            nb += b[i] * b[i]
        }
        if (na <= 1e-12f || nb <= 1e-12f) return -1f
        return (dot / (sqrt(na) * sqrt(nb))).coerceIn(-1f, 1f)
    }

    // L2-normalizes a vector; the epsilon floor avoids division by zero.
    private fun normalize(v: FloatArray): FloatArray {
        var sum = 0f
        for (x in v) sum += x * x
        val norm = sqrt(sum.coerceAtLeast(1e-12f))
        val out = FloatArray(v.size)
        for (i in v.indices) out[i] = v[i] / norm
        return out
    }
}

View File

@@ -0,0 +1,87 @@
package com.digitalperson.face
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.graphics.ImageFormat
import android.graphics.Matrix
import android.graphics.Rect
import android.graphics.YuvImage
import androidx.camera.core.ImageProxy
import java.io.ByteArrayOutputStream
/**
 * Converts a CameraX [ImageProxy] (YUV_420_888) into an upright, mirrored
 * ARGB_8888 [Bitmap] by going through NV21 -> JPEG -> Bitmap.
 *
 * NOTE(review): the JPEG round-trip (quality 80) is lossy and allocation-heavy;
 * acceptable here, but a direct YUV->RGB conversion would be faster.
 */
object ImageProxyBitmapConverter {
    /**
     * Returns the converted bitmap, or null when the frame cannot be decoded.
     * The caller keeps ownership of [image] (close it separately).
     */
    fun toBitmap(image: ImageProxy): Bitmap? {
        val nv21 = yuv420ToNv21(image) ?: return null
        val yuvImage = YuvImage(nv21, ImageFormat.NV21, image.width, image.height, null)
        val out = ByteArrayOutputStream()
        if (!yuvImage.compressToJpeg(Rect(0, 0, image.width, image.height), 80, out)) {
            return null
        }
        val bytes = out.toByteArray()
        var bitmap = BitmapFactory.decodeByteArray(bytes, 0, bytes.size) ?: return null
        // Normalize to ARGB_8888 so downstream native code sees one pixel format.
        if (bitmap.config != Bitmap.Config.ARGB_8888) {
            val converted = bitmap.copy(Bitmap.Config.ARGB_8888, false)
            bitmap.recycle()
            bitmap = converted
        }
        val matrix = Matrix()
        // The front camera needs a horizontal mirror.
        // NOTE: this assumes the front camera is in use; to support the back
        // camera this flip would have to depend on the camera facing.
        matrix.postScale(-1f, 1f, bitmap.width / 2f, bitmap.height / 2f)
        // Apply the sensor-to-display rotation reported by CameraX.
        val rotation = image.imageInfo.rotationDegrees
        if (rotation != 0) {
            matrix.postRotate(rotation.toFloat())
        }
        // Apply the combined mirror + rotation in one pass.
        val transformed = Bitmap.createBitmap(bitmap, 0, 0, bitmap.width, bitmap.height, matrix, true)
        bitmap.recycle()
        bitmap = transformed
        return bitmap
    }

    /**
     * Packs the three YUV_420_888 planes into a single NV21 buffer
     * (full-resolution Y followed by interleaved V/U at quarter resolution).
     * Handles arbitrary row strides and both planar and semi-planar U/V
     * layouts via pixelStride. Returns null when planes are missing.
     */
    private fun yuv420ToNv21(image: ImageProxy): ByteArray? {
        val planes = image.planes
        if (planes.size < 3) return null
        val width = image.width
        val height = image.height
        val ySize = width * height
        val uvSize = width * height / 4
        val nv21 = ByteArray(ySize + uvSize * 2)
        val yPlane = planes[0]
        val yBuffer = yPlane.buffer
        val yRowStride = yPlane.rowStride
        var dst = 0
        // Copy Y row by row, skipping any row-stride padding.
        for (row in 0 until height) {
            yBuffer.position(row * yRowStride)
            yBuffer.get(nv21, dst, width)
            dst += width
        }
        val uPlane = planes[1]
        val vPlane = planes[2]
        val uBuffer = uPlane.buffer
        val vBuffer = vPlane.buffer
        val uRowStride = uPlane.rowStride
        val vRowStride = vPlane.rowStride
        val uPixelStride = uPlane.pixelStride
        val vPixelStride = vPlane.pixelStride
        // NV21 expects V then U for each 2x2 block.
        for (row in 0 until height / 2) {
            for (col in 0 until width / 2) {
                val uIndex = row * uRowStride + col * uPixelStride
                val vIndex = row * vRowStride + col * vPixelStride
                nv21[dst++] = vBuffer.get(vIndex)
                nv21[dst++] = uBuffer.get(uIndex)
            }
        }
        return nv21
    }
}

View File

@@ -7,6 +7,7 @@ class Live2DAvatarManager(private val glSurfaceView: GLSurfaceView) {
init {
glSurfaceView.setEGLContextClientVersion(2)
glSurfaceView.setPreserveEGLContextOnPause(true)
glSurfaceView.setRenderer(renderer)
glSurfaceView.renderMode = GLSurfaceView.RENDERMODE_CONTINUOUSLY
}
@@ -16,11 +17,15 @@ class Live2DAvatarManager(private val glSurfaceView: GLSurfaceView) {
}
fun setMood(mood: String) {
renderer.setMood(mood)
glSurfaceView.queueEvent {
renderer.setMood(mood)
}
}
fun startSpecificMotion(motionName: String) {
renderer.startSpecificMotion(motionName)
glSurfaceView.queueEvent {
renderer.startSpecificMotion(motionName)
}
}
fun onResume() {
@@ -32,6 +37,8 @@ class Live2DAvatarManager(private val glSurfaceView: GLSurfaceView) {
}
fun release() {
renderer.release()
glSurfaceView.queueEvent {
renderer.release()
}
}
}

View File

@@ -214,32 +214,8 @@ class Live2DCharacter : CubismUserModel() {
}
private fun loadMoodMotions(assets: AssetManager, modelDir: String) {
// 开心心情动作
moodMotions["开心"] = listOf(
"haru_g_m22.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m22.motion3.json"),
"haru_g_m21.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m21.motion3.json"),
"haru_g_m18.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m18.motion3.json")
).mapNotNull { (fileName, motion) ->
motion?.let {
motionFileMap[it] = fileName
it
}
}
// 伤心心情动作
moodMotions["伤心"] = listOf(
"haru_g_m25.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m25.motion3.json"),
"haru_g_m24.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m24.motion3.json"),
"haru_g_m05.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m05.motion3.json")
).mapNotNull { (fileName, motion) ->
motion?.let {
motionFileMap[it] = fileName
it
}
}
// 平和心情动作
moodMotions["平和"] = listOf(
// 中性心情动作
moodMotions["中性"] = listOf(
"haru_g_m15.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m15.motion3.json"),
"haru_g_m07.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m07.motion3.json"),
"haru_g_m06.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m06.motion3.json"),
@@ -252,8 +228,50 @@ class Live2DCharacter : CubismUserModel() {
}
}
// 惊讶心情动作
moodMotions["惊讶"] = listOf(
// 悲伤心情动作
moodMotions["悲伤"] = listOf(
"haru_g_m25.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m25.motion3.json"),
"haru_g_m24.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m24.motion3.json"),
"haru_g_m05.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m05.motion3.json"),
"haru_g_m16.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m16.motion3.json"),
"haru_g_m20.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m20.motion3.json"),
).mapNotNull { (fileName, motion) ->
motion?.let {
motionFileMap[it] = fileName
it
}
}
// 高兴心情动作
moodMotions["高兴"] = listOf(
"haru_g_m22.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m22.motion3.json"),
"haru_g_m21.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m21.motion3.json"),
"haru_g_m18.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m18.motion3.json"),
"haru_g_m09.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m09.motion3.json"),
"haru_g_m08.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m08.motion3.json")
).mapNotNull { (fileName, motion) ->
motion?.let {
motionFileMap[it] = fileName
it
}
}
// 生气心情动作
moodMotions["生气"] = listOf(
"haru_g_m10.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m10.motion3.json"),
"haru_g_m11.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m11.motion3.json"),
"haru_g_m04.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m04.motion3.json"),
"haru_g_m03.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m03.motion3.json"),
).mapNotNull { (fileName, motion) ->
motion?.let {
motionFileMap[it] = fileName
it
}
}
// 恐惧心情动作
moodMotions["恐惧"] = listOf(
"haru_g_m26.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m26.motion3.json"),
"haru_g_m12.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m12.motion3.json")
).mapNotNull { (fileName, motion) ->
@@ -263,18 +281,8 @@ class Live2DCharacter : CubismUserModel() {
}
}
// 关心心情动作
moodMotions["关心"] = listOf(
"haru_g_m17.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m17.motion3.json")
).mapNotNull { (fileName, motion) ->
motion?.let {
motionFileMap[it] = fileName
it
}
}
// 害羞心情动作
moodMotions["害羞"] = listOf(
// 撒娇心情动作
moodMotions["撒娇"] = listOf(
"haru_g_m19.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m19.motion3.json")
).mapNotNull { (fileName, motion) ->
motion?.let {
@@ -282,6 +290,38 @@ class Live2DCharacter : CubismUserModel() {
it
}
}
// 震惊心情动作
moodMotions["震惊"] = listOf(
"haru_g_m26.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m26.motion3.json"),
"haru_g_m12.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m12.motion3.json")
).mapNotNull { (fileName, motion) ->
motion?.let {
motionFileMap[it] = fileName
it
}
}
// 厌恶心情动作
moodMotions["厌恶"] = listOf(
"haru_g_m14.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m14.motion3.json"),
"haru_g_m13.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m13.motion3.json")
).mapNotNull { (fileName, motion) ->
motion?.let {
motionFileMap[it] = fileName
it
}
}
// 兼容旧的心情名称
moodMotions["开心"] = moodMotions["高兴"] ?: emptyList()
moodMotions["伤心"] = moodMotions["悲伤"] ?: emptyList()
moodMotions["平和"] = moodMotions["平静"] ?: emptyList()
moodMotions["惊讶"] = moodMotions["震惊"] ?: emptyList()
moodMotions["关心"] = moodMotions["中性"] ?: emptyList()
moodMotions["害羞"] = moodMotions["撒娇"] ?: emptyList()
}
private fun loadSpecificMotions(assets: AssetManager, modelDir: String) {

View File

@@ -14,6 +14,9 @@ import javax.microedition.khronos.opengles.GL10
class Live2DRenderer(
private val context: Context
) : GLSurfaceView.Renderer {
companion object {
private const val TAG = "Live2DRenderer"
}
@Volatile
private var speaking = false
@@ -25,6 +28,7 @@ class Live2DRenderer(
GLES20.glClearColor(0f, 0f, 0f, 0f)
ensureFrameworkInitialized()
startTimeMs = SystemClock.elapsedRealtime()
Log.i(TAG, "onSurfaceCreated")
runCatching {
val model = Live2DCharacter()
@@ -35,6 +39,7 @@ class Live2DRenderer(
)
model.bindTextures(context.assets, AppConfig.Avatar.MODEL_DIR)
character = model
Log.i(TAG, "Live2D model loaded and textures bound")
}.onFailure {
Log.e(AppConfig.TAG, "Load Live2D model failed: ${it.message}", it)
character = null

View File

@@ -0,0 +1,46 @@
package com.digitalperson.llm
/**
 * Receives the LLM token stream split into two channels: "thinking" output
 * (between <think>/</think> markers) and normal answer output. `finished`
 * is true on the final call of each channel.
 */
interface LLMManagerCallback {
    fun onThinking(msg: String, finished: Boolean)
    fun onResult(msg: String, finished: Boolean)
}
/**
 * High-level wrapper over [RKLLM] that splits the raw token stream into
 * "thinking" output (between <think> and </think> markers) and normal answer
 * output, reporting each through the supplied [LLMManagerCallback].
 */
class LLMManager(modelPath: String, callback: LLMManagerCallback) :
    RKLLM(modelPath, object : LLMCallback {
        // True while tokens between <think> and </think> are streaming.
        var insideThinkBlock = false

        override fun onCallback(data: String, state: LLMCallback.State) {
            if (state != LLMCallback.State.NORMAL) {
                // FINISH or ERROR: close out both channels.
                callback.onThinking("", true)
                callback.onResult("", true)
                return
            }
            when {
                data == "<think>" -> insideThinkBlock = true
                data == "</think>" -> {
                    insideThinkBlock = false
                    callback.onThinking("", true)
                }
                insideThinkBlock -> callback.onThinking(data, false)
                data == "\n" -> Unit // bare newline tokens are dropped from answers
                else -> callback.onResult(data, false)
            }
        }
    }) {

    /** Sends a plain user prompt wrapped in the model's chat template. */
    fun generateResponse(prompt: String) {
        say("<User>$prompt<Assistant>")
    }

    /** Sends a system prompt plus user prompt wrapped in the chat template. */
    fun generateResponseWithSystem(systemPrompt: String, userPrompt: String) {
        say("<System>$systemPrompt<User>$userPrompt<Assistant>")
    }
}

View File

@@ -0,0 +1,52 @@
package com.digitalperson.llm
/**
 * Low-level token callback from the RKLLM native runtime.
 * NORMAL carries one token; FINISH/ERROR end the stream.
 */
interface LLMCallback {
    enum class State {
        ERROR, NORMAL, FINISH
    }
    fun onCallback(data: String, state: State)
}
/**
 * Thin JNI wrapper around the Rockchip RKLLM runtime (librkllmrt.so).
 *
 * The native handle is created eagerly in the constructor; tokens are
 * streamed back from native code through [callbackFromNative] and forwarded
 * to the supplied [LLMCallback].
 *
 * @throws IllegalStateException when native initialization returns a null handle
 */
open class RKLLM(modelPath: String, callback: LLMCallback) {
    companion object {
        init {
            // Must be loaded before any native method of this class is bound.
            System.loadLibrary("rkllmrt")
        }
    }

    private var mInstance: Long
    private var mCallback: LLMCallback

    init {
        mInstance = initLLM(modelPath)
        mCallback = callback
        if (mInstance == 0L) {
            throw IllegalStateException("RKLLM init failed: native handle is null")
        }
    }

    /** Releases the native handle. Safe to call more than once. */
    fun destroy() {
        // Guard against double-free: a repeated destroy() must not hand a
        // stale/zero handle back to the native layer.
        if (mInstance != 0L) {
            deinitLLM(mInstance)
            mInstance = 0
        }
    }

    /** Runs inference on [text]; reports an ERROR callback when already destroyed. */
    protected fun say(text: String) {
        if (mInstance == 0L) {
            mCallback.onCallback("RKLLM is not initialized", LLMCallback.State.ERROR)
            return
        }
        infer(mInstance, text)
    }

    /**
     * Invoked from native code for each streamed token.
     * state == 0 means the stream finished, negative values are errors,
     * positive values carry a normal token.
     */
    fun callbackFromNative(data: String, state: Int) {
        val s = when {
            state == 0 -> LLMCallback.State.FINISH
            state < 0 -> LLMCallback.State.ERROR
            else -> LLMCallback.State.NORMAL
        }
        mCallback.onCallback(data, s)
    }

    private external fun initLLM(modelPath: String): Long
    private external fun deinitLLM(handle: Long)
    private external fun infer(handle: Long, text: String)
}

View File

@@ -0,0 +1,330 @@
package com.digitalperson.tts
import android.content.Context
import android.media.AudioAttributes
import android.media.AudioFormat
import android.media.AudioManager
import android.media.AudioTrack
import android.util.Log
import com.digitalperson.config.AppConfig
import com.digitalperson.mood.MoodManager
import com.tencent.cloud.realtime.tts.RealTimeSpeechSynthesizer
import com.tencent.cloud.realtime.tts.RealTimeSpeechSynthesizerListener
import com.tencent.cloud.realtime.tts.RealTimeSpeechSynthesizerRequest
import com.tencent.cloud.realtime.tts.SpeechSynthesizerResponse
import com.tencent.cloud.realtime.tts.core.ws.Credential
import com.tencent.cloud.realtime.tts.core.ws.SpeechClient
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import java.nio.ByteBuffer
import java.util.UUID
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.atomic.AtomicBoolean
class QCloudTtsManager(private val context: Context) {
    companion object {
        private const val TAG = "QCloudTtsManager"
        // PCM sample rate requested from the synthesizer and used by AudioTrack.
        private const val SAMPLE_RATE = 16000
        // Shared WebSocket client for the Tencent Cloud realtime TTS service.
        private val proxy = SpeechClient()
    }
    private var audioTrack: AudioTrack? = null
    private var synthesizer: RealTimeSpeechSynthesizer? = null
    // Work items for the single TTS worker: text segments plus an End marker,
    // which both terminates a turn and wakes the worker's blocking take().
    private sealed class TtsQueueItem {
        data class Segment(val text: String) : TtsQueueItem()
        data object End : TtsQueueItem()
    }
    private val ttsQueue = LinkedBlockingQueue<TtsQueueItem>()
    private val ttsStopped = AtomicBoolean(false) // set by stop()/interrupt to abort playback
    private val ttsWorkerRunning = AtomicBoolean(false) // ensures a single worker coroutine
    private val ttsPlaying = AtomicBoolean(false) // true while a segment is being played
    private val interrupting = AtomicBoolean(false) // re-entrancy guard for interruptForNewTurn()
    private val ioScope = CoroutineScope(Dispatchers.IO)
    /** Hooks back into the dialog layer (speaking state, ASR queue, turn end). */
    interface TtsCallback {
        fun onTtsStarted(text: String)
        fun onTtsCompleted()
        fun onTtsSegmentCompleted(durationMs: Long)
        fun isTtsStopped(): Boolean
        fun onClearAsrQueue()
        fun onSetSpeaking(speaking: Boolean)
        fun onEndTurn()
    }
    private var callback: TtsCallback? = null
    /** Registers the dialog-layer callback. */
    fun setCallback(callback: TtsCallback) {
        this.callback = callback
    }
fun init(): Boolean {
return try {
initAudioTrack()
true
} catch (e: Exception) {
Log.e(TAG, "Init QCloud TTS failed: ${e.message}", e)
false
}
}
private fun initAudioTrack() {
val bufferSize = AudioTrack.getMinBufferSize(
SAMPLE_RATE,
AudioFormat.CHANNEL_OUT_MONO,
AudioFormat.ENCODING_PCM_16BIT
)
val attr = AudioAttributes.Builder()
.setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
.setUsage(AudioAttributes.USAGE_MEDIA)
.build()
val format = AudioFormat.Builder()
.setEncoding(AudioFormat.ENCODING_PCM_16BIT)
.setChannelMask(AudioFormat.CHANNEL_OUT_MONO)
.setSampleRate(SAMPLE_RATE)
.build()
audioTrack = AudioTrack(
attr,
format,
bufferSize,
AudioTrack.MODE_STREAM,
AudioManager.AUDIO_SESSION_ID_GENERATE
)
}
    /**
     * Queues one text segment for synthesis, clearing a previous stop flag so
     * new speech can start. Trailing sentence punctuation is stripped before
     * the text is sent to the synthesizer.
     * NOTE(review): some char literals below render as empty ('') — they are
     * presumably full-width punctuation lost in transfer; verify the source.
     */
    fun enqueueSegment(seg: String) {
        if (ttsStopped.get()) {
            ttsStopped.set(false)
        }
        val cleanedSeg = seg.trimEnd('.', '。', '!', '', '?', '', ',', '', ';', '', ':', '')
        ttsQueue.offer(TtsQueueItem.Segment(cleanedSeg))
        ensureTtsWorker()
    }
    /** Marks the end of the current utterance stream. */
    fun enqueueEnd() {
        ttsQueue.offer(TtsQueueItem.End)
    }
    /** True while a segment is being synthesized/played. */
    fun isPlaying(): Boolean = ttsPlaying.get()
    /**
     * Clears pending segments and playback flags for a fresh turn.
     * If a stop was in effect while the worker is still draining, an End
     * marker is pushed so the worker's blocking take() wakes up and exits.
     */
    fun reset() {
        val workerRunning = ttsWorkerRunning.get()
        val wasStopped = ttsStopped.get()
        ttsStopped.set(false)
        ttsPlaying.set(false)
        ttsQueue.clear()
        if (wasStopped && workerRunning) {
            ttsQueue.offer(TtsQueueItem.End)
        }
    }
    /**
     * Hard-stops playback: sets the stop flag, drains the queue, wakes the
     * worker with an End marker, and cancels/flushes the audio pipeline.
     * Any failure while tearing down the synthesizer/track is ignored.
     */
    fun stop() {
        ttsStopped.set(true)
        ttsPlaying.set(false)
        ttsQueue.clear()
        ttsQueue.offer(TtsQueueItem.End)
        try {
            synthesizer?.cancel()
            synthesizer = null
            audioTrack?.pause()
            audioTrack?.flush()
        } catch (_: Throwable) {
        }
    }
    /**
     * Barge-in support: aborts any in-progress playback so a new user turn
     * can start cleanly.
     *
     * @param waitTimeoutMs how long to wait for the worker to drain
     * @return true when there was playback to interrupt, false when the
     *         pipeline was already idle or an interrupt is already running
     */
    fun interruptForNewTurn(waitTimeoutMs: Long = 300): Boolean {
        // Re-entrancy guard: only one interrupt at a time.
        if (!interrupting.compareAndSet(false, true)) return false
        try {
            val hadPendingPlayback = ttsPlaying.get() || ttsWorkerRunning.get() || ttsQueue.isNotEmpty()
            if (!hadPendingPlayback) {
                // Nothing to interrupt; just make sure the flags are clean.
                ttsStopped.set(false)
                ttsPlaying.set(false)
                return false
            }
            // Signal the worker to stop, drain the queue and wake its take().
            ttsStopped.set(true)
            ttsPlaying.set(false)
            ttsQueue.clear()
            ttsQueue.offer(TtsQueueItem.End)
            try {
                synthesizer?.cancel()
                synthesizer = null
                audioTrack?.pause()
                audioTrack?.flush()
            } catch (_: Throwable) {
            }
            // Busy-wait (bounded) for the worker coroutine to exit.
            val deadline = System.currentTimeMillis() + waitTimeoutMs
            while (ttsWorkerRunning.get() && System.currentTimeMillis() < deadline) {
                Thread.sleep(10)
            }
            if (ttsWorkerRunning.get()) {
                Log.w(TAG, "interruptForNewTurn timeout: worker still running")
            }
            // Reset state for the next turn.
            ttsQueue.clear()
            ttsStopped.set(false)
            ttsPlaying.set(false)
            callback?.onSetSpeaking(false)
            return true
        } finally {
            interrupting.set(false)
        }
    }
fun release() {
try {
synthesizer?.cancel()
synthesizer = null
} catch (_: Throwable) {
}
try {
audioTrack?.release()
audioTrack = null
} catch (_: Throwable) {
}
}
    /**
     * Launches the single queue-draining worker if it is not already running.
     * When the worker exits, re-launch it if new items arrived meanwhile —
     * this covers the race between the worker finishing and a fresh enqueue.
     */
    private fun ensureTtsWorker() {
        if (!ttsWorkerRunning.compareAndSet(false, true)) return
        ioScope.launch {
            try {
                runTtsWorker()
            } finally {
                ttsWorkerRunning.set(false)
                if (!ttsStopped.get() && ttsQueue.isNotEmpty()) {
                    ensureTtsWorker()
                }
            }
        }
    }
/**
 * Queue-draining loop executed on the IO dispatcher.
 *
 * Each Segment item opens a QCloud real-time synthesis session whose listener
 * streams PCM chunks straight into [audioTrack]. An End item waits for
 * playback to drain, fires the completion callbacks and exits the loop.
 *
 * NOTE(review): synthesizer.start() appears to return before synthesis
 * completes, after which the loop immediately take()s the next item — confirm
 * the SDK serializes sessions, otherwise consecutive segments could overlap.
 */
private fun runTtsWorker() {
    val audioTrack = audioTrack ?: return
    while (true) {
        val item = ttsQueue.take()
        // A stop request invalidates whatever was just dequeued.
        if (ttsStopped.get()) break
        when (item) {
            is TtsQueueItem.Segment -> {
                ttsPlaying.set(true)
                callback?.onSetSpeaking(true)
                Log.d(TAG, "QCloud TTS started: processing segment '${item.text}'")
                callback?.onTtsStarted(item.text)
                val startMs = System.currentTimeMillis()
                try {
                    if (audioTrack.playState != AudioTrack.PLAYSTATE_PLAYING) {
                        audioTrack.play()
                    }
                    val credential = Credential(
                        AppConfig.QCloud.APP_ID,
                        AppConfig.QCloud.SECRET_ID,
                        AppConfig.QCloud.SECRET_KEY,
                        ""
                    )
                    val request = RealTimeSpeechSynthesizerRequest()
                    request.setVolume(0f) // volume, SDK range approx. [-10, 10]
                    request.setSpeed(0f) // speech rate, SDK range approx. [-2, 6]
                    request.setCodec("pcm") // request raw PCM output
                    request.setSampleRate(SAMPLE_RATE) // audio sample rate
                    request.setVoiceType(601010) // voice (timbre) id
                    request.setEnableSubtitle(true) // enable word-level timestamps
                    // Map the avatar's current mood onto the synthesizer's emotion category.
                    val currentMood = MoodManager.getCurrentMood()
                    val emotionCategory = when (currentMood) {
                        "中性" -> "neutral"
                        "悲伤" -> "sad"
                        "高兴" -> "happy"
                        "生气" -> "angry"
                        "恐惧" -> "fear"
                        "撒娇" -> "sajiao"
                        "震惊" -> "amaze"
                        "厌恶" -> "disgusted"
                        "平静" -> "peaceful"
                        // Legacy mood names kept for backward compatibility.
                        "开心" -> "happy"
                        "伤心" -> "sad"
                        "平和" -> "peaceful"
                        "惊讶" -> "amaze"
                        "关心" -> "neutral"
                        "害羞" -> "sajiao"
                        else -> "neutral"
                    }
                    request.setEmotionCategory(emotionCategory) // emotion of the synthesized voice
                    request.setEmotionIntensity(100) // emotion strength
                    request.setSessionId(UUID.randomUUID().toString()) // unique session id
                    request.setText(item.text) // text to synthesize
                    val listener = object : RealTimeSpeechSynthesizerListener() {
                        override fun onSynthesisStart(response: SpeechSynthesizerResponse) {
                            Log.d(TAG, "onSynthesisStart: ${response.sessionId}")
                        }
                        override fun onSynthesisEnd(response: SpeechSynthesizerResponse) {
                            Log.d(TAG, "onSynthesisEnd: ${response.sessionId}")
                            val ttsMs = System.currentTimeMillis() - startMs
                            callback?.onTtsSegmentCompleted(ttsMs)
                        }
                        override fun onAudioResult(buffer: ByteBuffer) {
                            val data = ByteArray(buffer.remaining())
                            buffer.get(data)
                            // Stream the PCM chunk straight into the AudioTrack.
                            audioTrack.write(data, 0, data.size)
                        }
                        override fun onTextResult(response: SpeechSynthesizerResponse) {
                            Log.d(TAG, "onTextResult: ${response.sessionId}")
                        }
                        override fun onSynthesisCancel() {
                            Log.d(TAG, "onSynthesisCancel")
                        }
                        override fun onSynthesisFail(response: SpeechSynthesizerResponse) {
                            Log.e(TAG, "onSynthesisFail: ${response.sessionId}, error: ${response.message}")
                        }
                    }
                    synthesizer = RealTimeSpeechSynthesizer(proxy, credential, request, listener)
                    synthesizer?.start()
                } catch (e: Exception) {
                    Log.e(TAG, "QCloud TTS error: ${e.message}", e)
                }
            }
            TtsQueueItem.End -> {
                callback?.onClearAsrQueue()
                waitForPlaybackComplete(audioTrack)
                callback?.onTtsCompleted()
                ttsPlaying.set(false)
                callback?.onSetSpeaking(false)
                callback?.onEndTurn()
                break
            }
        }
    }
}
/**
 * Blocks until the AudioTrack has drained the PCM already written to it.
 *
 * The previous implementation was a blind 1-second sleep, which both cut off
 * long tail audio and wasted time after short segments. Instead, poll the
 * playback head position: once it stops advancing, the buffered audio has
 * played out (or the track is paused/stopped). A 5-second deadline bounds the
 * wait so a stalled track can never hang the worker thread.
 */
private fun waitForPlaybackComplete(audioTrack: AudioTrack) {
    try {
        val deadlineMs = System.currentTimeMillis() + 5000L
        var lastHead = audioTrack.playbackHeadPosition
        while (System.currentTimeMillis() < deadlineMs) {
            Thread.sleep(100)
            val head = audioTrack.playbackHeadPosition
            if (head == lastHead) {
                // Head did not advance for 100 ms: playback has drained.
                break
            }
            lastHead = head
        }
    } catch (_: Throwable) {
        // Track already released or thread interrupted — fall back to a short
        // fixed wait, mirroring the old best-effort behavior.
        try {
            Thread.sleep(300)
        } catch (_: InterruptedException) {
        }
    }
}
}

View File

@@ -0,0 +1,181 @@
package com.digitalperson.tts
import android.content.Context
import android.util.Log
/**
 * Facade that routes TTS calls to either the on-device engine ([TtsManager])
 * or the Tencent Cloud engine ([QCloudTtsManager]), selected at runtime via
 * [setUseQCloudTts]. Both engines are initialized up front so the mode can be
 * switched without re-initialization.
 */
class TtsController(private val context: Context) {

    companion object {
        private const val TAG = "TtsController"
    }

    private var localTts: TtsManager? = null
    private var qcloudTts: QCloudTtsManager? = null
    private var useQCloudTts = false

    /** Unified callback surface exposed to the owner of this controller. */
    interface TtsCallback {
        fun onTtsStarted(text: String)
        fun onTtsCompleted()
        fun onTtsSegmentCompleted(durationMs: Long)
        fun isTtsStopped(): Boolean
        fun onClearAsrQueue()
        fun onSetSpeaking(speaking: Boolean)
        fun onEndTurn()
    }

    private var callback: TtsCallback? = null

    /**
     * Registers [callback] and bridges it onto both engines' native callback
     * interfaces. The local engine's trace hooks are intentionally no-ops.
     */
    fun setCallback(callback: TtsCallback) {
        this.callback = callback
        localTts?.setCallback(object : TtsManager.TtsCallback {
            override fun onTtsStarted(text: String) = callback.onTtsStarted(text)
            override fun onTtsCompleted() = callback.onTtsCompleted()
            override fun onTtsSegmentCompleted(durationMs: Long) = callback.onTtsSegmentCompleted(durationMs)
            override fun isTtsStopped(): Boolean = callback.isTtsStopped()
            override fun onClearAsrQueue() = callback.onClearAsrQueue()
            override fun onSetSpeaking(speaking: Boolean) = callback.onSetSpeaking(speaking)
            override fun getCurrentTrace() = null
            override fun onTraceMarkTtsRequestEnqueued() = Unit
            override fun onTraceMarkTtsSynthesisStart() = Unit
            override fun onTraceMarkTtsFirstPcmReady() = Unit
            override fun onTraceMarkTtsFirstAudioPlay() = Unit
            override fun onTraceMarkTtsDone() = Unit
            override fun onTraceAddDuration(name: String, value: Long) = Unit
            override fun onEndTurn() = callback.onEndTurn()
        })
        qcloudTts?.setCallback(object : QCloudTtsManager.TtsCallback {
            override fun onTtsStarted(text: String) = callback.onTtsStarted(text)
            override fun onTtsCompleted() = callback.onTtsCompleted()
            override fun onTtsSegmentCompleted(durationMs: Long) = callback.onTtsSegmentCompleted(durationMs)
            override fun isTtsStopped(): Boolean = callback.isTtsStopped()
            override fun onClearAsrQueue() = callback.onClearAsrQueue()
            override fun onSetSpeaking(speaking: Boolean) = callback.onSetSpeaking(speaking)
            override fun onEndTurn() = callback.onEndTurn()
        })
    }

    /** Initializes both engines; returns true when at least one is usable. */
    fun init(): Boolean {
        localTts = TtsManager(context)
        val localInit = localTts?.initTtsAndAudioTrack() ?: false
        Log.d(TAG, "Local TTS init: $localInit")
        qcloudTts = QCloudTtsManager(context)
        val qcloudInit = qcloudTts?.init() ?: false
        Log.d(TAG, "QCloud TTS init: $qcloudInit")
        return localInit || qcloudInit
    }

    /** Selects the active engine for subsequent calls. */
    fun setUseQCloudTts(useQCloud: Boolean) {
        this.useQCloudTts = useQCloud
        Log.d(TAG, "TTS mode changed: ${if (useQCloud) "QCloud" else "Local"}")
    }

    /** Queues one text segment on the active engine. */
    fun enqueueSegment(seg: String) {
        if (useQCloudTts) qcloudTts?.enqueueSegment(seg) else localTts?.enqueueSegment(seg)
    }

    /** Marks the end of the current turn on the active engine. */
    fun enqueueEnd() {
        if (useQCloudTts) qcloudTts?.enqueueEnd() else localTts?.enqueueEnd()
    }

    /** True while the active engine is playing synthesized audio. */
    fun isPlaying(): Boolean =
        if (useQCloudTts) qcloudTts?.isPlaying() ?: false else localTts?.isPlaying() ?: false

    /** Resets the active engine's queue/flags for a fresh turn. */
    fun reset() {
        if (useQCloudTts) qcloudTts?.reset() else localTts?.reset()
    }

    /** Hard-stops playback on the active engine. */
    fun stop() {
        if (useQCloudTts) qcloudTts?.stop() else localTts?.stop()
    }

    /** Barge-in on the active engine; see the engine's own documentation. */
    fun interruptForNewTurn(waitTimeoutMs: Long = 300): Boolean =
        if (useQCloudTts) qcloudTts?.interruptForNewTurn(waitTimeoutMs) ?: false
        else localTts?.interruptForNewTurn(waitTimeoutMs) ?: false

    /** Releases both engines regardless of the active mode. */
    fun release() {
        localTts?.release()
        qcloudTts?.release()
    }
}

View File

@@ -1,12 +1,14 @@
package com.digitalperson.ui
import android.app.Activity
import android.app.ProgressDialog
import android.opengl.GLSurfaceView
import android.text.method.ScrollingMovementMethod
import android.view.MotionEvent
import android.widget.Button
import android.widget.LinearLayout
import android.widget.ScrollView
import android.widget.Switch
import android.widget.TextView
import android.widget.Toast
import com.digitalperson.live2d.Live2DAvatarManager
@@ -18,7 +20,10 @@ class Live2DUiManager(private val activity: Activity) {
private var stopButton: Button? = null
private var recordButton: Button? = null
private var traditionalButtons: LinearLayout? = null
private var llmModeSwitch: Switch? = null
private var llmModeSwitchRow: LinearLayout? = null
private var avatarManager: Live2DAvatarManager? = null
private var downloadProgressDialog: ProgressDialog? = null
private var lastUiText: String = ""
@@ -29,6 +34,8 @@ class Live2DUiManager(private val activity: Activity) {
stopButtonId: Int = -1,
recordButtonId: Int = -1,
traditionalButtonsId: Int = -1,
llmModeSwitchId: Int = -1,
llmModeSwitchRowId: Int = -1,
silentPlayerViewId: Int,
speakingPlayerViewId: Int,
live2dViewId: Int
@@ -39,12 +46,17 @@ class Live2DUiManager(private val activity: Activity) {
if (stopButtonId != -1) stopButton = activity.findViewById(stopButtonId)
if (recordButtonId != -1) recordButton = activity.findViewById(recordButtonId)
if (traditionalButtonsId != -1) traditionalButtons = activity.findViewById(traditionalButtonsId)
if (llmModeSwitchId != -1) llmModeSwitch = activity.findViewById(llmModeSwitchId)
if (llmModeSwitchRowId != -1) llmModeSwitchRow = activity.findViewById(llmModeSwitchRowId)
textView?.movementMethod = ScrollingMovementMethod()
val glView = activity.findViewById<GLSurfaceView>(live2dViewId)
avatarManager = Live2DAvatarManager(glView)
avatarManager?.setSpeaking(false)
// 默认隐藏本地 LLM 开关
llmModeSwitchRow?.visibility = LinearLayout.GONE
}
fun setStartButtonListener(listener: () -> Unit) {
@@ -131,6 +143,72 @@ class Live2DUiManager(private val activity: Activity) {
}
}
/**
 * Shows or hides the local-LLM toggle row. Safe to call from any thread.
 */
fun showLLMSwitch(show: Boolean) {
    activity.runOnUiThread {
        llmModeSwitchRow?.visibility = if (show) LinearLayout.VISIBLE else LinearLayout.GONE
    }
}
/**
 * Registers a listener invoked whenever the local-LLM mode switch is toggled.
 * NOTE(review): not wrapped in runOnUiThread — presumably called from the UI
 * thread during setup; confirm against callers.
 */
fun setLLMSwitchListener(listener: (Boolean) -> Unit) {
    llmModeSwitch?.setOnCheckedChangeListener { _, isChecked ->
        listener(isChecked)
    }
}
/**
 * Sets the checked state of the local-LLM switch. Safe to call from any thread.
 */
fun setLLMSwitchChecked(checked: Boolean) {
    activity.runOnUiThread {
        llmModeSwitch?.isChecked = checked
    }
}
/**
 * Shows a modal, horizontal progress dialog for the LLM model download.
 *
 * Any dialog left over from a previous call is dismissed first, so repeated
 * invocations do not leak window handles (the old code overwrote the field
 * without dismissing, leaving an orphaned dialog on screen).
 */
fun showDownloadProgressDialog() {
    activity.runOnUiThread {
        // Guard against a stale dialog from an earlier download attempt.
        downloadProgressDialog?.dismiss()
        downloadProgressDialog = ProgressDialog(activity).apply {
            setTitle("下载模型")
            setMessage("正在下载 LLM 模型文件,请稍候...")
            setProgressStyle(ProgressDialog.STYLE_HORIZONTAL)
            isIndeterminate = false
            setCancelable(false)
            setCanceledOnTouchOutside(false)
            show()
        }
    }
}
/**
 * Updates the download dialog's message and progress bar.
 *
 * @param fileName     name of the file currently downloading
 * @param downloadedMB megabytes downloaded so far
 * @param totalMB      total megabytes expected
 * @param progress     overall percentage (0-100)
 */
fun updateDownloadProgress(fileName: String, downloadedMB: Long, totalMB: Long, progress: Int) {
    activity.runOnUiThread {
        downloadProgressDialog?.apply {
            setMessage("正在下载: $fileName\n$downloadedMB MB / $totalMB MB")
            setProgress(progress)
        }
    }
}
/** Dismisses and forgets the download progress dialog, if any is showing. */
fun dismissDownloadProgressDialog() {
    activity.runOnUiThread {
        downloadProgressDialog?.dismiss()
        downloadProgressDialog = null
    }
}
/** Forwards the Activity's onResume to the Live2D avatar renderer. */
fun onResume() {
    avatarManager?.onResume()
}

View File

@@ -1,6 +1,8 @@
package com.digitalperson.util
import android.content.ContentUris
import android.content.Context
import android.provider.MediaStore
import android.util.Log
import com.digitalperson.config.AppConfig
import java.io.File
@@ -48,13 +50,270 @@ object FileHelper {
)
return copyAssetsToInternal(context, AppConfig.Asr.MODEL_DIR, outDir, files)
}
/** Copies the RetinaFace detection model from assets into internal storage. */
@JvmStatic
fun copyRetinaFaceAssets(context: Context): File =
    copyAssetsToInternal(
        context,
        AppConfig.Face.MODEL_DIR,
        File(context.filesDir, AppConfig.Face.MODEL_DIR),
        arrayOf(AppConfig.Face.MODEL_NAME)
    )
/** Copies the InsightFace recognition model from assets into internal storage. */
@JvmStatic
fun copyInsightFaceAssets(context: Context): File =
    copyAssetsToInternal(
        context,
        AppConfig.FaceRecognition.MODEL_DIR,
        File(context.filesDir, AppConfig.FaceRecognition.MODEL_DIR),
        arrayOf(AppConfig.FaceRecognition.MODEL_NAME)
    )
/**
 * Ensures [dir] exists, creating it (and any missing parents) if necessary.
 *
 * The previous version called mkdirs() twice (the second check was dead code)
 * and, on failure, fabricated a path from the parent directory's *name* as if
 * it were a package name, with "llm" hard-coded for every caller — a path that
 * could never be correct. Creation failure is now logged and the original
 * [dir] is returned so callers fail at use time with a meaningful path.
 *
 * @return the same [dir] instance, whether or not creation succeeded
 */
fun ensureDir(dir: File): File {
    // mkdirs() returns false both on failure and when the directory already
    // exists (e.g. created concurrently), so re-check before logging an error.
    if (!dir.exists() && !dir.mkdirs() && !dir.exists()) {
        Log.e(TAG, "Failed to create directory: ${dir.absolutePath}")
    }
    return dir
}
/** Returns (creating if needed) the app-private directory for ASR audio dumps. */
fun getAsrAudioDir(context: Context): File {
    return ensureDir(File(context.filesDir, "asr_audio"))
}
// File name of the RKLLM model currently in use
// (Qwen3-0.6B, w8a8-quantized for the RK3588 NPU).
private const val MODEL_FILE_NAME = "Qwen3-0.6B-rk3588-w8a8.rkllm"
/**
 * Resolves the absolute path of the RKLLM model inside app-private storage
 * (<filesDir>/llm). Presence is only logged; the path is returned either way
 * so the caller surfaces any load failure itself.
 */
fun getLLMModelPath(context: Context): String {
    Log.d(TAG, "=== getLLMModelPath START ===")
    val llmDir = ensureDir(File(context.filesDir, "llm"))
    Log.d(TAG, "Loading models from: ${llmDir.absolutePath}")
    val rkllmFile = File(llmDir, MODEL_FILE_NAME)
    if (rkllmFile.exists()) {
        Log.i(TAG, "RKLLM model exists, size: ${rkllmFile.length() / (1024*1024)} MB")
    } else {
        Log.e(TAG, "RKLLM model not found: ${rkllmFile.absolutePath}")
    }
    val modelPath = rkllmFile.absolutePath
    Log.i(TAG, "Using RKLLM model path: $modelPath")
    Log.d(TAG, "=== getLLMModelPath END ===")
    return modelPath
}
/**
 * Downloads any missing LLM model files on a background thread, reporting
 * aggregate progress.
 *
 * Fix: files already on disk used to be added to the downloaded counter but
 * excluded from the size total, so the reported percentage could exceed 100%.
 * Existing files now contribute their on-disk size to both sides, and the
 * percentage is clamped to 100 as a final safety net.
 *
 * @param context    used to resolve the app-private "llm" directory
 * @param onProgress (fileName, bytesDownloadedOfCurrentFile, totalBytesOfCurrentFile, overallPercent)
 * @param onComplete (success, human-readable message)
 */
@JvmStatic
fun downloadModelFilesWithProgress(
    context: Context,
    onProgress: (String, Long, Long, Int) -> Unit,
    onComplete: (Boolean, String) -> Unit
) {
    Log.d(TAG, "=== downloadModelFilesWithProgress START ===")
    val llmDir = ensureDir(File(context.filesDir, "llm"))
    // Model file list — currently a single RKLLM model (see MODEL_FILE_NAME).
    val modelFiles = listOf(
        MODEL_FILE_NAME
    )
    // Download on a background thread; callbacks fire on that thread.
    Thread {
        try {
            var allSuccess = true
            var totalDownloaded: Long = 0
            var totalSize: Long = 0
            // First pass: compute the overall byte total. Present files count
            // their on-disk size; missing files use the server-reported size
            // or a ~1 GB estimate when the server does not report one.
            for (fileName in modelFiles) {
                val modelFile = File(llmDir, fileName)
                if (modelFile.exists() && modelFile.length() > 0L) {
                    totalSize += modelFile.length()
                } else {
                    val size = getFileSizeFromServer("http://192.168.1.19:5000/download/$fileName")
                    if (size > 0) {
                        totalSize += size
                    } else {
                        totalSize += 1L * 1024 * 1024 * 1024 // ~1 GB fallback estimate
                        Log.i(TAG, "Using estimated size for $fileName: ${totalSize / (1024*1024)} MB")
                    }
                }
            }
            // Second pass: download whatever is missing or empty.
            for (fileName in modelFiles) {
                val modelFile = File(llmDir, fileName)
                if (!modelFile.exists() || modelFile.length() == 0L) {
                    Log.i(TAG, "Downloading model file: $fileName")
                    try {
                        downloadFileWithProgress(
                            "http://192.168.1.19:5000/download/$fileName",
                            modelFile
                        ) { downloaded, total ->
                            val progress = if (totalSize > 0) {
                                (((totalDownloaded + downloaded) * 100 / totalSize).toInt()).coerceAtMost(100)
                            } else 0
                            onProgress(fileName, downloaded, total, progress)
                        }
                        totalDownloaded += modelFile.length()
                        Log.i(TAG, "Downloaded model file: $fileName, size: ${modelFile.length() / (1024*1024)} MB")
                    } catch (e: Exception) {
                        Log.e(TAG, "Failed to download model file $fileName: ${e.message}")
                        allSuccess = false
                    }
                } else {
                    // Already present: advance the aggregate counter so later
                    // files report a consistent overall percentage.
                    totalDownloaded += modelFile.length()
                    Log.i(TAG, "Model file exists: $fileName, size: ${modelFile.length() / (1024*1024)} MB")
                }
            }
            Log.d(TAG, "=== downloadModelFilesWithProgress END ===")
            if (allSuccess) {
                onComplete(true, "模型下载完成")
            } else {
                onComplete(false, "部分模型下载失败")
            }
        } catch (e: Exception) {
            Log.e(TAG, "Download failed: ${e.message}")
            onComplete(false, "下载失败: ${e.message}")
        }
    }.start()
}
/**
 * Queries the server for a file's size via an HTTP HEAD request.
 *
 * Reads Content-Length from the raw header (parsed as Long to avoid the int
 * overflow in HttpURLConnection.contentLength for files > 2 GB).
 *
 * @return the size in bytes, or 0 when unavailable/invalid or on any error
 */
private fun getFileSizeFromServer(url: String): Long = try {
    val conn = java.net.URL(url).openConnection() as java.net.HttpURLConnection
    conn.requestMethod = "HEAD"
    conn.connectTimeout = 15000
    conn.readTimeout = 15000
    val header = conn.getHeaderField("Content-Length")
    var size = 0L
    if (header == null) {
        // No raw header: fall back to the (int-limited) parsed value.
        val fallback = conn.contentLength
        if (fallback > 0) {
            size = fallback.toLong()
        } else {
            Log.w(TAG, "Content-Length not available or invalid: $fallback")
        }
    } else {
        try {
            size = header.toLong()
            if (size < 0) {
                Log.w(TAG, "Invalid Content-Length value: $size")
                size = 0
            }
        } catch (e: NumberFormatException) {
            Log.w(TAG, "Invalid Content-Length format: $header")
            size = 0
        }
    }
    conn.disconnect()
    Log.i(TAG, "File size for $url: $size bytes")
    size
} catch (e: Exception) {
    Log.w(TAG, "Failed to get file size: ${e.message}")
    0
}
/**
 * Streams [url] into [destination], invoking [onProgress] with
 * (bytesWritten, expectedTotalBytes) after every chunk.
 *
 * Content-Length is parsed from the raw header as Long to avoid the int
 * overflow in HttpURLConnection.contentLength for files > 2 GB; 0 is passed
 * as the total when the header is missing or malformed.
 */
private fun downloadFileWithProgress(
    url: String,
    destination: File,
    onProgress: (Long, Long) -> Unit
) {
    val conn = java.net.URL(url).openConnection() as java.net.HttpURLConnection
    conn.connectTimeout = 30000
    conn.readTimeout = 6000000
    val header = conn.getHeaderField("Content-Length")
    val expectedBytes: Long = if (header == null) {
        conn.contentLength.toLong()
    } else {
        try {
            header.toLong()
        } catch (e: NumberFormatException) {
            Log.w(TAG, "Invalid Content-Length format: $header")
            0
        }
    }
    Log.i(TAG, "Downloading file $url, size: $expectedBytes bytes")
    try {
        conn.inputStream.use { input ->
            FileOutputStream(destination).use { output ->
                val chunk = ByteArray(8192)
                var written = 0L
                while (true) {
                    val n = input.read(chunk)
                    if (n == -1) break
                    output.write(chunk, 0, n)
                    written += n
                    onProgress(written, expectedBytes)
                }
            }
        }
    } finally {
        conn.disconnect()
    }
}
/**
 * Checks whether the RKLLM model file exists in app-private storage and is
 * non-empty (an empty file indicates an aborted download).
 */
@JvmStatic
fun isLocalLLMAvailable(context: Context): Boolean {
    val rkllmFile = File(File(context.filesDir, "llm"), MODEL_FILE_NAME)
    val rkllmExists = rkllmFile.exists() && rkllmFile.length() > 0
    Log.i(TAG, "LLM model check: rkllm=$rkllmExists")
    Log.i(TAG, "RKLLM file: ${rkllmFile.absolutePath}, size: ${if (rkllmFile.exists()) rkllmFile.length() / (1024*1024) else 0} MB")
    return rkllmExists
}
/**
 * Downloads [url] to [destination] in one shot, with no progress reporting.
 */
private fun downloadFile(url: String, destination: File) {
    val conn = java.net.URL(url).openConnection() as java.net.HttpURLConnection
    conn.connectTimeout = 30000 // 30 s connect timeout
    conn.readTimeout = 60000 // 60 s read timeout
    try {
        conn.inputStream.use { input ->
            FileOutputStream(destination).use { output ->
                input.copyTo(output)
            }
        }
    } finally {
        // Always free the connection, even if the transfer throws.
        conn.disconnect()
    }
}
}

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -16,11 +16,36 @@
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintTop_toTopOf="parent" />
<FrameLayout
android:id="@+id/face_preview_container"
android:layout_width="220dp"
android:layout_height="300dp"
android:layout_margin="12dp"
android:background="#55000000"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintTop_toTopOf="parent">
<androidx.camera.view.PreviewView
android:id="@+id/camera_preview"
android:layout_width="match_parent"
android:layout_height="match_parent"
app:implementationMode="compatible"
app:scaleType="fitCenter" />
<com.digitalperson.face.FaceOverlayView
android:id="@+id/face_overlay"
android:layout_width="match_parent"
android:layout_height="match_parent" />
</FrameLayout>
<ScrollView
android:id="@+id/scroll_view"
android:layout_width="match_parent"
android:layout_width="0dp"
android:layout_height="200dp"
android:fillViewport="true">
android:fillViewport="true"
app:layout_constraintBottom_toTopOf="@+id/llm_mode_switch_row"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintStart_toStartOf="parent">
<TextView
android:id="@+id/my_text"
@@ -33,6 +58,31 @@
android:textIsSelectable="true" />
</ScrollView>
<LinearLayout
android:id="@+id/llm_mode_switch_row"
android:layout_width="0dp"
android:layout_height="wrap_content"
android:gravity="center_vertical"
android:orientation="horizontal"
android:padding="16dp"
app:layout_constraintBottom_toTopOf="@+id/streaming_switch_row"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintStart_toStartOf="parent">
<TextView
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_marginEnd="16dp"
android:text="本地LLM"
android:textSize="16sp" />
<Switch
android:id="@+id/llm_mode_switch"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:checked="false" />
</LinearLayout>
<LinearLayout
android:id="@+id/streaming_switch_row"
android:layout_width="0dp"
@@ -40,6 +90,31 @@
android:gravity="center_vertical"
android:orientation="horizontal"
android:padding="16dp"
app:layout_constraintBottom_toTopOf="@+id/tts_mode_switch_row"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintStart_toStartOf="parent">
<TextView
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_marginEnd="16dp"
android:text="流式输出"
android:textSize="16sp" />
<Switch
android:id="@+id/streaming_switch"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:checked="false" />
</LinearLayout>
<LinearLayout
android:id="@+id/tts_mode_switch_row"
android:layout_width="0dp"
android:layout_height="wrap_content"
android:gravity="center_vertical"
android:orientation="horizontal"
android:padding="16dp"
app:layout_constraintBottom_toTopOf="@+id/button_row"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintStart_toStartOf="parent">
@@ -48,11 +123,11 @@
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_marginEnd="16dp"
android:text="流式输出"
android:text="腾讯云TTS"
android:textSize="16sp" />
<Switch
android:id="@+id/streaming_switch"
android:id="@+id/tts_mode_switch"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:checked="false" />

View File

@@ -3,5 +3,5 @@
<string name="start">开始</string>
<string name="stop">结束</string>
<string name="hint">点击“开始”说话;识别后会请求大模型并用 TTS 播放回复。</string>
<string name="system_prompt">你是一名小学女老师喜欢回答学生的各种问题请简洁但温柔地回答每个回答不超过30字。在每次回复的最前面用方括号标注你的心情格式为[开心/伤心/愤怒/平和/惊讶/关心/害羞],例如:[开心]同学你好呀!请问有什么问题吗?</string>
<string name="system_prompt">你是一名小学女老师喜欢回答学生的各种问题请简洁但温柔地回答每个回答不超过30字。在每次回复的最前面用方括号标注你的心情格式为[中性、悲伤、高兴、生气、恐惧、撒娇、震惊、厌恶],例如:[高兴]同学你好呀!请问有什么问题吗?</string>
</resources>

View File

@@ -6,7 +6,7 @@
# http://www.gradle.org/docs/current/userguide/build_environment.html
# Specifies the JVM arguments used for the daemon process.
# The setting is particularly useful for tweaking memory settings.
org.gradle.jvmargs=-Xmx6g -Dfile.encoding=UTF-8
org.gradle.jvmargs=-Xmx8g -Dfile.encoding=UTF-8
# When configured, Gradle will run in incubating parallel mode.
# This option should only be used with decoupled projects. More details, visit
# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects