initial commit
This commit is contained in:
689
app/src/main/java/com/digitalperson/MainActivity.kt
Normal file
689
app/src/main/java/com/digitalperson/MainActivity.kt
Normal file
@@ -0,0 +1,689 @@
|
||||
package com.digitalperson
|
||||
|
||||
import android.Manifest
|
||||
import android.content.pm.PackageManager
|
||||
import android.media.AudioAttributes
|
||||
import android.media.AudioFormat
|
||||
import android.media.AudioManager
|
||||
import android.media.AudioRecord
|
||||
import android.media.AudioTrack
|
||||
import android.media.MediaRecorder
|
||||
import android.media.audiofx.AcousticEchoCanceler
|
||||
import android.media.audiofx.NoiseSuppressor
|
||||
import android.os.Bundle
|
||||
import android.os.SystemClock
|
||||
import android.text.method.ScrollingMovementMethod
|
||||
import android.util.Log
|
||||
import android.widget.Button
|
||||
import android.widget.TextView
|
||||
import android.widget.Toast
|
||||
import androidx.appcompat.app.AppCompatActivity
|
||||
import androidx.core.app.ActivityCompat
|
||||
import com.digitalperson.cloud.CloudApiManager
|
||||
import com.digitalperson.engine.SenseVoiceEngineRKNN
|
||||
import com.digitalperson.metrics.TraceManager
|
||||
import com.digitalperson.metrics.TraceSession
|
||||
import com.k2fsa.sherpa.onnx.OfflineTts
|
||||
import com.k2fsa.sherpa.onnx.SileroVadModelConfig
|
||||
import com.k2fsa.sherpa.onnx.Vad
|
||||
import com.k2fsa.sherpa.onnx.VadModelConfig
|
||||
import com.k2fsa.sherpa.onnx.getOfflineTtsConfig
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.Job
|
||||
import kotlinx.coroutines.SupervisorJob
|
||||
import kotlinx.coroutines.cancel
|
||||
import kotlinx.coroutines.isActive
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.withContext
|
||||
import java.io.File
|
||||
import java.io.FileOutputStream
|
||||
import java.util.concurrent.LinkedBlockingQueue
|
||||
import java.util.concurrent.atomic.AtomicBoolean
|
||||
import kotlin.math.max
|
||||
|
||||
// Logcat tag shared by everything in this file.
private const val TAG = "DigitalPerson"

// Request code used to correlate the RECORD_AUDIO permission result callback.
private const val REQUEST_RECORD_AUDIO_PERMISSION = 200
/**
 * Voice-assistant activity: microphone -> Silero VAD -> SenseVoice ASR (RKNN)
 * -> streaming cloud LLM -> segmented offline TTS playback.
 */
class MainActivity : AppCompatActivity() {

    // UI controls, bound in onCreate.
    private lateinit var startButton: Button
    private lateinit var stopButton: Button
    private lateinit var textView: TextView

    // Silero VAD; initialized on a background thread in onCreate.
    private lateinit var vad: Vad
    // On-device ASR engine (RKNN backend); null until initialized.
    private var senseVoice: SenseVoiceEngineRKNN? = null
    // Offline TTS engine; stays null when initialization failed.
    private var tts: OfflineTts? = null
    // Playback sink for synthesized PCM; created after TTS init.
    private var track: AudioTrack? = null

    // Optional platform audio effects attached to the record session.
    private var aec: AcousticEchoCanceler? = null
    private var ns: NoiseSuppressor? = null

    // Microphone capture configuration: 16 kHz mono PCM16 from MIC.
    private var audioRecord: AudioRecord? = null
    private val audioSource = MediaRecorder.AudioSource.MIC
    private val sampleRateInHz = 16000
    private val channelConfig = AudioFormat.CHANNEL_IN_MONO
    private val audioFormat = AudioFormat.ENCODING_PCM_16BIT
    private val permissions: Array<String> = arrayOf(Manifest.permission.RECORD_AUDIO)

    // Written on the UI thread, read by the capture coroutine.
    @Volatile
    private var isRecording: Boolean = false

    // Background scope for model init, the capture loop and the TTS worker.
    private val ioScope = CoroutineScope(SupervisorJob() + Dispatchers.IO)
    private var recordingJob: Job? = null
    // Serializes access to the native VAD/ASR engines.
    private val nativeLock = Any()

    private lateinit var cloudApiManager: CloudApiManager
    private val segmenter = StreamingTextSegmenter()

    // Work items for the TTS playback queue.
    private sealed class TtsQueueItem {
        data class Segment(val text: String) : TtsQueueItem()
        // Sentinel telling the worker that no more segments will arrive.
        data object End : TtsQueueItem()
    }

    private val ttsQueue = LinkedBlockingQueue<TtsQueueItem>()
    // Raised when playback is cancelled; the worker then drops items.
    private val ttsStopped = AtomicBoolean(false)
    // Guards against launching more than one TTS worker coroutine.
    private val ttsWorkerRunning = AtomicBoolean(false)

    // Per-turn latency trace; null outside of a turn.
    private var currentTrace: TraceSession? = null

    // Accumulated transcript/LLM text shown in the TextView.
    private var lastUiText: String = ""
    // At most one LLM request may be in flight at a time.
    @Volatile private var llmInFlight: Boolean = false
override fun onRequestPermissionsResult(
|
||||
requestCode: Int,
|
||||
permissions: Array<String>,
|
||||
grantResults: IntArray
|
||||
) {
|
||||
super.onRequestPermissionsResult(requestCode, permissions, grantResults)
|
||||
val ok = requestCode == REQUEST_RECORD_AUDIO_PERMISSION &&
|
||||
grantResults.isNotEmpty() &&
|
||||
grantResults[0] == PackageManager.PERMISSION_GRANTED
|
||||
if (!ok) {
|
||||
Log.e(TAG, "Audio record is disallowed")
|
||||
finish()
|
||||
}
|
||||
}
|
||||
|
||||
    /**
     * Binds the UI, kicks off model initialization on [ioScope] (keeps the
     * UI thread free to avoid ANR) and wires the cloud-LLM listener that
     * feeds streamed text into the UI and the TTS segmenter.
     */
    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)

        ActivityCompat.requestPermissions(this, permissions, REQUEST_RECORD_AUDIO_PERMISSION)

        startButton = findViewById(R.id.start_button)
        stopButton = findViewById(R.id.stop_button)
        textView = findViewById(R.id.my_text)
        textView.movementMethod = ScrollingMovementMethod()

        startButton.setOnClickListener { onStartClicked() }
        stopButton.setOnClickListener { onStopClicked(userInitiated = true) }

        // Initialize models and the AudioTrack in the background so the UI
        // thread is not blocked by heavy re-initialization (avoids ANR).
        startButton.isEnabled = false
        stopButton.isEnabled = false
        textView.text = "初始化中…"
        ioScope.launch {
            try {
                Log.i(TAG, "Init VAD + SenseVoice(RKNN) + TTS (background)")
                // Native engines are created under nativeLock for exclusivity.
                synchronized(nativeLock) {
                    initVadModel()
                    initSenseVoiceModel()
                }
                withContext(Dispatchers.Main) {
                    initTtsAndAudioTrack()
                    textView.text = getString(R.string.hint)
                    startButton.isEnabled = true
                    stopButton.isEnabled = false
                }
            } catch (t: Throwable) {
                Log.e(TAG, "Initialization failed: ${t.message}", t)
                withContext(Dispatchers.Main) {
                    textView.text = "初始化失败:${t.javaClass.simpleName}: ${t.message}"
                    Toast.makeText(
                        this@MainActivity,
                        "初始化失败(请看 Logcat): ${t.javaClass.simpleName}",
                        Toast.LENGTH_LONG
                    ).show()
                    // Keep both buttons disabled — the pipeline is unusable.
                    startButton.isEnabled = false
                    stopButton.isEnabled = false
                }
            }
        }

        cloudApiManager = CloudApiManager(object : CloudApiManager.CloudApiListener {
            // Set once the first streamed chunk arrives (first-token latency).
            private var llmFirstChunkMarked = false

            override fun onLLMResponseReceived(response: String) {
                currentTrace?.markLlmDone()
                llmInFlight = false
                // flush remaining buffer into TTS
                for (seg in segmenter.flush()) {
                    enqueueTtsSegment(seg)
                }
                // signal queue end (no more segments after this)
                ttsQueue.offer(TtsQueueItem.End)
            }

            override fun onLLMStreamingChunkReceived(chunk: String) {
                if (!llmFirstChunkMarked) {
                    llmFirstChunkMarked = true
                    currentTrace?.markLlmFirstChunk()
                }
                appendToUi(chunk)

                // Cut the streamed text into speakable segments as it arrives.
                val segments = segmenter.processChunk(chunk)
                for (seg in segments) {
                    enqueueTtsSegment(seg)
                }
            }

            override fun onTTSAudioReceived(audioFilePath: String) {
                // unused
            }

            override fun onError(errorMessage: String) {
                llmInFlight = false
                Toast.makeText(this@MainActivity, errorMessage, Toast.LENGTH_LONG).show()
                onStopClicked(userInitiated = false)
            }
        })
    }
override fun onDestroy() {
|
||||
super.onDestroy()
|
||||
onStopClicked(userInitiated = false)
|
||||
ioScope.cancel()
|
||||
synchronized(nativeLock) {
|
||||
try {
|
||||
vad.release()
|
||||
} catch (_: Throwable) {
|
||||
}
|
||||
try {
|
||||
senseVoice?.deinitialize()
|
||||
} catch (_: Throwable) {
|
||||
}
|
||||
}
|
||||
try {
|
||||
tts?.release()
|
||||
} catch (_: Throwable) {
|
||||
}
|
||||
}
|
||||
|
||||
private fun onStartClicked() {
|
||||
if (isRecording) return
|
||||
|
||||
if (!initMicrophone()) {
|
||||
Toast.makeText(this, "麦克风初始化失败/无权限", Toast.LENGTH_SHORT).show()
|
||||
return
|
||||
}
|
||||
|
||||
// Start a new trace turn
|
||||
currentTrace = TraceManager.getInstance().startNewTurn()
|
||||
currentTrace?.mark("turn_start")
|
||||
llmInFlight = false
|
||||
|
||||
lastUiText = ""
|
||||
textView.text = ""
|
||||
|
||||
ttsStopped.set(false)
|
||||
ttsQueue.clear()
|
||||
segmenter.reset()
|
||||
|
||||
vad.reset()
|
||||
audioRecord!!.startRecording()
|
||||
isRecording = true
|
||||
|
||||
startButton.isEnabled = false
|
||||
stopButton.isEnabled = true
|
||||
|
||||
recordingJob?.cancel()
|
||||
recordingJob = ioScope.launch {
|
||||
processSamplesLoop()
|
||||
}
|
||||
}
|
||||
|
||||
private fun onStopClicked(userInitiated: Boolean) {
|
||||
isRecording = false
|
||||
try {
|
||||
audioRecord?.stop()
|
||||
} catch (_: Throwable) {
|
||||
}
|
||||
try {
|
||||
audioRecord?.release()
|
||||
} catch (_: Throwable) {
|
||||
}
|
||||
audioRecord = null
|
||||
|
||||
recordingJob?.cancel()
|
||||
recordingJob = null
|
||||
|
||||
ttsStopped.set(true)
|
||||
ttsQueue.clear()
|
||||
// wake worker if waiting
|
||||
ttsQueue.offer(TtsQueueItem.End)
|
||||
|
||||
try {
|
||||
track?.pause()
|
||||
track?.flush()
|
||||
} catch (_: Throwable) {
|
||||
}
|
||||
try { aec?.release() } catch (_: Throwable) {}
|
||||
try { ns?.release() } catch (_: Throwable) {}
|
||||
aec = null
|
||||
ns = null
|
||||
startButton.isEnabled = true
|
||||
stopButton.isEnabled = false
|
||||
|
||||
if (userInitiated) {
|
||||
TraceManager.getInstance().endTurn()
|
||||
currentTrace = null
|
||||
}
|
||||
}
|
||||
|
||||
private fun initVadModel() {
|
||||
// 你的 VAD 模型在 assets/vad_model/ 下
|
||||
val config = VadModelConfig(
|
||||
sileroVadModelConfig = SileroVadModelConfig(
|
||||
model = "vad_model/silero_vad.onnx",
|
||||
threshold = 0.5F,
|
||||
minSilenceDuration = 0.25F,
|
||||
minSpeechDuration = 0.25F,
|
||||
windowSize = 512,
|
||||
),
|
||||
sampleRate = sampleRateInHz,
|
||||
numThreads = 1,
|
||||
provider = "cpu",
|
||||
)
|
||||
vad = Vad(assetManager = application.assets, config = config)
|
||||
}
|
||||
|
||||
    /**
     * Copies the SenseVoice model files to internal storage, logs diagnostics
     * about native libraries and model files, then creates and loads the
     * RKNN-backed ASR engine into [senseVoice].
     *
     * @throws IllegalStateException when native libraries fail to load or the
     * model cannot be loaded.
     */
    private fun initSenseVoiceModel() {
        Log.i(TAG, "ASR: init SenseVoice RKNN (scheme A)")
        // Copy assets/sensevoice_models/* -> filesDir/sensevoice_models/*
        val modelDir = copySenseVoiceAssetsToInternal()
        val modelPath = File(modelDir, "sense-voice-encoder.rknn").absolutePath
        val embeddingPath = File(modelDir, "embedding.npy").absolutePath
        val bpePath = File(modelDir, "chn_jpn_yue_eng_ko_spectok.bpe.model").absolutePath

        // Print quick diagnostics for native libs + model files
        try {
            val libDir = applicationInfo.nativeLibraryDir
            Log.i(TAG, "nativeLibraryDir=$libDir")
            try {
                val names = File(libDir).list()?.joinToString(", ") ?: "(empty)"
                Log.i(TAG, "nativeLibraryDir files: $names")
            } catch (t: Throwable) {
                Log.w(TAG, "Failed to list nativeLibraryDir: ${t.message}")
            }
        } catch (_: Throwable) {
        }
        Log.i(TAG, "SenseVoice model paths:")
        Log.i(TAG, " model=$modelPath exists=${File(modelPath).exists()} size=${File(modelPath).length()}")
        Log.i(TAG, " embedding=$embeddingPath exists=${File(embeddingPath).exists()} size=${File(embeddingPath).length()}")
        Log.i(TAG, " bpe=$bpePath exists=${File(bpePath).exists()} size=${File(bpePath).length()}")

        val t0 = SystemClock.elapsedRealtime()
        val engine = try {
            SenseVoiceEngineRKNN(this)
        } catch (e: UnsatisfiedLinkError) {
            // Most common: libsensevoiceEngine.so not packaged/built, or dependent libs missing
            throw IllegalStateException("Load native libraries failed: ${e.message}", e)
        }

        val ok = try {
            engine.loadModelDirectly(modelPath, embeddingPath, bpePath)
        } catch (t: Throwable) {
            // Re-wrap native crashes with a readable message for the init UI.
            throw IllegalStateException("SenseVoice loadModelDirectly crashed: ${t.message}", t)
        }

        val dt = SystemClock.elapsedRealtime() - t0
        Log.i(TAG, "SenseVoice loadModelDirectly ok=$ok costMs=$dt")
        if (!ok) throw IllegalStateException("SenseVoiceEngineRKNN loadModelDirectly returned false")

        senseVoice = engine
    }
    /**
     * Creates the offline VITS TTS engine from assets and a matching
     * float-PCM streaming AudioTrack for playback. On TTS failure the user
     * is notified and [tts]/[track] stay null (playback is then skipped).
     * Must be called on the main thread (updates UI via runOnUiThread on
     * failure only).
     */
    private fun initTtsAndAudioTrack() {
        try {
            // sherpa-onnx VITS Chinese model directory:
            // assets/tts_model/sherpa-onnx-vits-zh-ll/{model.onnx,tokens.txt,lexicon.txt,...}
            val modelDir = "tts_model/sherpa-onnx-vits-zh-ll"
            val modelName = "model.onnx"
            val lexicon = "lexicon.txt"
            val dataDir = ""

            val ttsConfig = getOfflineTtsConfig(
                modelDir = modelDir,
                modelName = modelName,
                acousticModelName = "",
                vocoder = "",
                voices = "",
                lexicon = lexicon,
                dataDir = dataDir,
                dictDir = "",
                // Chinese text-normalization rules (these .fst files ship in the model dir)
                ruleFsts = "$modelDir/phone.fst,$modelDir/date.fst,$modelDir/number.fst,$modelDir/new_heteronym.fst",
                ruleFars = "",
                numThreads = null,
                isKitten = false
            )
            tts = OfflineTts(assetManager = application.assets, config = ttsConfig)
        } catch (t: Throwable) {
            Log.e(TAG, "Init TTS failed: ${t.message}", t)
            tts = null
            runOnUiThread {
                Toast.makeText(
                    this,
                    "TTS 初始化失败:请确认 assets/tts_model/sherpa-onnx-vits-zh-ll/ 下有 model.onnx、tokens.txt、lexicon.txt 以及 phone/date/number/new_heteronym.fst",
                    Toast.LENGTH_LONG
                ).show()
            }
        }

        val t = tts ?: return
        // Stream float PCM at the model's native sample rate.
        val sr = t.sampleRate()
        val bufLength = AudioTrack.getMinBufferSize(
            sr,
            AudioFormat.CHANNEL_OUT_MONO,
            AudioFormat.ENCODING_PCM_FLOAT
        )
        val attr = AudioAttributes.Builder()
            .setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
            .setUsage(AudioAttributes.USAGE_MEDIA)
            .build()
        val format = AudioFormat.Builder()
            .setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
            .setChannelMask(AudioFormat.CHANNEL_OUT_MONO)
            .setSampleRate(sr)
            .build()
        track = AudioTrack(
            attr,
            format,
            bufLength,
            AudioTrack.MODE_STREAM,
            AudioManager.AUDIO_SESSION_ID_GENERATE
        )
        track?.play()
    }
private fun assetExists(path: String): Boolean {
|
||||
return try {
|
||||
application.assets.open(path).close()
|
||||
true
|
||||
} catch (_: Throwable) {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
private fun copySenseVoiceAssetsToInternal(): File {
|
||||
val outDir = File(filesDir, "sensevoice_models")
|
||||
if (!outDir.exists()) outDir.mkdirs()
|
||||
|
||||
val files = arrayOf(
|
||||
"am.mvn",
|
||||
"chn_jpn_yue_eng_ko_spectok.bpe.model",
|
||||
"embedding.npy",
|
||||
"sense-voice-encoder.rknn"
|
||||
)
|
||||
|
||||
for (name in files) {
|
||||
val assetPath = "sensevoice_models/$name"
|
||||
val outFile = File(outDir, name)
|
||||
if (outFile.exists() && outFile.length() > 0) continue
|
||||
application.assets.open(assetPath).use { input ->
|
||||
FileOutputStream(outFile).use { output ->
|
||||
input.copyTo(output)
|
||||
}
|
||||
}
|
||||
}
|
||||
return outDir
|
||||
}
|
||||
|
||||
private fun initMicrophone(): Boolean {
|
||||
if (ActivityCompat.checkSelfPermission(this, Manifest.permission.RECORD_AUDIO)
|
||||
!= PackageManager.PERMISSION_GRANTED
|
||||
) {
|
||||
ActivityCompat.requestPermissions(this, permissions, REQUEST_RECORD_AUDIO_PERMISSION)
|
||||
return false
|
||||
}
|
||||
|
||||
val numBytes = AudioRecord.getMinBufferSize(sampleRateInHz, channelConfig, audioFormat)
|
||||
audioRecord = AudioRecord(
|
||||
audioSource,
|
||||
sampleRateInHz,
|
||||
channelConfig,
|
||||
audioFormat,
|
||||
numBytes * 2
|
||||
)
|
||||
val sessionId = audioRecord?.audioSessionId ?: 0
|
||||
if (sessionId != 0) {
|
||||
if (android.media.audiofx.AcousticEchoCanceler.isAvailable()) {
|
||||
aec = android.media.audiofx.AcousticEchoCanceler.create(sessionId)?.apply {
|
||||
enabled = true
|
||||
}
|
||||
Log.i(TAG, "AEC enabled=${aec?.enabled}")
|
||||
} else {
|
||||
Log.w(TAG, "AEC not available on this device")
|
||||
}
|
||||
|
||||
if (android.media.audiofx.NoiseSuppressor.isAvailable()) {
|
||||
ns = android.media.audiofx.NoiseSuppressor.create(sessionId)?.apply {
|
||||
enabled = true
|
||||
}
|
||||
Log.i(TAG, "NS enabled=${ns?.enabled}")
|
||||
} else {
|
||||
Log.w(TAG, "NS not available on this device")
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
private suspend fun processSamplesLoop() {
|
||||
// Avoid calling vad.front()/vad.pop() (native queue APIs) since it crashes on some builds.
|
||||
// Use vad.compute() and implement a simple VAD segmenter in Kotlin instead.
|
||||
val windowSize = 512
|
||||
val buffer = ShortArray(windowSize)
|
||||
val threshold = 0.5f
|
||||
val minSilenceSamples = (0.25f * sampleRateInHz).toInt()
|
||||
val minSpeechSamples = (0.25f * sampleRateInHz).toInt()
|
||||
val maxSpeechSamples = (5.0f * sampleRateInHz).toInt()
|
||||
|
||||
var inSpeech = false
|
||||
var silenceSamples = 0
|
||||
|
||||
var speechBuf = FloatArray(0)
|
||||
var speechLen = 0
|
||||
|
||||
fun appendSpeech(chunk: FloatArray) {
|
||||
val needed = speechLen + chunk.size
|
||||
if (speechBuf.size < needed) {
|
||||
var newCap = maxOf(needed, maxOf(1024, speechBuf.size * 2))
|
||||
if (newCap > maxSpeechSamples) newCap = maxSpeechSamples
|
||||
val n = FloatArray(newCap)
|
||||
if (speechLen > 0) System.arraycopy(speechBuf, 0, n, 0, speechLen)
|
||||
speechBuf = n
|
||||
}
|
||||
val copyN = minOf(chunk.size, max(0, maxSpeechSamples - speechLen))
|
||||
if (copyN > 0) {
|
||||
System.arraycopy(chunk, 0, speechBuf, speechLen, copyN)
|
||||
speechLen += copyN
|
||||
}
|
||||
}
|
||||
|
||||
suspend fun finalizeSegmentIfAny() {
|
||||
if (speechLen < minSpeechSamples) {
|
||||
speechLen = 0
|
||||
inSpeech = false
|
||||
silenceSamples = 0
|
||||
return
|
||||
}
|
||||
|
||||
val seg = speechBuf.copyOf(speechLen)
|
||||
speechLen = 0
|
||||
inSpeech = false
|
||||
silenceSamples = 0
|
||||
|
||||
// 每次只允许一个 LLM 请求在飞,避免堆积导致卡死/竞态
|
||||
if (llmInFlight) return
|
||||
|
||||
val trace = currentTrace
|
||||
trace?.markASRStart()
|
||||
val raw = synchronized(nativeLock) {
|
||||
val e = senseVoice
|
||||
if (e == null || !e.isInitialized) "" else e.transcribeBuffer(seg)
|
||||
}
|
||||
val text = removeTokens(raw)
|
||||
if (text.isBlank()) return
|
||||
trace?.markASREnd()
|
||||
if (text.isBlank()) return
|
||||
|
||||
withContext(Dispatchers.Main) {
|
||||
appendToUi("\n\n[ASR] ${text}\n")
|
||||
}
|
||||
|
||||
trace?.markRecordingDone()
|
||||
trace?.markLlmResponseReceived()
|
||||
|
||||
if (BuildConfig.LLM_API_KEY.isBlank()) {
|
||||
withContext(Dispatchers.Main) {
|
||||
Toast.makeText(
|
||||
this@MainActivity,
|
||||
"未配置 LLM_API_KEY(在 local.properties 或 gradle.properties 里设置)",
|
||||
Toast.LENGTH_LONG
|
||||
).show()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
llmInFlight = true
|
||||
cloudApiManager.callLLM(text)
|
||||
}
|
||||
|
||||
while (isRecording && ioScope.coroutineContext.isActive) {
|
||||
val ret = audioRecord?.read(buffer, 0, buffer.size) ?: break
|
||||
if (ret <= 0) continue
|
||||
if (ret != windowSize) continue
|
||||
|
||||
val chunk = FloatArray(ret) { buffer[it] / 32768.0f }
|
||||
val prob = synchronized(nativeLock) { vad.compute(chunk) }
|
||||
|
||||
if (prob >= threshold) {
|
||||
if (!inSpeech) {
|
||||
inSpeech = true
|
||||
silenceSamples = 0
|
||||
}
|
||||
appendSpeech(chunk)
|
||||
|
||||
if (speechLen >= maxSpeechSamples) {
|
||||
finalizeSegmentIfAny()
|
||||
}
|
||||
} else {
|
||||
if (inSpeech) {
|
||||
silenceSamples += ret
|
||||
if (silenceSamples >= minSilenceSamples) {
|
||||
finalizeSegmentIfAny()
|
||||
} else {
|
||||
// keep a bit of trailing silence to avoid chopping
|
||||
appendSpeech(chunk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 时间兜底切段(避免长时间无标点导致首包太慢)
|
||||
val forced = segmenter.maybeForceByTime()
|
||||
for (seg in forced) enqueueTtsSegment(seg)
|
||||
}
|
||||
|
||||
// flush last partial segment
|
||||
finalizeSegmentIfAny()
|
||||
}
|
||||
|
||||
private fun removeTokens(text: String): String {
|
||||
// Remove tokens like <|zh|>, <|NEUTRAL|>, <|Speech|>, <|woitn|> and stray '>' chars
|
||||
var cleaned = text.replace(Regex("<\\|[^>]+\\|>"), "")
|
||||
cleaned = cleaned.replace(Regex("[>>≥≫]"), "")
|
||||
cleaned = cleaned.trim().replace(Regex("\\s+"), " ")
|
||||
return cleaned
|
||||
}
|
||||
|
||||
private fun enqueueTtsSegment(seg: String) {
|
||||
currentTrace?.markTtsRequestEnqueued()
|
||||
ttsQueue.offer(TtsQueueItem.Segment(seg))
|
||||
ensureTtsWorker()
|
||||
}
|
||||
|
||||
private fun ensureTtsWorker() {
|
||||
if (!ttsWorkerRunning.compareAndSet(false, true)) return
|
||||
ioScope.launch {
|
||||
try {
|
||||
runTtsWorker()
|
||||
} finally {
|
||||
ttsWorkerRunning.set(false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
    /**
     * TTS worker loop: blocks on [ttsQueue], synthesizes each Segment with
     * the offline TTS engine and streams the PCM straight into the
     * AudioTrack. Exits on the End sentinel (closing the trace turn) or as
     * soon as [ttsStopped] is observed after a take().
     *
     * Runs on [ioScope]; at most one instance at a time (see ensureTtsWorker).
     */
    private fun runTtsWorker() {
        val t = tts ?: return
        val audioTrack = track ?: return

        // Marks time-to-first-audio once per worker run.
        var firstAudioMarked = false
        while (true) {
            val item = ttsQueue.take()
            if (ttsStopped.get()) break

            when (item) {
                is TtsQueueItem.Segment -> {
                    val trace = currentTrace
                    trace?.markTtsSynthesisStart()

                    val startMs = System.currentTimeMillis()
                    var firstPcmMarked = false

                    // flush to reduce latency between segments
                    try {
                        audioTrack.pause()
                        audioTrack.flush()
                        audioTrack.play()
                    } catch (_: Throwable) {
                    }

                    t.generateWithCallback(
                        text = item.text,
                        sid = 0,
                        speed = 1.0f
                    ) { samples ->
                        // Returning 0 aborts generation once playback is cancelled.
                        if (ttsStopped.get()) return@generateWithCallback 0
                        if (!firstPcmMarked && samples.isNotEmpty()) {
                            firstPcmMarked = true
                            trace?.markTtsFirstPcmReady()
                        }
                        if (!firstAudioMarked && samples.isNotEmpty()) {
                            firstAudioMarked = true
                            trace?.markTtsFirstAudioPlay()
                        }
                        // Blocking write paces synthesis at playback speed.
                        audioTrack.write(samples, 0, samples.size, AudioTrack.WRITE_BLOCKING)
                        1
                    }

                    val ttsMs = System.currentTimeMillis() - startMs
                    trace?.addDuration("tts_segment_ms_total", ttsMs)
                }

                TtsQueueItem.End -> {
                    currentTrace?.markTtsDone()
                    TraceManager.getInstance().endTurn()
                    currentTrace = null
                    break
                }
            }
        }
    }
private fun appendToUi(s: String) {
|
||||
lastUiText += s
|
||||
textView.text = lastUiText
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
package com.digitalperson
|
||||
|
||||
/**
 * Splits a streaming LLM response into small pieces for "pseudo-streaming"
 * TTS:
 *  - cuts at the first Chinese/English punctuation so the first segment is
 *    produced as early as possible (shorter time-to-first-audio);
 *  - falls back to a length limit, and to a time limit, when no punctuation
 *    arrives for a while.
 *
 * All public methods are synchronized, so one instance may be shared between
 * the UI thread and background workers.
 */
class StreamingTextSegmenter(
    private val maxLen: Int = 18,
    private val maxWaitMs: Long = 400,
) {
    private val pending = StringBuilder()
    private var lastEmitAtMs: Long = 0

    /** Clears any buffered text and restarts the wait-time clock. */
    @Synchronized
    fun reset(nowMs: Long = System.currentTimeMillis()) {
        pending.setLength(0)
        lastEmitAtMs = nowMs
    }

    /** Feeds one streamed chunk and returns any segments that became ready. */
    @Synchronized
    fun processChunk(chunk: String, nowMs: Long = System.currentTimeMillis()): List<String> {
        if (chunk.isEmpty()) return emptyList()
        pending.append(chunk)
        return drain(nowMs, forceByTime = false)
    }

    /** Emits whatever remains in the buffer as one final segment. */
    @Synchronized
    fun flush(nowMs: Long = System.currentTimeMillis()): List<String> {
        val leftover = pending.toString().trim()
        pending.setLength(0)
        lastEmitAtMs = nowMs
        return if (leftover.isEmpty()) emptyList() else listOf(leftover)
    }

    /** Cuts a segment by elapsed time even without punctuation or length. */
    @Synchronized
    fun maybeForceByTime(nowMs: Long = System.currentTimeMillis()): List<String> =
        drain(nowMs, forceByTime = true)

    private fun drain(nowMs: Long, forceByTime: Boolean): List<String> {
        val ready = mutableListOf<String>()

        // Trims and records a candidate segment; empty segments are dropped.
        fun emit(raw: String) {
            val seg = raw.trim()
            if (seg.isNotEmpty()) {
                ready.add(seg)
                lastEmitAtMs = nowMs
            }
        }

        // 1) Punctuation first: produce a small segment as soon as possible.
        while (true) {
            val cut = firstPunctuationIndex(pending)
            if (cut < 0) break
            emit(pending.substring(0, cut + 1))
            pending.delete(0, cut + 1)
        }

        // 2) Length fallback.
        while (pending.length >= maxLen) {
            emit(pending.substring(0, maxLen))
            pending.delete(0, maxLen)
        }

        // 3) Time fallback: cut even without punctuation after maxWaitMs.
        if (forceByTime && pending.isNotEmpty() && nowMs - lastEmitAtMs >= maxWaitMs) {
            val cut = minOf(pending.length, maxLen)
            emit(pending.substring(0, cut))
            pending.delete(0, cut)
        }

        return ready
    }

    private fun firstPunctuationIndex(sb: StringBuilder): Int =
        sb.indexOfFirst { it in PUNCTUATION }

    private companion object {
        // Both strong and weak punctuation may end a segment early.
        private val PUNCTUATION = setOf('。', '!', '?', ';', ',', '.', '!', '?', ';', ',', '\n')
    }
}
207
app/src/main/java/com/digitalperson/cloud/CloudApiManager.java
Normal file
207
app/src/main/java/com/digitalperson/cloud/CloudApiManager.java
Normal file
@@ -0,0 +1,207 @@
|
||||
package com.digitalperson.cloud;
|
||||
|
||||
import android.os.Handler;
|
||||
import android.os.Looper;
|
||||
import android.util.Log;
|
||||
|
||||
import com.digitalperson.BuildConfig;
|
||||
|
||||
import org.json.JSONArray;
|
||||
import org.json.JSONException;
|
||||
import org.json.JSONObject;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
import java.io.DataOutputStream;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStreamReader;
|
||||
import java.net.HttpURLConnection;
|
||||
import java.net.URL;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
|
||||
public class CloudApiManager {
|
||||
private static final String TAG = "CloudApiManager";
|
||||
|
||||
// 火山引擎OpenAI兼容API配置
|
||||
private static final String LLM_API_URL = BuildConfig.LLM_API_URL;
|
||||
private static final String API_KEY = BuildConfig.LLM_API_KEY;
|
||||
private static final String LLM_MODEL = BuildConfig.LLM_MODEL;
|
||||
|
||||
private CloudApiListener mListener;
|
||||
private Handler mMainHandler; // 用于在主线程执行UI更新
|
||||
private JSONArray mConversationHistory; // 存储对话历史
|
||||
|
||||
public interface CloudApiListener {
|
||||
void onLLMResponseReceived(String response);
|
||||
void onLLMStreamingChunkReceived(String chunk);
|
||||
void onTTSAudioReceived(String audioFilePath);
|
||||
void onError(String errorMessage);
|
||||
}
|
||||
|
||||
public CloudApiManager(CloudApiListener listener) {
|
||||
this.mListener = listener;
|
||||
this.mMainHandler = new Handler(Looper.getMainLooper()); // 初始化主线程Handler
|
||||
this.mConversationHistory = new JSONArray(); // 初始化对话历史
|
||||
}
|
||||
|
||||
public void callLLM(String userInput) {
|
||||
new Thread(() -> {
|
||||
try {
|
||||
// 添加用户输入到对话历史
|
||||
addMessageToHistory("user", userInput);
|
||||
|
||||
// 创建HTTP连接
|
||||
URL url = new URL(LLM_API_URL);
|
||||
HttpURLConnection conn = (HttpURLConnection) url.openConnection();
|
||||
conn.setRequestMethod("POST");
|
||||
conn.setRequestProperty("Content-Type", "application/json");
|
||||
conn.setRequestProperty("Authorization", "Bearer " + API_KEY);
|
||||
conn.setDoOutput(true);
|
||||
conn.setConnectTimeout(10000);
|
||||
conn.setReadTimeout(60000); // 延长读取超时以支持流式响应
|
||||
|
||||
// 构建请求体
|
||||
JSONObject requestBody = new JSONObject();
|
||||
requestBody.put("model", LLM_MODEL);
|
||||
requestBody.put("messages", mConversationHistory);
|
||||
requestBody.put("stream", true); // 启用流式响应
|
||||
|
||||
String jsonBody = requestBody.toString();
|
||||
|
||||
Log.d(TAG, "LLM Request: " + jsonBody);
|
||||
|
||||
// 发送请求
|
||||
try (DataOutputStream dos = new DataOutputStream(conn.getOutputStream())) {
|
||||
dos.write(jsonBody.getBytes(StandardCharsets.UTF_8));
|
||||
dos.flush();
|
||||
}
|
||||
|
||||
// 读取响应
|
||||
int responseCode = conn.getResponseCode();
|
||||
StringBuilder fullResponse = new StringBuilder();
|
||||
StringBuilder accumulatedContent = new StringBuilder();
|
||||
|
||||
Log.d(TAG, "LLM Response Code: " + responseCode);
|
||||
|
||||
if (responseCode == 200) {
|
||||
// 逐行读取流式响应
|
||||
try (BufferedReader br = new BufferedReader(
|
||||
new InputStreamReader(conn.getInputStream(), StandardCharsets.UTF_8))) {
|
||||
String line;
|
||||
while ((line = br.readLine()) != null) {
|
||||
Log.d(TAG, "LLM Streaming Line: " + line);
|
||||
|
||||
// 处理SSE格式的响应
|
||||
if (line.startsWith("data: ")) {
|
||||
String dataPart = line.substring(6);
|
||||
if (dataPart.equals("[DONE]")) {
|
||||
// 流式响应结束
|
||||
break;
|
||||
}
|
||||
|
||||
try {
|
||||
// 解析JSON
|
||||
JSONObject chunkObj = new JSONObject(dataPart);
|
||||
JSONArray choices = chunkObj.getJSONArray("choices");
|
||||
if (choices.length() > 0) {
|
||||
JSONObject choice = choices.getJSONObject(0);
|
||||
JSONObject delta = choice.getJSONObject("delta");
|
||||
|
||||
if (delta.has("content")) {
|
||||
String chunkContent = delta.getString("content");
|
||||
accumulatedContent.append(chunkContent);
|
||||
|
||||
// 发送流式chunk到监听器
|
||||
if (mListener != null) {
|
||||
mMainHandler.post(() -> {
|
||||
mListener.onLLMStreamingChunkReceived(chunkContent);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (JSONException e) {
|
||||
Log.e(TAG, "Failed to parse streaming chunk: " + e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
fullResponse.append(line).append("\n");
|
||||
}
|
||||
}
|
||||
|
||||
String content = accumulatedContent.toString();
|
||||
Log.d(TAG, "Full LLM Response: " + content);
|
||||
|
||||
// 添加AI回复到对话历史
|
||||
addMessageToHistory("assistant", content);
|
||||
|
||||
if (mListener != null) {
|
||||
mMainHandler.post(() -> {
|
||||
mListener.onLLMResponseReceived(content);
|
||||
});
|
||||
}
|
||||
} else {
|
||||
// 读取错误响应
|
||||
StringBuilder errorResponse = new StringBuilder();
|
||||
try (BufferedReader br = new BufferedReader(
|
||||
new InputStreamReader(conn.getErrorStream(), StandardCharsets.UTF_8))) {
|
||||
String line;
|
||||
while ((line = br.readLine()) != null) {
|
||||
errorResponse.append(line);
|
||||
}
|
||||
}
|
||||
throw new IOException("HTTP " + responseCode + ": " + errorResponse.toString());
|
||||
}
|
||||
|
||||
} catch (Exception e) {
|
||||
Log.e(TAG, "LLM call failed: " + e.getMessage());
|
||||
if (mListener != null) {
|
||||
mMainHandler.post(() -> {
|
||||
mListener.onError("LLM调用失败: " + e.getMessage());
|
||||
});
|
||||
}
|
||||
}
|
||||
}).start();
|
||||
}
|
||||
|
||||
/**
|
||||
* 添加消息到对话历史
|
||||
*/
|
||||
private void addMessageToHistory(String role, String content) {
|
||||
try {
|
||||
JSONObject message = new JSONObject();
|
||||
message.put("role", role);
|
||||
message.put("content", content);
|
||||
mConversationHistory.put(message);
|
||||
} catch (JSONException e) {
|
||||
Log.e(TAG, "Failed to add message to history: " + e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 清空对话历史
|
||||
*/
|
||||
public void clearConversationHistory() {
|
||||
mConversationHistory = new JSONArray();
|
||||
}
|
||||
|
||||
public void callTTS(String text, File outputFile) {
|
||||
if (mListener != null) {
|
||||
mMainHandler.post(() -> {
|
||||
mListener.onError("TTS功能暂未实现");
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
private String extractContentFromResponse(String response) {
|
||||
try {
|
||||
int contentStart = response.indexOf("\"content\":\"") + 11;
|
||||
int contentEnd = response.indexOf("\"", contentStart);
|
||||
if (contentStart > 10 && contentEnd > contentStart) {
|
||||
return response.substring(contentStart, contentEnd);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
Log.e(TAG, "Failed to parse response: " + e.getMessage());
|
||||
}
|
||||
return "抱歉,无法解析响应";
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,95 @@
|
||||
package com.digitalperson.engine;
|
||||
|
||||
import android.content.Context;
|
||||
import android.util.Log;
|
||||
|
||||
/**
|
||||
* RKNN SenseVoice engine (copied from school_teacher, used as ASR backend on RK3588).
|
||||
*
|
||||
* It depends on native libs:
|
||||
* - libsentencepiece.so
|
||||
* - librknnrt.so
|
||||
* - libsensevoiceEngine.so (built from app/src/main/cpp)
|
||||
*/
|
||||
public class SenseVoiceEngineRKNN implements WhisperEngine {
|
||||
private static final String TAG = "SenseVoiceEngineRKNN";
|
||||
private final long nativePtr;
|
||||
private boolean mIsInitialized = false;
|
||||
|
||||
static {
|
||||
try {
|
||||
System.loadLibrary("sentencepiece");
|
||||
System.loadLibrary("rknnrt");
|
||||
System.loadLibrary("sensevoiceEngine");
|
||||
Log.d(TAG, "Loaded libsentencepiece.so, librknnrt.so, libsensevoiceEngine.so");
|
||||
} catch (UnsatisfiedLinkError e) {
|
||||
Log.e(TAG, "Failed to load native libraries for SenseVoice", e);
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
public SenseVoiceEngineRKNN(Context context) {
|
||||
nativePtr = createSenseVoiceEngine();
|
||||
if (nativePtr == 0) {
|
||||
throw new RuntimeException("Failed to create native SenseVoice engine");
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isInitialized() {
|
||||
return mIsInitialized;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean initialize(String modelPath, String vocabPath, boolean multilingual) {
|
||||
// SenseVoice needs three files: model, embedding, bpe
|
||||
String embeddingPath = modelPath.replace("sense-voice-encoder.rknn", "embedding.npy");
|
||||
String bpePath = modelPath.replace("sense-voice-encoder.rknn", "chn_jpn_yue_eng_ko_spectok.bpe.model");
|
||||
int ret = loadModel(nativePtr, modelPath, embeddingPath, bpePath);
|
||||
if (ret == 0) {
|
||||
mIsInitialized = true;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deinitialize() {
|
||||
if (nativePtr != 0) {
|
||||
freeModel(nativePtr);
|
||||
}
|
||||
mIsInitialized = false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String transcribeBuffer(float[] samples) {
|
||||
if (!mIsInitialized) {
|
||||
return "Error: Engine not initialized";
|
||||
}
|
||||
return transcribeBufferNative(nativePtr, samples, 0, false);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String transcribeFile(String waveFile) {
|
||||
if (!mIsInitialized) {
|
||||
return "Error: Engine not initialized";
|
||||
}
|
||||
return transcribeFileNative(nativePtr, waveFile, 0, false);
|
||||
}
|
||||
|
||||
public boolean loadModelDirectly(String modelPath, String embeddingPath, String bpePath) {
|
||||
int ret = loadModel(nativePtr, modelPath, embeddingPath, bpePath);
|
||||
if (ret == 0) {
|
||||
mIsInitialized = true;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private native long createSenseVoiceEngine();
|
||||
private native int loadModel(long nativePtr, String modelPath, String embeddingPath, String bpeModelPath);
|
||||
private native void freeModel(long ptr);
|
||||
private native String transcribeBufferNative(long ptr, float[] samples, int language, boolean use_itn);
|
||||
private native String transcribeFileNative(long ptr, String waveFile, int language, boolean use_itn);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
package com.digitalperson.engine;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
/**
 * Minimal ASR engine abstraction implemented by the Whisper- and
 * SenseVoice-based backends (see SenseVoiceEngineRKNN).
 */
public interface WhisperEngine {
    /** Returns true once initialize() has completed successfully. */
    boolean isInitialized();

    /**
     * Loads the model.
     *
     * @param modelPath    path to the main model file
     * @param vocabPath    vocabulary path; may be ignored by backends that
     *                     derive companion files from modelPath
     * @param multilingual whether to enable multilingual mode; may be ignored
     * @return true on success
     * @throws IOException when model files cannot be read
     */
    boolean initialize(String modelPath, String vocabPath, boolean multilingual) throws IOException;

    /** Releases model resources; the engine must not be used afterwards. */
    void deinitialize();

    /** Transcribes an audio file on disk and returns the recognized text. */
    String transcribeFile(String wavePath);

    /** Transcribes raw float PCM samples and returns the recognized text. */
    String transcribeBuffer(float[] samples);
}
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
package com.digitalperson.metrics;
|
||||
|
||||
import android.util.Log;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class TraceManager {
|
||||
private static final String TAG = "TraceManager";
|
||||
|
||||
private static TraceManager instance;
|
||||
|
||||
private TraceSession currentSession;
|
||||
private final List<TraceSession> sessionHistory = new ArrayList<>();
|
||||
|
||||
private TraceManager() {
|
||||
}
|
||||
|
||||
public static synchronized TraceManager getInstance() {
|
||||
if (instance == null) {
|
||||
instance = new TraceManager();
|
||||
}
|
||||
return instance;
|
||||
}
|
||||
|
||||
public TraceSession startNewTurn() {
|
||||
endTurn(); // 清理之前的会话
|
||||
currentSession = new TraceSession();
|
||||
Log.i(TAG, "Started new trace turn: " + currentSession.getTurnId());
|
||||
return currentSession;
|
||||
}
|
||||
|
||||
public TraceSession getCurrent() {
|
||||
return currentSession;
|
||||
}
|
||||
|
||||
public void setCurrent(TraceSession session) {
|
||||
this.currentSession = session;
|
||||
}
|
||||
|
||||
public void endTurn() {
|
||||
if (currentSession != null) {
|
||||
Log.i(TAG, "Ending trace turn: " + currentSession.getTurnId());
|
||||
currentSession.printSummary();
|
||||
sessionHistory.add(currentSession);
|
||||
// 限制历史会话数量,避免内存占用过大
|
||||
if (sessionHistory.size() > 10) {
|
||||
sessionHistory.remove(0);
|
||||
}
|
||||
currentSession = null;
|
||||
}
|
||||
}
|
||||
|
||||
public boolean hasActiveSession() {
|
||||
return currentSession != null;
|
||||
}
|
||||
|
||||
public List<TraceSession> getSessionHistory() {
|
||||
return sessionHistory;
|
||||
}
|
||||
|
||||
public TraceSession getSessionById(String turnId) {
|
||||
for (TraceSession session : sessionHistory) {
|
||||
if (session.getTurnId().equals(turnId)) {
|
||||
return session;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
public void clearSessionHistory() {
|
||||
sessionHistory.clear();
|
||||
Log.i(TAG, "Cleared session history");
|
||||
}
|
||||
}
|
||||
264
app/src/main/java/com/digitalperson/metrics/TraceSession.java
Normal file
264
app/src/main/java/com/digitalperson/metrics/TraceSession.java
Normal file
@@ -0,0 +1,264 @@
|
||||
package com.digitalperson.metrics;
|
||||
|
||||
import android.util.Log;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
public class TraceSession {
|
||||
private static final String TAG = "TraceSession";
|
||||
|
||||
private final String turnId;
|
||||
private final Map<String, Long> marks = new ConcurrentHashMap<>();
|
||||
private final Map<String, Long> durations = new ConcurrentHashMap<>();
|
||||
private final Map<String, SegmentMetrics> segments = new ConcurrentHashMap<>();
|
||||
|
||||
public TraceSession() {
|
||||
this.turnId = String.valueOf(System.currentTimeMillis());
|
||||
}
|
||||
|
||||
public TraceSession(String turnId) {
|
||||
this.turnId = turnId;
|
||||
}
|
||||
|
||||
public String getTurnId() {
|
||||
return turnId;
|
||||
}
|
||||
|
||||
public void mark(String name) {
|
||||
long timestamp = System.currentTimeMillis();
|
||||
marks.put(name, timestamp);
|
||||
Log.d(TAG, "Marked " + name + " at " + timestamp);
|
||||
}
|
||||
|
||||
public void addDuration(String name, long deltaMs) {
|
||||
// 使用循环确保原子性,处理并发更新
|
||||
while (true) {
|
||||
Long currentValue = durations.get(name);
|
||||
long newValue = (currentValue != null) ? currentValue + deltaMs : deltaMs;
|
||||
if (currentValue == null) {
|
||||
// 如果键不存在,尝试添加
|
||||
if (durations.putIfAbsent(name, newValue) == null) {
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
// 如果键存在,尝试更新
|
||||
if (durations.replace(name, currentValue, newValue)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void recordSegment(String segId, int textLen, long phonemizeMs, long encoderMs, long decoderMs, int invalidCount, int pcmSamples) {
|
||||
segments.put(segId, new SegmentMetrics(textLen, phonemizeMs, encoderMs, decoderMs, invalidCount, pcmSamples));
|
||||
}
|
||||
|
||||
public void recordSegment(String segId, int textLen, long phonemizeMs, long encoderMs, long decoderMs, int invalidCount) {
|
||||
segments.put(segId, new SegmentMetrics(textLen, phonemizeMs, encoderMs, decoderMs, invalidCount, 0));
|
||||
}
|
||||
|
||||
public Map<String, Long> getMarks() {
|
||||
return marks;
|
||||
}
|
||||
|
||||
public Map<String, Long> getDurations() {
|
||||
return durations;
|
||||
}
|
||||
|
||||
public Map<String, SegmentMetrics> getSegments() {
|
||||
return segments;
|
||||
}
|
||||
|
||||
public String summary() {
|
||||
StringBuilder sb = new StringBuilder();
|
||||
sb.append("=== 会话 " + turnId.substring(8) + " ===\n");
|
||||
|
||||
// ASR metrics
|
||||
long asrTime = getASRTime();
|
||||
if (asrTime > 0) {
|
||||
sb.append("ASR处理: " + asrTime + "ms\n");
|
||||
} else {
|
||||
sb.append("ASR处理: 无数据\n");
|
||||
}
|
||||
|
||||
// Derived metrics
|
||||
if (marks.containsKey("tts_first_audio_play")) {
|
||||
if (marks.containsKey("llm_first_chunk")) {
|
||||
long llmToTts = marks.get("tts_first_audio_play") - marks.get("llm_first_chunk");
|
||||
sb.append("LLM到TTS播放: " + llmToTts + "ms\n");
|
||||
} else {
|
||||
sb.append("LLM到TTS播放: 无数据\n");
|
||||
}
|
||||
if (marks.containsKey("recording_done")) {
|
||||
long recToTts = marks.get("tts_first_audio_play") - marks.get("recording_done");
|
||||
sb.append("录音到TTS播放: " + recToTts + "ms\n");
|
||||
} else {
|
||||
sb.append("录音到TTS播放: 无数据\n");
|
||||
}
|
||||
} else {
|
||||
sb.append("TTS播放未开始\n");
|
||||
}
|
||||
|
||||
// Durations
|
||||
if (!durations.isEmpty()) {
|
||||
sb.append("处理耗时:\n");
|
||||
if (durations.containsKey("phonemize_ms_total")) {
|
||||
sb.append(" 音素化: " + durations.get("phonemize_ms_total") + "ms\n");
|
||||
} else {
|
||||
sb.append(" 音素化: 无数据\n");
|
||||
}
|
||||
if (durations.containsKey("encoder_ms_total")) {
|
||||
sb.append(" 编码器: " + durations.get("encoder_ms_total") + "ms\n");
|
||||
} else {
|
||||
sb.append(" 编码器: 无数据\n");
|
||||
}
|
||||
if (durations.containsKey("decoder_ms_total")) {
|
||||
sb.append(" 解码器: " + durations.get("decoder_ms_total") + "ms\n");
|
||||
} else {
|
||||
sb.append(" 解码器: 无数据\n");
|
||||
}
|
||||
} else {
|
||||
sb.append("处理耗时: 无数据\n");
|
||||
}
|
||||
|
||||
// Segments count
|
||||
sb.append("片段数: " + segments.size() + "\n");
|
||||
|
||||
// Invalid samples total
|
||||
long invalidTotal = 0;
|
||||
for (SegmentMetrics seg : segments.values()) {
|
||||
invalidTotal += seg.invalidCount;
|
||||
}
|
||||
sb.append("无效样本数: " + invalidTotal + "\n");
|
||||
|
||||
sb.append("\n");
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
// 获取ASR处理耗时
|
||||
public long getASRTime() {
|
||||
if (marks.containsKey("asr_processing_start") && marks.containsKey("asr_processing_done")) {
|
||||
return marks.get("asr_processing_done") - marks.get("asr_processing_start");
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
// 标记ASR处理开始
|
||||
public void markASRStart() {
|
||||
mark("asr_processing_start");
|
||||
}
|
||||
|
||||
// 标记ASR处理完成
|
||||
public void markASREnd() {
|
||||
mark("asr_processing_done");
|
||||
}
|
||||
|
||||
// 标记录音完成
|
||||
public void markRecordingDone() {
|
||||
mark("recording_done");
|
||||
}
|
||||
|
||||
// 标记LLM首字回复
|
||||
public void markLlmFirstChunk() {
|
||||
mark("llm_first_chunk");
|
||||
}
|
||||
|
||||
// 标记LLM回复完成
|
||||
public void markLlmDone() {
|
||||
mark("llm_done");
|
||||
}
|
||||
|
||||
// 标记LLM响应收到
|
||||
public void markLlmResponseReceived() {
|
||||
mark("llm_response_received");
|
||||
}
|
||||
|
||||
// 标记TTS请求入队
|
||||
public void markTtsRequestEnqueued() {
|
||||
mark("tts_request_enqueued");
|
||||
}
|
||||
|
||||
// 标记TTS合成开始
|
||||
public void markTtsSynthesisStart() {
|
||||
mark("tts_synthesis_start");
|
||||
}
|
||||
|
||||
// 标记TTS第一段PCM准备就绪
|
||||
public void markTtsFirstPcmReady() {
|
||||
mark("tts_first_pcm_ready");
|
||||
}
|
||||
|
||||
// 标记TTS第一段音频播放
|
||||
public void markTtsFirstAudioPlay() {
|
||||
mark("tts_first_audio_play");
|
||||
}
|
||||
|
||||
// 标记TTS合成完成
|
||||
public void markTtsDone() {
|
||||
mark("tts_done");
|
||||
}
|
||||
|
||||
// 获取录音完成到LLM首字回复的时间
|
||||
public long getRecordingToLlmFirstChunkMs() {
|
||||
return getStageTime("recording_done", "llm_first_chunk");
|
||||
}
|
||||
|
||||
// 获取LLM首字回复到TTS第一段音频播放的时间
|
||||
public long getLlmFirstChunkToTtsFirstAudioMs() {
|
||||
return getStageTime("llm_first_chunk", "tts_first_audio_play");
|
||||
}
|
||||
|
||||
// 获取录音完成到TTS第一段音频播放的时间
|
||||
public long getRecordingToTtsFirstAudioMs() {
|
||||
return getStageTime("recording_done", "tts_first_audio_play");
|
||||
}
|
||||
|
||||
// 获取指定阶段的耗时
|
||||
public long getStageTime(String startMark, String endMark) {
|
||||
if (marks.containsKey(startMark) && marks.containsKey(endMark)) {
|
||||
return marks.get(endMark) - marks.get(startMark);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
public String getShortSummary() {
|
||||
StringBuilder sb = new StringBuilder();
|
||||
|
||||
// Derived metrics
|
||||
if (marks.containsKey("tts_first_audio_play")) {
|
||||
if (marks.containsKey("llm_first_chunk")) {
|
||||
long llmToTts = marks.get("tts_first_audio_play") - marks.get("llm_first_chunk");
|
||||
sb.append("LLM→TTS: " + llmToTts + "ms");
|
||||
}
|
||||
if (marks.containsKey("recording_done")) {
|
||||
long recToTts = marks.get("tts_first_audio_play") - marks.get("recording_done");
|
||||
sb.append(" | 录音→TTS: " + recToTts + "ms");
|
||||
}
|
||||
}
|
||||
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
public void printSummary() {
|
||||
Log.i(TAG, summary());
|
||||
}
|
||||
|
||||
public static class SegmentMetrics {
|
||||
public final int textLen;
|
||||
public final long phonemizeMs;
|
||||
public final long encoderMs;
|
||||
public final long decoderMs;
|
||||
public final int invalidCount;
|
||||
public final int pcmSamples;
|
||||
|
||||
public SegmentMetrics(int textLen, long phonemizeMs, long encoderMs, long decoderMs, int invalidCount, int pcmSamples) {
|
||||
this.textLen = textLen;
|
||||
this.phonemizeMs = phonemizeMs;
|
||||
this.encoderMs = encoderMs;
|
||||
this.decoderMs = decoderMs;
|
||||
this.invalidCount = invalidCount;
|
||||
this.pcmSamples = pcmSamples;
|
||||
}
|
||||
}
|
||||
}
|
||||
11
app/src/main/java/com/k2fsa/sherpa/onnx/FeatureConfig.kt
Normal file
11
app/src/main/java/com/k2fsa/sherpa/onnx/FeatureConfig.kt
Normal file
@@ -0,0 +1,11 @@
|
||||
package com.k2fsa.sherpa.onnx
|
||||
|
||||
/** Acoustic feature-extraction settings passed to the native recognizer. */
data class FeatureConfig(
    var sampleRate: Int = 16000, // expected input sample rate in Hz
    var featureDim: Int = 80,    // feature vector dimension
    var dither: Float = 0.0f     // dithering amount; 0 disables dithering
)
|
||||
|
||||
/** Builds a [FeatureConfig] with the given rate and dimension, keeping the default dither. */
fun getFeatureConfig(sampleRate: Int, featureDim: Int): FeatureConfig =
    FeatureConfig(sampleRate, featureDim)
|
||||
@@ -0,0 +1,7 @@
|
||||
package com.k2fsa.sherpa.onnx
|
||||
|
||||
/** Configuration for homophone replacement post-processing of ASR output. */
data class HomophoneReplacerConfig(
    var dictDir: String = "", // unused
    var lexicon: String = "",  // path to the lexicon file
    var ruleFsts: String = "", // comma-separated rule FST paths
)
|
||||
1241
app/src/main/java/com/k2fsa/sherpa/onnx/OfflineRecognizer.kt
Normal file
1241
app/src/main/java/com/k2fsa/sherpa/onnx/OfflineRecognizer.kt
Normal file
File diff suppressed because it is too large
Load Diff
32
app/src/main/java/com/k2fsa/sherpa/onnx/OfflineStream.kt
Normal file
32
app/src/main/java/com/k2fsa/sherpa/onnx/OfflineStream.kt
Normal file
@@ -0,0 +1,32 @@
|
||||
package com.k2fsa.sherpa.onnx
|
||||
|
||||
/**
 * Thin Kotlin wrapper over a native sherpa-onnx offline stream.
 *
 * [ptr] is the opaque native handle; it is zeroed after release so the
 * finalizer cannot double-free it.
 */
class OfflineStream(var ptr: Long) {
    /** Feeds float PCM [samples] recorded at [sampleRate] Hz into the stream. */
    fun acceptWaveform(samples: FloatArray, sampleRate: Int) =
        acceptWaveform(ptr, samples, sampleRate)

    protected fun finalize() {
        if (ptr != 0L) {
            delete(ptr)
            ptr = 0
        }
    }

    /** Frees the native stream immediately; safe to call more than once. */
    fun release() = finalize()

    /**
     * Runs [block] with this stream and releases it afterwards, even when the
     * block throws — mirroring `kotlin.io.use`.
     *
     * Generalized to return the block's result; previously the result was
     * discarded. Backward compatible: Unit-returning callers are unaffected.
     */
    fun <R> use(block: (OfflineStream) -> R): R {
        try {
            return block(this)
        } finally {
            release()
        }
    }

    private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
    private external fun delete(ptr: Long)

    companion object {
        init {
            System.loadLibrary("sherpa-onnx-jni")
        }
    }
}
|
||||
7
app/src/main/java/com/k2fsa/sherpa/onnx/QnnConfig.kt
Normal file
7
app/src/main/java/com/k2fsa/sherpa/onnx/QnnConfig.kt
Normal file
@@ -0,0 +1,7 @@
|
||||
package com.k2fsa.sherpa.onnx
|
||||
|
||||
/** Qualcomm QNN execution-provider configuration (file paths to QNN artifacts). */
data class QnnConfig(
    var backendLib: String = "",    // path to the QNN backend library
    var contextBinary: String = "", // path to a prebuilt QNN context binary
    var systemLib: String = "",     // path to the QNN system library
)
|
||||
373
app/src/main/java/com/k2fsa/sherpa/onnx/Tts.kt
Normal file
373
app/src/main/java/com/k2fsa/sherpa/onnx/Tts.kt
Normal file
@@ -0,0 +1,373 @@
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
package com.k2fsa.sherpa.onnx
|
||||
|
||||
import android.content.res.AssetManager
|
||||
|
||||
/** VITS TTS model configuration (single-model TTS). */
data class OfflineTtsVitsModelConfig(
    var model: String = "",
    var lexicon: String = "",
    var tokens: String = "",
    var dataDir: String = "",
    var dictDir: String = "", // unused
    var noiseScale: Float = 0.667f,
    var noiseScaleW: Float = 0.8f,
    var lengthScale: Float = 1.0f, // >1 slows speech, <1 speeds it up
)

/** Matcha TTS configuration: an acoustic model plus a separate vocoder. */
data class OfflineTtsMatchaModelConfig(
    var acousticModel: String = "",
    var vocoder: String = "",
    var lexicon: String = "",
    var tokens: String = "",
    var dataDir: String = "",
    var dictDir: String = "", // unused
    var noiseScale: Float = 1.0f,
    var lengthScale: Float = 1.0f,
)

/** Kokoro TTS configuration (multi-voice model with a voices file). */
data class OfflineTtsKokoroModelConfig(
    var model: String = "",
    var voices: String = "",
    var tokens: String = "",
    var dataDir: String = "",
    var lexicon: String = "",
    var lang: String = "",
    var dictDir: String = "", // unused
    var lengthScale: Float = 1.0f,
)

/** Kitten TTS configuration (multi-voice model with a voices file). */
data class OfflineTtsKittenModelConfig(
    var model: String = "",
    var voices: String = "",
    var tokens: String = "",
    var dataDir: String = "",
    var lengthScale: Float = 1.0f,
)

/**
 * Configuration for Pocket TTS models.
 *
 * See https://k2-fsa.github.io/sherpa/onnx/tts/pocket/index.html for details.
 *
 * @property lmFlow Path to the LM flow model (.onnx)
 * @property lmMain Path to the LM main model (.onnx)
 * @property encoder Path to the encoder model (.onnx)
 * @property decoder Path to the decoder model (.onnx)
 * @property textConditioner Path to the text conditioner model (.onnx)
 * @property vocabJson Path to vocabulary JSON file
 * @property tokenScoresJson Path to token scores JSON file
 */
data class OfflineTtsPocketModelConfig(
    var lmFlow: String = "",
    var lmMain: String = "",
    var encoder: String = "",
    var decoder: String = "",
    var textConditioner: String = "",
    var vocabJson: String = "",
    var tokenScoresJson: String = "",
)
|
||||
|
||||
/**
 * Union of all supported TTS backend configurations; exactly one of the
 * backend sub-configs is expected to be populated.
 */
data class OfflineTtsModelConfig(
    var vits: OfflineTtsVitsModelConfig = OfflineTtsVitsModelConfig(),
    var matcha: OfflineTtsMatchaModelConfig = OfflineTtsMatchaModelConfig(),
    var kokoro: OfflineTtsKokoroModelConfig = OfflineTtsKokoroModelConfig(),
    var kitten: OfflineTtsKittenModelConfig = OfflineTtsKittenModelConfig(),
    // Changed `val` -> `var` for consistency with the sibling backend configs.
    // Backward compatible: this only adds a setter; existing reads still work.
    var pocket: OfflineTtsPocketModelConfig = OfflineTtsPocketModelConfig(),

    var numThreads: Int = 1,
    var debug: Boolean = false,
    var provider: String = "cpu",
)
|
||||
|
||||
/** Top-level offline TTS configuration handed to [OfflineTts]. */
data class OfflineTtsConfig(
    var model: OfflineTtsModelConfig = OfflineTtsModelConfig(),
    var ruleFsts: String = "",      // comma-separated rule FSTs for text processing
    var ruleFars: String = "",      // comma-separated rule FAR archives
    var maxNumSentences: Int = 1,
    var silenceScale: Float = 0.2f, // scale factor applied to silence durations
)
|
||||
|
||||
/**
 * Audio produced by the TTS engine.
 *
 * @property samples float PCM, each sample normalized to [-1, 1]
 * @property sampleRate sample rate of [samples] in Hz
 */
class GeneratedAudio(
    val samples: FloatArray,
    val sampleRate: Int,
) {
    /** Writes the samples to [filename] via the native helper. */
    fun save(filename: String) =
        saveImpl(filename = filename, samples = samples, sampleRate = sampleRate)

    private external fun saveImpl(
        filename: String,
        samples: FloatArray,
        sampleRate: Int
    ): Boolean
}
|
||||
|
||||
/**
 * Per-call generation options for [OfflineTts.generateWithConfig].
 *
 * NOTE(review): this data class holds a FloatArray ([referenceAudio]); the
 * generated equals()/hashCode() compare that field by reference, not content.
 * Confirm no caller relies on structural equality of GenerationConfig.
 */
data class GenerationConfig(
    var silenceScale: Float = 0.2f,         // scale factor for silence durations
    var speed: Float = 1.0f,                // speech speed multiplier
    var sid: Int = 0,                       // speaker id
    var referenceAudio: FloatArray? = null, // optional reference samples — purpose depends on the model; verify with native side
    var referenceSampleRate: Int = 0,       // sample rate of referenceAudio in Hz
    var referenceText: String? = null,      // transcript of referenceAudio, if any
    var numSteps: Int = 5,
    var extra: Map<String, String>? = null  // backend-specific extra options
)
|
||||
|
||||
/**
 * Kotlin wrapper around the native sherpa-onnx offline TTS engine.
 *
 * The native engine is created from [config] either via Android assets
 * ([assetManager] non-null) or from the file system. The native handle is
 * zeroed after [free]/[release] so the finalizer cannot double-free it.
 */
class OfflineTts(
    assetManager: AssetManager? = null,
    var config: OfflineTtsConfig,
) {
    // Opaque native handle; 0 means "not allocated".
    private var ptr: Long

    init {
        ptr = if (assetManager != null) {
            newFromAsset(assetManager, config)
        } else {
            newFromFile(config)
        }
    }

    /** Output sample rate of the loaded model, in Hz. */
    fun sampleRate() = getSampleRate(ptr)

    /** Number of speakers supported by the loaded model. */
    fun numSpeakers() = getNumSpeakers(ptr)

    /** Synthesizes [text] for speaker [sid] at the given [speed]. */
    fun generate(
        text: String,
        sid: Int = 0,
        speed: Float = 1.0f
    ): GeneratedAudio {
        return toGeneratedAudio(generateImpl(ptr, text = text, sid = sid, speed = speed))
    }

    /**
     * Synthesizes [text], invoking [callback] with each chunk of samples as it
     * is produced; the callback's return value is passed to the native layer.
     */
    fun generateWithCallback(
        text: String,
        sid: Int = 0,
        speed: Float = 1.0f,
        callback: (samples: FloatArray) -> Int
    ): GeneratedAudio {
        return toGeneratedAudio(generateWithCallbackImpl(
            ptr,
            text = text,
            sid = sid,
            speed = speed,
            callback = callback
        ))
    }

    /** Synthesizes [text] using the per-call options in [config]. */
    fun generateWithConfig(
        text: String,
        config: GenerationConfig
    ): GeneratedAudio {
        return toGeneratedAudio(generateWithConfigImpl(ptr, text, config, null))
    }

    /** Like [generateWithConfig], but streams chunks through [callback]. */
    fun generateWithConfigAndCallback(
        text: String,
        config: GenerationConfig,
        callback: (samples: FloatArray) -> Int
    ): GeneratedAudio {
        return toGeneratedAudio(generateWithConfigImpl(ptr, text, config, callback))
    }

    /**
     * Normalizes the native return value: the JNI layer may hand back either a
     * [GeneratedAudio] directly or an Object[]{ float[], int } pair.
     */
    @Suppress("UNCHECKED_CAST")
    private fun toGeneratedAudio(obj: Any): GeneratedAudio {
        return when (obj) {
            is GeneratedAudio -> obj
            is Array<*> -> {
                // Native may return Object[]{ float[] samples, int sampleRate }
                val samples = obj.getOrNull(0) as? FloatArray
                    ?: error("Unexpected native TTS return: element[0] is not FloatArray")
                val sampleRate = (obj.getOrNull(1) as? Number)?.toInt()
                    ?: error("Unexpected native TTS return: element[1] is not Int/Number")
                GeneratedAudio(samples = samples, sampleRate = sampleRate)
            }
            else -> error("Unexpected native TTS return type: ${obj::class.java.name}")
        }
    }

    /** Re-creates the native engine after [free]; no-op when already allocated. */
    fun allocate(assetManager: AssetManager? = null) {
        if (ptr == 0L) {
            ptr = if (assetManager != null) {
                newFromAsset(assetManager, config)
            } else {
                newFromFile(config)
            }
        }
    }

    /** Frees the native engine; safe to call more than once. */
    fun free() {
        if (ptr != 0L) {
            delete(ptr)
            ptr = 0
        }
    }

    protected fun finalize() {
        if (ptr != 0L) {
            delete(ptr)
            ptr = 0
        }
    }

    /** Alias for [free] via the finalizer logic. */
    fun release() = finalize()

    private external fun newFromAsset(
        assetManager: AssetManager,
        config: OfflineTtsConfig,
    ): Long

    private external fun newFromFile(
        config: OfflineTtsConfig,
    ): Long

    private external fun delete(ptr: Long)
    private external fun getSampleRate(ptr: Long): Int
    private external fun getNumSpeakers(ptr: Long): Int

    // The returned array has two entries:
    //  - the first entry is an 1-D float array containing audio samples.
    //    Each sample is normalized to the range [-1, 1]
    //  - the second entry is the sample rate
    private external fun generateImpl(
        ptr: Long,
        text: String,
        sid: Int = 0,
        speed: Float = 1.0f
    ): Any

    private external fun generateWithCallbackImpl(
        ptr: Long,
        text: String,
        sid: Int = 0,
        speed: Float = 1.0f,
        callback: (samples: FloatArray) -> Int
    ): Any


    private external fun generateWithConfigImpl(
        ptr: Long,
        text: String,
        config: GenerationConfig,
        callback: ((samples: FloatArray) -> Int)?
    ): Any

    companion object {
        init {
            System.loadLibrary("sherpa-onnx-jni")
        }
    }
}
|
||||
|
||||
// please refer to
|
||||
// https://k2-fsa.github.io/sherpa/onnx/tts/pretrained_models/index.html
|
||||
// to download models
|
||||
/**
 * Builds an [OfflineTtsConfig] for one of the supported backends.
 *
 * Backend selection:
 *  - VITS   -> set [modelName]
 *  - Matcha -> set [acousticModelName] and [vocoder]
 *  - Kokoro -> set [modelName] and [voices]
 *  - Kitten -> set [modelName], [voices], and [isKitten]
 *
 * @throws IllegalArgumentException when no model (or an ambiguous pair of
 *         models) is specified, or when Matcha is missing its vocoder.
 */
fun getOfflineTtsConfig(
    modelDir: String,
    modelName: String, // for VITS
    acousticModelName: String, // for Matcha
    vocoder: String, // for Matcha
    voices: String, // for Kokoro or kitten
    lexicon: String,
    dataDir: String,
    dictDir: String, // unused
    ruleFsts: String,
    ruleFars: String,
    numThreads: Int? = null,
    isKitten: Boolean = false
): OfflineTtsConfig {
    // Kokoro/Kitten models (voices set) default to more threads than VITS/Matcha.
    val threadCount = numThreads ?: if (voices.isNotEmpty()) 4 else 2

    require(modelName.isNotEmpty() || acousticModelName.isNotEmpty()) {
        "Please specify a TTS model"
    }
    require(modelName.isEmpty() || acousticModelName.isEmpty()) {
        "Please specify either a VITS or a Matcha model, but not both"
    }
    require(acousticModelName.isEmpty() || vocoder.isNotEmpty()) {
        "Please provide vocoder for Matcha TTS"
    }

    val vitsConfig = if (modelName.isNotEmpty() && voices.isEmpty()) {
        OfflineTtsVitsModelConfig(
            model = "$modelDir/$modelName",
            lexicon = "$modelDir/$lexicon",
            tokens = "$modelDir/tokens.txt",
            dataDir = dataDir,
        )
    } else {
        OfflineTtsVitsModelConfig()
    }

    val matchaConfig = if (acousticModelName.isNotEmpty()) {
        OfflineTtsMatchaModelConfig(
            acousticModel = "$modelDir/$acousticModelName",
            vocoder = vocoder,
            lexicon = "$modelDir/$lexicon",
            tokens = "$modelDir/tokens.txt",
            dataDir = dataDir,
        )
    } else {
        OfflineTtsMatchaModelConfig()
    }

    val kokoroConfig = if (voices.isNotEmpty() && !isKitten) {
        OfflineTtsKokoroModelConfig(
            model = "$modelDir/$modelName",
            voices = "$modelDir/$voices",
            tokens = "$modelDir/tokens.txt",
            dataDir = dataDir,
            // Empty or comma-separated lexicons pass through unchanged;
            // otherwise the path is resolved relative to the model directory.
            lexicon = when {
                lexicon == "" -> lexicon
                "," in lexicon -> lexicon
                else -> "$modelDir/$lexicon"
            },
        )
    } else {
        OfflineTtsKokoroModelConfig()
    }

    val kittenConfig = if (isKitten) {
        OfflineTtsKittenModelConfig(
            model = "$modelDir/$modelName",
            voices = "$modelDir/$voices",
            tokens = "$modelDir/tokens.txt",
            dataDir = dataDir,
        )
    } else {
        OfflineTtsKittenModelConfig()
    }

    return OfflineTtsConfig(
        model = OfflineTtsModelConfig(
            vits = vitsConfig,
            matcha = matchaConfig,
            kokoro = kokoroConfig,
            kitten = kittenConfig,
            numThreads = threadCount,
            debug = true,
            provider = "cpu",
        ),
        ruleFsts = ruleFsts,
        ruleFars = ruleFars,
    )
}
|
||||
149
app/src/main/java/com/k2fsa/sherpa/onnx/Vad.kt
Normal file
149
app/src/main/java/com/k2fsa/sherpa/onnx/Vad.kt
Normal file
@@ -0,0 +1,149 @@
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
package com.k2fsa.sherpa.onnx
|
||||
|
||||
import android.content.res.AssetManager
|
||||
|
||||
/** Silero VAD model settings (16 kHz input; window of 512 samples). */
data class SileroVadModelConfig(
    var model: String = "",
    var threshold: Float = 0.5F,           // speech probability threshold
    var minSilenceDuration: Float = 0.25F, // seconds of silence that end a segment
    var minSpeechDuration: Float = 0.25F,  // minimum segment length in seconds
    var windowSize: Int = 512,             // analysis window in samples
    var maxSpeechDuration: Float = 5.0F,   // force-split segments longer than this
)

/** ten-vad model settings (window of 256 samples). */
data class TenVadModelConfig(
    var model: String = "",
    var threshold: Float = 0.5F,
    var minSilenceDuration: Float = 0.25F,
    var minSpeechDuration: Float = 0.25F,
    var windowSize: Int = 256,
    var maxSpeechDuration: Float = 5.0F,
)

/** Top-level VAD configuration; populate exactly one backend sub-config. */
data class VadModelConfig(
    var sileroVadModelConfig: SileroVadModelConfig = SileroVadModelConfig(),
    var tenVadModelConfig: TenVadModelConfig = TenVadModelConfig(),
    var sampleRate: Int = 16000,
    var numThreads: Int = 1,
    var provider: String = "cpu",
    var debug: Boolean = false,
)

/** One detected speech segment: [start] is the sample offset of [samples]. */
class SpeechSegment(val start: Int, val samples: FloatArray)
|
||||
|
||||
/**
 * Kotlin wrapper around the native sherpa-onnx voice-activity detector.
 *
 * Created from [config] via Android assets ([assetManager] non-null) or the
 * file system. The native handle is zeroed on release so the finalizer cannot
 * double-free it.
 */
class Vad(
    assetManager: AssetManager? = null,
    var config: VadModelConfig,
) {
    // Opaque native handle; 0 means "released".
    private var ptr: Long

    init {
        if (assetManager != null) {
            ptr = newFromAsset(assetManager, config)
        } else {
            ptr = newFromFile(config)
        }
    }

    protected fun finalize() {
        if (ptr != 0L) {
            delete(ptr)
            ptr = 0
        }
    }

    /** Frees the native detector immediately; safe to call more than once. */
    fun release() = finalize()

    /** Returns the speech probability for one window of samples. */
    fun compute(samples: FloatArray): Float = compute(ptr, samples)


    /** Feeds float PCM samples into the detector. */
    fun acceptWaveform(samples: FloatArray) = acceptWaveform(ptr, samples)

    /** True when no completed speech segment is queued. */
    fun empty(): Boolean = empty(ptr)
    /** Removes the front segment from the queue. */
    fun pop() = pop(ptr)

    /** Returns the oldest queued speech segment without removing it. */
    fun front(): SpeechSegment {
        return front(ptr)
    }

    /** Discards all queued segments. */
    fun clear() = clear(ptr)

    /** True while the detector currently considers the input to be speech. */
    fun isSpeechDetected(): Boolean = isSpeechDetected(ptr)

    /** Resets the detector's internal state. */
    fun reset() = reset(ptr)

    /** Flushes any buffered audio so trailing speech is emitted as a segment. */
    fun flush() = flush(ptr)

    private external fun delete(ptr: Long)

    private external fun newFromAsset(
        assetManager: AssetManager,
        config: VadModelConfig,
    ): Long

    private external fun newFromFile(
        config: VadModelConfig,
    ): Long

    private external fun acceptWaveform(ptr: Long, samples: FloatArray)
    private external fun compute(ptr: Long, samples: FloatArray): Float

    private external fun empty(ptr: Long): Boolean
    private external fun pop(ptr: Long)
    private external fun clear(ptr: Long)
    private external fun front(ptr: Long): SpeechSegment
    private external fun isSpeechDetected(ptr: Long): Boolean
    private external fun reset(ptr: Long)
    private external fun flush(ptr: Long)

    companion object {
        init {
            System.loadLibrary("sherpa-onnx-jni")
        }
    }
}
|
||||
|
||||
// Please visit
|
||||
// https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx
|
||||
// to download silero_vad.onnx
|
||||
// and put it inside the assets/
|
||||
// directory
|
||||
//
|
||||
// For ten-vad, please use
|
||||
// https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/ten-vad.onnx
|
||||
//
|
||||
/**
 * Returns a ready-made [VadModelConfig] for the given backend, or null for an
 * unknown [type].
 *
 * @param type 0 -> Silero VAD (silero_vad.onnx), 1 -> ten-vad (ten-vad.onnx)
 */
fun getVadModelConfig(type: Int): VadModelConfig? = when (type) {
    0 -> VadModelConfig(
        sileroVadModelConfig = SileroVadModelConfig(
            model = "silero_vad.onnx",
            threshold = 0.5F,
            minSilenceDuration = 0.25F,
            minSpeechDuration = 0.25F,
            windowSize = 512,
        ),
        sampleRate = 16000,
        numThreads = 1,
        provider = "cpu",
    )

    1 -> VadModelConfig(
        tenVadModelConfig = TenVadModelConfig(
            model = "ten-vad.onnx",
            threshold = 0.5F,
            minSilenceDuration = 0.25F,
            minSpeechDuration = 0.25F,
            windowSize = 256,
        ),
        sampleRate = 16000,
        numThreads = 1,
        provider = "cpu",
    )

    else -> null
}
|
||||
Reference in New Issue
Block a user