Skip to content

Commit

Permalink
allow inner fn to be sync
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-dixon committed Oct 22, 2024
1 parent c54b95a commit 71ea23f
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 6 deletions.
4 changes: 2 additions & 2 deletions typescript/src/lmp/_track.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ export type Invocation = {



type F = (...args: any[]) => Promise<string | Array<Message>>
type F = (...args: any[]) => Promise<string | Array<Message>> | string | Array<Message>

/**
* Used for tracing of invocations.
Expand Down Expand Up @@ -303,7 +303,7 @@ export const invokeWithTracking = async (lmp: LMPDefinition & { lmpId: string },
}

const start = performance.now()
const lmpfnoutput = await f(...args)
const lmpfnoutput = await Promise.resolve(f(...args))
// const event = await getNextPausedEvent()
// console.log('event', event)
// await handleBreakpointHit(event)
Expand Down
4 changes: 2 additions & 2 deletions typescript/src/lmp/simple.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import { EllCallParams } from '../provider'

const logger = logging.getLogger('ell')

type SimpleLMPInner = (...args: any[]) => Promise<string | Array<Message>>
type SimpleLMPInner = (...args: any[]) => string | Array<Message> | Promise<string | Array<Message>>
type SimpleLMP<A extends SimpleLMPInner> = ((...args: Parameters<A>) => Promise<string>) & {
__ell_type__?: 'simple'
__ell_lmp_name__?: string
Expand Down Expand Up @@ -49,7 +49,7 @@ export const simple = <F extends SimpleLMPInner>(a: Kwargs, f: F): SimpleLMP<F>
if (lmpId && !a.exempt_from_tracking) {
return await invokeWithTracking({ ...lmpDefinition!, lmpId }, args, f, a)
}
const promptFnOutput = await f(...args)
const promptFnOutput = await Promise.resolve(f(...args))
const modelClient = await getModelClient(a)
const provider = config.getProviderFor(modelClient)
if (!provider) {
Expand Down
3 changes: 2 additions & 1 deletion typescript/src/serialize/sql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ export type InvocationContents = {
free_vars: Record<string, any>
is_external: boolean
invocation: Invocation
// todo. created_at?
}
export const InvocationContents = (props: InvocationContents) => ({
...props,
Expand Down Expand Up @@ -138,7 +139,7 @@ class Mutex {
}

export class SQLiteStore extends Store {
private db: Database | null = null
public db: Database | null = null
private dbPath: string
private txMutex = new Mutex()

Expand Down
2 changes: 1 addition & 1 deletion typescript/test/fixtures/hello_world.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ function getRandomPunctuation(): string {
return randomChoice(['!', '!!', '!!!'])
}

export const hello = ell.simple({ model: 'gpt-4o-mini' }, async (name: string) => {
export const hello = ell.simple({ model: 'gpt-4o-mini' }, (name: string) => {
const adjective = getRandomAdjective()
const punctuation = getRandomPunctuation()

Expand Down
20 changes: 20 additions & 0 deletions typescript/test/runtime.mocha.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,24 @@ describe('lmp', () => {
const result = await child2('world')
assert.deepStrictEqual(result, new Message('assistant', 'child'))
})

test('sync prompt functions', async () => {
const child = simple({ model: 'gpt-4o-mini' }, (a: string) => {
return 'child'
})
const hello = simple({ model: 'gpt-4o' }, async (a: { a: string }) => {
const ok = await child(a.a)
return a.a + ok
})

const result = await hello({ a: 'world' })

assert.equal(result, 'worldchild')

assert.ok(hello.__ell_lmp_id__?.startsWith('lmp-'))
assert.equal(hello.__ell_lmp_name__, 'hello')

assert.ok(child.__ell_lmp_id__?.startsWith('lmp-'))
assert.equal(child.__ell_lmp_name__, 'child')
})
})
115 changes: 115 additions & 0 deletions typescript/test/tracing.mocha.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import assert from 'assert'
import OpenAI from 'openai'
import { config } from '../src/configurator'
import { chatCompletionsToStream } from './util'
import { SQLiteStore } from '../src/serialize/sql'
import * as ell from 'ell-ai'

describe('tracing', () => {
let store: SQLiteStore
beforeEach(async () => {
store = new SQLiteStore(':memory:')
await store.initialize()
ell.init({ store })

config.defaultClient = config.defaultClient || new OpenAI({ apiKey: 'test' })
// @ts-expect-error
config.defaultClient.chat.completions.create = async (...args) => {
return chatCompletionsToStream([
<OpenAI.Chat.Completions.ChatCompletion>{
usage: {
prompt_tokens: 10,
completion_tokens: 10,
latency_ms: 10,
total_tokens: 20,
},
id: 'chatcmpl-123',
created: 1677652288,
model: 'gpt-3.5-turbo-0125',
object: 'chat.completion',
choices: [
<OpenAI.Chat.Completions.ChatCompletion.Choice>{
index: 0,
finish_reason: 'stop',
logprobs: null,
message: {
// @ts-expect-error
content: args[0].messages[0].content[0].text,
role: 'assistant',
refusal: null,
},
},
],
},
])
}
})

it('simple', async () => {
const hello = require('./fixtures/hello_world')
const result = await hello.hello('world')

assert.equal(result, 'You are a helpful and expressive assistant.')

const lmp = (await store.db?.all('SELECT * FROM serializedlmp'))?.[0]
assert.ok(typeof lmp.created_at === 'string')
delete lmp.created_at
assert.deepEqual(lmp, {
lmp_id: 'lmp-a79d4140040f36d6c8074901fd00d769',
name: 'test.fixtures.hello_world.hello',
source:
'export const hello = ell.simple({ model: \'gpt-4o-mini\' }, (name: string) => {\n const adjective = getRandomAdjective()\n const punctuation = getRandomPunctuation()\n\n return [\n ell.system(\'You are a helpful and expressive assistant.\'),\n ell.user(`Say a ${adjective} hello to ${name}${punctuation}`),\n ] \n})',
language: 'typescript',
dependencies: '',
lmp_type: 'LM',
api_params: '{"model":"gpt-4o-mini"}',
initial_free_vars: '{}',
initial_global_vars: '{}',
num_invocations: 1,
commit_message: 'Initial version',
version_number: 1,
})

const invocation = (await store.db?.all('SELECT * FROM invocation'))?.[0]
const invocationId = invocation.id

assert.ok(invocationId.startsWith('invocation-'))
delete invocation.id
assert.ok(typeof invocation.created_at === 'string')
delete invocation.created_at
assert.ok(typeof invocation.latency_ms === 'number')
delete invocation.latency_ms

assert.deepStrictEqual(invocation, {
lmp_id: lmp.lmp_id,
prompt_tokens: null,
completion_tokens: null,
state_cache_key: '',
used_by_id: null,
})

const invocationContents = (await store.db?.all('SELECT * FROM invocationcontents'))?.[0]

assert.equal(invocationContents?.invocation_id, invocationId)
delete invocationContents.invocation_id

// Free vars
const freeVars = JSON.parse(invocationContents.free_vars)
assert.deepEqual(freeVars.name, 'world')
assert.ok(['enthusiastic', 'cheerful', 'warm', 'friendly', 'heartfelt', 'sincere'].includes(freeVars.adjective))
assert.ok(['!', '!!', '!!!'].includes(freeVars.punctuation))
delete invocationContents.free_vars

// Global vars
const globalVars = JSON.parse(invocationContents.global_vars)
assert.deepEqual(globalVars, {})
delete invocationContents.global_vars

assert.deepStrictEqual(invocationContents, {
'invocation_api_params': '{"model":"gpt-4o-mini"}',
'is_external': 0,
'params': '["world"]',
'results': '"You are a helpful and expressive assistant."',
})
})
})

0 comments on commit 71ea23f

Please sign in to comment.