Skip to content

Commit

Permalink
working audio input
Browse files Browse the repository at this point in the history
  • Loading branch information
MadcowD committed Oct 2, 2024
1 parent d02a2f9 commit 11f6245
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 9 deletions.
17 changes: 12 additions & 5 deletions examples/openai_audio.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
import ell

from pydub import AudioSegment
import numpy as np

# Helper function to load and convert audio files
def load_audio_sample(file_path):
    """Load an audio file and return its raw samples as a numpy array.

    Uses pydub to decode the file, so any format ffmpeg supports works.
    """
    segment = AudioSegment.from_file(file_path)
    return np.array(segment.get_array_of_samples())


ell.init(verbose=True)

@ell.complex("gpt-4o-audio-preview")
def test():
return [ell.user("Hey! Could you talk to me in spanish? I'd like to hear how you say 'ell'.")]
return [ell.user(["Hey! what do you think about this?", load_audio_sample("toronto.mp3")])]

response = test()
print(response.audios[0])

if __name__ == "__main__":
test()

response = test()
print(response.audios[0])
Binary file added examples/toronto.mp3
Binary file not shown.
6 changes: 5 additions & 1 deletion src/ell/providers/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import json
from ell.configurator import _Model, config, register_provider
from ell.types.message import LMP
from ell.util.serialization import serialize_image
from ell.util.serialization import array_buffer_to_base64, serialize_image

try:
# XXX: Could genericize.
Expand Down Expand Up @@ -122,6 +122,7 @@ def translate_from_provider(
logger(delta.content, is_refusal=hasattr(delta, "refusal") and delta.refusal)
for _, message_stream in sorted(message_streams.items(), key=lambda x: x[0]):
text = "".join((choice.delta.content or "") for choice in message_stream)
# XXX: API might be close to something else.
messages.append(
Message(role=role,
content=_lstr(content=text,origin_trace=origin_id)))
Expand All @@ -144,6 +145,7 @@ def translate_from_provider(
ContentBlock(
text=_lstr(content=content,origin_trace=origin_id)))
if logger: logger(content)
#XXX: Streaming tool calls are coming.
if (tool_calls := message.tool_calls):
for tool_call in tool_calls:
matching_tool = ell_call.get_tool_by_name(tool_call.function.name)
Expand Down Expand Up @@ -178,6 +180,8 @@ def _content_block_to_openai_format(content_block: ContentBlock) -> Dict[str, An
"type": "image_url",
"image_url": image_url
}
elif (audio := content_block.audio) is not None:
return dict(type="input_audio", audio=array_buffer_to_base64(audio))
elif ((text := content_block.text) is not None): return dict(type="text", text=text)
elif (parsed := content_block.parsed): return dict(type="text", text=parsed.model_dump_json())
else:
Expand Down
22 changes: 19 additions & 3 deletions src/ell/types/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,10 @@ def coerce(cls, content: AnyContent) -> "ContentBlock":
return cls(tool_call=content)
if isinstance(content, ToolResult):
return cls(tool_result=content)
if isinstance(content, (ImageContent, np.ndarray, PILImage.Image)):
if isinstance(content, (ImageContent, PILImage.Image)) or (isinstance(content, np.ndarray) and content.ndim >= 3):
return cls(image=ImageContent.coerce(content))
if isinstance(content, np.ndarray) and content.ndim == 1:
return cls(audio=content)
if isinstance(content, BaseModel):
return cls(parsed=content)

Expand Down Expand Up @@ -352,8 +354,8 @@ def images(self) -> List[ImageContent]:
return [c.image for c in self.content if c.image]

@property
def audios(self) -> List[Union[np.ndarray, List[float]]]:
"""Returns a list of all audio content.
def audios(self) -> List[Union[np.ndarray, List[float], List[int]]]:
"""Returns a list of all audio content in each content block..
Example:
>>> audio1 = np.array([0.1, 0.2, 0.3])
Expand All @@ -363,6 +365,20 @@ def audios(self) -> List[Union[np.ndarray, List[float]]]:
2
"""
return [c.audio for c in self.content if c.audio]


@property
def audio(self) -> np.ndarray:
    """Return all audio content of this message joined into a single array.

    Note: despite the singular name, this concatenates every audio block
    via ``np.concatenate``; it raises ``ValueError`` when the message
    contains no audio (``np.concatenate`` rejects an empty sequence).

    Example:
        >>> audio1 = np.array([0.1, 0.2, 0.3])
        >>> message = Message(role="user", content=["Text", audio1, "More text"])
        >>> message.audio
        array([0.1, 0.2, 0.3])
    """
    return np.concatenate(self.audios)


@property
def text_only(self) -> str:
Expand Down
25 changes: 25 additions & 0 deletions src/ell/util/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,31 @@ def compute_state_cache_key(ipstr, fn_closure):
_free_vars_str = f"{json.dumps(get_immutable_vars(fn_closure[3]), sort_keys=True, default=repr)}"
state_cache_key = hashlib.sha256(f"{ipstr}{_global_free_vars_str}{_free_vars_str}".encode('utf-8')).hexdigest()
return state_cache_key


def float_to_16bit_pcm(float32_array):
    """Convert float samples in [-1, 1] to raw 16-bit PCM bytes.

    Values outside [-1, 1] are clipped before scaling to the int16 range.
    """
    clipped = np.clip(float32_array, -1, 1)
    return (clipped * 32767).astype(np.int16).tobytes()

def base64_to_array_buffer(base64_string):
    """Decode a base64 string back into its raw bytes."""
    raw_bytes = base64.b64decode(base64_string)
    return raw_bytes

def array_buffer_to_base64(array_buffer):
    """Base64-encode audio data for transport to the API.

    Args:
        array_buffer: Raw ``bytes``, an integer PCM ``np.ndarray`` (e.g.
            int16), or a floating-point sample array in [-1, 1].

    Returns:
        str: ASCII base64 of the underlying 16-bit PCM / raw bytes.

    Float arrays of ANY precision (float32, float64, ...) are converted to
    16-bit PCM first — the original only handled float32, silently base64
    encoding raw float64 bytes otherwise, which the audio API cannot use.
    """
    if isinstance(array_buffer, np.ndarray):
        if np.issubdtype(array_buffer.dtype, np.floating):
            # Mirrors float_to_16bit_pcm: clip to [-1, 1], scale to int16.
            array_buffer = (np.clip(array_buffer, -1, 1) * 32767).astype(np.int16).tobytes()
        else:
            # Integer PCM: take the raw bytes (handles non-contiguous arrays
            # too, which b64encode's buffer protocol path would reject).
            array_buffer = array_buffer.tobytes()
    return base64.b64encode(array_buffer).decode('utf-8')

def merge_int16_arrays(left, right):
    """Concatenate two int16 audio buffers.

    Each side may be a numpy array or a ``bytes`` object holding int16
    samples; bytes are reinterpreted (zero-copy) before concatenation.

    Raises:
        ValueError: if either side is neither a numpy array nor bytes.
    """
    def _coerce(buf):
        return np.frombuffer(buf, dtype=np.int16) if isinstance(buf, bytes) else buf

    left_arr = _coerce(left)
    right_arr = _coerce(right)
    if not (isinstance(left_arr, np.ndarray) and isinstance(right_arr, np.ndarray)):
        raise ValueError("Both items must be numpy arrays or bytes objects")
    return np.concatenate((left_arr, right_arr))


def prepare_invocation_params(params):
Expand Down
47 changes: 47 additions & 0 deletions src/ell/util/verbosity.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,51 @@ def get_terminal_width() -> int:
logger.warning("Unable to determine terminal size. Defaulting to 80 columns.")
return 80

import numpy as np
def plot_ascii_waveform(audio_data: np.ndarray, width: int = 80) -> List[str]:
    """
    Plot an ASCII waveform of the given audio data with a height of 1.

    Args:
        audio_data (np.ndarray): 1D array of audio samples to plot.
        width (int): Maximum width of the ASCII plot in characters.

    Returns:
        List[str]: Three strings (top border, waveform, bottom border).

    Raises:
        ValueError: if the input is not a non-empty 1D array.

    Fixes over the naive version: constant input no longer divides by zero
    (NaN -> crash); the borders are sized to the actual number of rendered
    samples, so the box always lines up; the label is skipped when it does
    not fit.
    """
    if audio_data.ndim != 1:
        raise ValueError("Audio data must be a 1D numpy array")
    if audio_data.size == 0:
        raise ValueError("Audio data must not be empty")

    # Normalize to [0, 1]; a flat signal maps to the lowest glyph.
    lo = np.min(audio_data)
    hi = np.max(audio_data)
    span = hi - lo
    if span == 0:
        normalized = np.zeros(audio_data.shape, dtype=float)
    else:
        normalized = (audio_data - lo) / span

    # Downsample so at most `width` glyphs are emitted.
    step = max(1, len(audio_data) // width)

    # Characters for increasing amplitude.
    chars = ' ▁▂▃▄▅▆▇█'
    waveform = ''.join(
        chars[int(normalized[i] * (len(chars) - 1))]
        for i in range(0, len(audio_data), step)
    )

    # Size the borders to the rendered content, not the requested width,
    # so all three lines are exactly the same length.
    inner = len(waveform)
    border = '─' * inner
    lines = [f'╭{border}╮', f'│{waveform}│', f'╰{border}╯']

    # Center a label in the top border when there is room for it.
    label = "Audio ContentBlock"
    if len(label) <= inner:
        pos = (inner - len(label)) // 2
        lines[0] = lines[0][:pos] + label + lines[0][pos + len(label):]

    return lines


def wrap_text_with_prefix(message, width: int, prefix: str, subsequent_prefix: str, text_color: str) -> List[str]:
"""Wrap text while preserving the prefix and color for each line."""
result = []
Expand All @@ -102,6 +147,8 @@ def wrap_text_with_prefix(message, width: int, prefix: str, subsequent_prefix: s
for c in contnets_to_wrap:
if c.image and c.image.image:
block_wrapped_lines = plot_ascii(c.image.image, min(80, width - len(prefix)))
elif c.audio is not None:
block_wrapped_lines = plot_ascii_waveform(c.audio)
else:
text = _content_to_text([c])
paragraphs = text.split('\n')
Expand Down

0 comments on commit 11f6245

Please sign in to comment.