diff --git a/Demo/DemoChat/Sources/UI/DetailView.swift b/Demo/DemoChat/Sources/UI/DetailView.swift
index 9e2a07e9..3cd119ea 100644
--- a/Demo/DemoChat/Sources/UI/DetailView.swift
+++ b/Demo/DemoChat/Sources/UI/DetailView.swift
@@ -19,7 +19,7 @@ struct DetailView: View {
     @State private var showsModelSelectionSheet = false
     @State private var selectedChatModel: Model = .gpt4_0613
 
-    private let availableChatModels: [Model] = [.gpt3_5Turbo0613, .gpt4_0613]
+    private let availableChatModels: [Model] = [.gpt3_5Turbo, .gpt4_0613]
 
     let conversation: Conversation
     let error: Error?
diff --git a/Demo/DemoChat/Sources/UI/TextToSpeechView.swift b/Demo/DemoChat/Sources/UI/TextToSpeechView.swift
index 459a4423..023548ce 100644
--- a/Demo/DemoChat/Sources/UI/TextToSpeechView.swift
+++ b/Demo/DemoChat/Sources/UI/TextToSpeechView.swift
@@ -15,7 +15,7 @@ public struct TextToSpeechView: View {
 
     @State private var prompt: String = ""
     @State private var voice: AudioSpeechQuery.AudioSpeechVoice = .alloy
-    @State private var speed: Double = 1
+    @State private var speed: Double = AudioSpeechQuery.Speed.normal.rawValue
     @State private var responseFormat: AudioSpeechQuery.AudioSpeechResponseFormat = .mp3
 
     public init(store: SpeechStore) {
@@ -56,7 +56,7 @@ public struct TextToSpeechView: View {
                 HStack {
                     Text("Speed: ")
                     Spacer()
-                    Stepper(value: $speed, in: 0.25...4, step: 0.25) {
+                    Stepper(value: $speed, in: AudioSpeechQuery.Speed.min.rawValue...AudioSpeechQuery.Speed.max.rawValue, step: 0.25) {
                         HStack {
                             Spacer()
                             Text("**\(String(format: "%.2f", speed))**")
diff --git a/Sources/OpenAI/Public/Models/AudioSpeechQuery.swift b/Sources/OpenAI/Public/Models/AudioSpeechQuery.swift
index 36db44a5..f41a91f8 100644
--- a/Sources/OpenAI/Public/Models/AudioSpeechQuery.swift
+++ b/Sources/OpenAI/Public/Models/AudioSpeechQuery.swift
@@ -54,13 +54,7 @@ public struct AudioSpeechQuery: Codable, Equatable {
         case responseFormat = "response_format"
         case speed
     }
-
-    private enum Constants {
-        static let normalSpeed = 1.0
-        static let maxSpeed = 4.0
-        static let minSpeed = 0.25
-    }
-
+
     public init(model: Model, input: String, voice: AudioSpeechVoice, responseFormat: AudioSpeechResponseFormat = .mp3, speed: Double?) {
         self.model = AudioSpeechQuery.validateSpeechModel(model)
         self.speed = AudioSpeechQuery.normalizeSpeechSpeed(speed)
@@ -80,13 +74,26 @@ private extension AudioSpeechQuery {
         }
         return inputModel
     }
-
+}
+
+public extension AudioSpeechQuery {
+
+    /// Reference speech speeds: the default and the inclusive lower/upper bounds.
+    enum Speed: Double {
+        case normal = 1.0
+        case max = 4.0
+        case min = 0.25
+    }
+
+    /// Converts `inputSpeed` to its string form, clamped to `Speed.min...Speed.max`.
+    /// - Parameter inputSpeed: Requested speed; `nil` selects `Speed.normal`.
+    /// - Returns: The speed as a string. Both bounds are valid and returned unchanged.
     static func normalizeSpeechSpeed(_ inputSpeed: Double?) -> String {
-        guard let inputSpeed else { return "\(Constants.normalSpeed)" }
-        let isSpeedOutOfBounds = inputSpeed >= Constants.maxSpeed && inputSpeed <= Constants.minSpeed
+        guard let inputSpeed else { return "\(Self.Speed.normal.rawValue)" }
+        let isSpeedOutOfBounds = inputSpeed < Self.Speed.min.rawValue || Self.Speed.max.rawValue < inputSpeed
         guard !isSpeedOutOfBounds else {
             print("[AudioSpeech] Speed value must be between 0.25 and 4.0. Setting value to closest valid.")
-            return inputSpeed < Constants.minSpeed ? "\(Constants.minSpeed)" : "\(Constants.maxSpeed)"
+            return inputSpeed < Self.Speed.min.rawValue ? "\(Self.Speed.min.rawValue)" : "\(Self.Speed.max.rawValue)"
         }
         return "\(inputSpeed)"
     }
diff --git a/Tests/OpenAITests/OpenAITests.swift b/Tests/OpenAITests/OpenAITests.swift
index 8a552b78..b0f2c932 100644
--- a/Tests/OpenAITests/OpenAITests.swift
+++ b/Tests/OpenAITests/OpenAITests.swift
@@ -258,6 +258,42 @@ class OpenAITests: XCTestCase {
         XCTAssertEqual(inError, apiError)
     }
 
+    func testAudioSpeechDoesNotNormalize() async throws {
+        let query = AudioSpeechQuery(model: .tts_1, input: "Hello, world!", voice: .alloy, responseFormat: .mp3, speed: 2.0)
+
+        XCTAssertEqual(query.speed, "\(2.0)")
+    }
+
+    func testAudioSpeechNormalizeNil() async throws {
+        let query = AudioSpeechQuery(model: .tts_1, input: "Hello, world!", voice: .alloy, responseFormat: .mp3, speed: nil)
+
+        XCTAssertEqual(query.speed, "\(1.0)")
+    }
+
+    func testAudioSpeechNormalizeLow() async throws {
+        let query = AudioSpeechQuery(model: .tts_1, input: "Hello, world!", voice: .alloy, responseFormat: .mp3, speed: 0.0)
+
+        XCTAssertEqual(query.speed, "\(0.25)")
+    }
+
+    func testAudioSpeechNormalizeHigh() async throws {
+        let query = AudioSpeechQuery(model: .tts_1, input: "Hello, world!", voice: .alloy, responseFormat: .mp3, speed: 10.0)
+
+        XCTAssertEqual(query.speed, "\(4.0)")
+    }
+
+    func testAudioSpeechKeepsMinBoundary() async throws {
+        let query = AudioSpeechQuery(model: .tts_1, input: "Hello, world!", voice: .alloy, responseFormat: .mp3, speed: 0.25)
+
+        XCTAssertEqual(query.speed, "\(0.25)")
+    }
+
+    func testAudioSpeechKeepsMaxBoundary() async throws {
+        let query = AudioSpeechQuery(model: .tts_1, input: "Hello, world!", voice: .alloy, responseFormat: .mp3, speed: 4.0)
+
+        XCTAssertEqual(query.speed, "\(4.0)")
+    }
+
     func testAudioSpeechError() async throws {
         let query = AudioSpeechQuery(model: .tts_1, input: "Hello, world!", voice: .alloy, responseFormat: .mp3, speed: 1.0)
         let inError = APIError(message: "foo", type: "bar", param: "baz", code: "100")