From b2820b4faacd5553e759b98b54fb77697c6627c9 Mon Sep 17 00:00:00 2001 From: ActuallyTaylor Date: Sun, 19 Nov 2023 22:19:29 -0500 Subject: [PATCH] Adds support for the embeddings endpoint Signed-off-by: ActuallyTaylor --- .../Extensions/URLRequest+Request.swift | 2 + Sources/SLlama/LlamaRequests.swift | 4 ++ Sources/SLlama/Models/EmbeddingResults.swift | 12 +++++ Sources/SLlama/Models/EmbeddingSettings.swift | 13 +++++ .../Return Data/CompletionResponse1.json | 51 ------------------- Tests/SLlamaTests/SLlamaTests.swift | 19 +++++-- 6 files changed, 45 insertions(+), 56 deletions(-) create mode 100644 Sources/SLlama/Models/EmbeddingResults.swift create mode 100644 Sources/SLlama/Models/EmbeddingSettings.swift delete mode 100644 Tests/SLlamaTests/Return Data/CompletionResponse1.json diff --git a/Sources/SLlama/Extensions/URLRequest+Request.swift b/Sources/SLlama/Extensions/URLRequest+Request.swift index 2b84e37..306e5c0 100644 --- a/Sources/SLlama/Extensions/URLRequest+Request.swift +++ b/Sources/SLlama/Extensions/URLRequest+Request.swift @@ -13,6 +13,8 @@ extension URLRequest { httpMethod = request.method.name httpBody = try request.method.httpBody + + addValue("application/json", forHTTPHeaderField: "Content-Type") } } diff --git a/Sources/SLlama/LlamaRequests.swift b/Sources/SLlama/LlamaRequests.swift index 90c627a..aa5d4ce 100644 --- a/Sources/SLlama/LlamaRequests.swift +++ b/Sources/SLlama/LlamaRequests.swift @@ -11,4 +11,8 @@ public enum LlamaRequests { public static func completion(_ parameters: CompletionSettings) -> Request { return Request(path: "/completion", method: .post(.body(model: parameters))) } + + public static func embedding(_ parameters: EmbeddingSettings) -> Request { + return Request(path: "/embedding", method: .post(.body(model: parameters))) + } } diff --git a/Sources/SLlama/Models/EmbeddingResults.swift b/Sources/SLlama/Models/EmbeddingResults.swift new file mode 100644 index 0000000..ccb285d --- /dev/null +++ 
b/Sources/SLlama/Models/EmbeddingResults.swift @@ -0,0 +1,12 @@ +// +// EmbeddingResults.swift +// +// +// Created by Taylor Lineman on 11/19/23. +// + +import Foundation + +public struct EmbeddingResults: Codable { + public let embedding: [Double] +} diff --git a/Sources/SLlama/Models/EmbeddingSettings.swift b/Sources/SLlama/Models/EmbeddingSettings.swift new file mode 100644 index 0000000..6364599 --- /dev/null +++ b/Sources/SLlama/Models/EmbeddingSettings.swift @@ -0,0 +1,13 @@ +// +// EmbeddingSettings.swift +// +// +// Created by Taylor Lineman on 11/19/23. +// + +import Foundation + +public struct EmbeddingSettings: Codable { + /// Text to process + let content: String +} diff --git a/Tests/SLlamaTests/Return Data/CompletionResponse1.json b/Tests/SLlamaTests/Return Data/CompletionResponse1.json deleted file mode 100644 index 902f358..0000000 --- a/Tests/SLlamaTests/Return Data/CompletionResponse1.json +++ /dev/null @@ -1,51 +0,0 @@ -{ - "content": "\n\n1. Choose a domain name: A domain name is the address people type into their browser to access your website. Make it simple, easy to remember, and related to your brand or business.\n\n2. Select a hosting service: Your website needs a server space to store its files and make them accessible to visitors. Look for reliable hosting services with good uptime records.\n\n3. Choose a content management system (CMS): A CMS is a software application that allows you to create, manage, and publish digital content without coding skills. 
Popular options include WordPress, Drupal, and Joomla.\n\n", - "generation_settings": { - "frequency_penalty": 0.0, - "grammar": "", - "ignore_eos": false, - "logit_bias": [], - "mirostat": 0, - "mirostat_eta": 0.10000000149011612, - "mirostat_tau": 5.0, - "model": "/Users/taylorlineman/Developer/Impel/Models/zephyr-7b-beta.Q5_K_S.gguf", - "n_ctx": 512, - "n_keep": 0, - "n_predict": 128, - "n_probs": 0, - "penalize_nl": true, - "presence_penalty": 0.0, - "repeat_last_n": 64, - "repeat_penalty": 1.100000023841858, - "seed": 4294967295, - "stop": [], - "stream": false, - "temp": 0.800000011920929, - "tfs_z": 1.0, - "top_k": 40, - "top_p": 0.949999988079071, - "typical_p": 1.0 - }, - "model": "/Users/taylorlineman/Developer/Impel/Models/zephyr-7b-beta.Q5_K_S.gguf", - "prompt": "Building a website can be done in 10 simple steps:", - "slot_id": 0, - "stop": true, - "stopped_eos": false, - "stopped_limit": true, - "stopped_word": false, - "stopping_word": "", - "timings": { - "predicted_ms": 7516.141, - "predicted_n": 128, - "predicted_per_second": 17.030015801992008, - "predicted_per_token_ms": 58.7198515625, - "prompt_ms": 394.926, - "prompt_n": 14, - "prompt_per_second": 35.449679180403415, - "prompt_per_token_ms": 28.209 - }, - "tokens_cached": 142, - "tokens_evaluated": 14, - "tokens_predicted": 128, - "truncated": false -} diff --git a/Tests/SLlamaTests/SLlamaTests.swift b/Tests/SLlamaTests/SLlamaTests.swift index 706a698..757dcf6 100644 --- a/Tests/SLlamaTests/SLlamaTests.swift +++ b/Tests/SLlamaTests/SLlamaTests.swift @@ -14,6 +14,16 @@ final class SLlamaTests: XCTestCase { """ + func testEmbeddings() async throws { + let client = Client(baseURLString: "http://127.0.0.1:8080") + + let settings: EmbeddingSettings = .init(content: "This is some contents to embed") + + let request = LlamaRequests.embedding(settings) + + _ = try await client.run(request) + } + func testCompletionEndpoint() async throws { let client = Client(baseURLString: "http://127.0.0.1:8080") @@ 
-22,15 +32,14 @@ final class SLlamaTests: XCTestCase { let request = LlamaRequests.completion(settings) - let response = try await client.run(request) - - print(response.content) + _ = try await client.run(request) } func testCompletionStreaming() async throws { - let client = Client(baseURLString: "http://127.0.0.1:8080") + let client = Client(baseURLString: "http://127.0.0.1:24445") let preparedPrompt = PromptProcessor.prepareTemplate(template: template, systemPrompt: "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.", userPrompt: "Hello Assistant") + let settings: CompletionSettings = .init(prompt: preparedPrompt, temperature: 0.7, n_predict: 256, stream: true) let request = LlamaRequests.completion(settings) @@ -50,7 +59,7 @@ final class SLlamaTests: XCTestCase { extension SLlamaTests: ClientStreamDelegate { func didReceiveModel(model: Model) where Model : Codable { guard let model = model as? CompletionResult else { return } - print(model.content) + print(model.content, terminator: "") } func didFinish(error: Error?) {