Skip to content

Commit

Permalink
Merge pull request #117 from bolna-ai/feat/BOLNA-15/web-calling-with-…
Browse files Browse the repository at this point in the history
…deepgram-enhancements

Feat/bolna 15/web calling with deepgram enhancements
  • Loading branch information
prateeksachan authored Feb 25, 2025
2 parents b508080 + 2f8e644 commit 74b7225
Show file tree
Hide file tree
Showing 22 changed files with 633 additions and 352 deletions.
479 changes: 286 additions & 193 deletions bolna/agent_manager/task_manager.py

Large diffs are not rendered by default.

22 changes: 22 additions & 0 deletions bolna/helpers/mark_event_meta_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from bolna.helpers.logger_config import configure_logger

logger = configure_logger(__name__)

class MarkEventMetaData:
def __init__(self):
self.mark_event_meta_data = {}

def update_data(self, mark_id, value):
logger.info(f"Updating mark_id = {mark_id} with value = {value}")
self.mark_event_meta_data[mark_id] = value

def fetch_data(self, mark_id):
logger.info(f"Fetching meta data details for mark_id = {mark_id}")
return self.mark_event_meta_data.pop(mark_id, {})

def clear_data(self):
logger.info(f"Clearing mark meta data dict")
self.mark_event_meta_data = {}

def __str__(self):
return f"{self.mark_event_meta_data}"
47 changes: 47 additions & 0 deletions bolna/helpers/observable_variable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import asyncio
import inspect
from bolna.helpers.logger_config import configure_logger

logger = configure_logger(__name__)

class ObservableVariable:
def __init__(self, value):
self._value = value
self._observers = []

def add_observer(self, observer):
"""
Register an observer function.
The observer can be a synchronous function or an async function.
"""
self._observers.append(observer)

@property
def value(self):
"""Getter for the observable variable."""
return self._value

@value.setter
def value(self, new_value):
"""Setter that updates the variable and notifies observers if the value changes."""
if self._value != new_value:
self._value = new_value
self._notify_observers(new_value)

def _notify_observers(self, new_value):
"""
Notify each observer about the new value.
Async observers are scheduled appropriately.
"""
for observer in self._observers:
if inspect.iscoroutinefunction(observer):
try:
# If an event loop is already running, schedule the async observer
loop = asyncio.get_running_loop()
loop.create_task(observer(new_value))
except RuntimeError:
# No running loop; run the async function in a temporary event loop
asyncio.run(observer(new_value))
else:
# Synchronous observer: call it directly
observer(new_value)
2 changes: 1 addition & 1 deletion bolna/input_handlers/daily_webcall.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


class DailyInputHandler:
def __init__(self, queues=None, websocket=None, input_types=None, mark_set=None, queue=None,
def __init__(self, queues=None, websocket=None, input_types=None, mark_event_meta_data=None, queue=None,
conversation_recording=None, room_url=None):
self.queues = queues
self.websocket = websocket
Expand Down
67 changes: 63 additions & 4 deletions bolna/input_handlers/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@


class DefaultInputHandler:
def __init__(self, queues=None, websocket=None, input_types=None, mark_set = None, queue = None, turn_based_conversation=False, conversation_recording = None):
def __init__(self, queues=None, websocket=None, input_types=None, mark_event_meta_data=None, queue=None,
turn_based_conversation=False, conversation_recording=None, is_welcome_message_played=False,
observable_variables=None):
self.queues = queues
self.websocket = websocket
self.input_types = input_types
Expand All @@ -20,7 +22,24 @@ def __init__(self, queues=None, websocket=None, input_types=None, mark_set = Non
self.turn_based_conversation = turn_based_conversation
self.queue = queue
self.conversation_recording = conversation_recording
self.is_welcome_message_played = is_welcome_message_played
# This variable stores the response which has been heard by the user
self.response_heard_by_user = ""
self._is_audio_being_played_to_user = False
self.observable_variables = observable_variables
self.mark_event_meta_data = mark_event_meta_data

def update_is_audio_being_played(self, value):
self._is_audio_being_played_to_user = value

def is_audio_being_played_to_user(self):
return self._is_audio_being_played_to_user

def get_response_heard_by_user(self):
response = self.response_heard_by_user
self.response_heard_by_user = ""
return response.strip()

async def stop_handler(self):
self.running = False
try:
Expand All @@ -32,6 +51,40 @@ async def stop_handler(self):
def get_stream_sid(self):
return str(uuid.uuid4())

def welcome_message_played(self):
return self.is_welcome_message_played

def get_mark_event_meta_data_obj(self, packet):
mark_id = packet["name"]
return self.mark_event_meta_data.fetch_data(mark_id)

def process_mark_message(self, packet):
mark_event_meta_data_obj = self.get_mark_event_meta_data_obj(packet)
if not mark_event_meta_data_obj:
logger.info(f"No object retrieved from global dict of mark_event_meta_data for received mark event - {packet}")
return

logger.info(f"Mark event meta data object retrieved = {mark_event_meta_data_obj}")
message_type = mark_event_meta_data_obj.get("type")
self.response_heard_by_user += mark_event_meta_data_obj.get("text_synthesized")

if mark_event_meta_data_obj.get("is_final_chunk"):
self._is_audio_being_played_to_user = False

if message_type != "is_user_online_message":
self.observable_variables["final_chunk_played_observable"].value = not self.observable_variables["final_chunk_played_observable"].value

if message_type == "agent_welcome_message":
logger.info("Received mark event for agent_welcome_message")
self.is_welcome_message_played = True

elif message_type == "agent_hangup":
logger.info(f"Agent hangup has been triggered")
self.observable_variables["agent_hangup_observable"].value = True

def __process_mark_event(self, packet):
self.process_mark_message(packet)

def __process_audio(self, audio):
data = base64.b64decode(audio)
ws_data_packet = create_ws_data_packet(
Expand Down Expand Up @@ -86,16 +139,22 @@ async def _listen(self):
return

async def process_message(self, message):
if message['type'] not in self.input_types.keys() and not self.turn_based_conversation:
logger.info(f"straight away returning")
return {"message": "invalid input type"}
# TODO check what condition needs to be added over here
# if message['type'] not in self.input_types.keys() and not self.turn_based_conversation:
# logger.info(f"straight away returning")
# return {"message": "invalid input type"}

if message['type'] == 'audio':
self.__process_audio(message['data'])

elif message["type"] == "text":
logger.info(f"Received text: {message['data']}")
self.__process_text(message['data'])

elif message["type"] == "mark":
logger.info(f"Received mark event")
self.__process_mark_event(message)

else:
return {"message": "Other modalities not implemented yet"}

Expand Down
17 changes: 9 additions & 8 deletions bolna/input_handlers/telephony.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@


class TelephonyInputHandler(DefaultInputHandler):
def __init__(self, queues, websocket=None, input_types=None, mark_set=None, turn_based_conversation=False):
super().__init__(queues, websocket, input_types, turn_based_conversation)
def __init__(self, queues, websocket=None, input_types=None, mark_event_meta_data=None, turn_based_conversation=False,
is_welcome_message_played=False, observable_variables=None):
super().__init__(queues, websocket, input_types, mark_event_meta_data, turn_based_conversation,
is_welcome_message_played=is_welcome_message_played, observable_variables=observable_variables)
self.stream_sid = None
self.call_sid = None
self.buffer = []
self.message_count = 0
self.mark_set = mark_set
# self.mark_event_meta_data = mark_event_meta_data
self.last_media_received = 0
self.io_provider = None

Expand All @@ -34,9 +36,8 @@ async def call_start(self, packet):
async def disconnect_stream(self):
pass

async def process_mark_message(self, packet):
if packet["mark"]["name"] in self.mark_set:
self.mark_set.remove(packet["mark"]["name"])
# def get_mark_event_meta_data_obj(self, packet):
# pass

async def stop_handler(self):
asyncio.create_task(self.disconnect_stream())
Expand Down Expand Up @@ -94,8 +95,8 @@ async def _listen(self):
else:
logger.info("Getting media elements but not inbound media")

elif packet['event'] == 'mark':
await self.process_mark_message(packet)
elif packet['event'] == 'mark' or packet['event'] == 'playedStream':
self.process_mark_message(packet)

elif packet['event'] == 'stop':
logger.info('call stopping')
Expand Down
9 changes: 7 additions & 2 deletions bolna/input_handlers/telephony_providers/exotel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,16 @@


class ExotelInputHandler(TelephonyInputHandler):
def __init__(self, queues, websocket=None, input_types=None, mark_set=None, turn_based_conversation=False):
super().__init__(queues, websocket, input_types, mark_set, turn_based_conversation)
def __init__(self, queues, websocket=None, input_types=None, mark_event_meta_data=None, turn_based_conversation=False,
is_welcome_message_played=False, observable_variables=None):
super().__init__(queues, websocket, input_types, mark_event_meta_data, turn_based_conversation,
is_welcome_message_played=is_welcome_message_played, observable_variables=observable_variables)
self.io_provider = 'exotel'

async def call_start(self, packet):
start = packet['start']
self.call_sid = start['call_sid']
self.stream_sid = start['stream_sid']

def get_mark_event_meta_data_obj(self, packet):
pass
10 changes: 8 additions & 2 deletions bolna/input_handlers/telephony_providers/plivo.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@


class PlivoInputHandler(TelephonyInputHandler):
def __init__(self, queues, websocket=None, input_types=None, mark_set=None, turn_based_conversation=False):
super().__init__(queues, websocket, input_types, mark_set, turn_based_conversation)
def __init__(self, queues, websocket=None, input_types=None, mark_event_meta_data=None, turn_based_conversation=False,
is_welcome_message_played=False, observable_variables=None):
super().__init__(queues, websocket, input_types, mark_event_meta_data, turn_based_conversation,
is_welcome_message_played=is_welcome_message_played, observable_variables=observable_variables)
self.io_provider = 'plivo'
self.client = plivosdk.RestClient(os.getenv('PLIVO_AUTH_ID'), os.getenv('PLIVO_AUTH_TOKEN'))

Expand All @@ -24,3 +26,7 @@ async def disconnect_stream(self):
self.client.calls.delete_all_streams(self.call_sid)
except Exception as e:
logger.info('Error deleting plivo stream: {}'.format(str(e)))

def get_mark_event_meta_data_obj(self, packet):
mark_id = packet["name"]
return self.mark_event_meta_data.fetch_data(mark_id)
10 changes: 8 additions & 2 deletions bolna/input_handlers/telephony_providers/twilio.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,17 @@


class TwilioInputHandler(TelephonyInputHandler):
def __init__(self, queues, websocket=None, input_types=None, mark_set=None, turn_based_conversation=False):
super().__init__(queues, websocket, input_types, mark_set, turn_based_conversation)
def __init__(self, queues, websocket=None, input_types=None, mark_event_meta_data=None, turn_based_conversation=False,
is_welcome_message_played=False, observable_variables=None):
super().__init__(queues, websocket, input_types, mark_event_meta_data, turn_based_conversation,
is_welcome_message_played=is_welcome_message_played, observable_variables=observable_variables)
self.io_provider = 'twilio'

async def call_start(self, packet):
start = packet['start']
self.call_sid = start['callSid']
self.stream_sid = start['streamSid']

def get_mark_event_meta_data_obj(self, packet):
mark_id = packet["mark"]["name"]
return self.mark_event_meta_data.fetch_data(mark_id)
2 changes: 1 addition & 1 deletion bolna/output_handlers/daily_webcall.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


class DailyOutputHandler:
def __init__(self, io_provider='daily', websocket=None, queue=None, room_url=None, mark_set=None):
def __init__(self, io_provider='daily', websocket=None, queue=None, room_url=None, mark_event_meta_data=None):
self.websocket = websocket
self.is_interruption_task_on = False
self.queue = queue
Expand Down
42 changes: 38 additions & 4 deletions bolna/output_handlers/default.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import json
import uuid
import base64
from dotenv import load_dotenv
from bolna.helpers.logger_config import configure_logger
Expand All @@ -7,14 +9,16 @@


class DefaultOutputHandler:
def __init__(self, io_provider='default', websocket=None, queue=None):
def __init__(self, io_provider='default', websocket=None, queue=None, is_web_based_call=False, mark_event_meta_data=None):
self.websocket = websocket
self.is_interruption_task_on = False
self.queue = queue
self.io_provider = io_provider
self.is_chunking_supported = True
self.is_last_hangup_chunk_sent = False
self.is_welcome_message_sent = False
# self.is_welcome_message_sent = False
self.is_web_based_call = is_web_based_call
self.mark_event_meta_data = mark_event_meta_data

# @TODO Figure out the best way to handle this
async def handle_interruption(self):
Expand All @@ -28,15 +32,22 @@ def process_in_chunks(self, yield_chunks=False):
def get_provider(self):
return self.io_provider

def set_hangup_sent(self):
self.is_last_hangup_chunk_sent = True

def hangup_sent(self):
return self.is_last_hangup_chunk_sent

def welcome_message_sent(self):
return self.is_welcome_message_sent
# def welcome_message_sent(self):
# return self.is_welcome_message_sent

async def handle(self, packet):
try:
logger.info(f"Packet received:")
# if (self.is_web_based_call and packet["meta_info"].get("message_category", "") == "agent_welcome_message" and
# packet["meta_info"].get("is_final_chunk_of_entire_response", True)):
# self.is_welcome_message_sent = True

data = None
if packet["meta_info"]['type'] in ('audio', 'text'):
if packet["meta_info"]['type'] == 'audio':
Expand All @@ -50,6 +61,29 @@ async def handle(self, packet):
response = {"data": data, "type": packet["meta_info"]['type']}
await self.websocket.send_json(response)

# sending mark message for type of audio
if packet["meta_info"]['type'] == 'audio':
meta_info = packet["meta_info"]
mark_event_meta_data = {
"text_synthesized": "" if meta_info["sequence_id"] == -1 else meta_info.get("text_synthesized",
""),
"type": meta_info.get('message_category', ''),
"is_first_chunk": meta_info.get("is_first_chunk", False),
"is_final_chunk": True if (meta_info["sequence_id"] == -1 or (
meta_info.get("end_of_llm_stream", False) and meta_info.get(
"end_of_synthesizer_stream", False))) else False,
"sequence_id": meta_info["sequence_id"]
}
mark_id = meta_info.get("mark_id") if (
meta_info.get("mark_id") and meta_info.get("mark_id") != "") else str(uuid.uuid4())
logger.info(f"Mark meta data being saved for mark id - {mark_id} is - {mark_event_meta_data}")
self.mark_event_meta_data.update_data(mark_id, mark_event_meta_data)
mark_message = {
"type": "mark",
"name": mark_id
}
logger.info(f"Sending mark event - {mark_message}")
await self.websocket.send_text(json.dumps(mark_message))
else:
logger.error("Other modalities are not implemented yet")
except Exception as e:
Expand Down
Loading

0 comments on commit 74b7225

Please sign in to comment.