From a46ba259564bd24494df0ed4f3d0c6c00a8a7a5f Mon Sep 17 00:00:00 2001 From: Daniel Rochetti Date: Mon, 27 Nov 2023 22:36:51 -0800 Subject: [PATCH] feat: realtime + pencilkit sample --- Package.resolved | 4 +- Package.swift | 20 +- Sources/FalClient/Client+Codable.swift | 14 +- Sources/FalClient/Client+Request.swift | 5 +- Sources/FalClient/Client.swift | 27 +- Sources/FalClient/FalClient.swift | 22 +- Sources/FalClient/Realtime+Codable.swift | 30 ++ Sources/FalClient/Realtime.swift | 322 ++++++++++++------ .../AccentColor.colorset/Contents.json | 11 + .../AppIcon.appiconset/Contents.json | 63 ++++ .../Assets.xcassets/Contents.json | 6 + .../FalRealtimeSampleApp/ContentView.swift | 60 ++++ .../DrawingCanvasView.swift | 103 ++++++ .../FalRealtimeSampleApp.entitlements | 10 + .../FalRealtimeSampleAppApp.swift | 10 + .../Preview Assets.xcassets/Contents.json | 6 + .../FalRealtimeSampleApp/ViewModel.swift | 66 ++++ .../FalRealtimeSampleApp/fal.swift | 4 + .../FalSampleApp/ContentView.swift | 2 +- 19 files changed, 642 insertions(+), 143 deletions(-) create mode 100644 Sources/FalClient/Realtime+Codable.swift create mode 100644 Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/Assets.xcassets/AccentColor.colorset/Contents.json create mode 100644 Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/Assets.xcassets/AppIcon.appiconset/Contents.json create mode 100644 Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/Assets.xcassets/Contents.json create mode 100644 Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/ContentView.swift create mode 100644 Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/DrawingCanvasView.swift create mode 100644 Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/FalRealtimeSampleApp.entitlements create mode 100644 Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/FalRealtimeSampleAppApp.swift create mode 100644 Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/Preview Content/Preview Assets.xcassets/Contents.json create mode 100644 Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/ViewModel.swift create mode 100644 Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/fal.swift diff --git a/Package.resolved b/Package.resolved index 9133dde..dce572f 100644 --- a/Package.resolved +++ b/Package.resolved @@ -5,8 +5,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/nicklockwood/SwiftFormat", "state" : { - "revision" : "d37a477177d5d4ff2a3ae6328eaaab5bf793e702", - "version" : "0.52.9" + "revision" : "cac06079ce883170ab44cb021faad298daeec2a5", + "version" : "0.52.10" } } ], diff --git a/Package.swift b/Package.swift index cf76a52..6b4de7f 100644 --- a/Package.swift +++ b/Package.swift @@ -6,8 +6,8 @@ import PackageDescription let package = Package( name: "FalClient", platforms: [ - .iOS(.v13), - .macOS(.v11), + .iOS(.v15), + .macOS(.v12), .macCatalyst(.v13), .tvOS(.v13), .watchOS(.v8), @@ -20,17 +20,25 @@ let package = Package( ), ], dependencies: [ - .package(url: "https://github.com/nicklockwood/SwiftFormat", from: "0.50.4"), + .package(url: "https://github.com/nicklockwood/SwiftFormat", from: "0.52.10"), ], targets: [ // Targets are the basic building blocks of a package, defining a module or a test suite. // Targets can depend on other targets in this package and products from dependencies. - .target( - name: "FalClient"), + .target(name: "FalClient"), .testTarget( name: "FalClientTests", dependencies: ["FalClient"] ), - .target(name: "FalSampleApp", dependencies: ["FalClient"]), + .target( + name: "FalSampleApp", + dependencies: ["FalClient"], + path: "Sources/Samples/FalSampleApp" + ), + .target( + name: "FalRealtimeSampleApp", + dependencies: ["FalClient"], + path: "Sources/Samples/FalRealtimeSampleApp" + ), ] ) diff --git a/Sources/FalClient/Client+Codable.swift b/Sources/FalClient/Client+Codable.swift index c5c0112..b774aee 100644 --- a/Sources/FalClient/Client+Codable.swift +++ b/Sources/FalClient/Client+Codable.swift @@ -19,7 +19,7 @@ public extension Client { } func run( - _ id: String, + _ app: String, input: (some Encodable) = EmptyInput.empty, options: RunOptions = DefaultRunOptions ) async throws -> Output { @@ -28,25 +28,25 @@ public extension Client { ? try JSONSerialization.jsonObject(with: inputData!) as? [String: Any] : nil - let data = try await sendRequest(id, input: inputData, queryParams: queryParams, options: options) + let url = buildUrl(fromId: app, path: options.path) + let data = try await sendRequest(url, input: inputData, queryParams: queryParams, options: options) return try decoder.decode(Output.self, from: data) } func subscribe( - _ id: String, + to app: String, input: (some Encodable) = EmptyInput.empty, pollInterval: DispatchTimeInterval = .seconds(1), timeout: DispatchTimeInterval = .minutes(3), includeLogs: Bool = false, - options _: RunOptions = DefaultRunOptions, onQueueUpdate: OnQueueUpdate? = nil ) async throws -> Output { - let requestId = try await queue.submit(id, input: input) + let requestId = try await queue.submit(app, input: input) let start = Int(Date().timeIntervalSince1970 * 1000) var elapsed = 0 var isCompleted = false while elapsed < timeout.milliseconds { - let update = try await queue.status(id, of: requestId, includeLogs: includeLogs) + let update = try await queue.status(app, of: requestId, includeLogs: includeLogs) if let onQueueUpdateCallback = onQueueUpdate { onQueueUpdateCallback(update) } @@ -60,6 +60,6 @@ public extension Client { if !isCompleted { throw FalError.queueTimeout } - return try await queue.response(id, of: requestId) + return try await queue.response(app, of: requestId) } } diff --git a/Sources/FalClient/Client+Request.swift b/Sources/FalClient/Client+Request.swift index 4fad8bc..f028232 100644 --- a/Sources/FalClient/Client+Request.swift +++ b/Sources/FalClient/Client+Request.swift @@ -1,8 +1,7 @@ import Foundation extension Client { - func sendRequest(_ id: String, input: Data?, queryParams: [String: Any]? = nil, options: RequestOptions) async throws -> Data { - let urlString = buildUrl(fromId: id, path: options.path) + func sendRequest(_ urlString: String, input: Data?, queryParams: [String: Any]? = nil, options: RunOptions) async throws -> Data { guard var url = URL(string: urlString) else { throw FalError.invalidUrl(url: urlString) } @@ -49,6 +48,6 @@ extension Client { var userAgent: String { let osVersion = ProcessInfo.processInfo.operatingSystemVersionString - return "fal.ai/swift-client 0.0.1 - \(osVersion)" + return "fal.ai/swift-client 0.1.0 - \(osVersion)" } } diff --git a/Sources/FalClient/Client.swift b/Sources/FalClient/Client.swift index 64ad75e..92040c7 100644 --- a/Sources/FalClient/Client.swift +++ b/Sources/FalClient/Client.swift @@ -1,27 +1,27 @@ import Dispatch import Foundation -enum HttpMethod: String { +public enum HttpMethod: String { case get case post case put case delete } -protocol RequestOptions { +public protocol RequestOptions { var httpMethod: HttpMethod { get } var path: String { get } } public struct RunOptions: RequestOptions { - let path: String - let httpMethod: HttpMethod + public let path: String + public let httpMethod: HttpMethod - static func withMethod(_ method: HttpMethod) -> Self { + static func withMethod(_ method: HttpMethod) -> RunOptions { RunOptions(path: "", httpMethod: method) } - static func route(_ path: String, withMethod method: HttpMethod = .post) -> Self { + static func route(_ path: String, withMethod method: HttpMethod = .post) -> RunOptions { RunOptions(path: path, httpMethod: method) } } @@ -35,39 +35,38 @@ public protocol Client { var queue: Queue { get } + var realtime: Realtime { get } + func run(_ id: String, input: [String: Any]?, options: RunOptions) async throws -> [String: Any] func subscribe( - _ id: String, + to app: String, input: [String: Any]?, pollInterval: DispatchTimeInterval, timeout: DispatchTimeInterval, includeLogs: Bool, - options: RunOptions, onQueueUpdate: OnQueueUpdate? ) async throws -> [String: Any] } public extension Client { - func run(_ id: String, input: [String: Any]? = nil, options: RunOptions = DefaultRunOptions) async throws -> [String: Any] { - return try await run(id, input: input, options: options) + func run(_ app: String, input: [String: Any]? = nil, options: RunOptions = DefaultRunOptions) async throws -> [String: Any] { + return try await run(app, input: input, options: options) } func subscribe( - _ id: String, + to app: String, input: [String: Any]? = nil, pollInterval: DispatchTimeInterval = .seconds(1), timeout: DispatchTimeInterval = .minutes(3), includeLogs: Bool = false, - options: RunOptions = DefaultRunOptions, onQueueUpdate: OnQueueUpdate? = nil ) async throws -> [String: Any] { - return try await subscribe(id, + return try await subscribe(to: app, input: input, pollInterval: pollInterval, timeout: timeout, includeLogs: includeLogs, - options: options, onQueueUpdate: onQueueUpdate) } } diff --git a/Sources/FalClient/FalClient.swift b/Sources/FalClient/FalClient.swift index c0edb60..1cd7a7a 100644 --- a/Sources/FalClient/FalClient.swift +++ b/Sources/FalClient/FalClient.swift @@ -29,10 +29,13 @@ public struct FalClient: Client { public var queue: Queue { QueueClient(client: self) } - public func run(_ id: String, input: [String: Any]?, options: RunOptions) async throws -> [String: Any] { + public var realtime: Realtime { RealtimeClient(client: self) } + + public func run(_ app: String, input: [String: Any]?, options: RunOptions) async throws -> [String: Any] { let inputData = input != nil ? try JSONSerialization.data(withJSONObject: input as Any) : nil let queryParams = options.httpMethod == .get ? input : nil - let data = try await sendRequest(id, input: inputData, queryParams: queryParams, options: options) + let url = buildUrl(fromId: app, path: options.path) + let data = try await sendRequest(url, input: inputData, queryParams: queryParams, options: options) guard let result = try JSONSerialization.jsonObject(with: data) as? [String: Any] else { throw FalError.invalidResultFormat } @@ -40,20 +43,19 @@ public struct FalClient: Client { } public func subscribe( - _ id: String, + to app: String, input: [String: Any]?, pollInterval: DispatchTimeInterval, timeout: DispatchTimeInterval, includeLogs: Bool, - options _: RunOptions, onQueueUpdate: OnQueueUpdate? ) async throws -> [String: Any] { - let requestId = try await queue.submit(id, input: input) + let requestId = try await queue.submit(app, input: input) let start = Int(Date().timeIntervalSince1970 * 1000) var elapsed = 0 var isCompleted = false while elapsed < timeout.milliseconds { - let update = try await queue.status(id, of: requestId, includeLogs: includeLogs) + let update = try await queue.status(app, of: requestId, includeLogs: includeLogs) if let onQueueUpdateCallback = onQueueUpdate { onQueueUpdateCallback(update) } @@ -67,12 +69,16 @@ public struct FalClient: Client { if !isCompleted { throw FalError.queueTimeout } - return try await queue.response(id, of: requestId) + return try await queue.response(app, of: requestId) } } public extension FalClient { - static func withProxy(_ url: String) -> FalClient { + static func withProxy(_ url: String) -> Client { return FalClient(config: ClientConfig(requestProxy: url)) } + + static func withCredentials(_ credentials: ClientCredentials) -> Client { + return FalClient(config: ClientConfig(credentials: credentials)) + } } diff --git a/Sources/FalClient/Realtime+Codable.swift b/Sources/FalClient/Realtime+Codable.swift new file mode 100644 index 0000000..513815a --- /dev/null +++ b/Sources/FalClient/Realtime+Codable.swift @@ -0,0 +1,30 @@ +import Dispatch +import Foundation + +class CodableRealtimeConnection: RealtimeConnection { + override public func send(_ data: Input) throws { + let json = try JSONEncoder().encode(data) + try sendReference(json) + } +} + +public extension Realtime { + func connect( + to app: String, + connectionKey: String, + throttleInterval: DispatchTimeInterval, + onResult completion: @escaping (Result) -> Void + ) throws -> RealtimeConnection { + return handleConnection( + to: app, connectionKey: connectionKey, throttleInterval: throttleInterval, + resultConverter: { data in + let result = try JSONDecoder().decode(Output.self, from: data) + return result + }, + connectionFactory: { send, close in + CodableRealtimeConnection(send, close) + }, + onResult: completion + ) + } +} diff --git a/Sources/FalClient/Realtime.swift b/Sources/FalClient/Realtime.swift index d7f8d89..eb75598 100644 --- a/Sources/FalClient/Realtime.swift +++ b/Sources/FalClient/Realtime.swift @@ -15,6 +15,12 @@ func throttle(_ function: @escaping (T) -> Void, throttleInterval: DispatchTi return throttledFunction } +public enum FalRealtimeError: Error { + case connectionError + case unauthorized + case invalidResult +} + public class RealtimeConnection { var sendReference: SendFunction var closeReference: CloseFunction @@ -36,17 +42,13 @@ public class RealtimeConnection { typealias SendFunction = (Data) throws -> Void typealias CloseFunction = () -> Void -public class RealtimeConnectionUntyped: RealtimeConnection<[String: Any]> { +class UntypedRealtimeConnection: RealtimeConnection<[String: Any]> { override public func send(_ data: [String: Any]) throws { let json = try JSONSerialization.data(withJSONObject: data) try sendReference(json) } } -// public protocol RealtimeConnectionTyped: RealtimeConnection where Input: Encodable { -// func send(_ data: Input) throws -// } - func buildRealtimeUrl(forApp app: String, host: String, token: String? = nil) -> URL { var components = URLComponents() components.scheme = "wss" @@ -54,66 +56,192 @@ func buildRealtimeUrl(forApp app: String, host: String, token: String? = nil) -> components.path = "/ws" if let token = token { - let queryItem = URLQueryItem(name: "fal_jwt_token", value: token) - components.queryItems = [queryItem] + components.queryItems = [URLQueryItem(name: "fal_jwt_token", value: token)] } - // swiftlint:disable:next force_unwrapping return components.url! } -class ConnectionManager { - private let session = URLSession(configuration: .default) - private var connections: [String: URLSessionWebSocketTask] = [:] - private var currentToken: String? +typealias RefreshTokenFunction = (String, (Result) -> Void) -> Void - // Singleton pattern for global access - static let shared = ConnectionManager() +private let TokenExpirationInterval: DispatchTimeInterval = .minutes(1) - init() {} +class WebSocketConnection: NSObject, URLSessionWebSocketDelegate { + let app: String + let client: Client + let onMessage: (Data) -> Void + let onError: (Error) -> Void - func token() -> String? { - return currentToken + private let queue = DispatchQueue(label: "ai.fal.WebSocketConnection.\(UUID().uuidString)") + private let session = URLSession(configuration: .default) + private var enqueuedMessages: [Data] = [] + private var task: URLSessionWebSocketTask? + private var token: String? + + private var isConnecting = false + private var isRefreshingToken = false + + init( + app: String, + client: Client, + onMessage: @escaping (Data) -> Void, + onError: @escaping (Error) -> Void + ) { + self.app = app + self.client = client + self.onMessage = onMessage + self.onError = onError } - func refreshToken(for app: String, completion: @escaping (String?) -> Void) { - // Assuming getToken is a function that fetches the token for the app - getToken(for: app) { [weak self] newToken in - self?.currentToken = newToken - print("Refreshed token: \(String(describing: newToken))") - completion(newToken) + func connect() { + if task == nil && !isConnecting && !isRefreshingToken { + isConnecting = true + if token == nil && !isRefreshingToken { + isRefreshingToken = true + refreshToken(app) { result in + switch result { + case let .success(token): + self.token = token + self.isRefreshingToken = false + self.isConnecting = false + + // Very simple token expiration handling for now + // Create the deadline 90% of the way through the token's lifetime + let tokenExpirationDeadline: DispatchTime = .now() + TokenExpirationInterval - .seconds(20) + DispatchQueue.main.asyncAfter(deadline: tokenExpirationDeadline) { + self.token = nil + } + + self.connect() + case let .failure(error): + self.isConnecting = false + self.isRefreshingToken = false + self.onError(error) + } + } + return + } + + // TODO: get host from config + let url = buildRealtimeUrl(forApp: app, host: "gateway.alpha.fal.ai", token: token) + let webSocketTask = session.webSocketTask(with: url) + webSocketTask.delegate = self + task = webSocketTask + // connect and keep the task reference + task?.resume() + isConnecting = false + receiveMessage() } } - func hasConnection(for app: String) -> Bool { - return connections[app] != nil + func refreshToken(_ app: String, completion: @escaping (Result) -> Void) { + Task { + // TODO: improve app alias resolution + let appAlias = app.split(separator: "-").dropFirst().joined(separator: "-") + let url = "https://rest.alpha.fal.ai/tokens/" + let body = try? JSONSerialization.data(withJSONObject: [ + "allowed_apps": [appAlias], + "token_expiration": 300, + ]) + do { + let response = try await self.client.sendRequest( + url, + input: body, + options: .withMethod(.post) + ) + if let token = String(data: response, encoding: .utf8) { + completion(.success(token.replacingOccurrences(of: "\"", with: ""))) + } else { + completion(.failure(FalRealtimeError.unauthorized)) + } + } catch { + completion(.failure(error)) + } + } } - func getConnection(for app: String) -> URLSessionWebSocketTask { - if let connection = connections[app] { - return connection + func receiveMessage() { + task?.receive { [weak self] incomingMessage in + switch incomingMessage { + case let .success(message): + do { + let data = try message.data() + guard let parsedMessage = try JSONSerialization.jsonObject(with: data) as? [String: Any] else { + self?.onError(FalRealtimeError.invalidResult) + return + } + if isSuccessResult(parsedMessage) { + self?.onMessage(data) + } +// if (parsedMessage["status"] as? String != "error") { +// self?.task?.cancel() +// } + + } catch { + self?.onError(error) + } + case let .failure(error): + self?.onError(error) + } + self?.receiveMessage() } + } - // TODO: get host from config - return session.webSocketTask(with: buildRealtimeUrl(forApp: app, host: "gateway.alpha.fal.ai", token: currentToken)) + func send(_ data: Data) throws { + if let task = task { + guard let message = String(data: data, encoding: .utf8) else { + return + } + task.send(.string(message)) { [weak self] error in + if let error = error { + self?.onError(error) + } + } + } else { + enqueuedMessages.append(data) + queue.sync { + if !isConnecting { + connect() + } + } + } } - func setConnection(for app: String, connection: URLSessionWebSocketTask) { - connections[app] = connection + func close() { + task?.cancel(with: .normalClosure, reason: "Programmatically closed".data(using: .utf8)) } - func removeConnection(for app: String) { - connections[app]?.cancel(with: .normalClosure, reason: nil) - connections.removeValue(forKey: app) + func urlSession( + _: URLSession, + webSocketTask _: URLSessionWebSocketTask, + didOpenWithProtocol _: String? + ) { + if let lastMessage = enqueuedMessages.last { + do { + try send(lastMessage) + } catch { + onError(error) + } + } + enqueuedMessages.removeAll() } - // Implement the getToken function or integrate your existing token fetching logic - private func getToken(for app: String, onComplete completion: @escaping (String?) -> Void) { - completion("token" + app) + func urlSession( + _: URLSession, + webSocketTask _: URLSessionWebSocketTask, + didCloseWith _: URLSessionWebSocketTask.CloseCode, + reason _: Data? + ) { + task = nil } } +var connectionPool: [String: WebSocketConnection] = [:] + public protocol Realtime { + + var client: Client { get } + func connect( to app: String, connectionKey: String, @@ -122,22 +250,9 @@ public protocol Realtime { ) throws -> RealtimeConnection<[String: Any]> } -// public extension Realtime { -// func connect( -// to app: String, -// throttleInterval: DispatchTimeInterval = .milliseconds(64), -// onResult: @escaping (Result<[String: Any], Error>) -> Void -// ) throws -> any RealtimeConnectionUntyped { -// return try connect(to: app, throttleInterval: throttleInterval, onResult: onResult) -// } -// } - -// func establishConnection( -// to _: String, -// onConnect _: @escaping (URLSessionWebSocketTask) -> Void -// ) { -// preconditionFailure() -// } +func isSuccessResult(_ message: [String: Any]) -> Bool { + return message["status"] as? String != "error" && message["type"] as? String != "x-fal-message" +} extension URLSessionWebSocketTask.Message { func data() throws -> Data { @@ -146,8 +261,7 @@ extension URLSessionWebSocketTask.Message { return data case let .string(string): guard let data = string.data(using: .utf8) else { - // TODO: improve exception type - preconditionFailure() + throw FalRealtimeError.invalidResult } return data @unknown default: @@ -156,8 +270,11 @@ extension URLSessionWebSocketTask.Message { } } -struct RealtimeClient: Realtime { - private let client: Client +public struct RealtimeClient: Realtime { + + // TODO in the future make this non-public + // External APIs should not use it + public let client: Client init(client: Client) { self.client = client @@ -170,78 +287,79 @@ struct RealtimeClient: Realtime { onResult completion: @escaping (Result<[String: Any], Error>) -> Void ) throws -> RealtimeConnection<[String: Any]> { return handleConnection( - to: app, connectionKey: connectionKey, throttleInterval: throttleInterval, + to: app, + connectionKey: connectionKey, + throttleInterval: throttleInterval, resultConverter: { data in guard let result = try JSONSerialization.jsonObject(with: data) as? [String: Any] else { - // TODO: throw exception - preconditionFailure() + throw FalRealtimeError.invalidResult } return result }, connectionFactory: { send, close in - RealtimeConnectionUntyped(send, close) + UntypedRealtimeConnection(send, close) }, onResult: completion ) } +} - func handleConnection( +extension Realtime { + internal func handleConnection( to app: String, - connectionKey _: String, + connectionKey: String, throttleInterval: DispatchTimeInterval, resultConverter convertToResultType: @escaping (Data) throws -> ResultType, connectionFactory createRealtimeConnection: @escaping (@escaping SendFunction, @escaping CloseFunction) -> RealtimeConnection, onResult completion: @escaping (Result) -> Void ) -> RealtimeConnection { - var enqueuedMessages: [Data] = [] - var ws: URLSessionWebSocketTask? = nil - - let reconnect = { - let connection = ConnectionManager.shared.getConnection(for: app) - connection.receive { incomingMessage in - switch incomingMessage { - case let .success(message): - do { - let data = try message.data() - let result = try convertToResultType(data) - completion(.success(result)) - } catch { - completion(.failure(error)) - } - case let .failure(error): - // TODO: only drop the connection if the error is fatal - ws = nil - ConnectionManager.shared.removeConnection(for: app) - // TODO: only send certain errors to the completion callback + let key = "\(app):\(connectionKey)" + let ws = connectionPool[key] ?? WebSocketConnection( + app: app, + client: self.client, + onMessage: { data in + do { + let result = try convertToResultType(data) + completion(.success(result)) + } catch { completion(.failure(error)) } + }, + onError: { error in + completion(.failure(error)) } - connection.resume() - ws = connection + ) + if connectionPool[key] == nil { + connectionPool[key] = ws } let sendData = { (data: Data) in - if let task = ws, task.state == .running { - guard let message = String(data: data, encoding: .utf8) else { - // TODO: throw exception - return - } - task.send(.string(message)) { error in - if let error = error { - completion(.failure(error)) - } - } - } else { - enqueuedMessages.append(data) - reconnect() + do { + try ws.send(data) + } catch { + completion(.failure(error)) } } - let send: SendFunction = throttleInterval.milliseconds > 0 ? throttle(sendData, throttleInterval: throttleInterval) : sendData let close: CloseFunction = { - ws?.cancel(with: .normalClosure, reason: nil) + ws.close() } - return createRealtimeConnection(send, close) } } + +public extension Realtime { + func connect( + to app: String, + connectionKey: String = UUID().uuidString, + throttleInterval: DispatchTimeInterval = .milliseconds(64), + onResult completion: @escaping (Result<[String: Any], Error>) -> Void + ) throws -> RealtimeConnection<[String: Any]> { + return try connect( + to: app, + connectionKey: connectionKey, + throttleInterval: throttleInterval, + onResult: completion + ) + } +} diff --git a/Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/Assets.xcassets/AccentColor.colorset/Contents.json b/Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/Assets.xcassets/AccentColor.colorset/Contents.json new file mode 100644 index 0000000..eb87897 --- /dev/null +++ b/Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/Assets.xcassets/AccentColor.colorset/Contents.json @@ -0,0 +1,11 @@ +{ + "colors" : [ + { + "idiom" : "universal" + } + ], + "info" : { + "author" : "xcode", + "version" : 1 + } +} diff --git a/Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/Assets.xcassets/AppIcon.appiconset/Contents.json b/Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/Assets.xcassets/AppIcon.appiconset/Contents.json new file mode 100644 index 0000000..532cd72 --- /dev/null +++ b/Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/Assets.xcassets/AppIcon.appiconset/Contents.json @@ -0,0 +1,63 @@ +{ + "images" : [ + { + "idiom" : "universal", + "platform" : "ios", + "size" : "1024x1024" + }, + { + "idiom" : "mac", + "scale" : "1x", + "size" : "16x16" + }, + { + "idiom" : "mac", + "scale" : "2x", + "size" : "16x16" + }, + { + "idiom" : "mac", + "scale" : "1x", + "size" : "32x32" + }, + { + "idiom" : "mac", + "scale" : "2x", + "size" : "32x32" + }, + { + "idiom" : "mac", + "scale" : "1x", + "size" : "128x128" + }, + { + "idiom" : "mac", + "scale" : "2x", + "size" : "128x128" + }, + { + "idiom" : "mac", + "scale" : "1x", + "size" : "256x256" + }, + { + "idiom" : "mac", + "scale" : "2x", + "size" : "256x256" + }, + { + "idiom" : "mac", + "scale" : "1x", + "size" : "512x512" + }, + { + "idiom" : "mac", + "scale" : "2x", + "size" : "512x512" + } + ], + "info" : { + "author" : "xcode", + "version" : 1 + } +} diff --git a/Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/Assets.xcassets/Contents.json b/Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/Assets.xcassets/Contents.json new file mode 100644 index 0000000..73c0059 --- /dev/null +++ b/Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/Assets.xcassets/Contents.json @@ -0,0 +1,6 @@ +{ + "info" : { + "author" : "xcode", + "version" : 1 + } +} diff --git a/Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/ContentView.swift b/Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/ContentView.swift new file mode 100644 index 0000000..91a9510 --- /dev/null +++ b/Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/ContentView.swift @@ -0,0 +1,60 @@ +import Kingfisher +import SwiftUI + +struct ContentView: View { + @State var canvasView = CanvasView() + @State var drawingData: Data? + @ObservedObject var liveImage = LiveImage() + + var body: some View { + GeometryReader { geometry in + if geometry.size.width > geometry.size.height { + // Landscape + HStack { + DrawingCanvasView(canvasView: $canvasView, drawingData: $drawingData) + .onChange(of: drawingData) { onDrawingChange() } + .frame(maxWidth: .infinity, maxHeight: .infinity) + ImageViewContainer(imageData: liveImage.currentImage) + .frame(maxWidth: .infinity, maxHeight: .infinity) + } + } else { + // Portrait + VStack { + ImageViewContainer(imageData: liveImage.currentImage) + .frame(maxWidth: .infinity, maxHeight: .infinity) + DrawingCanvasView(canvasView: $canvasView, drawingData: $drawingData) + .onChange(of: drawingData) { onDrawingChange() } + .frame(maxWidth: .infinity, maxHeight: .infinity) + } + } + } + .padding() + } + + func onDrawingChange() { + guard let data = drawingData else { + return + } + do { + try liveImage.generate(prompt: "a moon in a starry night sky", drawing: data) + } catch { + print(error) + } + } +} + +struct ImageViewContainer: View { + var imageData: Data? + + var body: some View { + VStack { + if let image = imageData { + KFImage.data(image, cacheKey: UUID().uuidString) + .transition(.opacity) + } else { + Rectangle() + .fill(Color.gray.opacity(0.4)) + } + } + } +} diff --git a/Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/DrawingCanvasView.swift b/Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/DrawingCanvasView.swift new file mode 100644 index 0000000..2bedbc9 --- /dev/null +++ b/Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/DrawingCanvasView.swift @@ -0,0 +1,103 @@ +import PencilKit +import SwiftUI +import UIKit + +class CanvasView: PKCanvasView { + // keep a list of functions that will be called when the touchMoves event is fired + var touchMoveListeners: [(Set) -> Void] = [] + + func addTouchMoveListener(_ listener: @escaping (Set) -> Void) { + touchMoveListeners.append(listener) + } + + override func touchesMoved(_ touches: Set, with _: UIEvent?) { + // call all the touchMove listeners + touchMoveListeners.forEach { listener in + listener(touches) + } + } +} + +struct DrawingCanvasView: UIViewRepresentable { + @Binding var canvasView: CanvasView + @Binding var drawingData: Data? + @State var toolPicker = PKToolPicker() + @State var isDrawing = false + + func makeUIView(context: Context) -> CanvasView { + canvasView.tool = PKInkingTool(.pen, color: .black, width: 10) + canvasView.delegate = context.coordinator + canvasView.addTouchMoveListener { _ in + if self.isDrawing { +// self.triggerDrawingChange() + } + } + return canvasView + } + +// @MainActor + func triggerDrawingChange() { + if let image = drawingToImage(canvasView: canvasView), + let imageData = image.jpegData(compressionQuality: 0.6) + { + drawingData = imageData + } + } + + func updateUIView(_: CanvasView, context _: Context) { + showToolPicker() + if let data = drawingData, + let drawing = try? PKDrawing(data: data) + { + canvasView.drawing = drawing + } + } + + func makeCoordinator() -> Coordinator { + return Coordinator(self) + } + + class Coordinator: NSObject, PKCanvasViewDelegate { + var parent: DrawingCanvasView + + init(_ parent: DrawingCanvasView) { + self.parent = parent + } + + @MainActor + func canvasViewDrawingDidChange(_: PKCanvasView) { + parent.triggerDrawingChange() + } + + @MainActor + func canvasViewDidBeginUsingTool(_: PKCanvasView) { + parent.isDrawing = true + } + + @MainActor + func canvasViewDidEndUsingTool(_: PKCanvasView) { + parent.isDrawing = false + } + } +} + +extension DrawingCanvasView { + func showToolPicker() { + toolPicker.setVisible(true, forFirstResponder: canvasView) + toolPicker.addObserver(canvasView) + canvasView.becomeFirstResponder() + } + +// @MainActor + func drawingToImage(canvasView: PKCanvasView) -> UIImage? { + // TODO: improve this, so the drawable area is clear to the user and also cropped + // correctly when the image is submitted + let drawingArea = CGRect(x: 0, y: 0, width: 512, height: 512) + return canvasView.drawing.image(from: drawingArea, scale: 1.0) +// let renderer = ImageRenderer(content: self) +// guard let image = renderer.cgImage?.cropping(to: drawingArea) else { +// return nil +// } +// return UIImage(cgImage: image) + } +} diff --git a/Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/FalRealtimeSampleApp.entitlements b/Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/FalRealtimeSampleApp.entitlements new file mode 100644 index 0000000..f2ef3ae --- /dev/null +++ b/Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/FalRealtimeSampleApp.entitlements @@ -0,0 +1,10 @@ + + + + + com.apple.security.app-sandbox + + com.apple.security.files.user-selected.read-only + + + diff --git a/Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/FalRealtimeSampleAppApp.swift b/Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/FalRealtimeSampleAppApp.swift new file mode 100644 index 0000000..7756b09 --- /dev/null +++ b/Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/FalRealtimeSampleAppApp.swift @@ -0,0 +1,10 @@ +import SwiftUI + +@main +struct FalRealtimeSampleAppApp: App { + var body: some Scene { + WindowGroup { + ContentView() + } + } +} diff --git a/Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/Preview Content/Preview Assets.xcassets/Contents.json b/Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/Preview Content/Preview Assets.xcassets/Contents.json new file mode 100644 index 0000000..73c0059 --- /dev/null +++ b/Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/Preview Content/Preview Assets.xcassets/Contents.json @@ -0,0 +1,6 @@ +{ + "info" : { + "author" : "xcode", + "version" : 1 + } +} diff --git a/Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/ViewModel.swift b/Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/ViewModel.swift new file mode 100644 index 0000000..f4a44ae --- /dev/null +++ b/Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/ViewModel.swift @@ -0,0 +1,66 @@ +import FalClient +import SwiftUI + +// See https://www.fal.ai/models/latent-consistency-sd/api for API documentation + +let OptimizedLatentConsistency = "110602490-lcm-sd15-i2i" + +struct LcmInput: Encodable { + let prompt: String + let imageUrl: String + let seed: Int + let syncMode: Bool + + enum CodingKeys: String, CodingKey { + case prompt + case imageUrl = "image_url" + case seed + case syncMode = "sync_mode" + } +} + +struct LcmImage: Decodable { + let url: String + let width: Int + let height: Int +} + +struct LcmResponse: Decodable { + let images: [LcmImage] +} + +class LiveImage: ObservableObject { + @Published var currentImage: Data? + + // This example demonstrates the support to Codable types + // RealtimeConnection<[String: Any]> can also be used + // for untyped input / output using dictionaries + private var connection: RealtimeConnection? + + init() { + connection = try? fal.realtime.connect( + to: OptimizedLatentConsistency, + connectionKey: "PencilKitDemo", + throttleInterval: .milliseconds(128) + ) { (result: Result) in + if case let .success(data) = result, + let image = data.images.first { + let data = try? Data(contentsOf: URL(string: image.url)!) + DispatchQueue.main.async { + self.currentImage = data + } + } + } + } + + func generate(prompt: String, drawing: Data) throws { + if let connection = connection { + try connection.send(LcmInput( + prompt: prompt, + imageUrl: "data:image/jpeg;base64,\(drawing.base64EncodedString())", + seed: 6_252_023, + syncMode: true + )) + } + } +} diff --git a/Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/fal.swift b/Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/fal.swift new file mode 100644 index 0000000..a0e174d --- /dev/null +++ b/Sources/Samples/FalRealtimeSampleApp/FalRealtimeSampleApp/fal.swift @@ -0,0 +1,4 @@ +import FalClient + +let fal = FalClient.withProxy("http://localhost:3333/api/fal/proxy") +// let fal = FalClient.withCredentials(.keyPair("fal_key_id:fal_key_secret")) diff --git a/Sources/Samples/FalSampleApp/FalSampleApp/ContentView.swift b/Sources/Samples/FalSampleApp/FalSampleApp/ContentView.swift index 5e5560c..1f85a39 100644 --- a/Sources/Samples/FalSampleApp/FalSampleApp/ContentView.swift +++ b/Sources/Samples/FalSampleApp/FalSampleApp/ContentView.swift @@ -17,7 +17,7 @@ struct ContentView: View { print("Generate image...") isLoading = true do { - let result = try await fal.subscribe("110602490-fast-sdxl", input: [ + let result = try await fal.subscribe(to: "110602490-fast-sdxl", input: [ "prompt": PROMPT, ], includeLogs: true) { update in print(update)