Skip to content

Commit

Permalink
test(tabby): add mocks and test prompt building for CompletionService (
Browse files Browse the repository at this point in the history
…#2293)

* Begin writing tests for CompletionService

* Assert output and built prompt

* Add test for ChatService

* Fix conflict
  • Loading branch information
boxbeam committed May 31, 2024
1 parent 830343d commit 13f03a1
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 0 deletions.
70 changes: 70 additions & 0 deletions crates/tabby/src/services/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,73 @@ pub async fn create_chat_service(logger: Arc<dyn EventLogger>, chat: &ModelConfi

ChatService::new(engine, logger)
}

#[cfg(test)]
mod tests {
use std::sync::Mutex;

use anyhow::Result;
use async_trait::async_trait;
use futures::StreamExt;
use tabby_inference::ChatCompletionOptions;

use super::*;

struct MockChatCompletionStream;

#[async_trait]
impl ChatCompletionStream for MockChatCompletionStream {
async fn chat_completion(
&self,
_messages: &[Message],
_options: ChatCompletionOptions,
) -> Result<BoxStream<String>> {
let s = stream! {
yield "Hello, world!".into();
};
Ok(Box::pin(s))
}
}

struct MockEventLogger(Mutex<Vec<Event>>);

impl EventLogger for MockEventLogger {
fn write(&self, x: tabby_common::api::event::LogEntry) {
self.0.lock().unwrap().push(x.event);
}
}

#[tokio::test]
async fn test_chat_service() {
let engine = Arc::new(MockChatCompletionStream);
let logger = Arc::new(MockEventLogger(Default::default()));
let service = Arc::new(ChatService::new(engine, logger.clone()));

let request = ChatCompletionRequest {
messages: vec![Message {
role: "user".into(),
content: "Hello, computer!".into(),
}],
temperature: None,
seed: None,
presence_penalty: None,
user: None,
};
let mut output = service.generate(request).await;
let response = output.next().await.unwrap();
assert_eq!(response.choices[0].delta.content, "Hello, world!");

let finish = output.next().await.unwrap();
assert_eq!(finish.choices[0].delta.content, "");
assert_eq!(finish.choices[0].finish_reason.as_ref().unwrap(), "stop");

assert!(output.next().await.is_none());

let event = &logger.0.lock().unwrap()[0];
let Event::ChatCompletion { output, .. } = event else {
panic!("Expected ChatCompletion event");
};
assert_eq!(output.role, "assistant");
assert_eq!(output.content, "Hello, world!");
}
}
87 changes: 87 additions & 0 deletions crates/tabby/src/services/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,3 +354,90 @@ pub async fn create_completion_service(

CompletionService::new(engine.clone(), code, logger, prompt_template)
}

#[cfg(test)]
mod tests {
use async_stream::stream;
use async_trait::async_trait;
use futures::stream::BoxStream;
use tabby_common::api::code::{CodeSearchError, CodeSearchQuery, CodeSearchResponse};
use tabby_inference::{CompletionOptions, CompletionStream};

use super::*;

struct MockEventLogger;

impl EventLogger for MockEventLogger {
fn write(&self, _x: api::event::LogEntry) {}
}

struct MockCompletionStream;

#[async_trait]
impl CompletionStream for MockCompletionStream {
async fn generate(&self, _prompt: &str, _options: CompletionOptions) -> BoxStream<String> {
let s = stream! {
yield r#""Hello, world!""#.into();
};

Box::pin(s)
}
}

struct MockCodeSearch;

#[async_trait]
impl CodeSearch for MockCodeSearch {
async fn search_in_language(
&self,
_query: CodeSearchQuery,
_limit: usize,
_offset: usize,
) -> Result<CodeSearchResponse, CodeSearchError> {
Ok(CodeSearchResponse {
num_hits: 0,
hits: vec![],
})
}
}

fn mock_completion_service() -> CompletionService {
let generation = CodeGeneration::new(Arc::new(MockCompletionStream));
CompletionService::new(
Arc::new(generation),
Arc::new(MockCodeSearch),
Arc::new(MockEventLogger),
Some("<pre>{prefix}<mid>{suffix}<end>".into()),
)
}

#[tokio::test]
async fn test_completion_service() {
let completion_service = mock_completion_service();
let segment = Segments {
prefix: "fn hello_world() -> &'static str {".into(),
suffix: Some("}".into()),
filepath: None,
git_url: None,
declarations: None,
relevant_snippets_from_changed_files: None,
clipboard: None,
};
let request = CompletionRequest {
language: Some("rust".into()),
segments: Some(segment.clone()),
user: None,
debug_options: None,
temperature: None,
seed: None,
};

let response = completion_service.generate(&request).await.unwrap();
assert_eq!(response.choices[0].text, r#""Hello, world!""#);

let prompt = completion_service
.prompt_builder
.build("rust", segment.clone(), &[]);
assert_eq!(prompt, "<pre>fn hello_world() -> &'static str {<mid>}<end>");
}
}

0 comments on commit 13f03a1

Please sign in to comment.