-
Notifications
You must be signed in to change notification settings - Fork 90
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
82 changed files
with
6,663 additions
and
713 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
import * as Eff from 'effect' | ||
import invariant from 'invariant' | ||
|
||
import * as React from 'react' | ||
|
||
import * as AWS from 'utils/AWS' | ||
import * as Actor from 'utils/Actor' | ||
|
||
import * as Bedrock from './Bedrock' | ||
import * as Context from './Context' | ||
import * as Conversation from './Conversation' | ||
import * as GlobalContext from './GlobalContext' | ||
import useIsEnabled from './enabled' | ||
|
||
export const DISABLED = Symbol('DISABLED') | ||
|
||
function usePassThru<T>(val: T) { | ||
const ref = React.useRef(val) | ||
ref.current = val | ||
return ref | ||
} | ||
|
||
function useConstructAssistantAPI() { | ||
const passThru = usePassThru({ | ||
bedrock: AWS.Bedrock.useClient(), | ||
context: Context.useLayer(), | ||
}) | ||
const layerEff = Eff.Effect.sync(() => | ||
Eff.Layer.merge( | ||
Bedrock.LLMBedrock(passThru.current.bedrock), | ||
passThru.current.context, | ||
), | ||
) | ||
const [state, dispatch] = Actor.useActorLayer( | ||
Conversation.ConversationActor, | ||
Conversation.init, | ||
layerEff, | ||
) | ||
|
||
GlobalContext.use() | ||
|
||
// XXX: move this to actor state? | ||
const [visible, setVisible] = React.useState(false) | ||
const show = React.useCallback(() => setVisible(true), []) | ||
const hide = React.useCallback(() => setVisible(false), []) | ||
|
||
const assist = React.useCallback( | ||
(msg?: string) => { | ||
if (msg) dispatch(Conversation.Action.Ask({ content: msg })) | ||
show() | ||
}, | ||
[show, dispatch], | ||
) | ||
|
||
return { | ||
visible, | ||
show, | ||
hide, | ||
assist, | ||
state, | ||
dispatch, | ||
} | ||
} | ||
|
||
type AssistantAPI = ReturnType<typeof useConstructAssistantAPI> | ||
|
||
const Ctx = React.createContext<AssistantAPI | typeof DISABLED | null>(null) | ||
|
||
function AssistantAPIProvider({ children }: React.PropsWithChildren<{}>) { | ||
return <Ctx.Provider value={useConstructAssistantAPI()}>{children}</Ctx.Provider> | ||
} | ||
|
||
function DisabledAPIProvider({ children }: React.PropsWithChildren<{}>) { | ||
return <Ctx.Provider value={DISABLED}>{children}</Ctx.Provider> | ||
} | ||
|
||
export function AssistantProvider({ children }: React.PropsWithChildren<{}>) { | ||
return useIsEnabled() ? ( | ||
<Context.ContextAggregatorProvider> | ||
<AssistantAPIProvider>{children}</AssistantAPIProvider> | ||
</Context.ContextAggregatorProvider> | ||
) : ( | ||
<DisabledAPIProvider>{children}</DisabledAPIProvider> | ||
) | ||
} | ||
|
||
export function useAssistantAPI() { | ||
const api = React.useContext(Ctx) | ||
invariant(api, 'AssistantAPI must be used within an AssistantProvider') | ||
return api === DISABLED ? null : api | ||
} | ||
|
||
export function useAssistant() { | ||
return useAssistantAPI()?.assist | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
import BedrockRuntime from 'aws-sdk/clients/bedrockruntime' | ||
import * as Eff from 'effect' | ||
|
||
import * as Log from 'utils/Logging' | ||
|
||
import * as Content from './Content' | ||
import * as LLM from './LLM' | ||
|
||
const MODULE = 'Bedrock' | ||
|
||
const MODEL_ID = 'us.anthropic.claude-3-5-sonnet-20240620-v1:0' | ||
|
||
const mapContent = (contentBlocks: BedrockRuntime.ContentBlocks | undefined) => | ||
Eff.pipe( | ||
contentBlocks, | ||
Eff.Option.fromNullable, | ||
Eff.Option.map( | ||
Eff.Array.flatMapNullable((c) => { | ||
if (c.document) { | ||
return Content.ResponseMessageContentBlock.Document({ | ||
format: c.document.format as $TSFixMe, | ||
source: c.document.source.bytes as $TSFixMe, | ||
name: c.document.name, | ||
}) | ||
} | ||
if (c.image) { | ||
return Content.ResponseMessageContentBlock.Image({ | ||
format: c.image.format as $TSFixMe, | ||
source: c.image.source.bytes as $TSFixMe, | ||
}) | ||
} | ||
if (c.text) { | ||
return Content.ResponseMessageContentBlock.Text({ text: c.text }) | ||
} | ||
if (c.toolUse) { | ||
return Content.ResponseMessageContentBlock.ToolUse(c.toolUse as $TSFixMe) | ||
} | ||
// if (c.guardContent) { | ||
// // TODO | ||
// return acc | ||
// } | ||
// if (c.toolResult) { | ||
// // XXX: is it really supposed to occur here in LLM response? | ||
// return acc | ||
// } | ||
return null | ||
}), | ||
), | ||
) | ||
|
||
// TODO: use Schema | ||
const contentToBedrock = Content.PromptMessageContentBlock.$match({ | ||
GuardContent: ({ text }) => ({ guardContent: { text: { text } } }), | ||
ToolResult: ({ toolUseId, status, content }) => ({ | ||
toolResult: { | ||
toolUseId, | ||
status, | ||
content: content.map( | ||
Content.ToolResultContentBlock.$match({ | ||
Json: ({ _tag, ...rest }) => rest, | ||
Text: ({ _tag, ...rest }) => rest, | ||
// XXX: be careful with base64/non-base64 encoding | ||
Image: ({ format, source }) => ({ | ||
image: { format, source: { bytes: source } }, | ||
}), | ||
Document: ({ format, source, name }) => ({ | ||
document: { format, source: { bytes: source }, name }, | ||
}), | ||
}), | ||
), | ||
}, | ||
}), | ||
ToolUse: ({ _tag, ...toolUse }) => ({ toolUse }), | ||
Text: ({ _tag, ...rest }) => rest, | ||
Image: ({ format, source }) => ({ image: { format, source: { bytes: source } } }), | ||
Document: ({ format, source, name }) => ({ | ||
document: { format, source: { bytes: source }, name }, | ||
}), | ||
}) | ||
|
||
const messagesToBedrock = ( | ||
messages: Eff.Array.NonEmptyArray<LLM.PromptMessage>, | ||
): BedrockRuntime.Message[] => | ||
// create an array of alternating assistant and user messages | ||
Eff.pipe( | ||
messages, | ||
Eff.Array.groupWith((m1, m2) => m1.role === m2.role), | ||
Eff.Array.map((group) => ({ | ||
role: group[0].role, | ||
content: group.map((m) => contentToBedrock(m.content)), | ||
})), | ||
) | ||
|
||
const toolConfigToBedrock = ( | ||
toolConfig: LLM.ToolConfig, | ||
): BedrockRuntime.ToolConfiguration => ({ | ||
tools: Object.entries(toolConfig.tools).map(([name, { description, schema }]) => ({ | ||
toolSpec: { | ||
name, | ||
description, | ||
inputSchema: { json: schema }, | ||
}, | ||
})), | ||
toolChoice: | ||
toolConfig.choice && | ||
LLM.ToolChoice.$match(toolConfig.choice, { | ||
Auto: () => ({ auto: {} }), | ||
Any: () => ({ any: {} }), | ||
Specific: ({ name }) => ({ tool: { name } }), | ||
}), | ||
}) | ||
|
||
// a layer providing the service over aws.bedrock | ||
export function LLMBedrock(bedrock: BedrockRuntime) { | ||
const converse = (prompt: LLM.Prompt, opts?: LLM.Options) => | ||
Log.scoped({ | ||
name: `${MODULE}.converse`, | ||
enter: [ | ||
Log.br, | ||
'model id:', | ||
MODEL_ID, | ||
Log.br, | ||
'prompt:', | ||
prompt, | ||
Log.br, | ||
'opts:', | ||
opts, | ||
], | ||
})( | ||
Eff.Effect.tryPromise(() => | ||
bedrock | ||
.converse({ | ||
modelId: MODEL_ID, | ||
system: [{ text: prompt.system }], | ||
messages: messagesToBedrock(prompt.messages), | ||
toolConfig: prompt.toolConfig && toolConfigToBedrock(prompt.toolConfig), | ||
...opts, | ||
}) | ||
.promise() | ||
.then((backendResponse) => ({ | ||
backendResponse, | ||
content: mapContent(backendResponse.output.message?.content), | ||
})), | ||
), | ||
) | ||
|
||
return Eff.Layer.succeed(LLM.LLM, { converse }) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
import * as Eff from 'effect' | ||
|
||
import { JsonRecord } from 'utils/types' | ||
|
||
// XXX: schema for en/decoding to/from aws bedrock types? | ||
|
||
export const DOCUMENT_FORMATS = [ | ||
'pdf', | ||
'csv', | ||
'doc', | ||
'docx', | ||
'xls', | ||
'xlsx', | ||
'html', | ||
'txt', | ||
'md', | ||
] as const | ||
export type DocumentFormat = (typeof DOCUMENT_FORMATS)[number] | ||
|
||
export interface DocumentBlock { | ||
format: DocumentFormat | ||
name: string | ||
// A base64-encoded string of a UTF-8 encoded file, that is the document to include in the message. | ||
source: Buffer | Uint8Array | Blob | string | ||
} | ||
|
||
export const IMAGE_FORMATS = ['png', 'jpeg', 'gif', 'webp'] as const | ||
export type ImageFormat = (typeof IMAGE_FORMATS)[number] | ||
|
||
export interface ImageBlock { | ||
format: ImageFormat | ||
// The raw image bytes for the image. If you use an AWS SDK, you don't need to base64 encode the image bytes. | ||
source: Buffer | Uint8Array | Blob | string | ||
} | ||
|
||
export interface JsonBlock { | ||
json: JsonRecord | ||
} | ||
|
||
export interface TextBlock { | ||
text: string | ||
} | ||
|
||
export interface GuardBlock { | ||
text: string | ||
} | ||
|
||
export interface ToolUseBlock { | ||
toolUseId: string | ||
name: string | ||
input: JsonRecord | ||
} | ||
|
||
export type ToolResultContentBlock = Eff.Data.TaggedEnum<{ | ||
Json: JsonBlock | ||
Text: TextBlock | ||
Image: ImageBlock | ||
Document: DocumentBlock | ||
}> | ||
|
||
// eslint-disable-next-line @typescript-eslint/no-redeclare | ||
export const ToolResultContentBlock = Eff.Data.taggedEnum<ToolResultContentBlock>() | ||
|
||
export type ToolResultStatus = 'success' | 'error' | ||
|
||
export interface ToolResultBlock { | ||
toolUseId: string | ||
content: ToolResultContentBlock[] | ||
status: ToolResultStatus | ||
} | ||
|
||
export type ResponseMessageContentBlock = Eff.Data.TaggedEnum<{ | ||
// GuardContent: {} | ||
// ToolResult: {} | ||
ToolUse: ToolUseBlock | ||
Text: TextBlock | ||
Image: ImageBlock | ||
Document: DocumentBlock | ||
}> | ||
|
||
// eslint-disable-next-line @typescript-eslint/no-redeclare | ||
export const ResponseMessageContentBlock = | ||
Eff.Data.taggedEnum<ResponseMessageContentBlock>() | ||
|
||
export type MessageContentBlock = Eff.Data.TaggedEnum<{ | ||
Text: TextBlock | ||
Image: ImageBlock | ||
Document: DocumentBlock | ||
}> | ||
|
||
// eslint-disable-next-line @typescript-eslint/no-redeclare | ||
export const MessageContentBlock = Eff.Data.taggedEnum<MessageContentBlock>() | ||
|
||
export type PromptMessageContentBlock = Eff.Data.TaggedEnum<{ | ||
GuardContent: GuardBlock | ||
ToolResult: ToolResultBlock | ||
ToolUse: ToolUseBlock | ||
Text: TextBlock | ||
Image: ImageBlock | ||
Document: DocumentBlock | ||
}> | ||
|
||
// eslint-disable-next-line @typescript-eslint/no-redeclare | ||
export const PromptMessageContentBlock = Eff.Data.taggedEnum<PromptMessageContentBlock>() | ||
|
||
export const text = (first: string, ...rest: string[]) => | ||
PromptMessageContentBlock.Text({ text: [first, ...rest].join('') }) |
Oops, something went wrong.