Skip to content

Commit

Permalink
add observation support
Browse files Browse the repository at this point in the history
  • Loading branch information
alexcheng1982 committed Apr 5, 2024
1 parent 97e6257 commit 3e9993b
Show file tree
Hide file tree
Showing 18 changed files with 393 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ data class NextStep(
}

data class AgentExecutor(
val agent: Planner,
val planner: Planner,
val nameToToolMap: Map<String, FunctionCallback>,
val returnIntermediateSteps: Boolean = false,
val maxIterations: Int? = 10,
Expand Down Expand Up @@ -71,7 +71,10 @@ data class AgentExecutor(
Duration.ofMillis(timeElapsed)
)
val output =
agent.returnStoppedResponse(earlyStoppingMethod, intermediateSteps)
planner.returnStoppedResponse(
earlyStoppingMethod,
intermediateSteps
)
return returnResult(output, intermediateSteps)
}

Expand Down Expand Up @@ -119,7 +122,7 @@ data class AgentExecutor(
): MutableList<Plannable> {
val result = mutableListOf<Plannable>()
try {
val action = { agent.plan(inputs, intermediateSteps) }
val action = { planner.plan(inputs, intermediateSteps) }
val output = observationRegistry?.let { registry ->
Observation.createNotStarted("agent.execution.plan", registry)
.observe(action)
Expand Down Expand Up @@ -149,7 +152,10 @@ data class AgentExecutor(
agentAction: AgentAction
): AgentStep {
val agentTool =
nameToToolMap[agentAction.tool] ?: return AgentStep(agentAction, "Invalid tool")
nameToToolMap[agentAction.tool] ?: return AgentStep(
agentAction,
"Invalid tool"
)
val (tool, toolInput) = agentAction
logger.info(
"Start executing tool [{}] with request [{}]",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package io.github.alexcheng1982.agentappbuilder.core.observation

import io.micrometer.common.KeyValue
import io.micrometer.common.KeyValues
import io.micrometer.common.docs.KeyName
import io.micrometer.observation.Observation
import io.micrometer.observation.ObservationConvention
import io.micrometer.observation.ObservationRegistry
import io.micrometer.observation.docs.ObservationDocumentation
import io.micrometer.observation.transport.RequestReplySenderContext
import org.springframework.ai.chat.ChatClient
import org.springframework.ai.chat.ChatResponse
import org.springframework.ai.chat.prompt.Prompt

enum class ChatClientObservationDocumentation : ObservationDocumentation {
CHAT_CLIENT_CALL {
override fun getDefaultConvention(): Class<out ObservationConvention<out Observation.Context>>? {
return DefaultChatClientObservationConvention::class.java
}

override fun getLowCardinalityKeyNames(): Array<KeyName> {
return LowCardinalityKeyNames.values().toList().toTypedArray()
}

override fun getHighCardinalityKeyNames(): Array<KeyName> {
return HighCardinalityKeyNames.values().toList().toTypedArray()
}
};

enum class LowCardinalityKeyNames : KeyName

enum class HighCardinalityKeyNames : KeyName {
PROMPT_CONTENT {
override fun asString(): String {
return "prompt.content"
}
},
RESPONSE_CONTENT {
override fun asString(): String {
return "response.content"
}

}
}
}

class DefaultChatClientObservationConvention(private val name: String? = null) :
ChatClientObservationConvention {
private val defaultName = "chat.client.call"

private val promptContentNone: KeyValue = KeyValue.of(
ChatClientObservationDocumentation.HighCardinalityKeyNames.PROMPT_CONTENT,
KeyValue.NONE_VALUE
)

private val responseContentNone: KeyValue = KeyValue.of(
ChatClientObservationDocumentation.HighCardinalityKeyNames.RESPONSE_CONTENT,
KeyValue.NONE_VALUE
)

override fun getName(): String {
return name ?: defaultName
}

override fun getLowCardinalityKeyValues(context: ChatClientRequestObservationContext): KeyValues {
return KeyValues.empty()
}

override fun getHighCardinalityKeyValues(context: ChatClientRequestObservationContext): KeyValues {
return KeyValues.of(promptContent(context), responseContent(context))
}

private fun promptContent(context: ChatClientRequestObservationContext): KeyValue {
return context.carrier?.contents?.let { content ->
KeyValue.of(
ChatClientObservationDocumentation.HighCardinalityKeyNames.PROMPT_CONTENT,
content
)
} ?: promptContentNone
}

private fun responseContent(context: ChatClientRequestObservationContext): KeyValue {
return context.response?.let { response ->
response.results.joinToString("\n") { it.output?.content ?: "" }
.let { content ->
KeyValue.of(
ChatClientObservationDocumentation.HighCardinalityKeyNames.RESPONSE_CONTENT,
content
)
}
} ?: responseContentNone
}
}

interface ChatClientObservationConvention :
ObservationConvention<ChatClientRequestObservationContext> {
override fun supportsContext(context: Observation.Context): Boolean {
return context is ChatClientRequestObservationContext
}
}


class ChatClientRequestObservationContext(val prompt: Prompt) :
RequestReplySenderContext<Prompt, ChatResponse>({ _, _, _ ->
run {}
}) {
init {
setCarrier(prompt)
}
}

class InstrumentedChatClient(
private val chatClient: ChatClient,
private val observationRegistry: ObservationRegistry? = null,
) : ChatClient {
override fun call(prompt: Prompt): ChatResponse {
val action = { chatClient.call(prompt) }
return observationRegistry?.let { registry ->
instrumentedCall(prompt, action, registry)
} ?: action.invoke()
}

private fun instrumentedCall(
prompt: Prompt,
action: () -> ChatResponse,
registry: ObservationRegistry
): ChatResponse {
val observationContext =
ChatClientRequestObservationContext(prompt)
val observation =
ChatClientObservationDocumentation.CHAT_CLIENT_CALL.observation(
null,
DefaultChatClientObservationConvention(),
{ observationContext },
registry
).start()
return try {
observation.openScope().use {
val response = action.invoke()
observationContext.setResponse(response)
response
}
} catch (e: Exception) {
observation.error(e)
throw e
} finally {
observation.stop()
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,18 @@ import io.github.alexcheng1982.agentappbuilder.core.chatmemory.ChatMemory
import io.github.alexcheng1982.agentappbuilder.core.chatmemory.ChatMemoryStore
import io.github.alexcheng1982.agentappbuilder.core.chatmemory.MessageWindowChatMemory
import io.github.alexcheng1982.agentappbuilder.core.executor.ActionPlanningResult
import io.github.alexcheng1982.agentappbuilder.core.observation.InstrumentedChatClient
import io.github.alexcheng1982.agentappbuilder.core.tool.AgentTool
import io.github.alexcheng1982.agentappbuilder.core.tool.AgentToolsProvider
import io.micrometer.observation.ObservationRegistry
import org.springframework.ai.chat.ChatClient
import org.springframework.ai.chat.messages.SystemMessage
import org.springframework.ai.chat.prompt.ChatOptions
import org.springframework.ai.chat.prompt.Prompt
import org.springframework.ai.chat.prompt.PromptTemplate

open class LLMPlanner(
private val chatClient: ChatClient,
private var chatClient: ChatClient,
private val toolsProvider: AgentToolsProvider,
private val outputParser: OutputParser,
private val userPromptTemplate: PromptTemplate,
Expand All @@ -30,7 +32,15 @@ open class LLMPlanner(
MessageWindowChatMemory(store, memoryId.toString(), 10)
}
},
observationRegistry: ObservationRegistry? = null,
) : Planner {
init {
chatClient =
if (chatClient is InstrumentedChatClient) chatClient else InstrumentedChatClient(
chatClient, observationRegistry
)
}

override fun plan(
inputs: Map<String, Any>,
intermediateSteps: List<IntermediateAgentStep>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import io.github.alexcheng1982.agentappbuilder.core.chatmemory.ChatMemoryStore
import io.github.alexcheng1982.agentappbuilder.core.planner.LLMPlanner
import io.github.alexcheng1982.agentappbuilder.core.tool.AgentToolsProvider
import io.github.alexcheng1982.agentappbuilder.core.tool.AutoDiscoveredAgentToolsProvider
import io.micrometer.observation.ObservationRegistry
import org.springframework.ai.chat.ChatClient
import org.springframework.ai.chat.prompt.PromptTemplate
import org.springframework.core.io.ClassPathResource
Expand All @@ -19,6 +20,7 @@ class NoFeedbackPlanner(
systemPromptResource: Resource,
systemInstruction: String? = null,
chatMemoryStore: ChatMemoryStore? = null,
observationRegistry: ObservationRegistry? = null,
) : LLMPlanner(
chatClient,
agentToolsProvider,
Expand All @@ -27,13 +29,15 @@ class NoFeedbackPlanner(
PromptTemplate(systemPromptResource),
systemInstruction,
chatMemoryStore,
observationRegistry = observationRegistry,
) {
companion object {
fun createDefault(
chatClient: ChatClient,
agentToolsProvider: AgentToolsProvider = AutoDiscoveredAgentToolsProvider,
systemInstruction: String? = null,
chatMemoryStore: ChatMemoryStore? = null,
observationRegistry: ObservationRegistry? = null,
): NoFeedbackPlanner {
return NoFeedbackPlanner(
chatClient,
Expand All @@ -42,6 +46,7 @@ class NoFeedbackPlanner(
ClassPathResource("prompts/no-feedback/system.st"),
systemInstruction,
chatMemoryStore,
observationRegistry,
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import io.github.alexcheng1982.agentappbuilder.core.chatmemory.ChatMemoryStore
import io.github.alexcheng1982.agentappbuilder.core.planner.LLMPlanner
import io.github.alexcheng1982.agentappbuilder.core.tool.AgentToolsProvider
import io.github.alexcheng1982.agentappbuilder.core.tool.AutoDiscoveredAgentToolsProvider
import io.micrometer.observation.ObservationRegistry
import org.springframework.ai.chat.ChatClient
import org.springframework.ai.chat.prompt.PromptTemplate
import org.springframework.core.io.ClassPathResource
Expand All @@ -16,6 +17,7 @@ class ReActPlanner(
systemPromptResource: Resource,
systemInstruction: String? = null,
chatMemoryStore: ChatMemoryStore? = null,
observationRegistry: ObservationRegistry? = null,
) :
LLMPlanner(
chatClient,
Expand All @@ -25,13 +27,15 @@ class ReActPlanner(
PromptTemplate(systemPromptResource),
systemInstruction,
chatMemoryStore,
observationRegistry = observationRegistry,
) {
companion object {
fun createDefault(
chatClient: ChatClient,
agentToolsProvider: AgentToolsProvider = AutoDiscoveredAgentToolsProvider,
systemInstruction: String? = null,
chatMemoryStore: ChatMemoryStore? = null,
observationRegistry: ObservationRegistry? = null,
): ReActPlanner {
return ReActPlanner(
chatClient,
Expand All @@ -40,6 +44,7 @@ class ReActPlanner(
ClassPathResource("prompts/react/system.st"),
systemInstruction,
chatMemoryStore,
observationRegistry,
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import io.github.alexcheng1982.agentappbuilder.core.chatmemory.ChatMemoryStore
import io.github.alexcheng1982.agentappbuilder.core.planner.LLMPlanner
import io.github.alexcheng1982.agentappbuilder.core.tool.AgentToolsProvider
import io.github.alexcheng1982.agentappbuilder.core.tool.AutoDiscoveredAgentToolsProvider
import io.micrometer.observation.ObservationRegistry
import org.springframework.ai.chat.ChatClient
import org.springframework.ai.chat.prompt.PromptTemplate
import org.springframework.core.io.ClassPathResource
Expand All @@ -16,6 +17,7 @@ class ReActJsonPlanner(
systemPromptResource: Resource,
systemInstruction: String? = null,
chatMemoryStore: ChatMemoryStore? = null,
observationRegistry: ObservationRegistry? = null,
) : LLMPlanner(
chatClient,
agentToolsProvider,
Expand All @@ -24,13 +26,15 @@ class ReActJsonPlanner(
PromptTemplate(systemPromptResource),
systemInstruction,
chatMemoryStore,
observationRegistry = observationRegistry,
) {
companion object {
fun createDefault(
chatClient: ChatClient,
agentToolsProvider: AgentToolsProvider = AutoDiscoveredAgentToolsProvider,
systemInstruction: String? = null,
chatMemoryStore: ChatMemoryStore? = null,
observationRegistry: ObservationRegistry? = null,
): ReActJsonPlanner {
return ReActJsonPlanner(
chatClient,
Expand All @@ -39,6 +43,7 @@ class ReActJsonPlanner(
ClassPathResource("prompts/react-json/system.st"),
systemInstruction,
chatMemoryStore,
observationRegistry,
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import io.github.alexcheng1982.agentappbuilder.core.chatmemory.ChatMemoryStore
import io.github.alexcheng1982.agentappbuilder.core.planner.LLMPlanner
import io.github.alexcheng1982.agentappbuilder.core.tool.AgentToolsProvider
import io.github.alexcheng1982.agentappbuilder.core.tool.AutoDiscoveredAgentToolsProvider
import io.micrometer.observation.ObservationRegistry
import org.springframework.ai.chat.ChatClient
import org.springframework.ai.chat.prompt.PromptTemplate
import org.springframework.core.io.ClassPathResource
Expand All @@ -16,6 +17,7 @@ class StructuredChatPlanner(
systemPromptResource: Resource,
systemInstruction: String? = null,
chatMemoryStore: ChatMemoryStore? = null,
observationRegistry: ObservationRegistry? = null,
) : LLMPlanner(
chatClient,
agentToolsProvider,
Expand All @@ -24,13 +26,15 @@ class StructuredChatPlanner(
PromptTemplate(systemPromptResource),
systemInstruction,
chatMemoryStore,
observationRegistry = observationRegistry,
) {
companion object {
fun createDefault(
chatClient: ChatClient,
agentToolsProvider: AgentToolsProvider = AutoDiscoveredAgentToolsProvider,
systemInstruction: String? = null,
chatMemoryStore: ChatMemoryStore? = null,
observationRegistry: ObservationRegistry? = null,
): StructuredChatPlanner {
return StructuredChatPlanner(
chatClient,
Expand All @@ -39,6 +43,7 @@ class StructuredChatPlanner(
ClassPathResource("prompts/structured-chat/system.st"),
systemInstruction,
chatMemoryStore,
observationRegistry,
)
}
}
Expand Down
Loading

0 comments on commit 3e9993b

Please sign in to comment.