initial commit

This commit is contained in:
gcw_4spBpAfv
2026-02-25 18:13:26 +08:00
commit 6aa84d6b77
239 changed files with 995156 additions and 0 deletions

View File

@@ -0,0 +1,689 @@
package com.digitalperson
import android.Manifest
import android.content.pm.PackageManager
import android.media.AudioAttributes
import android.media.AudioFormat
import android.media.AudioManager
import android.media.AudioRecord
import android.media.AudioTrack
import android.media.MediaRecorder
import android.media.audiofx.AcousticEchoCanceler
import android.media.audiofx.NoiseSuppressor
import android.os.Bundle
import android.os.SystemClock
import android.text.method.ScrollingMovementMethod
import android.util.Log
import android.widget.Button
import android.widget.TextView
import android.widget.Toast
import androidx.appcompat.app.AppCompatActivity
import androidx.core.app.ActivityCompat
import com.digitalperson.cloud.CloudApiManager
import com.digitalperson.engine.SenseVoiceEngineRKNN
import com.digitalperson.metrics.TraceManager
import com.digitalperson.metrics.TraceSession
import com.k2fsa.sherpa.onnx.OfflineTts
import com.k2fsa.sherpa.onnx.SileroVadModelConfig
import com.k2fsa.sherpa.onnx.Vad
import com.k2fsa.sherpa.onnx.VadModelConfig
import com.k2fsa.sherpa.onnx.getOfflineTtsConfig
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.cancel
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import java.io.File
import java.io.FileOutputStream
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.atomic.AtomicBoolean
import kotlin.math.max
// Logcat tag shared by this file.
private const val TAG = "DigitalPerson"
// Request code used when asking for RECORD_AUDIO at runtime.
private const val REQUEST_RECORD_AUDIO_PERMISSION = 200
/**
 * Voice-assistant activity: mic capture -> Silero VAD -> SenseVoice ASR (RKNN)
 * -> streaming LLM -> segmented offline TTS playback.
 *
 * Heavy model initialization runs on [ioScope]; all UI mutation happens on the
 * main thread. Access to the native VAD/ASR engines is serialized by [nativeLock].
 */
class MainActivity : AppCompatActivity() {
    // UI controls bound in onCreate().
    private lateinit var startButton: Button
    private lateinit var stopButton: Button
    private lateinit var textView: TextView
    // Voice-activity detector (sherpa-onnx Silero VAD); created during background init.
    private lateinit var vad: Vad
    // On-device ASR engine (RKNN backend); null until background init succeeds.
    private var senseVoice: SenseVoiceEngineRKNN? = null
    // Offline TTS engine; stays null when TTS init fails (app then runs ASR-only).
    private var tts: OfflineTts? = null
    // Playback sink for synthesized float PCM.
    private var track: AudioTrack? = null
    // Optional audio effects attached to the microphone session.
    private var aec: AcousticEchoCanceler? = null
    private var ns: NoiseSuppressor? = null
    private var audioRecord: AudioRecord? = null
    // Capture parameters: default mic, 16 kHz mono PCM16.
    private val audioSource = MediaRecorder.AudioSource.MIC
    private val sampleRateInHz = 16000
    private val channelConfig = AudioFormat.CHANNEL_IN_MONO
    private val audioFormat = AudioFormat.ENCODING_PCM_16BIT
    private val permissions: Array<String> = arrayOf(Manifest.permission.RECORD_AUDIO)
    // Flag polled by the capture loop; flipped on the main thread.
    @Volatile
    private var isRecording: Boolean = false
    // Background scope for model init, the capture loop and the TTS worker.
    private val ioScope = CoroutineScope(SupervisorJob() + Dispatchers.IO)
    private var recordingJob: Job? = null
    // Serializes every call into the native VAD/ASR engines.
    private val nativeLock = Any()
    private lateinit var cloudApiManager: CloudApiManager
    // Splits streaming LLM text into small TTS-ready segments.
    private val segmenter = StreamingTextSegmenter()
    // Work items consumed by the TTS worker: text segments, then an End sentinel.
    private sealed class TtsQueueItem {
        data class Segment(val text: String) : TtsQueueItem()
        data object End : TtsQueueItem()
    }
    private val ttsQueue = LinkedBlockingQueue<TtsQueueItem>()
    // Set on stop; makes the TTS worker drop pending audio and exit.
    private val ttsStopped = AtomicBoolean(false)
    private val ttsWorkerRunning = AtomicBoolean(false)
    // Per-turn latency trace; null when no turn is active.
    private var currentTrace: TraceSession? = null
    private var lastUiText: String = ""
    // Enforces at most one LLM request in flight.
    @Volatile private var llmInFlight: Boolean = false

    /** Finishes the activity when the RECORD_AUDIO permission is denied. */
    override fun onRequestPermissionsResult(
        requestCode: Int,
        permissions: Array<String>,
        grantResults: IntArray
    ) {
        super.onRequestPermissionsResult(requestCode, permissions, grantResults)
        val ok = requestCode == REQUEST_RECORD_AUDIO_PERMISSION &&
            grantResults.isNotEmpty() &&
            grantResults[0] == PackageManager.PERMISSION_GRANTED
        if (!ok) {
            Log.e(TAG, "Audio record is disallowed")
            finish()
        }
    }

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)
        ActivityCompat.requestPermissions(this, permissions, REQUEST_RECORD_AUDIO_PERMISSION)
        startButton = findViewById(R.id.start_button)
        stopButton = findViewById(R.id.stop_button)
        textView = findViewById(R.id.my_text)
        textView.movementMethod = ScrollingMovementMethod()
        startButton.setOnClickListener { onStartClicked() }
        stopButton.setOnClickListener { onStopClicked(userInitiated = true) }
        // Avoid ANR from heavy init on the UI thread: load models and the
        // AudioTrack in the background, keeping buttons disabled meanwhile.
        startButton.isEnabled = false
        stopButton.isEnabled = false
        textView.text = "初始化中…"
        ioScope.launch {
            try {
                Log.i(TAG, "Init VAD + SenseVoice(RKNN) + TTS (background)")
                synchronized(nativeLock) {
                    initVadModel()
                    initSenseVoiceModel()
                }
                // TTS + AudioTrack creation and UI updates run on the main thread.
                withContext(Dispatchers.Main) {
                    initTtsAndAudioTrack()
                    textView.text = getString(R.string.hint)
                    startButton.isEnabled = true
                    stopButton.isEnabled = false
                }
            } catch (t: Throwable) {
                Log.e(TAG, "Initialization failed: ${t.message}", t)
                withContext(Dispatchers.Main) {
                    textView.text = "初始化失败:${t.javaClass.simpleName}: ${t.message}"
                    Toast.makeText(
                        this@MainActivity,
                        "初始化失败(请看 Logcat: ${t.javaClass.simpleName}",
                        Toast.LENGTH_LONG
                    ).show()
                    startButton.isEnabled = false
                    stopButton.isEnabled = false
                }
            }
        }
        // Cloud listener callbacks arrive on the main thread (posted by CloudApiManager).
        cloudApiManager = CloudApiManager(object : CloudApiManager.CloudApiListener {
            // NOTE(review): set on the first chunk ever and never reset, so
            // llm_first_chunk is only marked for the first turn — confirm intent.
            private var llmFirstChunkMarked = false
            override fun onLLMResponseReceived(response: String) {
                currentTrace?.markLlmDone()
                llmInFlight = false
                // flush remaining buffer into TTS
                for (seg in segmenter.flush()) {
                    enqueueTtsSegment(seg)
                }
                // signal queue end (no more segments after this)
                ttsQueue.offer(TtsQueueItem.End)
            }
            override fun onLLMStreamingChunkReceived(chunk: String) {
                if (!llmFirstChunkMarked) {
                    llmFirstChunkMarked = true
                    currentTrace?.markLlmFirstChunk()
                }
                appendToUi(chunk)
                val segments = segmenter.processChunk(chunk)
                for (seg in segments) {
                    enqueueTtsSegment(seg)
                }
            }
            override fun onTTSAudioReceived(audioFilePath: String) {
                // unused
            }
            override fun onError(errorMessage: String) {
                llmInFlight = false
                Toast.makeText(this@MainActivity, errorMessage, Toast.LENGTH_LONG).show()
                onStopClicked(userInitiated = false)
            }
        })
    }

    override fun onDestroy() {
        super.onDestroy()
        onStopClicked(userInitiated = false)
        ioScope.cancel()
        // Best-effort teardown of native engines; failures are ignored on exit.
        synchronized(nativeLock) {
            try {
                vad.release()
            } catch (_: Throwable) {
            }
            try {
                senseVoice?.deinitialize()
            } catch (_: Throwable) {
            }
        }
        try {
            tts?.release()
        } catch (_: Throwable) {
        }
    }

    /** Starts a new turn: trace, mic, capture loop. No-op if already recording. */
    private fun onStartClicked() {
        if (isRecording) return
        if (!initMicrophone()) {
            Toast.makeText(this, "麦克风初始化失败/无权限", Toast.LENGTH_SHORT).show()
            return
        }
        // Start a new trace turn
        currentTrace = TraceManager.getInstance().startNewTurn()
        currentTrace?.mark("turn_start")
        llmInFlight = false
        lastUiText = ""
        textView.text = ""
        ttsStopped.set(false)
        ttsQueue.clear()
        segmenter.reset()
        vad.reset()
        // initMicrophone() returned true, so audioRecord is non-null here.
        audioRecord!!.startRecording()
        isRecording = true
        startButton.isEnabled = false
        stopButton.isEnabled = true
        recordingJob?.cancel()
        recordingJob = ioScope.launch {
            processSamplesLoop()
        }
    }

    /**
     * Stops capture and playback, releasing the mic and audio effects.
     * Must be called on the main thread (touches the buttons).
     * @param userInitiated when true, also ends the current trace turn.
     */
    private fun onStopClicked(userInitiated: Boolean) {
        isRecording = false
        try {
            audioRecord?.stop()
        } catch (_: Throwable) {
        }
        try {
            audioRecord?.release()
        } catch (_: Throwable) {
        }
        audioRecord = null
        recordingJob?.cancel()
        recordingJob = null
        ttsStopped.set(true)
        ttsQueue.clear()
        // wake worker if waiting
        ttsQueue.offer(TtsQueueItem.End)
        try {
            track?.pause()
            track?.flush()
        } catch (_: Throwable) {
        }
        try { aec?.release() } catch (_: Throwable) {}
        try { ns?.release() } catch (_: Throwable) {}
        aec = null
        ns = null
        startButton.isEnabled = true
        stopButton.isEnabled = false
        if (userInitiated) {
            TraceManager.getInstance().endTurn()
            currentTrace = null
        }
    }

    /** Builds the Silero VAD from the bundled assets. */
    private fun initVadModel() {
        // The VAD model ships under assets/vad_model/.
        val config = VadModelConfig(
            sileroVadModelConfig = SileroVadModelConfig(
                model = "vad_model/silero_vad.onnx",
                threshold = 0.5F,
                minSilenceDuration = 0.25F,
                minSpeechDuration = 0.25F,
                windowSize = 512,
            ),
            sampleRate = sampleRateInHz,
            numThreads = 1,
            provider = "cpu",
        )
        vad = Vad(assetManager = application.assets, config = config)
    }

    /**
     * Loads the SenseVoice RKNN engine from model files copied to internal storage
     * (RKNN cannot read directly from assets). Throws on any failure so the
     * caller can surface it in the UI.
     */
    private fun initSenseVoiceModel() {
        Log.i(TAG, "ASR: init SenseVoice RKNN (scheme A)")
        // Copy assets/sensevoice_models/* -> filesDir/sensevoice_models/*
        val modelDir = copySenseVoiceAssetsToInternal()
        val modelPath = File(modelDir, "sense-voice-encoder.rknn").absolutePath
        val embeddingPath = File(modelDir, "embedding.npy").absolutePath
        val bpePath = File(modelDir, "chn_jpn_yue_eng_ko_spectok.bpe.model").absolutePath
        // Print quick diagnostics for native libs + model files
        try {
            val libDir = applicationInfo.nativeLibraryDir
            Log.i(TAG, "nativeLibraryDir=$libDir")
            try {
                val names = File(libDir).list()?.joinToString(", ") ?: "(empty)"
                Log.i(TAG, "nativeLibraryDir files: $names")
            } catch (t: Throwable) {
                Log.w(TAG, "Failed to list nativeLibraryDir: ${t.message}")
            }
        } catch (_: Throwable) {
        }
        Log.i(TAG, "SenseVoice model paths:")
        Log.i(TAG, " model=$modelPath exists=${File(modelPath).exists()} size=${File(modelPath).length()}")
        Log.i(TAG, " embedding=$embeddingPath exists=${File(embeddingPath).exists()} size=${File(embeddingPath).length()}")
        Log.i(TAG, " bpe=$bpePath exists=${File(bpePath).exists()} size=${File(bpePath).length()}")
        val t0 = SystemClock.elapsedRealtime()
        val engine = try {
            SenseVoiceEngineRKNN(this)
        } catch (e: UnsatisfiedLinkError) {
            // Most common: libsensevoiceEngine.so not packaged/built, or dependent libs missing
            throw IllegalStateException("Load native libraries failed: ${e.message}", e)
        }
        val ok = try {
            engine.loadModelDirectly(modelPath, embeddingPath, bpePath)
        } catch (t: Throwable) {
            throw IllegalStateException("SenseVoice loadModelDirectly crashed: ${t.message}", t)
        }
        val dt = SystemClock.elapsedRealtime() - t0
        Log.i(TAG, "SenseVoice loadModelDirectly ok=$ok costMs=$dt")
        if (!ok) throw IllegalStateException("SenseVoiceEngineRKNN loadModelDirectly returned false")
        senseVoice = engine
    }

    /**
     * Initializes the offline TTS engine and a float-PCM AudioTrack at the
     * engine's sample rate. TTS failure is non-fatal: a toast is shown and the
     * app keeps running without speech output.
     */
    private fun initTtsAndAudioTrack() {
        try {
            // Bundled sherpa-onnx VITS Chinese model directory:
            // assets/tts_model/sherpa-onnx-vits-zh-ll/{model.onnx,tokens.txt,lexicon.txt,...}
            val modelDir = "tts_model/sherpa-onnx-vits-zh-ll"
            val modelName = "model.onnx"
            val lexicon = "lexicon.txt"
            val dataDir = ""
            val ttsConfig = getOfflineTtsConfig(
                modelDir = modelDir,
                modelName = modelName,
                acousticModelName = "",
                vocoder = "",
                voices = "",
                lexicon = lexicon,
                dataDir = dataDir,
                dictDir = "",
                // Chinese text-normalization rules (these .fst files ship in the model dir).
                ruleFsts = "$modelDir/phone.fst,$modelDir/date.fst,$modelDir/number.fst,$modelDir/new_heteronym.fst",
                ruleFars = "",
                numThreads = null,
                isKitten = false
            )
            tts = OfflineTts(assetManager = application.assets, config = ttsConfig)
        } catch (t: Throwable) {
            Log.e(TAG, "Init TTS failed: ${t.message}", t)
            tts = null
            runOnUiThread {
                Toast.makeText(
                    this,
                    "TTS 初始化失败:请确认 assets/tts_model/sherpa-onnx-vits-zh-ll/ 下有 model.onnx、tokens.txt、lexicon.txt 以及 phone/date/number/new_heteronym.fst",
                    Toast.LENGTH_LONG
                ).show()
            }
        }
        val t = tts ?: return
        val sr = t.sampleRate()
        val bufLength = AudioTrack.getMinBufferSize(
            sr,
            AudioFormat.CHANNEL_OUT_MONO,
            AudioFormat.ENCODING_PCM_FLOAT
        )
        val attr = AudioAttributes.Builder()
            .setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
            .setUsage(AudioAttributes.USAGE_MEDIA)
            .build()
        val format = AudioFormat.Builder()
            .setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
            .setChannelMask(AudioFormat.CHANNEL_OUT_MONO)
            .setSampleRate(sr)
            .build()
        track = AudioTrack(
            attr,
            format,
            bufLength,
            AudioTrack.MODE_STREAM,
            AudioManager.AUDIO_SESSION_ID_GENERATE
        )
        track?.play()
    }

    /** Returns true when the given asset path can be opened. (Currently unused helper.) */
    private fun assetExists(path: String): Boolean {
        return try {
            application.assets.open(path).close()
            true
        } catch (_: Throwable) {
        false
        }
    }

    /**
     * Copies the SenseVoice model files from assets into filesDir once
     * (skipped when a non-empty copy already exists). Returns the target dir.
     */
    private fun copySenseVoiceAssetsToInternal(): File {
        val outDir = File(filesDir, "sensevoice_models")
        if (!outDir.exists()) outDir.mkdirs()
        val files = arrayOf(
            "am.mvn",
            "chn_jpn_yue_eng_ko_spectok.bpe.model",
            "embedding.npy",
            "sense-voice-encoder.rknn"
        )
        for (name in files) {
            val assetPath = "sensevoice_models/$name"
            val outFile = File(outDir, name)
            if (outFile.exists() && outFile.length() > 0) continue
            application.assets.open(assetPath).use { input ->
                FileOutputStream(outFile).use { output ->
                    input.copyTo(output)
                }
            }
        }
        return outDir
    }

    /**
     * Creates the AudioRecord (double the minimum buffer) and attaches
     * AEC/NS effects when the device supports them.
     * @return false when the RECORD_AUDIO permission is missing (a new request is issued).
     */
    private fun initMicrophone(): Boolean {
        if (ActivityCompat.checkSelfPermission(this, Manifest.permission.RECORD_AUDIO)
            != PackageManager.PERMISSION_GRANTED
        ) {
            ActivityCompat.requestPermissions(this, permissions, REQUEST_RECORD_AUDIO_PERMISSION)
            return false
        }
        val numBytes = AudioRecord.getMinBufferSize(sampleRateInHz, channelConfig, audioFormat)
        audioRecord = AudioRecord(
            audioSource,
            sampleRateInHz,
            channelConfig,
            audioFormat,
            numBytes * 2
        )
        val sessionId = audioRecord?.audioSessionId ?: 0
        if (sessionId != 0) {
            if (android.media.audiofx.AcousticEchoCanceler.isAvailable()) {
                aec = android.media.audiofx.AcousticEchoCanceler.create(sessionId)?.apply {
                    enabled = true
                }
                Log.i(TAG, "AEC enabled=${aec?.enabled}")
            } else {
                Log.w(TAG, "AEC not available on this device")
            }
            if (android.media.audiofx.NoiseSuppressor.isAvailable()) {
                ns = android.media.audiofx.NoiseSuppressor.create(sessionId)?.apply {
                    enabled = true
                }
                Log.i(TAG, "NS enabled=${ns?.enabled}")
            } else {
                Log.w(TAG, "NS not available on this device")
            }
        }
        return true
    }

    /**
     * Capture loop run on [ioScope]: reads 512-sample windows, scores them with
     * the VAD, accumulates speech into a buffer (capped at ~5 s), and on
     * end-of-speech transcribes the segment and forwards the text to the LLM.
     */
    private suspend fun processSamplesLoop() {
        // Avoid calling vad.front()/vad.pop() (native queue APIs) since it crashes on some builds.
        // Use vad.compute() and implement a simple VAD segmenter in Kotlin instead.
        val windowSize = 512
        val buffer = ShortArray(windowSize)
        val threshold = 0.5f
        val minSilenceSamples = (0.25f * sampleRateInHz).toInt()
        val minSpeechSamples = (0.25f * sampleRateInHz).toInt()
        val maxSpeechSamples = (5.0f * sampleRateInHz).toInt()
        var inSpeech = false
        var silenceSamples = 0
        var speechBuf = FloatArray(0)
        var speechLen = 0
        // Appends a chunk to the speech buffer, growing it geometrically and
        // silently truncating once maxSpeechSamples is reached.
        fun appendSpeech(chunk: FloatArray) {
            val needed = speechLen + chunk.size
            if (speechBuf.size < needed) {
                var newCap = maxOf(needed, maxOf(1024, speechBuf.size * 2))
                if (newCap > maxSpeechSamples) newCap = maxSpeechSamples
                val n = FloatArray(newCap)
                if (speechLen > 0) System.arraycopy(speechBuf, 0, n, 0, speechLen)
                speechBuf = n
            }
            val copyN = minOf(chunk.size, max(0, maxSpeechSamples - speechLen))
            if (copyN > 0) {
                System.arraycopy(chunk, 0, speechBuf, speechLen, copyN)
                speechLen += copyN
            }
        }
        // Closes the current speech segment (if long enough), runs ASR, and
        // kicks off an LLM request. Always resets the segmenter state.
        suspend fun finalizeSegmentIfAny() {
            if (speechLen < minSpeechSamples) {
                speechLen = 0
                inSpeech = false
                silenceSamples = 0
                return
            }
            val seg = speechBuf.copyOf(speechLen)
            speechLen = 0
            inSpeech = false
            silenceSamples = 0
            // Allow only one in-flight LLM request at a time to avoid
            // pile-up causing stalls or races.
            if (llmInFlight) return
            val trace = currentTrace
            trace?.markASRStart()
            val raw = synchronized(nativeLock) {
                val e = senseVoice
                if (e == null || !e.isInitialized) "" else e.transcribeBuffer(seg)
            }
            val text = removeTokens(raw)
            if (text.isBlank()) return
            trace?.markASREnd()
            // NOTE(review): duplicated isBlank() check — the return above already
            // covers it, and blank results skip markASREnd() entirely. Confirm intent.
            if (text.isBlank()) return
            withContext(Dispatchers.Main) {
                appendToUi("\n\n[ASR] ${text}\n")
            }
            trace?.markRecordingDone()
            trace?.markLlmResponseReceived()
            if (BuildConfig.LLM_API_KEY.isBlank()) {
                withContext(Dispatchers.Main) {
                    Toast.makeText(
                        this@MainActivity,
                        "未配置 LLM_API_KEY在 local.properties 或 gradle.properties 里设置)",
                        Toast.LENGTH_LONG
                    ).show()
                }
                return
            }
            llmInFlight = true
            cloudApiManager.callLLM(text)
        }
        while (isRecording && ioScope.coroutineContext.isActive) {
            val ret = audioRecord?.read(buffer, 0, buffer.size) ?: break
            if (ret <= 0) continue
            // The VAD expects exactly one full window; partial reads are dropped.
            if (ret != windowSize) continue
            // Convert PCM16 to normalized floats in [-1, 1).
            val chunk = FloatArray(ret) { buffer[it] / 32768.0f }
            val prob = synchronized(nativeLock) { vad.compute(chunk) }
            if (prob >= threshold) {
                if (!inSpeech) {
                    inSpeech = true
                    silenceSamples = 0
                }
                appendSpeech(chunk)
                if (speechLen >= maxSpeechSamples) {
                    finalizeSegmentIfAny()
                }
            } else {
                if (inSpeech) {
                    silenceSamples += ret
                    if (silenceSamples >= minSilenceSamples) {
                        finalizeSegmentIfAny()
                    } else {
                        // keep a bit of trailing silence to avoid chopping
                        appendSpeech(chunk)
                    }
                }
            }
            // Time-based fallback segmentation (keeps the first TTS packet fast
            // when the LLM emits no punctuation for a while).
            val forced = segmenter.maybeForceByTime()
            for (seg in forced) enqueueTtsSegment(seg)
        }
        // flush last partial segment
        finalizeSegmentIfAny()
    }

    /** Strips SenseVoice control tokens and collapses whitespace from raw ASR output. */
    private fun removeTokens(text: String): String {
        // Remove tokens like <|zh|>, <|NEUTRAL|>, <|Speech|>, <|woitn|> and stray '>' chars
        var cleaned = text.replace(Regex("<\\|[^>]+\\|>"), "")
        cleaned = cleaned.replace(Regex("[>>≥≫]"), "")
        cleaned = cleaned.trim().replace(Regex("\\s+"), " ")
        return cleaned
    }

    /** Queues one text segment for synthesis and makes sure the worker is running. */
    private fun enqueueTtsSegment(seg: String) {
        currentTrace?.markTtsRequestEnqueued()
        ttsQueue.offer(TtsQueueItem.Segment(seg))
        ensureTtsWorker()
    }

    /** Launches the TTS worker coroutine unless one is already active. */
    private fun ensureTtsWorker() {
        if (!ttsWorkerRunning.compareAndSet(false, true)) return
        ioScope.launch {
            try {
                runTtsWorker()
            } finally {
                ttsWorkerRunning.set(false)
            }
        }
    }

    /**
     * Blocking TTS loop: takes items off [ttsQueue], synthesizes each segment
     * and writes float PCM to the AudioTrack. Exits on an End sentinel or when
     * [ttsStopped] is set. Runs on an IO dispatcher thread (take() blocks).
     */
    private fun runTtsWorker() {
        val t = tts ?: return
        val audioTrack = track ?: return
        var firstAudioMarked = false
        while (true) {
            val item = ttsQueue.take()
            if (ttsStopped.get()) break
            when (item) {
                is TtsQueueItem.Segment -> {
                    val trace = currentTrace
                    trace?.markTtsSynthesisStart()
                    val startMs = System.currentTimeMillis()
                    var firstPcmMarked = false
                    // flush to reduce latency between segments
                    try {
                        audioTrack.pause()
                        audioTrack.flush()
                        audioTrack.play()
                    } catch (_: Throwable) {
                    }
                    // Returning 0 from the callback aborts synthesis; 1 continues.
                    t.generateWithCallback(
                        text = item.text,
                        sid = 0,
                        speed = 1.0f
                    ) { samples ->
                        if (ttsStopped.get()) return@generateWithCallback 0
                        if (!firstPcmMarked && samples.isNotEmpty()) {
                            firstPcmMarked = true
                            trace?.markTtsFirstPcmReady()
                        }
                        if (!firstAudioMarked && samples.isNotEmpty()) {
                            firstAudioMarked = true
                            trace?.markTtsFirstAudioPlay()
                        }
                        audioTrack.write(samples, 0, samples.size, AudioTrack.WRITE_BLOCKING)
                        1
                    }
                    val ttsMs = System.currentTimeMillis() - startMs
                    trace?.addDuration("tts_segment_ms_total", ttsMs)
                }
                TtsQueueItem.End -> {
                    currentTrace?.markTtsDone()
                    TraceManager.getInstance().endTurn()
                    currentTrace = null
                    break
                }
            }
        }
    }

    /** Appends text to the transcript view. Must be called on the main thread. */
    private fun appendToUi(s: String) {
        lastUiText += s
        textView.text = lastUiText
    }
}

View File

@@ -0,0 +1,93 @@
package com.digitalperson
/**
 * Segments streaming LLM chunks for "pseudo-streaming" TTS.
 * - Splits eagerly at Chinese/English punctuation so the first segment (and
 *   therefore the first audio) is produced as early as possible.
 * - Falls back to length-based and time-based cuts when no punctuation arrives.
 *
 * All public methods are synchronized; safe to call from multiple threads.
 *
 * @param maxLen maximum characters per emitted segment (length fallback).
 * @param maxWaitMs time-based fallback: force a cut after this long without emitting.
 */
class StreamingTextSegmenter(
    private val maxLen: Int = 18,
    private val maxWaitMs: Long = 400,
) {
    private val buf = StringBuilder()
    private var lastEmitAtMs: Long = 0

    /** Clears any buffered text and resets the emission clock. */
    @Synchronized
    fun reset(nowMs: Long = System.currentTimeMillis()) {
        buf.setLength(0)
        lastEmitAtMs = nowMs
    }

    /** Appends a streaming chunk and returns any segments that became ready. */
    @Synchronized
    fun processChunk(chunk: String, nowMs: Long = System.currentTimeMillis()): List<String> {
        if (chunk.isEmpty()) return emptyList()
        buf.append(chunk)
        return drain(nowMs, forceByTime = false)
    }

    /** Emits whatever is left in the buffer as a final segment (if non-blank). */
    @Synchronized
    fun flush(nowMs: Long = System.currentTimeMillis()): List<String> {
        val out = mutableListOf<String>()
        val remaining = buf.toString().trim()
        buf.setLength(0)
        if (remaining.isNotEmpty()) out.add(remaining)
        lastEmitAtMs = nowMs
        return out
    }

    /** Time-based fallback: cuts a segment if too long has passed without emitting. */
    @Synchronized
    fun maybeForceByTime(nowMs: Long = System.currentTimeMillis()): List<String> {
        return drain(nowMs, forceByTime = true)
    }

    private fun drain(nowMs: Long, forceByTime: Boolean): List<String> {
        val out = mutableListOf<String>()
        // 1) Prefer punctuation boundaries (emit a small segment as early as possible).
        while (true) {
            val idx = firstPunctuationIndex(buf)
            if (idx < 0) break
            val seg = buf.substring(0, idx + 1).trim()
            buf.delete(0, idx + 1)
            if (seg.isNotEmpty()) {
                out.add(seg)
                lastEmitAtMs = nowMs
            }
        }
        // 2) Length fallback.
        while (buf.length >= maxLen) {
            val seg = buf.substring(0, maxLen).trim()
            buf.delete(0, maxLen)
            if (seg.isNotEmpty()) {
                out.add(seg)
                lastEmitAtMs = nowMs
            }
        }
        // 3) Time fallback: even without punctuation, cut after maxWaitMs.
        if (forceByTime && buf.isNotEmpty() && nowMs - lastEmitAtMs >= maxWaitMs) {
            val cut = minOf(buf.length, maxLen)
            val seg = buf.substring(0, cut).trim()
            buf.delete(0, cut)
            if (seg.isNotEmpty()) {
                out.add(seg)
                lastEmitAtMs = nowMs
            }
        }
        return out
    }

    private fun firstPunctuationIndex(sb: StringBuilder): Int {
        // Both strong and weak punctuation may end a segment early.
        // FIX: the original array contained empty char literals ('') — the
        // full-width punctuation had been destroyed by an encoding mangle and
        // would not compile. Restored the full-width ,!?; characters here.
        val punct = charArrayOf('。', ',', '!', '?', ';', '.', '!', '?', ';', ',', '\n')
        var best = -1
        for (p in punct) {
            val i = sb.indexOf(p.toString())
            if (i >= 0 && (best < 0 || i < best)) best = i
        }
        return best
    }
}

View File

@@ -0,0 +1,207 @@
package com.digitalperson.cloud;
import android.os.Handler;
import android.os.Looper;
import android.util.Log;
import com.digitalperson.BuildConfig;
import org.json.JSONArray;
import org.json.JSONException;
import org.json.JSONObject;
import java.io.BufferedReader;
import java.io.DataOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.HttpURLConnection;
import java.net.URL;
import java.nio.charset.StandardCharsets;
/**
 * Thin client for an OpenAI-compatible streaming chat-completions endpoint
 * (Volcano Engine). Each request runs on a fresh thread; all listener
 * callbacks are posted to the main thread via a Handler.
 *
 * Fixes vs. original:
 * - HttpURLConnection is now disconnected in a finally block (connection leak).
 * - getErrorStream() may legitimately return null; that case previously threw
 *   a NullPointerException instead of the intended IOException.
 */
public class CloudApiManager {
    private static final String TAG = "CloudApiManager";
    // OpenAI-compatible API configuration (Volcano Engine), injected via BuildConfig.
    private static final String LLM_API_URL = BuildConfig.LLM_API_URL;
    private static final String API_KEY = BuildConfig.LLM_API_KEY;
    private static final String LLM_MODEL = BuildConfig.LLM_MODEL;
    private CloudApiListener mListener;
    private Handler mMainHandler; // posts callbacks onto the main thread
    private JSONArray mConversationHistory; // chat history sent with every request

    /** Callbacks; always invoked on the main thread. */
    public interface CloudApiListener {
        void onLLMResponseReceived(String response);
        void onLLMStreamingChunkReceived(String chunk);
        void onTTSAudioReceived(String audioFilePath);
        void onError(String errorMessage);
    }

    public CloudApiManager(CloudApiListener listener) {
        this.mListener = listener;
        this.mMainHandler = new Handler(Looper.getMainLooper()); // main-thread Handler
        this.mConversationHistory = new JSONArray(); // empty conversation history
    }

    /**
     * Sends {@code userInput} (plus the accumulated history) to the LLM with
     * streaming enabled, forwarding each SSE content delta to
     * {@link CloudApiListener#onLLMStreamingChunkReceived} and the full text to
     * {@link CloudApiListener#onLLMResponseReceived}. Errors are reported via
     * {@link CloudApiListener#onError}. Runs on a new background thread.
     */
    public void callLLM(String userInput) {
        new Thread(() -> {
            HttpURLConnection conn = null;
            try {
                // Record the user turn before sending.
                addMessageToHistory("user", userInput);
                // Open the HTTP connection.
                URL url = new URL(LLM_API_URL);
                conn = (HttpURLConnection) url.openConnection();
                conn.setRequestMethod("POST");
                conn.setRequestProperty("Content-Type", "application/json");
                conn.setRequestProperty("Authorization", "Bearer " + API_KEY);
                conn.setDoOutput(true);
                conn.setConnectTimeout(10000);
                conn.setReadTimeout(60000); // long read timeout to support streaming
                // Build the request body.
                JSONObject requestBody = new JSONObject();
                requestBody.put("model", LLM_MODEL);
                requestBody.put("messages", mConversationHistory);
                requestBody.put("stream", true); // enable streaming responses
                String jsonBody = requestBody.toString();
                Log.d(TAG, "LLM Request: " + jsonBody);
                // Send the request.
                try (DataOutputStream dos = new DataOutputStream(conn.getOutputStream())) {
                    dos.write(jsonBody.getBytes(StandardCharsets.UTF_8));
                    dos.flush();
                }
                // Read the response.
                int responseCode = conn.getResponseCode();
                StringBuilder fullResponse = new StringBuilder();
                StringBuilder accumulatedContent = new StringBuilder();
                Log.d(TAG, "LLM Response Code: " + responseCode);
                if (responseCode == 200) {
                    // Read the SSE stream line by line.
                    try (BufferedReader br = new BufferedReader(
                            new InputStreamReader(conn.getInputStream(), StandardCharsets.UTF_8))) {
                        String line;
                        while ((line = br.readLine()) != null) {
                            Log.d(TAG, "LLM Streaming Line: " + line);
                            // SSE payload lines are prefixed with "data: ".
                            if (line.startsWith("data: ")) {
                                String dataPart = line.substring(6);
                                if (dataPart.equals("[DONE]")) {
                                    // End of the stream.
                                    break;
                                }
                                try {
                                    // Parse the JSON chunk; deltas without "content"
                                    // (e.g. role/finish chunks) are skipped.
                                    JSONObject chunkObj = new JSONObject(dataPart);
                                    JSONArray choices = chunkObj.getJSONArray("choices");
                                    if (choices.length() > 0) {
                                        JSONObject choice = choices.getJSONObject(0);
                                        JSONObject delta = choice.getJSONObject("delta");
                                        if (delta.has("content")) {
                                            String chunkContent = delta.getString("content");
                                            accumulatedContent.append(chunkContent);
                                            // Forward the streamed chunk to the listener.
                                            if (mListener != null) {
                                                mMainHandler.post(() -> {
                                                    mListener.onLLMStreamingChunkReceived(chunkContent);
                                                });
                                            }
                                        }
                                    }
                                } catch (JSONException e) {
                                    Log.e(TAG, "Failed to parse streaming chunk: " + e.getMessage());
                                }
                            }
                            fullResponse.append(line).append("\n");
                        }
                    }
                    String content = accumulatedContent.toString();
                    Log.d(TAG, "Full LLM Response: " + content);
                    // Record the assistant turn in the history.
                    addMessageToHistory("assistant", content);
                    if (mListener != null) {
                        mMainHandler.post(() -> {
                            mListener.onLLMResponseReceived(content);
                        });
                    }
                } else {
                    // Read the error body if any. FIX: getErrorStream() can be
                    // null (e.g. connection reset), which previously caused an
                    // NPE inside try-with-resources instead of this IOException.
                    StringBuilder errorResponse = new StringBuilder();
                    java.io.InputStream errorStream = conn.getErrorStream();
                    if (errorStream != null) {
                        try (BufferedReader br = new BufferedReader(
                                new InputStreamReader(errorStream, StandardCharsets.UTF_8))) {
                            String line;
                            while ((line = br.readLine()) != null) {
                                errorResponse.append(line);
                            }
                        }
                    }
                    throw new IOException("HTTP " + responseCode + ": " + errorResponse.toString());
                }
            } catch (Exception e) {
                Log.e(TAG, "LLM call failed: " + e.getMessage());
                if (mListener != null) {
                    mMainHandler.post(() -> {
                        mListener.onError("LLM调用失败: " + e.getMessage());
                    });
                }
            } finally {
                // FIX: release the connection on every path (previously leaked).
                if (conn != null) {
                    conn.disconnect();
                }
            }
        }).start();
    }

    /**
     * Appends a message to the conversation history.
     */
    private void addMessageToHistory(String role, String content) {
        try {
            JSONObject message = new JSONObject();
            message.put("role", role);
            message.put("content", content);
            mConversationHistory.put(message);
        } catch (JSONException e) {
            Log.e(TAG, "Failed to add message to history: " + e.getMessage());
        }
    }

    /**
     * Clears the conversation history.
     */
    public void clearConversationHistory() {
        mConversationHistory = new JSONArray();
    }

    /** Cloud TTS is not implemented; reports an error to the listener. */
    public void callTTS(String text, File outputFile) {
        if (mListener != null) {
            mMainHandler.post(() -> {
                mListener.onError("TTS功能暂未实现");
            });
        }
    }

    /**
     * Legacy best-effort extraction of a "content" field from a raw response
     * string. Fragile (naive index math, no JSON escaping) and currently
     * unused; kept for compatibility.
     */
    private String extractContentFromResponse(String response) {
        try {
            int contentStart = response.indexOf("\"content\":\"") + 11;
            int contentEnd = response.indexOf("\"", contentStart);
            if (contentStart > 10 && contentEnd > contentStart) {
                return response.substring(contentStart, contentEnd);
            }
        } catch (Exception e) {
            Log.e(TAG, "Failed to parse response: " + e.getMessage());
        }
        return "抱歉,无法解析响应";
    }
}

View File

@@ -0,0 +1,95 @@
package com.digitalperson.engine;
import android.content.Context;
import android.util.Log;
/**
 * RKNN SenseVoice engine (copied from school_teacher, used as ASR backend on RK3588).
 *
 * It depends on native libs:
 * - libsentencepiece.so
 * - librknnrt.so
 * - libsensevoiceEngine.so (built from app/src/main/cpp)
 *
 * Thin JNI wrapper: all real work happens in native code behind {@code nativePtr}.
 */
public class SenseVoiceEngineRKNN implements WhisperEngine {
    private static final String TAG = "SenseVoiceEngineRKNN";
    // Opaque handle to the native engine; owned by this object, freed in deinitialize().
    private final long nativePtr;
    private boolean mIsInitialized = false;

    static {
        // Load order matters: sensevoiceEngine depends on the first two.
        try {
            System.loadLibrary("sentencepiece");
            System.loadLibrary("rknnrt");
            System.loadLibrary("sensevoiceEngine");
            Log.d(TAG, "Loaded libsentencepiece.so, librknnrt.so, libsensevoiceEngine.so");
        } catch (UnsatisfiedLinkError e) {
            Log.e(TAG, "Failed to load native libraries for SenseVoice", e);
            throw e;
        }
    }

    /**
     * Creates the native engine instance.
     * @throws RuntimeException when native creation fails (returns 0).
     */
    public SenseVoiceEngineRKNN(Context context) {
        nativePtr = createSenseVoiceEngine();
        if (nativePtr == 0) {
            throw new RuntimeException("Failed to create native SenseVoice engine");
        }
    }

    @Override
    public boolean isInitialized() {
        return mIsInitialized;
    }

    /**
     * Loads the model given only the encoder path; the embedding and BPE paths
     * are derived by filename substitution, so {@code modelPath} must end in
     * "sense-voice-encoder.rknn" for this to work.
     * @return true when the native load succeeds (returns 0).
     */
    @Override
    public boolean initialize(String modelPath, String vocabPath, boolean multilingual) {
        // SenseVoice needs three files: model, embedding, bpe
        String embeddingPath = modelPath.replace("sense-voice-encoder.rknn", "embedding.npy");
        String bpePath = modelPath.replace("sense-voice-encoder.rknn", "chn_jpn_yue_eng_ko_spectok.bpe.model");
        int ret = loadModel(nativePtr, modelPath, embeddingPath, bpePath);
        if (ret == 0) {
            mIsInitialized = true;
            return true;
        }
        return false;
    }

    /**
     * Frees the native model.
     * NOTE(review): nativePtr is final and not cleared, so calling this twice
     * passes the same pointer to freeModel() again — confirm the native side
     * tolerates a repeated free.
     */
    @Override
    public void deinitialize() {
        if (nativePtr != 0) {
            freeModel(nativePtr);
        }
        mIsInitialized = false;
    }

    /**
     * Transcribes a buffer of float PCM samples.
     * NOTE(review): returns an error *string* (not an exception) when the
     * engine is uninitialized; callers must not treat it as a transcript.
     * The language=0 / use_itn=false arguments are fixed; semantics are
     * defined by the native layer.
     */
    @Override
    public String transcribeBuffer(float[] samples) {
        if (!mIsInitialized) {
            return "Error: Engine not initialized";
        }
        return transcribeBufferNative(nativePtr, samples, 0, false);
    }

    /** Transcribes a wave file; same error-string convention as transcribeBuffer. */
    @Override
    public String transcribeFile(String waveFile) {
        if (!mIsInitialized) {
            return "Error: Engine not initialized";
        }
        return transcribeFileNative(nativePtr, waveFile, 0, false);
    }

    /** Loads the model from three explicit file paths (no filename substitution). */
    public boolean loadModelDirectly(String modelPath, String embeddingPath, String bpePath) {
        int ret = loadModel(nativePtr, modelPath, embeddingPath, bpePath);
        if (ret == 0) {
            mIsInitialized = true;
            return true;
        }
        return false;
    }

    // Native bindings (implemented in libsensevoiceEngine.so).
    private native long createSenseVoiceEngine();
    private native int loadModel(long nativePtr, String modelPath, String embeddingPath, String bpeModelPath);
    private native void freeModel(long ptr);
    private native String transcribeBufferNative(long ptr, float[] samples, int language, boolean use_itn);
    private native String transcribeFileNative(long ptr, String waveFile, int language, boolean use_itn);
}

View File

@@ -0,0 +1,12 @@
package com.digitalperson.engine;
import java.io.IOException;
/**
 * Minimal contract for an on-device ASR engine.
 * Implementations must be initialized before transcription; see
 * {@link #isInitialized()}.
 */
public interface WhisperEngine {
    /** @return true once {@link #initialize} (or an equivalent load) has succeeded. */
    boolean isInitialized();
    /**
     * Loads the model files.
     * @param modelPath path to the main model file
     * @param vocabPath path to the vocabulary (implementation-specific; may be ignored)
     * @param multilingual whether to enable multilingual decoding (implementation-specific)
     * @return true on success
     */
    boolean initialize(String modelPath, String vocabPath, boolean multilingual) throws IOException;
    /** Releases model resources; the engine is no longer initialized afterwards. */
    void deinitialize();
    /** Transcribes an audio file at the given path. */
    String transcribeFile(String wavePath);
    /** Transcribes a buffer of float PCM samples. */
    String transcribeBuffer(float[] samples);
}

View File

@@ -0,0 +1,74 @@
package com.digitalperson.metrics;
import android.util.Log;
import java.util.ArrayList;
import java.util.List;
/**
 * Process-wide registry of latency trace sessions. At most one "turn" is
 * active at a time; finished turns are archived in a small bounded history.
 */
public class TraceManager {
    private static final String TAG = "TraceManager";
    private static TraceManager instance;

    // The turn currently being traced; null when idle.
    private TraceSession currentSession;
    // Most recently finished turns, bounded to keep memory small.
    private final List<TraceSession> sessionHistory = new ArrayList<>();

    private TraceManager() {
    }

    /** Lazily creates and returns the singleton instance. */
    public static synchronized TraceManager getInstance() {
        if (instance == null) {
            instance = new TraceManager();
        }
        return instance;
    }

    /** Finishes any active turn, then begins and returns a fresh one. */
    public TraceSession startNewTurn() {
        endTurn(); // close out the previous session first
        TraceSession session = new TraceSession();
        currentSession = session;
        Log.i(TAG, "Started new trace turn: " + session.getTurnId());
        return session;
    }

    /** @return the active session, or null when idle. */
    public TraceSession getCurrent() {
        return currentSession;
    }

    /** Replaces the active session (used by callers that manage their own turns). */
    public void setCurrent(TraceSession session) {
        this.currentSession = session;
    }

    /** Closes the active turn: prints its summary and archives it (history capped at 10). */
    public void endTurn() {
        TraceSession finished = currentSession;
        if (finished == null) {
            return;
        }
        Log.i(TAG, "Ending trace turn: " + finished.getTurnId());
        finished.printSummary();
        sessionHistory.add(finished);
        // Bound the archive so memory use stays small.
        while (sessionHistory.size() > 10) {
            sessionHistory.remove(0);
        }
        currentSession = null;
    }

    /** @return true while a turn is being traced. */
    public boolean hasActiveSession() {
        return currentSession != null;
    }

    /** @return the (mutable) archive of finished sessions, oldest first. */
    public List<TraceSession> getSessionHistory() {
        return sessionHistory;
    }

    /** Looks up an archived session by its turn id; null when not found. */
    public TraceSession getSessionById(String turnId) {
        for (int i = 0; i < sessionHistory.size(); i++) {
            TraceSession candidate = sessionHistory.get(i);
            if (turnId.equals(candidate.getTurnId())) {
                return candidate;
            }
        }
        return null;
    }

    /** Discards all archived sessions. */
    public void clearSessionHistory() {
        sessionHistory.clear();
        Log.i(TAG, "Cleared session history");
    }
}

View File

@@ -0,0 +1,264 @@
package com.digitalperson.metrics;
import android.util.Log;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
public class TraceSession {
private static final String TAG = "TraceSession";
private final String turnId;
private final Map<String, Long> marks = new ConcurrentHashMap<>();
private final Map<String, Long> durations = new ConcurrentHashMap<>();
private final Map<String, SegmentMetrics> segments = new ConcurrentHashMap<>();
public TraceSession() {
this.turnId = String.valueOf(System.currentTimeMillis());
}
public TraceSession(String turnId) {
this.turnId = turnId;
}
public String getTurnId() {
return turnId;
}
public void mark(String name) {
long timestamp = System.currentTimeMillis();
marks.put(name, timestamp);
Log.d(TAG, "Marked " + name + " at " + timestamp);
}
public void addDuration(String name, long deltaMs) {
// 使用循环确保原子性,处理并发更新
while (true) {
Long currentValue = durations.get(name);
long newValue = (currentValue != null) ? currentValue + deltaMs : deltaMs;
if (currentValue == null) {
// 如果键不存在,尝试添加
if (durations.putIfAbsent(name, newValue) == null) {
break;
}
} else {
// 如果键存在,尝试更新
if (durations.replace(name, currentValue, newValue)) {
break;
}
}
}
}
public void recordSegment(String segId, int textLen, long phonemizeMs, long encoderMs, long decoderMs, int invalidCount, int pcmSamples) {
segments.put(segId, new SegmentMetrics(textLen, phonemizeMs, encoderMs, decoderMs, invalidCount, pcmSamples));
}
public void recordSegment(String segId, int textLen, long phonemizeMs, long encoderMs, long decoderMs, int invalidCount) {
segments.put(segId, new SegmentMetrics(textLen, phonemizeMs, encoderMs, decoderMs, invalidCount, 0));
}
public Map<String, Long> getMarks() {
return marks;
}
public Map<String, Long> getDurations() {
return durations;
}
public Map<String, SegmentMetrics> getSegments() {
return segments;
}
/**
 * Builds a human-readable multi-line summary of this turn: ASR time,
 * latencies derived from the first TTS audio playback mark, accumulated
 * stage durations, segment count and total invalid-sample count.
 *
 * Fixes vs. previous version: guards the turn-id prefix strip against
 * short ids (substring(8) threw StringIndexOutOfBoundsException), and
 * reads each mark once instead of a containsKey/get check-then-act pair.
 */
public String summary() {
    StringBuilder sb = new StringBuilder();
    // Guard: turnId may be shorter than the 8-char prefix stripped for display.
    String displayId = turnId.length() > 8 ? turnId.substring(8) : turnId;
    sb.append("=== 会话 " + displayId + " ===\n");
    // ASR metrics
    long asrTime = getASRTime();
    if (asrTime > 0) {
        sb.append("ASR处理: " + asrTime + "ms\n");
    } else {
        sb.append("ASR处理: 无数据\n");
    }
    // Derived latencies; single reads avoid racing concurrent mark() calls.
    Long ttsPlay = marks.get("tts_first_audio_play");
    if (ttsPlay != null) {
        Long llmFirst = marks.get("llm_first_chunk");
        if (llmFirst != null) {
            sb.append("LLM到TTS播放: " + (ttsPlay - llmFirst) + "ms\n");
        } else {
            sb.append("LLM到TTS播放: 无数据\n");
        }
        Long recDone = marks.get("recording_done");
        if (recDone != null) {
            sb.append("录音到TTS播放: " + (ttsPlay - recDone) + "ms\n");
        } else {
            sb.append("录音到TTS播放: 无数据\n");
        }
    } else {
        sb.append("TTS播放未开始\n");
    }
    // Accumulated stage durations
    if (!durations.isEmpty()) {
        sb.append("处理耗时:\n");
        appendDurationLine(sb, " 音素化: ", "phonemize_ms_total");
        appendDurationLine(sb, " 编码器: ", "encoder_ms_total");
        appendDurationLine(sb, " 解码器: ", "decoder_ms_total");
    } else {
        sb.append("处理耗时: 无数据\n");
    }
    // Segment count
    sb.append("片段数: " + segments.size() + "\n");
    // Total invalid samples across all segments
    long invalidTotal = 0;
    for (SegmentMetrics seg : segments.values()) {
        invalidTotal += seg.invalidCount;
    }
    sb.append("无效样本数: " + invalidTotal + "\n");
    sb.append("\n");
    return sb.toString();
}

/** Appends "<label><value>ms" or "<label>无数据" for one accumulated duration key. */
private void appendDurationLine(StringBuilder sb, String label, String key) {
    Long value = durations.get(key);
    if (value != null) {
        sb.append(label + value + "ms\n");
    } else {
        sb.append(label + "无数据\n");
    }
}
/** Returns the ASR processing time in ms, or 0 when either mark is missing. */
public long getASRTime() {
    // Delegate to the generic stage helper so every stage timing shares one code path.
    return getStageTime("asr_processing_start", "asr_processing_done");
}
// Marks the start of ASR processing.
public void markASRStart() {
mark("asr_processing_start");
}
// Marks the completion of ASR processing.
public void markASREnd() {
mark("asr_processing_done");
}
// Marks the completion of audio recording.
public void markRecordingDone() {
mark("recording_done");
}
// Marks the arrival of the first LLM response chunk.
public void markLlmFirstChunk() {
mark("llm_first_chunk");
}
// Marks the completion of the LLM response.
public void markLlmDone() {
mark("llm_done");
}
// Marks receipt of the LLM response.
public void markLlmResponseReceived() {
mark("llm_response_received");
}
// Marks that a TTS request was enqueued.
public void markTtsRequestEnqueued() {
mark("tts_request_enqueued");
}
// Marks the start of TTS synthesis.
public void markTtsSynthesisStart() {
mark("tts_synthesis_start");
}
// Marks that the first TTS PCM chunk is ready.
public void markTtsFirstPcmReady() {
mark("tts_first_pcm_ready");
}
// Marks playback of the first TTS audio chunk.
public void markTtsFirstAudioPlay() {
mark("tts_first_audio_play");
}
// Marks the completion of TTS synthesis.
public void markTtsDone() {
mark("tts_done");
}
// Returns elapsed ms from recording completion to the first LLM chunk (0 if a mark is missing).
public long getRecordingToLlmFirstChunkMs() {
return getStageTime("recording_done", "llm_first_chunk");
}
// Returns elapsed ms from the first LLM chunk to the first TTS audio playback (0 if a mark is missing).
public long getLlmFirstChunkToTtsFirstAudioMs() {
return getStageTime("llm_first_chunk", "tts_first_audio_play");
}
// Returns elapsed ms from recording completion to the first TTS audio playback (0 if a mark is missing).
public long getRecordingToTtsFirstAudioMs() {
return getStageTime("recording_done", "tts_first_audio_play");
}
/**
 * Returns {@code endMark - startMark} in milliseconds, or 0 when either
 * mark has not been recorded. Each mark is read once to avoid the
 * containsKey/get check-then-act race with concurrent mark() calls.
 */
public long getStageTime(String startMark, String endMark) {
    Long start = marks.get(startMark);
    Long end = marks.get(endMark);
    if (start != null && end != null) {
        return end - start;
    }
    return 0;
}
/**
 * Returns a one-line summary of the key first-audio latencies, or an empty
 * string when TTS playback has not started. Marks are read once each to
 * avoid check-then-act races with concurrent mark() calls.
 */
public String getShortSummary() {
    StringBuilder sb = new StringBuilder();
    Long ttsPlay = marks.get("tts_first_audio_play");
    if (ttsPlay != null) {
        Long llmFirst = marks.get("llm_first_chunk");
        if (llmFirst != null) {
            sb.append("LLM→TTS: " + (ttsPlay - llmFirst) + "ms");
        }
        Long recDone = marks.get("recording_done");
        if (recDone != null) {
            sb.append(" | 录音→TTS: " + (ttsPlay - recDone) + "ms");
        }
    }
    return sb.toString();
}
/** Logs the full multi-line summary at INFO level. */
public void printSummary() {
Log.i(TAG, summary());
}
/** Immutable per-segment metrics recorded for each TTS synthesis segment. */
public static class SegmentMetrics {
// Length of the segment's source text — presumably characters; confirm at call site.
public final int textLen;
// Phonemization time in ms.
public final long phonemizeMs;
// Encoder time in ms.
public final long encoderMs;
// Decoder time in ms.
public final long decoderMs;
// Count of invalid samples — exact semantics not visible here; TODO confirm.
public final int invalidCount;
// Number of generated PCM samples (0 when not recorded).
public final int pcmSamples;
public SegmentMetrics(int textLen, long phonemizeMs, long encoderMs, long decoderMs, int invalidCount, int pcmSamples) {
this.textLen = textLen;
this.phonemizeMs = phonemizeMs;
this.encoderMs = encoderMs;
this.decoderMs = decoderMs;
this.invalidCount = invalidCount;
this.pcmSamples = pcmSamples;
}
}
}

View File

@@ -0,0 +1,11 @@
package com.k2fsa.sherpa.onnx
// Acoustic feature-extraction settings passed to the native recognizer.
data class FeatureConfig(
var sampleRate: Int = 16000, // expected input sample rate in Hz
var featureDim: Int = 80,    // feature dimension
var dither: Float = 0.0f     // dither amount; 0 disables dithering
)
/** Builds a [FeatureConfig] with the given rate and dimension; dither stays at its default. */
fun getFeatureConfig(sampleRate: Int, featureDim: Int): FeatureConfig =
    FeatureConfig(sampleRate = sampleRate, featureDim = featureDim)

View File

@@ -0,0 +1,7 @@
package com.k2fsa.sherpa.onnx
// Configuration for the homophone-replacement post-processing step.
data class HomophoneReplacerConfig(
var dictDir: String = "", // unused
var lexicon: String = "",  // path to the lexicon file
var ruleFsts: String = "", // comma-separated rule FST paths — presumably; confirm against native side
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,32 @@
package com.k2fsa.sherpa.onnx
// Kotlin wrapper around a native sherpa-onnx offline stream handle.
// `ptr` is the native pointer; 0 means the handle has been released.
class OfflineStream(var ptr: Long) {
// Feeds audio samples (float PCM) with the given sample rate to the native stream.
fun acceptWaveform(samples: FloatArray, sampleRate: Int) =
acceptWaveform(ptr, samples, sampleRate)
// Frees the native handle; the ptr guard makes a double free a no-op.
protected fun finalize() {
if (ptr != 0L) {
delete(ptr)
ptr = 0
}
}
// Deterministic release; safe to call more than once.
fun release() = finalize()
// Runs [block] with this stream and releases it afterwards, even on exception.
// NOTE(review): shadows kotlin.io.use and discards the block's result — consider renaming.
fun use(block: (OfflineStream) -> Unit) {
try {
block(this)
} finally {
release()
}
}
private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
private external fun delete(ptr: Long)
companion object {
init {
// Loads the JNI library backing the external functions above.
System.loadLibrary("sherpa-onnx-jni")
}
}
}

View File

@@ -0,0 +1,7 @@
package com.k2fsa.sherpa.onnx
// Qualcomm QNN execution-provider configuration (library/binary paths).
data class QnnConfig(
var backendLib: String = "",     // path to the QNN backend library
var contextBinary: String = "",  // path to a prebuilt QNN context binary
var systemLib: String = "",      // path to the QNN system library
)

View File

@@ -0,0 +1,373 @@
// Copyright (c) 2023 Xiaomi Corporation
package com.k2fsa.sherpa.onnx
import android.content.res.AssetManager
// VITS TTS model configuration; empty `model` means VITS is disabled.
data class OfflineTtsVitsModelConfig(
var model: String = "",
var lexicon: String = "",
var tokens: String = "",
var dataDir: String = "",
var dictDir: String = "", // unused
var noiseScale: Float = 0.667f,
var noiseScaleW: Float = 0.8f,
var lengthScale: Float = 1.0f, // >1 slows speech, <1 speeds it up — presumably; confirm with upstream docs
)
// Matcha TTS model configuration; requires both an acoustic model and a vocoder.
data class OfflineTtsMatchaModelConfig(
var acousticModel: String = "",
var vocoder: String = "",
var lexicon: String = "",
var tokens: String = "",
var dataDir: String = "",
var dictDir: String = "", // unused
var noiseScale: Float = 1.0f,
var lengthScale: Float = 1.0f,
)
// Kokoro TTS model configuration; `voices` selects the speaker-embedding file.
data class OfflineTtsKokoroModelConfig(
var model: String = "",
var voices: String = "",
var tokens: String = "",
var dataDir: String = "",
var lexicon: String = "",
var lang: String = "",
var dictDir: String = "", // unused
var lengthScale: Float = 1.0f,
)
// Kitten TTS model configuration.
data class OfflineTtsKittenModelConfig(
var model: String = "",
var voices: String = "",
var tokens: String = "",
var dataDir: String = "",
var lengthScale: Float = 1.0f,
)
/**
 * Configuration for Pocket TTS models.
 *
 * All paths default to empty strings, which leaves Pocket TTS disabled.
 *
 * See https://k2-fsa.github.io/sherpa/onnx/tts/pocket/index.html for details.
 *
 * @property lmFlow Path to the LM flow model (.onnx)
 * @property lmMain Path to the LM main model (.onnx)
 * @property encoder Path to the encoder model (.onnx)
 * @property decoder Path to the decoder model (.onnx)
 * @property textConditioner Path to the text conditioner model (.onnx)
 * @property vocabJson Path to vocabulary JSON file
 * @property tokenScoresJson Path to token scores JSON file
 */
data class OfflineTtsPocketModelConfig(
var lmFlow: String = "",
var lmMain: String = "",
var encoder: String = "",
var decoder: String = "",
var textConditioner: String = "",
var vocabJson: String = "",
var tokenScoresJson: String = "",
)
// Aggregates per-engine TTS model configs; normally exactly one engine is populated.
data class OfflineTtsModelConfig(
    var vits: OfflineTtsVitsModelConfig = OfflineTtsVitsModelConfig(),
    var matcha: OfflineTtsMatchaModelConfig = OfflineTtsMatchaModelConfig(),
    var kokoro: OfflineTtsKokoroModelConfig = OfflineTtsKokoroModelConfig(),
    var kitten: OfflineTtsKittenModelConfig = OfflineTtsKittenModelConfig(),
    // Was `val` while every sibling property is `var`; made mutable for
    // consistency (backward-compatible: adds a setter, constructor unchanged).
    var pocket: OfflineTtsPocketModelConfig = OfflineTtsPocketModelConfig(),
    var numThreads: Int = 1,
    var debug: Boolean = false,
    var provider: String = "cpu",
)
// Top-level offline TTS configuration passed to the native engine.
data class OfflineTtsConfig(
var model: OfflineTtsModelConfig = OfflineTtsModelConfig(),
var ruleFsts: String = "",       // text-normalization rule FSTs
var ruleFars: String = "",       // text-normalization rule FAR archives
var maxNumSentences: Int = 1,
var silenceScale: Float = 0.2f,
)
// Audio produced by the TTS engine: float samples plus their sample rate.
class GeneratedAudio(
val samples: FloatArray, // samples normalized to [-1, 1] (per the native generateImpl contract)
val sampleRate: Int,
) {
// Writes the audio to [filename] via the native implementation; returns success as Boolean.
fun save(filename: String) =
saveImpl(filename = filename, samples = samples, sampleRate = sampleRate)
private external fun saveImpl(
filename: String,
samples: FloatArray,
sampleRate: Int
): Boolean
}
// Per-call generation options for OfflineTts.generateWithConfig.
// NOTE(review): this data class holds a FloatArray, so the generated
// equals()/hashCode() compare that array by reference, not by content.
data class GenerationConfig(
var silenceScale: Float = 0.2f,
var speed: Float = 1.0f,
var sid: Int = 0,                          // speaker id
var referenceAudio: FloatArray? = null,    // optional reference audio for voice cloning — presumably; confirm
var referenceSampleRate: Int = 0,
var referenceText: String? = null,
var numSteps: Int = 5,
var extra: Map<String, String>? = null     // engine-specific extra options
)
// Kotlin wrapper around the native sherpa-onnx offline TTS engine.
// The engine is created from Android assets when [assetManager] is given,
// otherwise from files on disk. `ptr` holds the native handle; 0 means freed.
class OfflineTts(
assetManager: AssetManager? = null,
var config: OfflineTtsConfig,
) {
private var ptr: Long
init {
ptr = if (assetManager != null) {
newFromAsset(assetManager, config)
} else {
newFromFile(config)
}
}
// Output sample rate of the loaded model.
fun sampleRate() = getSampleRate(ptr)
// Number of speakers supported by the loaded model.
fun numSpeakers() = getNumSpeakers(ptr)
// Synthesizes [text] synchronously and returns the full audio.
fun generate(
text: String,
sid: Int = 0,
speed: Float = 1.0f
): GeneratedAudio {
return toGeneratedAudio(generateImpl(ptr, text = text, sid = sid, speed = speed))
}
// Synthesizes [text], invoking [callback] with audio chunks as they are produced.
// The callback's Int return presumably signals continue/stop to the native side — confirm.
fun generateWithCallback(
text: String,
sid: Int = 0,
speed: Float = 1.0f,
callback: (samples: FloatArray) -> Int
): GeneratedAudio {
return toGeneratedAudio(generateWithCallbackImpl(
ptr,
text = text,
sid = sid,
speed = speed,
callback = callback
))
}
// Synthesizes [text] using a full [GenerationConfig], without a streaming callback.
fun generateWithConfig(
text: String,
config: GenerationConfig
): GeneratedAudio {
return toGeneratedAudio(generateWithConfigImpl(ptr, text, config, null))
}
// Synthesizes [text] using a full [GenerationConfig] with a streaming callback.
fun generateWithConfigAndCallback(
text: String,
config: GenerationConfig,
callback: (samples: FloatArray) -> Int
): GeneratedAudio {
return toGeneratedAudio(generateWithConfigImpl(ptr, text, config, callback))
}
// Normalizes the native return value: either an already-built GeneratedAudio,
// or an Object[]{ float[] samples, int sampleRate } pair from JNI.
@Suppress("UNCHECKED_CAST")
private fun toGeneratedAudio(obj: Any): GeneratedAudio {
return when (obj) {
is GeneratedAudio -> obj
is Array<*> -> {
// Native may return Object[]{ float[] samples, int sampleRate }
val samples = obj.getOrNull(0) as? FloatArray
?: error("Unexpected native TTS return: element[0] is not FloatArray")
val sampleRate = (obj.getOrNull(1) as? Number)?.toInt()
?: error("Unexpected native TTS return: element[1] is not Int/Number")
GeneratedAudio(samples = samples, sampleRate = sampleRate)
}
else -> error("Unexpected native TTS return type: ${obj::class.java.name}")
}
}
// Re-creates the native engine after free(); no-op when already allocated.
fun allocate(assetManager: AssetManager? = null) {
if (ptr == 0L) {
ptr = if (assetManager != null) {
newFromAsset(assetManager, config)
} else {
newFromFile(config)
}
}
}
// Frees the native engine; safe to call more than once.
fun free() {
if (ptr != 0L) {
delete(ptr)
ptr = 0
}
}
// GC safety net; same guarded delete as free().
protected fun finalize() {
if (ptr != 0L) {
delete(ptr)
ptr = 0
}
}
// Deterministic release; equivalent to free().
fun release() = finalize()
private external fun newFromAsset(
assetManager: AssetManager,
config: OfflineTtsConfig,
): Long
private external fun newFromFile(
config: OfflineTtsConfig,
): Long
private external fun delete(ptr: Long)
private external fun getSampleRate(ptr: Long): Int
private external fun getNumSpeakers(ptr: Long): Int
// The returned array has two entries:
//  - the first entry is an 1-D float array containing audio samples.
//    Each sample is normalized to the range [-1, 1]
//  - the second entry is the sample rate
private external fun generateImpl(
ptr: Long,
text: String,
sid: Int = 0,
speed: Float = 1.0f
): Any
private external fun generateWithCallbackImpl(
ptr: Long,
text: String,
sid: Int = 0,
speed: Float = 1.0f,
callback: (samples: FloatArray) -> Int
): Any
private external fun generateWithConfigImpl(
ptr: Long,
text: String,
config: GenerationConfig,
callback: ((samples: FloatArray) -> Int)?
): Any
companion object {
init {
// Loads the JNI library backing the external functions above.
System.loadLibrary("sherpa-onnx-jni")
}
}
}
// please refer to
// https://k2-fsa.github.io/sherpa/onnx/tts/pretrained_models/index.html
// to download models
/**
 * Builds an [OfflineTtsConfig] for one of the supported engines.
 *
 * Engine selection:
 *  - Matcha: set [acousticModelName] and [vocoder]
 *  - Kokoro: set [modelName] and [voices]
 *  - Kitten: set [modelName], [voices], and [isKitten]
 *  - VITS:   set [modelName] only
 *
 * @throws IllegalArgumentException when no model is given, when both a VITS
 *   and a Matcha model are given, or when Matcha is missing its vocoder.
 */
fun getOfflineTtsConfig(
    modelDir: String,
    modelName: String, // for VITS
    acousticModelName: String, // for Matcha
    vocoder: String, // for Matcha
    voices: String, // for Kokoro or kitten
    lexicon: String,
    dataDir: String,
    dictDir: String, // unused
    ruleFsts: String,
    ruleFars: String,
    numThreads: Int? = null,
    isKitten: Boolean = false
): OfflineTtsConfig {
    // Explicit thread count wins; otherwise Kokoro/Kitten (voices set) get 4 threads, others 2.
    val threadCount = numThreads ?: if (voices.isNotEmpty()) 4 else 2
    require(!(modelName.isEmpty() && acousticModelName.isEmpty())) {
        "Please specify a TTS model"
    }
    require(!(modelName.isNotEmpty() && acousticModelName.isNotEmpty())) {
        "Please specify either a VITS or a Matcha model, but not both"
    }
    require(!(acousticModelName.isNotEmpty() && vocoder.isEmpty())) {
        "Please provide vocoder for Matcha TTS"
    }
    val tokensPath = "$modelDir/tokens.txt"
    // VITS: a model name without a voices file.
    val vitsConfig = if (modelName.isNotEmpty() && voices.isEmpty()) {
        OfflineTtsVitsModelConfig(
            model = "$modelDir/$modelName",
            lexicon = "$modelDir/$lexicon",
            tokens = tokensPath,
            dataDir = dataDir,
        )
    } else {
        OfflineTtsVitsModelConfig()
    }
    // Matcha: acoustic model plus external vocoder.
    val matchaConfig = if (acousticModelName.isNotEmpty()) {
        OfflineTtsMatchaModelConfig(
            acousticModel = "$modelDir/$acousticModelName",
            vocoder = vocoder,
            lexicon = "$modelDir/$lexicon",
            tokens = tokensPath,
            dataDir = dataDir,
        )
    } else {
        OfflineTtsMatchaModelConfig()
    }
    // Kokoro: voices file present and not flagged as Kitten. An empty lexicon or a
    // comma-separated lexicon list is passed through unchanged; otherwise it is
    // resolved relative to the model directory.
    val kokoroConfig = if (voices.isNotEmpty() && !isKitten) {
        OfflineTtsKokoroModelConfig(
            model = "$modelDir/$modelName",
            voices = "$modelDir/$voices",
            tokens = tokensPath,
            dataDir = dataDir,
            lexicon = if (lexicon.isEmpty() || "," in lexicon) lexicon else "$modelDir/$lexicon",
        )
    } else {
        OfflineTtsKokoroModelConfig()
    }
    val kittenConfig = if (isKitten) {
        OfflineTtsKittenModelConfig(
            model = "$modelDir/$modelName",
            voices = "$modelDir/$voices",
            tokens = tokensPath,
            dataDir = dataDir,
        )
    } else {
        OfflineTtsKittenModelConfig()
    }
    return OfflineTtsConfig(
        model = OfflineTtsModelConfig(
            vits = vitsConfig,
            matcha = matchaConfig,
            kokoro = kokoroConfig,
            kitten = kittenConfig,
            numThreads = threadCount,
            debug = true,
            provider = "cpu",
        ),
        ruleFsts = ruleFsts,
        ruleFars = ruleFars,
    )
}

View File

@@ -0,0 +1,149 @@
// Copyright (c) 2023 Xiaomi Corporation
package com.k2fsa.sherpa.onnx
import android.content.res.AssetManager
// Silero VAD model configuration; empty `model` leaves this detector disabled.
data class SileroVadModelConfig(
var model: String = "",
var threshold: Float = 0.5F,            // speech probability threshold
var minSilenceDuration: Float = 0.25F,  // seconds of silence ending a segment
var minSpeechDuration: Float = 0.25F,   // minimum segment length in seconds
var windowSize: Int = 512,              // analysis window in samples
var maxSpeechDuration: Float = 5.0F,    // maximum segment length in seconds
)
// ten-vad model configuration; note the smaller default window (256) vs silero (512).
data class TenVadModelConfig(
var model: String = "",
var threshold: Float = 0.5F,
var minSilenceDuration: Float = 0.25F,
var minSpeechDuration: Float = 0.25F,
var windowSize: Int = 256,
var maxSpeechDuration: Float = 5.0F,
)
// Top-level VAD configuration; set exactly one of the two model configs.
data class VadModelConfig(
var sileroVadModelConfig: SileroVadModelConfig = SileroVadModelConfig(),
var tenVadModelConfig: TenVadModelConfig = TenVadModelConfig(),
var sampleRate: Int = 16000,
var numThreads: Int = 1,
var provider: String = "cpu",
var debug: Boolean = false,
)
// One detected speech segment: its start offset (presumably in samples — confirm
// against the native side) and the segment's audio samples.
class SpeechSegment(val start: Int, val samples: FloatArray)
// Kotlin wrapper around a native sherpa-onnx voice-activity detector.
// Created from Android assets when [assetManager] is given, otherwise from disk.
// `ptr` holds the native handle; 0 means released.
class Vad(
assetManager: AssetManager? = null,
var config: VadModelConfig,
) {
private var ptr: Long
init {
if (assetManager != null) {
ptr = newFromAsset(assetManager, config)
} else {
ptr = newFromFile(config)
}
}
// Frees the native handle; the guard makes a double free a no-op.
protected fun finalize() {
if (ptr != 0L) {
delete(ptr)
ptr = 0
}
}
// Deterministic release; safe to call more than once.
fun release() = finalize()
// Returns the speech probability for one window of samples.
fun compute(samples: FloatArray): Float = compute(ptr, samples)
// Feeds samples into the detector's internal buffer.
fun acceptWaveform(samples: FloatArray) = acceptWaveform(ptr, samples)
// True when no completed speech segment is queued.
fun empty(): Boolean = empty(ptr)
// Removes the front segment from the queue.
fun pop() = pop(ptr)
// Returns the front queued speech segment without removing it.
fun front(): SpeechSegment {
return front(ptr)
}
// Clears all queued segments.
fun clear() = clear(ptr)
// True while the detector currently sees speech.
fun isSpeechDetected(): Boolean = isSpeechDetected(ptr)
// Resets detector state.
fun reset() = reset(ptr)
// Flushes any pending audio into a final segment.
fun flush() = flush(ptr)
private external fun delete(ptr: Long)
private external fun newFromAsset(
assetManager: AssetManager,
config: VadModelConfig,
): Long
private external fun newFromFile(
config: VadModelConfig,
): Long
private external fun acceptWaveform(ptr: Long, samples: FloatArray)
private external fun compute(ptr: Long, samples: FloatArray): Float
private external fun empty(ptr: Long): Boolean
private external fun pop(ptr: Long)
private external fun clear(ptr: Long)
private external fun front(ptr: Long): SpeechSegment
private external fun isSpeechDetected(ptr: Long): Boolean
private external fun reset(ptr: Long)
private external fun flush(ptr: Long)
companion object {
init {
// Loads the JNI library backing the external functions above.
System.loadLibrary("sherpa-onnx-jni")
}
}
}
// Model downloads (place the chosen file inside assets/):
//   silero:  https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx
//   ten-vad: https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/ten-vad.onnx
/** Returns a ready-made VAD config: 0 selects silero, 1 selects ten-vad, anything else yields null. */
fun getVadModelConfig(type: Int): VadModelConfig? = when (type) {
    0 -> VadModelConfig(
        sileroVadModelConfig = SileroVadModelConfig(
            model = "silero_vad.onnx",
            threshold = 0.5F,
            minSilenceDuration = 0.25F,
            minSpeechDuration = 0.25F,
            windowSize = 512,
        ),
        sampleRate = 16000,
        numThreads = 1,
        provider = "cpu",
    )
    1 -> VadModelConfig(
        tenVadModelConfig = TenVadModelConfig(
            model = "ten-vad.onnx",
            threshold = 0.5F,
            minSilenceDuration = 0.25F,
            minSpeechDuration = 0.25F,
            windowSize = 256,
        ),
        sampleRate = 16000,
        numThreads = 1,
        provider = "cpu",
    )
    else -> null
}