[WIP] working audio input #276

Open · wants to merge 1 commit into main
17 changes: 12 additions & 5 deletions examples/openai_audio.py
@@ -1,15 +1,22 @@
import ell

from pydub import AudioSegment
import numpy as np

# Helper function to load and convert audio files
def load_audio_sample(file_path):
    audio = AudioSegment.from_file(file_path)
    samplearray = np.array(audio.get_array_of_samples())
    return samplearray


ell.init(verbose=True)

@ell.complex("gpt-4o-audio-preview")
def test():
return [ell.user("Hey! Could you talk to me in spanish? I'd like to hear how you say 'ell'.")]
return [ell.user(["Hey! what do you think about this?", load_audio_sample("toronto.mp3")])]

response = test()
print(response.audios[0])

if __name__ == "__main__":
    test()

    response = test()
    print(response.audios[0])
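A side note on the helper above, offered as an assumption rather than something the PR states: pydub's `get_array_of_samples()` typically yields interleaved integer PCM (int16 for 16-bit files), which is one of the two dtypes the new `array_buffer_to_base64` accepts. A hypothetical variant (`load_audio_sample_float32` is not in this PR) that instead normalizes to float32 in [-1, 1], the other accepted dtype, in case a file comes back at a different bit depth:

```python
from pydub import AudioSegment
import numpy as np

def load_audio_sample_float32(file_path):
    # Hypothetical variant of load_audio_sample: scale whatever sample width
    # pydub reports down to float32 in [-1, 1]; the serializer added below
    # converts float32 back to 16-bit PCM before base64-encoding it.
    audio = AudioSegment.from_file(file_path)
    samples = np.array(audio.get_array_of_samples())
    full_scale = float(2 ** (8 * audio.sample_width - 1))
    return (samples / full_scale).astype(np.float32)

# e.g. load_audio_sample_float32("toronto.mp3")
```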
Binary file added examples/toronto.mp3
Binary file not shown.
6 changes: 5 additions & 1 deletion src/ell/providers/openai.py
@@ -9,7 +9,7 @@
import json
from ell.configurator import _Model, config, register_provider
from ell.types.message import LMP
from ell.util.serialization import serialize_image
from ell.util.serialization import array_buffer_to_base64, serialize_image

try:
# XXX: Could genericize.
@@ -122,6 +122,7 @@ def translate_from_provider(
                logger(delta.content, is_refusal=hasattr(delta, "refusal") and delta.refusal)
        for _, message_stream in sorted(message_streams.items(), key=lambda x: x[0]):
            text = "".join((choice.delta.content or "") for choice in message_stream)
            # XXX: API might be close to something else.
            messages.append(
                Message(role=role,
                        content=_lstr(content=text,origin_trace=origin_id)))
@@ -144,6 +145,7 @@ def translate_from_provider(
                    ContentBlock(
                        text=_lstr(content=content,origin_trace=origin_id)))
            if logger: logger(content)
            #XXX: Streaming tool calls are coming.
            if (tool_calls := message.tool_calls):
                for tool_call in tool_calls:
                    matching_tool = ell_call.get_tool_by_name(tool_call.function.name)
@@ -178,6 +180,8 @@ def _content_block_to_openai_format(content_block: ContentBlock) -> Dict[str, Any]:
            "type": "image_url",
            "image_url": image_url
        }
    elif (audio := content_block.audio) is not None:
        return dict(type="input_audio", audio=array_buffer_to_base64(audio))
    elif ((text := content_block.text) is not None): return dict(type="text", text=text)
    elif (parsed := content_block.parsed): return dict(type="text", text=parsed.model_dump_json())
    else:
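To make the new branch concrete, here is a minimal, self-contained sketch of what it would emit for one second of synthetic audio. The two helpers are copied from the `src/ell/util/serialization.py` hunk further down so the snippet runs on its own; whether the OpenAI endpoint actually wants this flat `{"type": "input_audio", "audio": ...}` shape, rather than a nested object with an explicit format field, is left open here, consistent with the PR's own `XXX` notes.

```python
import base64
import numpy as np

# Copies of the helpers added in src/ell/util/serialization.py, for a standalone demo.
def float_to_16bit_pcm(float32_array):
    return (np.clip(float32_array, -1, 1) * 32767).astype(np.int16).tobytes()

def array_buffer_to_base64(array_buffer):
    if isinstance(array_buffer, np.ndarray):
        if array_buffer.dtype == np.float32:
            array_buffer = float_to_16bit_pcm(array_buffer)
        elif array_buffer.dtype == np.int16:
            array_buffer = array_buffer.tobytes()
    return base64.b64encode(array_buffer).decode("utf-8")

# One second of a 440 Hz tone at 24 kHz, float32 in [-1, 1].
tone = np.sin(2 * np.pi * 440 * np.arange(24_000) / 24_000).astype(np.float32)
block = {"type": "input_audio", "audio": array_buffer_to_base64(tone)}
print(block["type"], len(block["audio"]), "base64 chars")
```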
22 changes: 19 additions & 3 deletions src/ell/types/message.py
@@ -252,8 +252,10 @@ def coerce(cls, content: AnyContent) -> "ContentBlock":
            return cls(tool_call=content)
        if isinstance(content, ToolResult):
            return cls(tool_result=content)
        if isinstance(content, (ImageContent, np.ndarray, PILImage.Image)):
        if isinstance(content, (ImageContent, PILImage.Image)) or (isinstance(content, np.ndarray) and content.ndim >= 3):
            return cls(image=ImageContent.coerce(content))
        if isinstance(content, np.ndarray) and content.ndim == 1:
            return cls(audio=content)
        if isinstance(content, BaseModel):
            return cls(parsed=content)

@@ -352,8 +354,8 @@ def images(self) -> List[ImageContent]:
        return [c.image for c in self.content if c.image]

    @property
    def audios(self) -> List[Union[np.ndarray, List[float]]]:
        """Returns a list of all audio content.
    def audios(self) -> List[Union[np.ndarray, List[float], List[int]]]:
        """Returns a list of all audio content in each content block.

        Example:
            >>> audio1 = np.array([0.1, 0.2, 0.3])
@@ -363,6 +365,20 @@ def audios(self) -> List[Union[np.ndarray, List[float]]]:
            2
        """
        return [c.audio for c in self.content if c.audio]


    @property
    def audio(self) -> np.ndarray:
        """Returns the first audio content.

        Example:
            >>> audio1 = np.array([0.1, 0.2, 0.3])
            >>> message = Message(role="user", content=["Text", audio1, "More text"])
            >>> message.audio
            array([0.1, 0.2, 0.3])
        """
        return np.concatenate(self.audios)


    @property
    def text_only(self) -> str:
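A short sketch of how the widened `coerce` rules route mixed user content, assuming this branch of ell is installed: a 1-D numpy array now lands in the new `audio` field, while image coercion is restricted to PIL images, `ImageContent`, or arrays with `ndim >= 3`.

```python
import numpy as np
import ell

# 1-D float array -> coerced to an audio content block by the new rule above.
waveform = np.array([0.0, 0.25, -0.25, 0.5], dtype=np.float32)
message = ell.user(["What do you hear in this clip?", waveform])

for block in message.content:
    print("audio" if block.audio is not None else "text")
# expected output: text, then audio
```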
25 changes: 25 additions & 0 deletions src/ell/util/serialization.py
@@ -95,6 +95,31 @@ def compute_state_cache_key(ipstr, fn_closure):
    _free_vars_str = f"{json.dumps(get_immutable_vars(fn_closure[3]), sort_keys=True, default=repr)}"
    state_cache_key = hashlib.sha256(f"{ipstr}{_global_free_vars_str}{_free_vars_str}".encode('utf-8')).hexdigest()
    return state_cache_key


def float_to_16bit_pcm(float32_array):
    int16_array = (np.clip(float32_array, -1, 1) * 32767).astype(np.int16)
    return int16_array.tobytes()

def base64_to_array_buffer(base64_string):
    return base64.b64decode(base64_string)

def array_buffer_to_base64(array_buffer):
    if isinstance(array_buffer, np.ndarray):
        if array_buffer.dtype == np.float32:
            array_buffer = float_to_16bit_pcm(array_buffer)
        elif array_buffer.dtype == np.int16:
            array_buffer = array_buffer.tobytes()
    return base64.b64encode(array_buffer).decode('utf-8')

def merge_int16_arrays(left, right):
    if isinstance(left, bytes):
        left = np.frombuffer(left, dtype=np.int16)
    if isinstance(right, bytes):
        right = np.frombuffer(right, dtype=np.int16)
    if not isinstance(left, np.ndarray) or not isinstance(right, np.ndarray):
        raise ValueError("Both items must be numpy arrays or bytes objects")
    return np.concatenate((left, right))


def prepare_invocation_params(params):
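As a quick check of the new helpers, a standalone sketch of `merge_int16_arrays`, which presumably is meant for stitching together streamed PCM chunks that may arrive either as raw bytes or as numpy arrays:

```python
import numpy as np

def merge_int16_arrays(left, right):
    # Standalone copy of the helper added above, for a quick local check.
    if isinstance(left, bytes):
        left = np.frombuffer(left, dtype=np.int16)
    if isinstance(right, bytes):
        right = np.frombuffer(right, dtype=np.int16)
    if not isinstance(left, np.ndarray) or not isinstance(right, np.ndarray):
        raise ValueError("Both items must be numpy arrays or bytes objects")
    return np.concatenate((left, right))

chunk_a = np.array([1, 2, 3], dtype=np.int16).tobytes()   # e.g. a streamed PCM chunk
chunk_b = np.array([4, 5, 6], dtype=np.int16)
print(merge_int16_arrays(chunk_a, chunk_b))               # [1 2 3 4 5 6]
```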
47 changes: 47 additions & 0 deletions src/ell/util/verbosity.py
@@ -85,6 +85,51 @@ def get_terminal_width() -> int:
        logger.warning("Unable to determine terminal size. Defaulting to 80 columns.")
        return 80

import numpy as np
def plot_ascii_waveform(audio_data: np.ndarray, width: int = 80) -> List[str]:
    """
    Plot an improved ASCII waveform of the given audio data with a height of 1.

    Args:
        audio_data (np.ndarray): The audio data to plot.
        width (int): The width of the ASCII plot.

    Returns:
        List[str]: A list of strings representing the ASCII waveform.
    """
    if audio_data.ndim != 1:
        raise ValueError("Audio data must be a 1D numpy array")

    # Normalize audio data to fit within the range [0, 1]
    normalized_data = (audio_data - np.min(audio_data)) / (np.max(audio_data) - np.min(audio_data))

    # Create the ASCII waveform
    step = max(1, len(audio_data) // width)

    # Characters for different amplitudes
    chars = ' ▁▂▃▄▅▆▇█'

    waveform = ''
    for i in range(0, len(audio_data), step):
        char_index = int(normalized_data[i] * (len(chars) - 1))
        waveform += chars[char_index]

    # Add top and bottom borders
    border = '─' * width
    waveform = [f'╭{border}╮', f'│{waveform}│', f'╰{border}╯']

    # Add audio label
    label = "Audio ContentBlock"
    label_position = (width - len(label)) // 2
    waveform[0] = (
        waveform[0][:label_position] +
        label +
        waveform[0][label_position + len(label):]
    )

    return waveform


def wrap_text_with_prefix(message, width: int, prefix: str, subsequent_prefix: str, text_color: str) -> List[str]:
"""Wrap text while preserving the prefix and color for each line."""
result = []
Expand All @@ -102,6 +147,8 @@ def wrap_text_with_prefix(message, width: int, prefix: str, subsequent_prefix: s
for c in contnets_to_wrap:
if c.image and c.image.image:
block_wrapped_lines = plot_ascii(c.image.image, min(80, width - len(prefix)))
elif c.audio is not None:
block_wrapped_lines = plot_ascii_waveform(c.audio)
else:
text = _content_to_text([c])
paragraphs = text.split('\n')
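Finally, a small way to eyeball the new waveform rendering without calling a model, assuming this branch is installed and the terminal handles UTF-8 block characters; the synthetic sweep below just exercises the amplitude glyphs:

```python
import numpy as np
from ell.util.verbosity import plot_ascii_waveform  # assumes this branch of ell is importable

# 4800 samples so that 4800 // 60 == 80 gives exactly 60 columns inside the border.
t = np.linspace(0, 1, 4800)
demo = np.sin(2 * np.pi * 5 * t) * np.linspace(0.2, 1.0, t.size)

for line in plot_ascii_waveform(demo, width=60):
    print(line)
```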