fix(ios): prefetch talk tts segments

This commit is contained in:
Nimrod Gutman
2026-02-21 20:42:26 +02:00
committed by Nimrod Gutman
parent 9632b9bcf0
commit 8a661e30c9
2 changed files with 262 additions and 70 deletions

View File

@@ -704,7 +704,7 @@ final class GatewayConnectionController {
var addr = in_addr()
let parsed = host.withCString { inet_pton(AF_INET, $0, &addr) == 1 }
guard parsed else { return false }
let value = ntohl(addr.s_addr)
let value = UInt32(bigEndian: addr.s_addr)
let firstOctet = UInt8((value >> 24) & 0xFF)
return firstOctet == 127
}

View File

@@ -91,6 +91,8 @@ final class TalkModeManager: NSObject {
private var incrementalSpeechBuffer = IncrementalSpeechBuffer()
private var incrementalSpeechContext: IncrementalSpeechContext?
private var incrementalSpeechDirective: TalkDirective?
private var incrementalSpeechPrefetch: IncrementalSpeechPrefetchState?
private var incrementalSpeechPrefetchMonitorTask: Task<Void, Never>?
private let logger = Logger(subsystem: "bot.molt", category: "TalkMode")
@@ -1177,6 +1179,7 @@ final class TalkModeManager: NSObject {
self.incrementalSpeechQueue.removeAll()
self.incrementalSpeechTask?.cancel()
self.incrementalSpeechTask = nil
self.cancelIncrementalPrefetch()
self.incrementalSpeechActive = true
self.incrementalSpeechUsed = false
self.incrementalSpeechLanguage = nil
@@ -1189,6 +1192,7 @@ final class TalkModeManager: NSObject {
self.incrementalSpeechQueue.removeAll()
self.incrementalSpeechTask?.cancel()
self.incrementalSpeechTask = nil
self.cancelIncrementalPrefetch()
self.incrementalSpeechActive = false
self.incrementalSpeechContext = nil
self.incrementalSpeechDirective = nil
@@ -1216,20 +1220,168 @@ final class TalkModeManager: NSObject {
self.incrementalSpeechTask = Task { @MainActor [weak self] in
guard let self else { return }
defer {
self.cancelIncrementalPrefetch()
self.isSpeaking = false
self.stopRecognition()
self.incrementalSpeechTask = nil
}
while !Task.isCancelled {
guard !self.incrementalSpeechQueue.isEmpty else { break }
let segment = self.incrementalSpeechQueue.removeFirst()
self.statusText = "Speaking…"
self.isSpeaking = true
self.lastSpokenText = segment
await self.speakIncrementalSegment(segment)
await self.updateIncrementalContextIfNeeded()
let context = self.incrementalSpeechContext
let prefetchedAudio = await self.consumeIncrementalPrefetchedAudioIfAvailable(
for: segment,
context: context)
if let context {
self.startIncrementalPrefetchMonitor(context: context)
}
await self.speakIncrementalSegment(
segment,
context: context,
prefetchedAudio: prefetchedAudio)
self.cancelIncrementalPrefetchMonitor()
}
self.isSpeaking = false
self.stopRecognition()
self.incrementalSpeechTask = nil
}
}
/// Drops any in-flight prefetch work: stops the monitor, cancels the
/// synthesis task, and clears the cached prefetch state.
private func cancelIncrementalPrefetch() {
    // Monitor goes first so it cannot re-arm a prefetch we are tearing down.
    self.cancelIncrementalPrefetchMonitor()
    if let pending = self.incrementalSpeechPrefetch {
        pending.task.cancel()
        self.incrementalSpeechPrefetch = nil
    }
}
/// Stops the polling task that watches for the next segment to prefetch.
private func cancelIncrementalPrefetchMonitor() {
    guard let monitor = self.incrementalSpeechPrefetchMonitorTask else { return }
    monitor.cancel()
    self.incrementalSpeechPrefetchMonitorTask = nil
}
/// Starts a lightweight poll that arms a TTS prefetch for the next queued
/// segment while the current one is being spoken.
/// - Parameter context: The TTS settings the prefetch must be synthesized with.
private func startIncrementalPrefetchMonitor(context: IncrementalSpeechContext) {
    self.cancelIncrementalPrefetchMonitor()
    // Fix: when ElevenLabs is unusable, ensureIncrementalPrefetchForUpcomingSegment
    // can never return true, so the loop below would spin at 25 Hz for the whole
    // segment. Mirror the cancel that ensure() would have performed and bail early.
    guard context.canUseElevenLabs else {
        self.cancelIncrementalPrefetch()
        return
    }
    self.incrementalSpeechPrefetchMonitorTask = Task { @MainActor [weak self] in
        guard let self else { return }
        while !Task.isCancelled {
            // Stop polling once a prefetch for the upcoming segment is in place.
            if self.ensureIncrementalPrefetchForUpcomingSegment(context: context) {
                return
            }
            // Re-check shortly; queue contents change as streaming text arrives.
            try? await Task.sleep(nanoseconds: 40_000_000)
        }
    }
}
/// Makes sure a prefetch exists for the segment at the head of the queue.
/// - Returns: `true` when a matching prefetch is armed (the monitor may stop
///   polling); `false` when no prefetch could be set up yet.
private func ensureIncrementalPrefetchForUpcomingSegment(context: IncrementalSpeechContext) -> Bool {
    // ElevenLabs unavailable: nothing can be prefetched, and any stale state is useless.
    guard context.canUseElevenLabs else {
        self.cancelIncrementalPrefetch()
        return false
    }
    // Nothing queued yet; the caller will poll again.
    guard let upcoming = self.incrementalSpeechQueue.first else { return false }
    switch self.incrementalSpeechPrefetch {
    case let current? where current.segment == upcoming && current.context == context:
        // Already prefetching exactly this segment with these settings.
        return true
    case let stale?:
        // Wrong segment or changed settings — discard before re-arming.
        stale.task.cancel()
        self.incrementalSpeechPrefetch = nil
    case nil:
        break
    }
    self.startIncrementalPrefetch(segment: upcoming, context: context)
    return self.incrementalSpeechPrefetch != nil
}
/// Kicks off background synthesis for `segment` so its audio is already
/// buffered when the segment reaches the front of the playback loop.
/// - Parameters:
///   - segment: The text to synthesize ahead of time.
///   - context: TTS settings (key, voice, model, directive) to synthesize with.
private func startIncrementalPrefetch(segment: String, context: IncrementalSpeechContext) {
    guard context.canUseElevenLabs, let apiKey = context.apiKey, let voiceId = context.voiceId else { return }
    // The prefetch may substitute a different format than the context's (see
    // resolveIncrementalPrefetchOutputFormat); record it so playback can match.
    let prefetchOutputFormat = self.resolveIncrementalPrefetchOutputFormat(context: context)
    let request = self.makeIncrementalTTSRequest(
        text: segment,
        context: context,
        outputFormat: prefetchOutputFormat)
    // The UUID ties the task's completion callbacks back to this specific
    // prefetch, so a stale task cannot clobber state belonging to a newer one.
    let id = UUID()
    let task = Task { [weak self] in
        let stream = ElevenLabsTTSClient(apiKey: apiKey).streamSynthesize(voiceId: voiceId, request: request)
        var chunks: [Data] = []
        do {
            for try await chunk in stream {
                // Cooperative cancellation: stop buffering as soon as this
                // prefetch is torn down.
                try Task.checkCancellation()
                chunks.append(chunk)
            }
            await self?.completeIncrementalPrefetch(id: id, chunks: chunks)
        } catch is CancellationError {
            await self?.clearIncrementalPrefetch(id: id)
        } catch {
            await self?.failIncrementalPrefetch(id: id, error: error)
        }
    }
    // chunks starts nil (still streaming); consumers await task.value when they
    // need audio that is still in flight.
    self.incrementalSpeechPrefetch = IncrementalSpeechPrefetchState(
        id: id,
        segment: segment,
        context: context,
        outputFormat: prefetchOutputFormat,
        chunks: nil,
        task: task)
}
/// Records the fully buffered audio for the prefetch identified by `id`.
/// Ignored when the state has since been replaced or cleared.
private func completeIncrementalPrefetch(id: UUID, chunks: [Data]) {
    guard self.incrementalSpeechPrefetch?.id == id else { return }
    self.incrementalSpeechPrefetch?.chunks = chunks
}
/// Forgets the prefetch identified by `id` (invoked when its task is cancelled).
private func clearIncrementalPrefetch(id: UUID) {
    guard let current = self.incrementalSpeechPrefetch, current.id == id else { return }
    current.task.cancel()
    self.incrementalSpeechPrefetch = nil
}
/// Drops the prefetch identified by `id` after a synthesis error; playback
/// falls back to live streaming for that segment.
private func failIncrementalPrefetch(id: UUID, error: any Error) {
    guard let failed = self.incrementalSpeechPrefetch, failed.id == id else { return }
    self.logger.debug("incremental prefetch failed: \(error.localizedDescription, privacy: .public)")
    failed.task.cancel()
    self.incrementalSpeechPrefetch = nil
}
/// Returns buffered prefetch audio for `segment` when a matching prefetch
/// exists, awaiting an in-flight one if necessary.
/// - Returns: The prefetched audio, or `nil` when none is usable (the caller
///   then streams live from the synthesizer).
private func consumeIncrementalPrefetchedAudioIfAvailable(
    for segment: String,
    context: IncrementalSpeechContext?
) async -> IncrementalPrefetchedAudio?
{
    guard let context else {
        // Without a context the prefetch cannot be validated against current
        // settings; discard it outright.
        self.cancelIncrementalPrefetch()
        return nil
    }
    guard let prefetch = self.incrementalSpeechPrefetch else {
        return nil
    }
    guard prefetch.context == context else {
        // Voice/model/format changed since the prefetch was armed; its audio
        // would be wrong, so throw it away.
        prefetch.task.cancel()
        self.incrementalSpeechPrefetch = nil
        return nil
    }
    guard prefetch.segment == segment else {
        // Prefetch targets a different segment; leave it for its own turn.
        return nil
    }
    if let chunks = prefetch.chunks, !chunks.isEmpty {
        // Fast path: synthesis already finished.
        let prefetched = IncrementalPrefetchedAudio(chunks: chunks, outputFormat: prefetch.outputFormat)
        self.incrementalSpeechPrefetch = nil
        return prefetched
    }
    // Slow path: synthesis still running. Wait for the task, then re-read and
    // re-validate the state — the completion/failure handlers may have replaced
    // or cleared it while we were suspended.
    await prefetch.task.value
    guard let completed = self.incrementalSpeechPrefetch else { return nil }
    guard completed.context == context, completed.segment == segment else { return nil }
    // NOTE(review): a successful prefetch that produced zero chunks is left in
    // state here rather than cleared; the next monitor tick appears to replace
    // it — confirm this is intended.
    guard let chunks = completed.chunks, !chunks.isEmpty else { return nil }
    let prefetched = IncrementalPrefetchedAudio(chunks: chunks, outputFormat: completed.outputFormat)
    self.incrementalSpeechPrefetch = nil
    return prefetched
}
/// Picks the synthesis format for prefetched audio: raw PCM contexts are
/// swapped for validated MP3 (presumably so buffered prefetch audio plays
/// through the MP3 path — TODO confirm); all other formats pass through.
private func resolveIncrementalPrefetchOutputFormat(context: IncrementalSpeechContext) -> String? {
    let isRawPCM = TalkTTSValidation.pcmSampleRate(from: context.outputFormat) != nil
    guard isRawPCM else { return context.outputFormat }
    return ElevenLabsTTSClient.validatedOutputFormat("mp3_44100")
}
private func finishIncrementalSpeech() async {
guard self.incrementalSpeechActive else { return }
let leftover = self.incrementalSpeechBuffer.flush()
@@ -1337,77 +1489,103 @@ final class TalkModeManager: NSObject {
canUseElevenLabs: canUseElevenLabs)
}
private func speakIncrementalSegment(_ text: String) async {
await self.updateIncrementalContextIfNeeded()
guard let context = self.incrementalSpeechContext else {
/// Builds the ElevenLabs synthesis request for one incremental segment,
/// funnelling every directive value through its corresponding validator.
private func makeIncrementalTTSRequest(
    text: String,
    context: IncrementalSpeechContext,
    outputFormat: String?
) -> ElevenLabsTTSRequest
{
    let directive = context.directive
    let resolvedSpeed = TalkTTSValidation.resolveSpeed(
        speed: directive?.speed,
        rateWPM: directive?.rateWPM)
    let resolvedStability = TalkTTSValidation.validatedStability(
        directive?.stability,
        modelId: context.modelId)
    return ElevenLabsTTSRequest(
        text: text,
        modelId: context.modelId,
        outputFormat: outputFormat,
        speed: resolvedSpeed,
        stability: resolvedStability,
        similarity: TalkTTSValidation.validatedUnit(directive?.similarity),
        style: TalkTTSValidation.validatedUnit(directive?.style),
        speakerBoost: directive?.speakerBoost,
        seed: TalkTTSValidation.validatedSeed(directive?.seed),
        normalize: ElevenLabsTTSClient.validatedNormalize(directive?.normalize),
        language: context.language,
        latencyTier: TalkTTSValidation.validatedLatencyTier(directive?.latencyTier))
}
/// Wraps already-downloaded audio chunks in the same stream type the live
/// synthesizer produces, so playback handles both sources uniformly.
private static func makeBufferedAudioStream(chunks: [Data]) -> AsyncThrowingStream<Data, Error> {
    AsyncThrowingStream { continuation in
        chunks.forEach { continuation.yield($0) }
        continuation.finish()
    }
}
private func speakIncrementalSegment(
_ text: String,
context preferredContext: IncrementalSpeechContext? = nil,
prefetchedAudio: IncrementalPrefetchedAudio? = nil
) async
{
let context: IncrementalSpeechContext
if let preferredContext {
context = preferredContext
} else {
await self.updateIncrementalContextIfNeeded()
guard let resolvedContext = self.incrementalSpeechContext else {
try? await TalkSystemSpeechSynthesizer.shared.speak(
text: text,
language: self.incrementalSpeechLanguage)
return
}
context = resolvedContext
}
guard context.canUseElevenLabs, let apiKey = context.apiKey, let voiceId = context.voiceId else {
try? await TalkSystemSpeechSynthesizer.shared.speak(
text: text,
language: self.incrementalSpeechLanguage)
return
}
if context.canUseElevenLabs, let apiKey = context.apiKey, let voiceId = context.voiceId {
let request = ElevenLabsTTSRequest(
text: text,
modelId: context.modelId,
outputFormat: context.outputFormat,
speed: TalkTTSValidation.resolveSpeed(
speed: context.directive?.speed,
rateWPM: context.directive?.rateWPM),
stability: TalkTTSValidation.validatedStability(
context.directive?.stability,
modelId: context.modelId),
similarity: TalkTTSValidation.validatedUnit(context.directive?.similarity),
style: TalkTTSValidation.validatedUnit(context.directive?.style),
speakerBoost: context.directive?.speakerBoost,
seed: TalkTTSValidation.validatedSeed(context.directive?.seed),
normalize: ElevenLabsTTSClient.validatedNormalize(context.directive?.normalize),
language: context.language,
latencyTier: TalkTTSValidation.validatedLatencyTier(context.directive?.latencyTier))
let client = ElevenLabsTTSClient(apiKey: apiKey)
let stream = client.streamSynthesize(voiceId: voiceId, request: request)
let sampleRate = TalkTTSValidation.pcmSampleRate(from: context.outputFormat)
let result: StreamingPlaybackResult
if let sampleRate {
self.lastPlaybackWasPCM = true
var playback = await self.pcmPlayer.play(stream: stream, sampleRate: sampleRate)
if !playback.finished, playback.interruptedAt == nil {
self.logger.warning("pcm playback failed; retrying mp3")
self.lastPlaybackWasPCM = false
let mp3Format = ElevenLabsTTSClient.validatedOutputFormat("mp3_44100")
let mp3Stream = client.streamSynthesize(
voiceId: voiceId,
request: ElevenLabsTTSRequest(
text: text,
modelId: context.modelId,
outputFormat: mp3Format,
speed: TalkTTSValidation.resolveSpeed(
speed: context.directive?.speed,
rateWPM: context.directive?.rateWPM),
stability: TalkTTSValidation.validatedStability(
context.directive?.stability,
modelId: context.modelId),
similarity: TalkTTSValidation.validatedUnit(context.directive?.similarity),
style: TalkTTSValidation.validatedUnit(context.directive?.style),
speakerBoost: context.directive?.speakerBoost,
seed: TalkTTSValidation.validatedSeed(context.directive?.seed),
normalize: ElevenLabsTTSClient.validatedNormalize(context.directive?.normalize),
language: context.language,
latencyTier: TalkTTSValidation.validatedLatencyTier(context.directive?.latencyTier)))
playback = await self.mp3Player.play(stream: mp3Stream)
}
result = playback
} else {
self.lastPlaybackWasPCM = false
result = await self.mp3Player.play(stream: stream)
}
if !result.finished, let interruptedAt = result.interruptedAt {
self.lastInterruptedAtSeconds = interruptedAt
}
let client = ElevenLabsTTSClient(apiKey: apiKey)
let request = self.makeIncrementalTTSRequest(
text: text,
context: context,
outputFormat: context.outputFormat)
let stream: AsyncThrowingStream<Data, Error>
if let prefetchedAudio, !prefetchedAudio.chunks.isEmpty {
stream = Self.makeBufferedAudioStream(chunks: prefetchedAudio.chunks)
} else {
try? await TalkSystemSpeechSynthesizer.shared.speak(
text: text,
language: self.incrementalSpeechLanguage)
stream = client.streamSynthesize(voiceId: voiceId, request: request)
}
let playbackFormat = prefetchedAudio?.outputFormat ?? context.outputFormat
let sampleRate = TalkTTSValidation.pcmSampleRate(from: playbackFormat)
let result: StreamingPlaybackResult
if let sampleRate {
self.lastPlaybackWasPCM = true
var playback = await self.pcmPlayer.play(stream: stream, sampleRate: sampleRate)
if !playback.finished, playback.interruptedAt == nil {
self.logger.warning("pcm playback failed; retrying mp3")
self.lastPlaybackWasPCM = false
let mp3Format = ElevenLabsTTSClient.validatedOutputFormat("mp3_44100")
let mp3Stream = client.streamSynthesize(
voiceId: voiceId,
request: self.makeIncrementalTTSRequest(
text: text,
context: context,
outputFormat: mp3Format))
playback = await self.mp3Player.play(stream: mp3Stream)
}
result = playback
} else {
self.lastPlaybackWasPCM = false
result = await self.mp3Player.play(stream: stream)
}
if !result.finished, let interruptedAt = result.interruptedAt {
self.lastInterruptedAtSeconds = interruptedAt
}
}
@@ -1874,7 +2052,7 @@ extension TalkModeManager {
}
#endif
private struct IncrementalSpeechContext {
private struct IncrementalSpeechContext: Equatable {
let apiKey: String?
let voiceId: String?
let modelId: String?
@@ -1884,4 +2062,18 @@ private struct IncrementalSpeechContext {
let canUseElevenLabs: Bool
}
/// Book-keeping for one in-flight TTS prefetch.
private struct IncrementalSpeechPrefetchState {
    // Unique token matching the task's completion callbacks to this instance.
    let id: UUID
    // The exact text the audio is being synthesized for.
    let segment: String
    // Snapshot of the TTS settings used; a context change invalidates the audio.
    let context: IncrementalSpeechContext
    // Format actually requested, which may differ from the context's
    // (e.g. PCM downgraded to MP3 for prefetch).
    let outputFormat: String?
    // nil while synthesis is still streaming; set once all chunks have arrived.
    var chunks: [Data]?
    // The background synthesis task, kept for cancellation and awaiting.
    let task: Task<Void, Never>
}
/// Fully buffered prefetched audio handed to segment playback.
private struct IncrementalPrefetchedAudio {
    // Ordered audio chunks exactly as received from the synthesizer.
    let chunks: [Data]
    // Format the chunks were synthesized in (selects PCM vs MP3 playback).
    let outputFormat: String?
}
// swiftlint:enable type_body_length