Skip to content

Commit

Permalink
refactor: removes Alamofire from embeddings (#22)
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinhermawan authored Jun 8, 2024
1 parent e43a9f6 commit 13eda0e
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 21 deletions.
21 changes: 11 additions & 10 deletions Sources/OllamaKit/OllamaKit+Embeddings.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
// Created by Paul Thrasher on 02/09/24.
//

import Alamofire
import Combine
import Foundation

Expand All @@ -24,11 +23,9 @@ extension OllamaKit {
/// - Returns: An ``OKEmbeddingsResponse`` containing the embeddings from the model.
/// - Throws: An error if the request fails or the response can't be decoded.
public func embeddings(data: OKEmbeddingsRequestData) async throws -> OKEmbeddingsResponse {
let request = AF.request(router.embeddings(data: data)).validate()
let response = request.serializingDecodable(OKEmbeddingsResponse.self, decoder: decoder)
let value = try await response.value
let request = try OKRouter.embeddings(data: data).asURLRequest()

return value
return try await OKHTTPClient.shared.sendRequest(for: request, with: OKEmbeddingsResponse.self)
}

/// Retrieves embeddings from a specific model from the Ollama API as a Combine publisher.
Expand All @@ -50,11 +47,15 @@ extension OllamaKit {
///
/// - Parameter data: The ``OKEmbeddingsRequestData`` used to query the API for embeddings from a specific model.
/// - Returns: A `AnyPublisher<OKEmbeddingsResponse, AFError>` that emits embeddings.
public func embeddings(data: OKEmbeddingsRequestData) -> AnyPublisher<OKEmbeddingsResponse, AFError> {
let request = AF.request(router.embeddings(data: data)).validate()
public func embeddings(data: OKEmbeddingsRequestData) -> AnyPublisher<OKEmbeddingsResponse, Error> {
let request: URLRequest

return request
.publishDecodable(type: OKEmbeddingsResponse.self, decoder: decoder).value()
.eraseToAnyPublisher()
do {
request = try OKRouter.embeddings(data: data).asURLRequest()
} catch {
return Fail(error: error).eraseToAnyPublisher()
}

return OKHTTPClient.shared.sendRequest(for: request, with: OKEmbeddingsResponse.self)
}
}
35 changes: 24 additions & 11 deletions Sources/OllamaKit/Utils/OKHTTPClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,28 @@ import Combine
import Foundation

internal struct OKHTTPClient {
private let decoder: JSONDecoder = .default

static let shared = OKHTTPClient()

func sendRequest(for request: URLRequest) async throws -> Void {
let (_, response) = try await URLSession.shared.data(for: request)

guard let httpResponse = response as? HTTPURLResponse, (200...299).contains(httpResponse.statusCode) else {
throw URLError(.badServerResponse)
}
}

func sendRequest<T: Decodable>(for request: URLRequest, with responseType: T.Type) async throws -> T {
let (data, response) = try await URLSession.shared.data(for: request)

guard let httpResponse = response as? HTTPURLResponse, (200...299).contains(httpResponse.statusCode) else {
throw URLError(.badServerResponse)
}

return try decoder.decode(T.self, from: data)
}

func sendRequest<T: Decodable>(for request: URLRequest, with responseType: T.Type) -> AnyPublisher<T, Error> {
return URLSession.shared.dataTaskPublisher(for: request)
.tryMap { data, response in
Expand All @@ -20,7 +40,7 @@ internal struct OKHTTPClient {

return data
}
.decode(type: T.self, decoder: JSONDecoder())
.decode(type: T.self, decoder: decoder)
.eraseToAnyPublisher()
}

Expand All @@ -36,14 +56,6 @@ internal struct OKHTTPClient {
.eraseToAnyPublisher()
}

func sendRequest(for request: URLRequest) async throws -> Void {
let (_, response) = try await URLSession.shared.data(for: request)

guard let httpResponse = response as? HTTPURLResponse, (200...299).contains(httpResponse.statusCode) else {
throw URLError(.badServerResponse)
}
}

func streamRequest<T: Decodable>(for request: URLRequest, with responseType: T.Type) -> AsyncThrowingStream<T, Error> {
return AsyncThrowingStream { continuation in
let task = URLSession.shared.dataTask(with: request) { data, response, error in
Expand All @@ -62,7 +74,7 @@ internal struct OKHTTPClient {

while let chunk = extractNextJSON(from: &buffer) {
do {
let response = try JSONDecoder().decode(T.self, from: chunk)
let response = try decoder.decode(T.self, from: chunk)
continuation.yield(response)
} catch {
continuation.finish(throwing: error)
Expand All @@ -72,6 +84,7 @@ internal struct OKHTTPClient {

continuation.finish()
}

task.resume()
}
}
Expand All @@ -95,7 +108,7 @@ internal struct OKHTTPClient {

while let chunk = extractNextJSON(from: &buffer) {
do {
let response = try JSONDecoder().decode(T.self, from: chunk)
let response = try decoder.decode(T.self, from: chunk)
subject.send(response)
} catch {
subject.send(completion: .failure(error))
Expand Down

0 comments on commit 13eda0e

Please sign in to comment.