Skip to content

Commit

Permalink
Merge branch 'develop'
Browse files Browse the repository at this point in the history
  • Loading branch information
clemlesne committed Dec 14, 2024
2 parents 7137e08 + 42286af commit 16b9118
Show file tree
Hide file tree
Showing 10 changed files with 418 additions and 436 deletions.
8 changes: 3 additions & 5 deletions app/helpers/call_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ async def on_call_connected(
_handle_recording(
call=call,
client=client,
scheduler=scheduler,
server_call_id=server_call_id,
), # Second, start recording the call
)
Expand Down Expand Up @@ -235,7 +234,7 @@ async def on_automation_recognize_error(
logger.info(
"Timeout, retrying language selection (%s/%s)",
call.recognition_retry,
await recognition_retry_max(scheduler),
await recognition_retry_max(),
)
await _handle_ivr_language(
call=call,
Expand Down Expand Up @@ -321,7 +320,7 @@ async def _pre_recognize_error(
Returns True if the call should continue, False if it should end.
"""
# Voice retries are exhausted, end call
if call.recognition_retry >= await recognition_retry_max(scheduler):
if call.recognition_retry >= await recognition_retry_max():
logger.info("Timeout, ending call")
return False

Expand Down Expand Up @@ -793,15 +792,14 @@ async def _handle_ivr_language(
async def _handle_recording(
call: CallStateModel,
client: CallAutomationClient,
scheduler: Scheduler,
server_call_id: str,
) -> None:
"""
Start recording the call.
Feature activation is checked before starting the recording.
"""
if not await recording_enabled(scheduler):
if not await recording_enabled():
return

assert CONFIG.communication_services.recording_container_url
Expand Down
184 changes: 38 additions & 146 deletions app/helpers/call_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,22 @@
from azure.cognitiveservices.speech import (
SpeechSynthesizer,
)
from azure.cognitiveservices.speech.audio import PushAudioInputStream
from azure.communication.callautomation.aio import CallAutomationClient
from openai import APIError

from app.helpers.call_utils import (
AECStream,
SttClient,
handle_media,
handle_realtime_tts,
tts_sentence_split,
use_stt_client,
use_tts_client,
)
from app.helpers.config import CONFIG
from app.helpers.features import (
answer_hard_timeout_sec,
answer_soft_timeout_sec,
phone_silence_timeout_sec,
recognition_stt_complete_timeout_ms,
vad_cutoff_timeout_ms,
vad_silence_timeout_ms,
)
Expand All @@ -38,9 +36,7 @@
from app.helpers.logging import logger
from app.helpers.monitoring import (
SpanAttributeEnum,
call_answer_latency,
call_cutoff_latency,
call_stt_complete_latency,
gauge_set,
tracer,
)
Expand All @@ -59,7 +55,7 @@

# TODO: Refacto, this function is too long
@tracer.start_as_current_span("call_load_llm_chat")
async def load_llm_chat( # noqa: PLR0913, PLR0915
async def load_llm_chat( # noqa: PLR0913
audio_in: asyncio.Queue[bytes],
audio_out: asyncio.Queue[bytes | bool],
audio_sample_rate: int,
Expand All @@ -70,89 +66,25 @@ async def load_llm_chat( # noqa: PLR0913, PLR0915
training_callback: Callable[[CallStateModel], Awaitable[None]],
) -> None:
# Init language recognition
stt_buffer: list[str] = [] # Temporary buffer for recognition
stt_complete_gate = asyncio.Event() # Gate to wait for the recognition
aec = AECStream(
sample_rate=audio_sample_rate,
scheduler=scheduler,
)
audio_reference: asyncio.Queue[bytes] = asyncio.Queue()
answer_start: float | None = None

async def _send_in_to_aec() -> None:
"""
Send input audio to the echo cancellation.
"""
while True:
in_chunck = await audio_in.get()
audio_in.task_done()
await aec.push_input(in_chunck)

async def _send_out_to_aec() -> None:
"""
Forward the TTS to the echo cancellation and output.
"""
while True:
# Consume the audio
out_chunck = await audio_reference.get()
audio_reference.task_done()

# Report the answer latency and reset the timer
nonlocal answer_start
if answer_start:
# Enrich span
gauge_set(
metric=call_answer_latency,
value=time.monotonic() - answer_start,
)
answer_start = None

# Forward the audio
await asyncio.gather(
# First, send the audio to the output
audio_out.put(out_chunck),
# Then, send the audio to the echo cancellation
aec.push_reference(out_chunck),
)

def _partial_stt_callback(text: str) -> None:
"""
Store the partial recognition in the buffer.
"""
# Init buffer if empty
if not stt_buffer:
stt_buffer.append("")
# Replace the partial recognition
stt_buffer[-1] = text
logger.debug("Partial recognition: %s", stt_buffer)

def _complete_stt_callback(text: str) -> None:
"""
Store the recognition in the buffer.
"""
# Init buffer if empty
if not stt_buffer:
stt_buffer.append("")
# Store the recognition
stt_buffer[-1] = text
logger.debug("Complete recognition: %s", stt_buffer)
# Add a new buffer for the next partial recognition
stt_buffer.append("")

# Open the recognition gate
stt_complete_gate.set()
audio_tts: asyncio.Queue[bytes] = asyncio.Queue()

async with (
use_stt_client(
audio_sample_rate=audio_sample_rate,
SttClient(
call=call,
complete_callback=_complete_stt_callback,
partial_callback=_partial_stt_callback,
) as stt_stream,
sample_rate=audio_sample_rate,
scheduler=scheduler,
) as stt_client,
use_tts_client(
call=call,
out=audio_reference,
out=audio_tts,
) as tts_client,
AECStream(
in_raw_queue=audio_in,
in_reference_queue=audio_tts,
out_queue=audio_out,
sample_rate=audio_sample_rate,
scheduler=scheduler,
) as aec,
):
# Build scheduler
last_chat: asyncio.Task | None = None
Expand Down Expand Up @@ -198,10 +130,6 @@ async def _stop_callback() -> None:
# Send a stop signal
await audio_out.put(False)

# Reset TTS buffer
stt_buffer.clear()
stt_complete_gate.clear()

# Report the cutoff latency
gauge_set(
metric=call_cutoff_latency,
Expand Down Expand Up @@ -235,40 +163,17 @@ async def _commit_answer(
if wait:
await last_chat

async def _compute_stt_metrics() -> None:
"""
Report the recognition latency.
"""
start = time.monotonic()
await stt_complete_gate.wait()
gauge_set(
metric=call_stt_complete_latency,
value=time.monotonic() - start,
)

async def _response_callback(_retry: bool = False) -> None:
"""
Triggered when the audio buffer needs to be processed.
If the recognition is empty, retry the recognition once. Otherwise, process the response.
"""
# Report the answer latency
nonlocal answer_start
answer_start = time.monotonic()
aec.answer_start()

# Report the STT metrics
stt_metrics_task = asyncio.create_task(_compute_stt_metrics())

# Wait the complete recognition for 50ms maximum
try:
await asyncio.wait_for(
stt_complete_gate.wait(),
timeout=await recognition_stt_complete_timeout_ms(scheduler) / 1000,
)
except TimeoutError:
logger.debug("Complete recognition timeout, using partial recognition")

stt_text = " ".join(stt_buffer).strip()
# Pull the recognition
stt_text = await stt_client.pull_recognition()

# Ignore empty recognition
if not stt_text:
Expand Down Expand Up @@ -297,10 +202,7 @@ async def _response_callback(_retry: bool = False) -> None:
)

# Process the response and wait for latency metrics
await asyncio.gather(
_commit_answer(wait=False),
stt_metrics_task,
)
await _commit_answer(wait=False)

# First call
if len(call.messages) <= 1:
Expand All @@ -319,22 +221,14 @@ async def _response_callback(_retry: bool = False) -> None:
wait=False,
)

await asyncio.gather(
# Start the echo cancellation
aec.process_stream(),
# Apply the echo cancellation
_send_in_to_aec(),
_send_out_to_aec(),
# Detect VAD
_process_audio_for_vad(
call=call,
echo_cancellation=aec,
out_stream=stt_stream,
response_callback=_response_callback,
scheduler=scheduler,
stop_callback=_stop_callback,
timeout_callback=_timeout_callback,
),
# Detect VAD
await _process_audio_for_vad(
call=call,
in_callback=aec.pull_audio,
out_callback=stt_client.push_audio,
response_callback=_response_callback,
stop_callback=_stop_callback,
timeout_callback=_timeout_callback,
)


Expand Down Expand Up @@ -412,10 +306,10 @@ def _loading_task() -> asyncio.Task:
# Timeouts
soft_timeout_triggered = False
soft_timeout_task = asyncio.create_task(
asyncio.sleep(await answer_soft_timeout_sec(scheduler))
asyncio.sleep(await answer_soft_timeout_sec())
)
hard_timeout_task = asyncio.create_task(
asyncio.sleep(await answer_hard_timeout_sec(scheduler))
asyncio.sleep(await answer_hard_timeout_sec())
)

def _clear_tasks() -> None:
Expand Down Expand Up @@ -445,7 +339,7 @@ def _clear_tasks() -> None:
if hard_timeout_task.done():
logger.warning(
"Hard timeout of %ss reached",
await answer_hard_timeout_sec(scheduler),
await answer_hard_timeout_sec(),
)
# Clean up
_clear_tasks()
Expand All @@ -457,7 +351,7 @@ def _clear_tasks() -> None:
if soft_timeout_task.done() and not soft_timeout_triggered:
logger.warning(
"Soft timeout of %ss reached",
await answer_soft_timeout_sec(scheduler),
await answer_soft_timeout_sec(),
)
soft_timeout_triggered = True
# Never store the error message in the call history, it has caused hallucinations in the LLM
Expand Down Expand Up @@ -606,7 +500,6 @@ async def _content_callback(buffer: str) -> None:
async for delta in completion_stream(
max_tokens=160, # Lowest possible value for 90% of the cases, if not sufficient, retry will be triggered, 100 tokens ~= 75 words, 20 words ~= 1 sentence, 6 sentences ~= 160 tokens
messages=call.messages,
scheduler=scheduler,
system=system,
tools=tools,
):
Expand Down Expand Up @@ -721,10 +614,9 @@ async def _content_callback(buffer: str) -> None:
# TODO: Refacto and simplify
async def _process_audio_for_vad( # noqa: PLR0913
call: CallStateModel,
echo_cancellation: AECStream,
out_stream: PushAudioInputStream,
in_callback: Callable[[], Awaitable[tuple[bytes, bool]]],
out_callback: Callable[[bytes], None],
response_callback: Callable[[], Awaitable[None]],
scheduler: Scheduler,
stop_callback: Callable[[], Awaitable[None]],
timeout_callback: Callable[[], Awaitable[None]],
) -> None:
Expand All @@ -748,7 +640,7 @@ async def _wait_for_silence() -> None:
"""
# Wait before flushing
nonlocal stop_task
timeout_ms = await vad_silence_timeout_ms(scheduler)
timeout_ms = await vad_silence_timeout_ms()
await asyncio.sleep(timeout_ms / 1000)

# Cancel the clear TTS task
Expand All @@ -761,7 +653,7 @@ async def _wait_for_silence() -> None:
await response_callback()

# Wait for silence and trigger timeout
timeout_sec = await phone_silence_timeout_sec(scheduler)
timeout_sec = await phone_silence_timeout_sec()
while True:
# Stop this time if the call played a message
timeout_start = datetime.now(UTC)
Expand Down Expand Up @@ -790,7 +682,7 @@ async def _wait_for_stop() -> None:
"""
Stop the TTS if user speaks for too long.
"""
timeout_ms = await vad_cutoff_timeout_ms(scheduler)
timeout_ms = await vad_cutoff_timeout_ms()

# Wait before clearing the TTS queue
await asyncio.sleep(timeout_ms / 1000)
Expand All @@ -801,10 +693,10 @@ async def _wait_for_stop() -> None:

while True:
# Wait for the next audio packet
out_chunck, is_speech = await echo_cancellation.pull_audio()
out_chunck, is_speech = await in_callback()

# Add audio to the buffer
out_stream.write(out_chunck)
out_callback(out_chunck)

# If no speech, init the silence task
if not is_speech:
Expand Down
Loading

0 comments on commit 16b9118

Please sign in to comment.