diff --git a/.gitignore b/.gitignore index 9564a1a9..5173dc0f 100644 --- a/.gitignore +++ b/.gitignore @@ -17,6 +17,7 @@ DerivedData *.ipa *.xcuserstate *.xcscmblueprint +*.resolved # CocoaPods # @@ -54,5 +55,12 @@ PerfectServer/perfectserverhttp # SwiftPM .build/ Packages/ -PerfectLib.xcodeproj/ +*.xcodeproj/ docs/ +.swiftpm/ +smtp.test.json + +.vscode/ +*.json +*.pem +*.sqlite* diff --git a/.swiftlint.yml b/.swiftlint.yml new file mode 100644 index 00000000..b819d26f --- /dev/null +++ b/.swiftlint.yml @@ -0,0 +1,32 @@ +cyclomatic_complexity: + - 64 # warning + - 128 # error +file_length: + - 2048 # warning + - 4096 # error +function_body_length: + - 128 # warning + - 256 # error +line_length: + - 256 # warning + - 512 # error +type_body_length: + - 512 # warning + - 1024 # error +disabled_rules: + - empty_enum_arguments + - function_parameter_count + - identifier_name + - inclusive_language + - large_tuple + - multiple_closures_with_trailing_closure + - nesting + - redundant_optional_initialization + - syntactic_sugar + - unused_optional_binding + - vertical_parameter_alignment + - void_return +excluded: + - .build/ + - Sources/PerfectHTTP/Mime*.swift + - Sources/PerfectHTTPServer/HTTP2/HPACK.swift \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..f4ef4f6c --- /dev/null +++ b/Dockerfile @@ -0,0 +1,17 @@ +FROM ubuntu:20.04 +RUN apt-get update -y +RUN apt-get install -y wget +WORKDIR /tmp +ARG arch +COPY ./${arch}.url.txt /tmp/url.txt +RUN rm -rf /tmp/sw* +RUN wget -O /tmp/swift.tgz $(cat /tmp/url.txt) +RUN cd /tmp && tar xf /tmp/swift.tgz && rm -rf /tmp/swift.tgz && mv $(ls|grep swift) /tmp/swift/ +RUN cd /tmp/swift/usr/ && tar cf /tmp/sw.tar * +RUN cd /usr && tar xf /tmp/sw.tar +RUN rm -rf /tmp/sw* +RUN apt-get update -y +RUN apt-get install -y build-essential clang git +RUN apt-get install -y libcurl4-openssl-dev uuid-dev +RUN apt-get install -y libsqlite3-dev libncurses-dev +RUN DEBIAN_FRONTEND=noninteractive apt-get install -y libxml2-dev diff --git a/Package.resolved b/Package.resolved new file mode 100644 index 00000000..e68406e0 --- /dev/null +++ b/Package.resolved @@ -0,0 +1,52 @@ +{ + "object": { + "pins": [ + { + "package": "COpenSSL", + "repositoryURL": "https://github.com/PerfectlySoft/Perfect-COpenSSL.git", + "state": { + "branch": null, + "revision": "ce3113e159b8c6d8565e5d8db2672b572c81aea9", + "version": "4.0.2" + } + }, + { + "package": "PerfectCZlib", + "repositoryURL": "https://github.com/RockfordWei/Perfect-CZlib-src.git", + "state": { + "branch": null, + "revision": "8295883fd760f601a2c8c3236af83c8c35f941c6", + "version": "0.0.6" + } + }, + { + "package": "cURL", + "repositoryURL": "https://github.com/PerfectlySoft/Perfect-libcurl.git", + "state": { + "branch": null, + "revision": "b3d7e65ef5c27c0a027cdc621f34835975301bf1", + "version": "2.1.0" + } + }, + { + "package": "LinuxBridge", + "repositoryURL": "https://github.com/PerfectlySoft/Perfect-LinuxBridge.git", + "state": { + "branch": null, + "revision": "d6e64c48e6b06b6f1ab7ab9338280447baa8ca5c", + "version": "3.1.0" + } + }, + { + "package": "PerfectCSQLite3", + "repositoryURL": "https://github.com/PerfectlySoft/Perfect-sqlite3-support.git", + "state": { + "branch": null, + "revision": "64c2bd87e1fd3a41cdeeba073bab794db7e97e42", + "version": "3.1.1" + } + } + ] + }, + "version": 1 +} diff --git a/Package.swift b/Package.swift index b6623ef2..b8dfcbc2 100644 --- a/Package.swift +++ b/Package.swift @@ -1,50 +1,92 @@ -// swift-tools-version:5.1 -// -// Package.swift -// PerfectLib -// -// Created by Kyle Jessup on 3/22/16. -// Copyright (C) 2016 PerfectlySoft, Inc. -// -//===----------------------------------------------------------------------===// -// -// This source file is part of the Perfect.org open source project -// -// Copyright (c) 2015 - 2016 PerfectlySoft Inc. and the Perfect project authors -// Licensed under Apache License v2.0 -// -// See http://perfect.org/licensing.html for license information -// -//===----------------------------------------------------------------------===// -// +// swift-tools-version: 5.4 +// The swift-tools-version declares the minimum version of Swift required to build this package. import PackageDescription #if os(Linux) -let package = Package( - name: "PerfectLib", - products: [ - .library(name: "PerfectLib", targets: ["PerfectLib"]) - ], - dependencies: [.package(url: "https://github.com/PerfectlySoft/Perfect-LinuxBridge.git", from: "3.0.0")], - targets: [ - .target(name: "PerfectLib", dependencies: ["LinuxBridge"]), - .testTarget(name: "PerfectLibTests", dependencies: ["PerfectLib"]) - ] -) +let pkdep: [Package.Dependency] = [ + .package(url: "https://github.com/PerfectlySoft/Perfect-LinuxBridge.git", from: "3.1.0"), + .package(url: "https://github.com/PerfectlySoft/Perfect-sqlite3-support.git", from: "3.1.1") +] + +let sqlite3dep: [Target.Dependency] = [ + .product(name: "PerfectCSQLite3", package: "Perfect-sqlite3-support") +] + +let osdep: [Target.Dependency] = sqlite3dep + [.product(name: "LinuxBridge", package: "Perfect-LinuxBridge")] +let sqldep: [Target.Dependency] = sqlite3dep + [.init(stringLiteral: "PerfectCRUD")] #else +let pkdep: [Package.Dependency] = [] +let osdep: [Target.Dependency] = [] +let sqldep: [Target.Dependency] = ["PerfectCRUD"] +#endif + let package = Package( - name: "PerfectLib", - platforms: [ - .macOS(.v10_15) - ], - products: [ - .library(name: "PerfectLib", targets: ["PerfectLib"]) - ], - dependencies: [], - targets: [ - .target(name: "PerfectLib", dependencies: []), - .testTarget(name: "PerfectLibTests", dependencies: ["PerfectLib"]) - ] + name: "Perfect", + products: [ + .library(name: "PerfectAuth", targets: ["PerfectAuth"]), + .library(name: "PerfectCRUD", targets: ["PerfectCRUD"]), + .library(name: "PerfectCrypto", targets: ["PerfectCrypto"]), + .library(name: "PerfectCURL", targets: ["PerfectCURL"]), + .library(name: "PerfectLib", targets: ["PerfectLib"]), + .library(name: "PerfectHTTP", targets: ["PerfectHTTP"]), + .library(name: "PerfectHTTPServer", targets: ["PerfectHTTPServer"]), + .library(name: "PerfectMustache", targets: ["PerfectMustache"]), + .library(name: "PerfectNet", targets: ["PerfectNet"]), + .library(name: "PerfectSMTP", targets: ["PerfectSMTP"]), + .library(name: "PerfectSQLite", targets: ["PerfectSQLite"]), + .library(name: "PerfectThread", targets: ["PerfectThread"]), + .executable(name: "httpd", targets: ["httpd"]) + ], + dependencies: pkdep + [ + .package(url: "https://github.com/PerfectlySoft/Perfect-libcurl.git", from: "2.0.0"), + .package(url: "https://github.com/PerfectlySoft/Perfect-COpenSSL.git", from: "4.0.2"), + .package(url: "https://github.com/RockfordWei/Perfect-CZlib-src.git", from: "0.0.6") + ], + targets: [ + .target(name: "PerfectAuth", dependencies: ["PerfectCrypto", "PerfectCRUD", "PerfectSQLite"]), + .target(name: "PerfectCHTTPParser"), + .target(name: "PerfectLib", dependencies: osdep), + .target(name: "PerfectThread", dependencies: osdep), + .target(name: "PerfectCRUD"), + .target(name: "PerfectCrypto", dependencies: [ + .init(stringLiteral: "PerfectLib"), + .init(stringLiteral: "PerfectThread"), + .product(name: "COpenSSL", package: "Perfect-COpenSSL") + ]), + .target(name: "PerfectCURL", dependencies: [ + .product(name: "cURL", package: "Perfect-libcurl"), + .init(stringLiteral: "PerfectLib"), + .init(stringLiteral: "PerfectThread") + ]), + .target(name: "PerfectHTTP", dependencies: ["PerfectLib", "PerfectNet"]), + .target(name: "PerfectHTTPServer", dependencies: [ + .init(stringLiteral: "PerfectCHTTPParser"), + .init(stringLiteral: "PerfectCrypto"), + .init(stringLiteral: "PerfectNet"), + .init(stringLiteral: "PerfectHTTP"), + .product(name: "PerfectCZlib", package: "Perfect-CZlib-src") + ]), + .target(name: "PerfectMustache", dependencies: ["PerfectLib"]), + .target(name: "PerfectNet", dependencies: ["PerfectCrypto", "PerfectThread"]), + .target(name: "PerfectSMTP", dependencies: ["PerfectCURL", "PerfectCrypto", "PerfectHTTP"]), + .target(name: "PerfectSQLite", dependencies: sqldep), + .testTarget(name: "PerfectAuthTests", dependencies: [ + "PerfectAuth", "PerfectCRUD", "PerfectCrypto", "PerfectLib", "PerfectSQLite" + ]), + .testTarget(name: "PerfectCryptoTests", dependencies: ["PerfectCrypto"]), + .testTarget(name: "PerfectCURLTests", dependencies: ["PerfectCURL"]), + .testTarget(name: "PerfectHTTPTests", dependencies: ["PerfectHTTP"]), + .testTarget(name: "PerfectHTTPServerTests", dependencies: ["PerfectHTTPServer"]), + .testTarget(name: "PerfectLibTests", dependencies: ["PerfectLib"]), + .testTarget(name: "PerfectMustacheTests", dependencies: ["PerfectMustache"]), + .testTarget(name: "PerfectNetTests", dependencies: ["PerfectNet"]), + .testTarget(name: "PerfectSMTPTests", dependencies: ["PerfectSMTP"]), + .testTarget(name: "PerfectSQLiteTests", dependencies: ["PerfectSQLite"]), + .testTarget(name: "PerfectThreadTests", dependencies: ["PerfectThread"]), + .executableTarget(name: "httpd", dependencies: [ + "PerfectAuth", "PerfectCrypto", "PerfectLib", "PerfectHTTPServer", "PerfectHTTP", + "PerfectMustache", "PerfectSMTP", "PerfectSQLite" + ]) + ] ) -#endif diff --git a/README.md b/README.md index e5b0ab46..190d3ad6 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@

- Swift 5.2 + Swift 5.6 Platforms macOS | Linux @@ -17,6 +17,19 @@

+**OS**|**Version**|**Chip**|**Status** +--|-------|----|------ +Ventura|macOS 13.6|Apple Silicon M2|Ventura Apple silicon +Ubuntu|22.04 LTS|i386|Ubuntu 22.04 LTS i386 +Ubuntu|22.04 LTS|arm64|Ubuntu 22.04 LTS arm64 +**Package**|**Status**|**Package**|**Status** +PerfectLib|PerfectLib|PerfectThread|PerfectThread +PerfectAuth|PerfectAuth|PerfectCRUD|PerfectCRUD +PerfectCrypto|PerfectCrypto|PerfectCURL|PerfectCURL +PerfectHTTP|PerfectHTTP|PerfectHTTPServer|PerfectHTTPServer +PerfectMustache|PerfectMustache|PerfectNet|PerfectNet +PerfectSMTP|PerfectSMTP|PerfectSQLite|PerfectSQLite + ## Perfect: Server-Side Swift Perfect is a complete and powerful toolbox, framework, and application server for Linux, iOS, and macOS (OS X). It provides everything a Swift engineer needs for developing lightweight, maintainable, and scalable apps and other REST services entirely in the Swift programming language for both client-facing and server-side applications. @@ -53,7 +66,6 @@ Your Perfect project can be deployed to any Swift compatible Linux server. We pr Our library continues to grow as members of [the Swift-Perfect development community have shared many samples and examples](https://github.com/PerfectExamples) of their projects in Perfect. Examples include: -- [WebSockets Server](https://github.com/PerfectExamples/PerfectExample-WebSocketsServer) - [URL Routing](https://github.com/PerfectExamples/PerfectExample-URLRouting) - [Upload Enumerator](https://github.com/PerfectExamples/PerfectExample-UploadEnumerator) @@ -90,7 +102,6 @@ Perfect project is divided into several repositories to make it easy for you to - [Perfect HTTP Server](https://github.com/PerfectlySoft/Perfect-HTTPServer) - HTTP 1.1 server for Perfect - [Perfect Mustache](https://github.com/PerfectlySoft/Perfect-Mustache) - Mustache template support for Perfect - [Perfect CURL](https://github.com/PerfectlySoft/Perfect-CURL) - cURL support for Perfect -- [Perfect WebSockets](https://github.com/PerfectlySoft/Perfect-WebSockets) - WebSockets support for Perfect - [Perfect Zip](https://github.com/PerfectlySoft/Perfect-Zip) - provides simple zip and unzip functionality - [Perfect Notifications](https://github.com/PerfectlySoft/Perfect-Notifications) - provides support for Apple Push Notification Service (APNS). @@ -98,7 +109,7 @@ Perfect project is divided into several repositories to make it easy for you to Perfect operates using either a standalone [HTTP server](https://github.com/PerfectlySoft/Perfect-HTTP), [HTTPS server](https://github.com/PerfectlySoft/Perfect-HTTPServer), or through [FastCGI server](https://github.com/PerfectlySoft/Perfect-FastCGI). It provides a system for loading your Swift-based modules at startup, for interfacing those modules with its request/response objects, or to the built-in [Mustache template processing system](https://github.com/PerfectlySoft/Perfect-Mustache). -Perfect is built on a completely asynchronous, high-performance networking engine to provide a scalable option for internet services. It supports Secure Sockets Layer (SSL) encryption, and it features a suite of tools commonly required by internet servers such as [WebSockets](https://github.com/PerfectlySoft/Perfect-WebSockets) and [iOS push notifications](https://github.com/PerfectlySoft/Perfect-Notifications), but you are not limited to those options. +Perfect is built on a completely asynchronous, high-performance networking engine to provide a scalable option for internet services. It supports Secure Sockets Layer (SSL) encryption, and it features a suite of tools commonly required by internet servers such as [iOS push notifications](https://github.com/PerfectlySoft/Perfect-Notifications), but you are not limited to those options. Feel free to use your favourite JSON or templating systems, etc. diff --git a/Sources/PerfectAuth/Nonce.swift b/Sources/PerfectAuth/Nonce.swift new file mode 100644 index 00000000..d16c5700 --- /dev/null +++ b/Sources/PerfectAuth/Nonce.swift @@ -0,0 +1,47 @@ +// +// Nonce.swift +// +// +// Created by Rockford Wei on 2022-06-27. +// + +import Foundation +import PerfectCrypto + +/// Nonce is a special server allocated JWT which can be typically used to check if a post is valid. +/// For example, any post method should include a valid nonce before action, so if not, the server can just simply ignore it. +public struct Nonce { + fileprivate struct Payload: Codable { + let host: UUID + let timestamp: TimeInterval + init(host h: UUID, timestamp t: TimeInterval = Date().timeIntervalSince1970) { + host = h; timestamp = t + } + } + + fileprivate static let algo = JWT.Alg.hs256 + fileprivate static let host = UUID() + + /// allocate a nonce string + public static func allocate(authorityPrivateKey: PEMKey) throws -> String { + let payload = Payload(host: host) + return try JWTCreator(payload: payload).sign(alg: algo, key: authorityPrivateKey) + } + + /// check if this nonce is valid + public static func validate(nonce: String, seconds: Int = 900, authorityPublicKey: PEMKey) throws { + // swiftlint:disable type_name + typealias exception = AuthenticationTokenClaim.Exception + guard let jwt = JWTVerifier(nonce) else { + throw exception.invalidJsonWebToken + } + try jwt.verify(algo: algo, key: authorityPublicKey) + let payload = try jwt.decode(as: Payload.self) + guard payload.host == host else { + throw exception.invalidHostKey + } + guard payload.timestamp + TimeInterval(seconds) > Date().timeIntervalSince1970 else { + throw exception.expired + } + } +} diff --git a/Sources/PerfectAuth/PerfectAuth.swift b/Sources/PerfectAuth/PerfectAuth.swift new file mode 100644 index 00000000..f038e77e --- /dev/null +++ b/Sources/PerfectAuth/PerfectAuth.swift @@ -0,0 +1,36 @@ +// +// Based on SAuth.swift +// [SAuthLib](https://github.com/kjessup/SAuthLib) +// +// Created by Kyle Jessup on 2018-02-26. +// Digested by Rockford Wei on 2022-06-23. +// + +import Foundation +import PerfectCrypto + +open class AuthenticationUtilities { + public static func hash(password: String) -> (hexSalt: String, hexHash: String)? { + let saltBytes = Array(randomCount: 32) + guard let saltHex = saltBytes.encode(.hex), + let hashHex = hash(password: password, saltBytes: saltBytes) else { + return nil + } + return (String(validatingUTF8: saltHex) ?? "", hashHex) + } + public static func validate(password: String, hexSalt: String, hexHash: String) -> Bool { + guard let saltBytes = hexSalt.decode(.hex), + let compareHexHash = hash(password: password, saltBytes: saltBytes) else { + return false + } + return compareHexHash == hexHash + } + private static func hash(password: String, saltBytes: [UInt8]) -> String? { + let pwBytes = Array(password.utf8) + guard let hashBytes = Digest.sha256.deriveKey(password: pwBytes, salt: saltBytes, iterations: 2048, keyLength: 32), + let hashHex = hashBytes.encode(.hex) else { + return nil + } + return String(validatingUTF8: hashHex) ?? "" + } +} diff --git a/Sources/PerfectAuth/TokenClaim.swift b/Sources/PerfectAuth/TokenClaim.swift new file mode 100644 index 00000000..42368816 --- /dev/null +++ b/Sources/PerfectAuth/TokenClaim.swift @@ -0,0 +1,87 @@ +// +// Based on SAuth.swift & Codables.swift +// [SAuthLib](https://github.com/kjessup/SAuthLib) +// [SAuthCodables](https://github.com/kjessup/SAuthCodables) +// +// Created by Kyle Jessup on 2018-02-26. +// Digested by Rockford Wei on 2022-06-23. +// + +import Foundation +import PerfectCrypto + +// swiftlint:disable line_length +public struct AuthenticationTokenClaim { + public enum Keys { + public static let account = "acc" + public static let expiration = "exp" + public static let issuedAt = "iat" + public static let issuer = "iss" + public static let subject = "sub" + } + public enum Exception: Error { + case invalidJsonWebToken + case invalidHostKey + case expired + } + public let payload: [String: Any] + public var account: String? { + payload[Keys.account] as? String + } + public var expiration: Int? { + payload[Keys.expiration] as? Int + } + public var issuer: String? { + payload[Keys.issuer] as? String + } + public var issuedAt: Int? { + payload[Keys.issuedAt] as? Int + } + public var subject: String? { + payload[Keys.subject] as? String + } + public init(fields: [String: Any]) { + var fields = fields + self.init(account: fields.removeValue(forKey: Keys.account) as? String, expiration: fields.removeValue(forKey: Keys.expiration) as? Int, issuer: fields.removeValue(forKey: Keys.issuer) as? String, issuedAt: fields.removeValue(forKey: Keys.issuedAt) as? Int, subject: fields.removeValue(forKey: Keys.subject) as? String, extra: fields) + } + public init(account: String? = nil, expiration: Int? = nil, issuer: String? = nil, issuedAt: Int? = nil, subject: String? = nil, extra: [String: Any]? = nil) { + var p: [String: Any] = [:] + if let v = account { + p[Keys.account] = v + } + if let v = expiration { + p[Keys.expiration] = v + } + if let v = issuer { + p[Keys.issuer] = v + } + if let v = issuedAt { + p[Keys.issuedAt] = v + } + if let v = subject { + p[Keys.subject] = v + } + if let v = extra { + p.merge(v, uniquingKeysWith: { $1 }) + } + payload = p + } + public static let algo = JWT.Alg.rs256 + public init(jsonWebToken: String, authorityPublicKey: PEMKey) throws { + guard let jwt = JWTVerifier(jsonWebToken) else { + throw Exception.invalidJsonWebToken + } + try jwt.verify(algo: AuthenticationTokenClaim.algo, key: authorityPublicKey) + self.init(fields: jwt.payload) + } + + public func generateJsonWebToken(authorityPrivateKey: PEMKey) throws -> String? { + return try JWTCreator(payload: payload)?.sign(alg: AuthenticationTokenClaim.algo, key: authorityPrivateKey) + } +} + +extension AuthenticationTokenClaim: Equatable { + public static func == (lhs: AuthenticationTokenClaim, rhs: AuthenticationTokenClaim) -> Bool { + return lhs.account == rhs.account && lhs.subject == rhs.subject && lhs.issuer == rhs.issuer && lhs.issuedAt == rhs.issuedAt && lhs.expiration == rhs.expiration + } +} diff --git a/Sources/PerfectAuth/Transient.swift b/Sources/PerfectAuth/Transient.swift new file mode 100644 index 00000000..456c437e --- /dev/null +++ b/Sources/PerfectAuth/Transient.swift @@ -0,0 +1,88 @@ +// +// Transient.swift +// +// +// Created by Rockford Wei on 2022-07-12. +// + +import Foundation +import PerfectCRUD +import PerfectSQLite + +public struct OneTimeRecord: Codable { + public let id: Int + public let subject: String + public let createdAt: Int + init(id i: Int = Int.random(in: 0..<1_000_000), subject s: String, createdAt c: Int = Date().timestamp) { + id = i; subject = s; createdAt = c + } +} + +public final class Transient { + public enum Exception: Error { + case overAttempted(shouldWaitSeconds: Int) + case subjectNotFound + case invalidCode + case expired + } + public static var minimalRetrySeconds = 60 + public static let expirySeconds = 900 // 15 minutes + public static let dbPath = "/tmp/perfect-transient.sqlite3" + private static let queue = DispatchQueue(label: UUID().uuidString) + private static let db: Database = { + do { + let config = try SQLiteDatabaseConfiguration(dbPath) + let db = Database(configuration: config) + try db.create(OneTimeRecord.self, policy: .reconcileTable) + return db + } catch { + fatalError("unable load \(dbPath) for transient code control because \(error)") + } + }() + public static func cleanup(expiry: Int = expirySeconds) { + do { + let expired = Date().timestamp - expiry + let table = db.table(OneTimeRecord.self) + try table.where(\OneTimeRecord.createdAt <= expired).delete() + queue.asyncAfter(deadline: .now() + Double(expiry)) { + #if DEBUG + CRUDLogging.log(.info, "\(expired):: cleaning up obsolete records") + #endif + cleanup(expiry: expiry) + } + } catch { + CRUDLogging.log(.warning, "unable to clean the transient allocator because \(error)") + } + } + public static func record(of subject: String) throws -> OneTimeRecord? { + let table = db.table(OneTimeRecord.self) + return try table.order(descending: \OneTimeRecord.createdAt) + .where(\OneTimeRecord.subject == subject) + .first() + } + public static func allocate(subject: String, minimalRetry: Int = minimalRetrySeconds) throws -> Int { + let table = db.table(OneTimeRecord.self) + if let record = try record(of: subject) { + let secondsToWait = minimalRetry - (Date().timestamp - record.createdAt) + if secondsToWait > 0 { + throw Exception.overAttempted(shouldWaitSeconds: secondsToWait) + } else { + try table.where(\OneTimeRecord.subject == subject).delete() + } + } + let record = OneTimeRecord(subject: subject) + try table.insert(record) + return record.id + } + public static func validate(id: Int, subject: String, expiry: Int = expirySeconds) throws { + guard let record = try record(of: subject) else { + throw Exception.subjectNotFound + } + guard id == record.id else { + throw Exception.invalidCode + } + guard Date().timestamp < record.createdAt + expiry else { + throw Exception.expired + } + } +} diff --git a/Sources/PerfectCHTTPParser/http_parser.c b/Sources/PerfectCHTTPParser/http_parser.c new file mode 100755 index 00000000..895bf0c7 --- /dev/null +++ b/Sources/PerfectCHTTPParser/http_parser.c @@ -0,0 +1,2470 @@ +/* Based on src/http/ngx_http_parse.c from NGINX copyright Igor Sysoev + * + * Additional changes are licensed under the same terms as NGINX and + * copyright Joyent, Inc. and other Node contributors. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ +#include "http_parser.h" +#include +#include +#include +#include +#include +#include + +#ifndef ULLONG_MAX +# define ULLONG_MAX ((uint64_t) -1) /* 2^64-1 */ +#endif + +#ifndef MIN +# define MIN(a,b) ((a) < (b) ? (a) : (b)) +#endif + +#ifndef ARRAY_SIZE +# define ARRAY_SIZE(a) (sizeof(a) / sizeof((a)[0])) +#endif + +#ifndef BIT_AT +# define BIT_AT(a, i) \ + (!!((unsigned int) (a)[(unsigned int) (i) >> 3] & \ + (1 << ((unsigned int) (i) & 7)))) +#endif + +#ifndef ELEM_AT +# define ELEM_AT(a, i, v) ((unsigned int) (i) < ARRAY_SIZE(a) ? (a)[(i)] : (v)) +#endif + +#define SET_ERRNO(e) \ +do { \ + parser->http_errno = (e); \ +} while(0) + +#define CURRENT_STATE() p_state +#define UPDATE_STATE(V) p_state = (enum state) (V); +#define RETURN(V) \ +do { \ + parser->state = CURRENT_STATE(); \ + return (V); \ +} while (0); +#define REEXECUTE() \ + goto reexecute; \ + + +#ifdef __GNUC__ +# define LIKELY(X) __builtin_expect(!!(X), 1) +# define UNLIKELY(X) __builtin_expect(!!(X), 0) +#else +# define LIKELY(X) (X) +# define UNLIKELY(X) (X) +#endif + + +/* Run the notify callback FOR, returning ER if it fails */ +#define CALLBACK_NOTIFY_(FOR, ER) \ +do { \ + assert(HTTP_PARSER_ERRNO(parser) == HPE_OK); \ + \ + if (LIKELY(settings->on_##FOR)) { \ + parser->state = CURRENT_STATE(); \ + if (UNLIKELY(0 != settings->on_##FOR(parser))) { \ + SET_ERRNO(HPE_CB_##FOR); \ + } \ + UPDATE_STATE(parser->state); \ + \ + /* We either errored above or got paused; get out */ \ + if (UNLIKELY(HTTP_PARSER_ERRNO(parser) != HPE_OK)) { \ + return (ER); \ + } \ + } \ +} while (0) + +/* Run the notify callback FOR and consume the current byte */ +#define CALLBACK_NOTIFY(FOR) CALLBACK_NOTIFY_(FOR, p - data + 1) + +/* Run the notify callback FOR and don't consume the current byte */ +#define CALLBACK_NOTIFY_NOADVANCE(FOR) CALLBACK_NOTIFY_(FOR, p - data) + +/* Run data callback FOR with LEN bytes, returning ER if it fails */ +#define CALLBACK_DATA_(FOR, LEN, ER) \ +do { \ + assert(HTTP_PARSER_ERRNO(parser) == HPE_OK); \ + \ + if (FOR##_mark) { \ + if (LIKELY(settings->on_##FOR)) { \ + parser->state = CURRENT_STATE(); \ + if (UNLIKELY(0 != \ + settings->on_##FOR(parser, FOR##_mark, (LEN)))) { \ + SET_ERRNO(HPE_CB_##FOR); \ + } \ + UPDATE_STATE(parser->state); \ + \ + /* We either errored above or got paused; get out */ \ + if (UNLIKELY(HTTP_PARSER_ERRNO(parser) != HPE_OK)) { \ + return (ER); \ + } \ + } \ + FOR##_mark = NULL; \ + } \ +} while (0) + +/* Run the data callback FOR and consume the current byte */ +#define CALLBACK_DATA(FOR) \ + CALLBACK_DATA_(FOR, p - FOR##_mark, p - data + 1) + +/* Run the data callback FOR and don't consume the current byte */ +#define CALLBACK_DATA_NOADVANCE(FOR) \ + CALLBACK_DATA_(FOR, p - FOR##_mark, p - data) + +/* Set the mark FOR; non-destructive if mark is already set */ +#define MARK(FOR) \ +do { \ + if (!FOR##_mark) { \ + FOR##_mark = p; \ + } \ +} while (0) + +/* Don't allow the total size of the HTTP headers (including the status + * line) to exceed HTTP_MAX_HEADER_SIZE. This check is here to protect + * embedders against denial-of-service attacks where the attacker feeds + * us a never-ending header that the embedder keeps buffering. + * + * This check is arguably the responsibility of embedders but we're doing + * it on the embedder's behalf because most won't bother and this way we + * make the web a little safer. HTTP_MAX_HEADER_SIZE is still far bigger + * than any reasonable request or response so this should never affect + * day-to-day operation. + */ +#define COUNT_HEADER_SIZE(V) \ +do { \ + parser->nread += (V); \ + if (UNLIKELY(parser->nread > (HTTP_MAX_HEADER_SIZE))) { \ + SET_ERRNO(HPE_HEADER_OVERFLOW); \ + goto error; \ + } \ +} while (0) + + +#define PROXY_CONNECTION "proxy-connection" +#define CONNECTION "connection" +#define CONTENT_LENGTH "content-length" +#define TRANSFER_ENCODING "transfer-encoding" +#define UPGRADE "upgrade" +#define CHUNKED "chunked" +#define KEEP_ALIVE "keep-alive" +#define CLOSE "close" + + +static const char *method_strings[] = + { +#define XX(num, name, string) #string, + HTTP_METHOD_MAP(XX) +#undef XX + }; + + +/* Tokens as defined by rfc 2616. Also lowercases them. + * token = 1* + * separators = "(" | ")" | "<" | ">" | "@" + * | "," | ";" | ":" | "\" | <"> + * | "/" | "[" | "]" | "?" | "=" + * | "{" | "}" | SP | HT + */ +static const char tokens[256] = { +/* 0 nul 1 soh 2 stx 3 etx 4 eot 5 enq 6 ack 7 bel */ + 0, 0, 0, 0, 0, 0, 0, 0, +/* 8 bs 9 ht 10 nl 11 vt 12 np 13 cr 14 so 15 si */ + 0, 0, 0, 0, 0, 0, 0, 0, +/* 16 dle 17 dc1 18 dc2 19 dc3 20 dc4 21 nak 22 syn 23 etb */ + 0, 0, 0, 0, 0, 0, 0, 0, +/* 24 can 25 em 26 sub 27 esc 28 fs 29 gs 30 rs 31 us */ + 0, 0, 0, 0, 0, 0, 0, 0, +/* 32 sp 33 ! 34 " 35 # 36 $ 37 % 38 & 39 ' */ + 0, '!', 0, '#', '$', '%', '&', '\'', +/* 40 ( 41 ) 42 * 43 + 44 , 45 - 46 . 47 / */ + 0, 0, '*', '+', 0, '-', '.', 0, +/* 48 0 49 1 50 2 51 3 52 4 53 5 54 6 55 7 */ + '0', '1', '2', '3', '4', '5', '6', '7', +/* 56 8 57 9 58 : 59 ; 60 < 61 = 62 > 63 ? */ + '8', '9', 0, 0, 0, 0, 0, 0, +/* 64 @ 65 A 66 B 67 C 68 D 69 E 70 F 71 G */ + 0, 'a', 'b', 'c', 'd', 'e', 'f', 'g', +/* 72 H 73 I 74 J 75 K 76 L 77 M 78 N 79 O */ + 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', +/* 80 P 81 Q 82 R 83 S 84 T 85 U 86 V 87 W */ + 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', +/* 88 X 89 Y 90 Z 91 [ 92 \ 93 ] 94 ^ 95 _ */ + 'x', 'y', 'z', 0, 0, 0, '^', '_', +/* 96 ` 97 a 98 b 99 c 100 d 101 e 102 f 103 g */ + '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', +/* 104 h 105 i 106 j 107 k 108 l 109 m 110 n 111 o */ + 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', +/* 112 p 113 q 114 r 115 s 116 t 117 u 118 v 119 w */ + 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', +/* 120 x 121 y 122 z 123 { 124 | 125 } 126 ~ 127 del */ + 'x', 'y', 'z', 0, '|', 0, '~', 0 }; + + +static const int8_t unhex[256] = + {-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 + ,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 + ,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 + , 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,-1,-1,-1,-1,-1,-1 + ,-1,10,11,12,13,14,15,-1,-1,-1,-1,-1,-1,-1,-1,-1 + ,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 + ,-1,10,11,12,13,14,15,-1,-1,-1,-1,-1,-1,-1,-1,-1 + ,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 + }; + + +#if HTTP_PARSER_STRICT +# define T(v) 0 +#else +# define T(v) v +#endif + + +static const uint8_t normal_url_char[32] = { +/* 0 nul 1 soh 2 stx 3 etx 4 eot 5 enq 6 ack 7 bel */ + 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0, +/* 8 bs 9 ht 10 nl 11 vt 12 np 13 cr 14 so 15 si */ + 0 | T(2) | 0 | 0 | T(16) | 0 | 0 | 0, +/* 16 dle 17 dc1 18 dc2 19 dc3 20 dc4 21 nak 22 syn 23 etb */ + 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0, +/* 24 can 25 em 26 sub 27 esc 28 fs 29 gs 30 rs 31 us */ + 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0, +/* 32 sp 33 ! 34 " 35 # 36 $ 37 % 38 & 39 ' */ + 0 | 2 | 4 | 0 | 16 | 32 | 64 | 128, +/* 40 ( 41 ) 42 * 43 + 44 , 45 - 46 . 47 / */ + 1 | 2 | 4 | 8 | 16 | 32 | 64 | 128, +/* 48 0 49 1 50 2 51 3 52 4 53 5 54 6 55 7 */ + 1 | 2 | 4 | 8 | 16 | 32 | 64 | 128, +/* 56 8 57 9 58 : 59 ; 60 < 61 = 62 > 63 ? */ + 1 | 2 | 4 | 8 | 16 | 32 | 64 | 0, +/* 64 @ 65 A 66 B 67 C 68 D 69 E 70 F 71 G */ + 1 | 2 | 4 | 8 | 16 | 32 | 64 | 128, +/* 72 H 73 I 74 J 75 K 76 L 77 M 78 N 79 O */ + 1 | 2 | 4 | 8 | 16 | 32 | 64 | 128, +/* 80 P 81 Q 82 R 83 S 84 T 85 U 86 V 87 W */ + 1 | 2 | 4 | 8 | 16 | 32 | 64 | 128, +/* 88 X 89 Y 90 Z 91 [ 92 \ 93 ] 94 ^ 95 _ */ + 1 | 2 | 4 | 8 | 16 | 32 | 64 | 128, +/* 96 ` 97 a 98 b 99 c 100 d 101 e 102 f 103 g */ + 1 | 2 | 4 | 8 | 16 | 32 | 64 | 128, +/* 104 h 105 i 106 j 107 k 108 l 109 m 110 n 111 o */ + 1 | 2 | 4 | 8 | 16 | 32 | 64 | 128, +/* 112 p 113 q 114 r 115 s 116 t 117 u 118 v 119 w */ + 1 | 2 | 4 | 8 | 16 | 32 | 64 | 128, +/* 120 x 121 y 122 z 123 { 124 | 125 } 126 ~ 127 del */ + 1 | 2 | 4 | 8 | 16 | 32 | 64 | 0, }; + +#undef T + +enum state + { s_dead = 1 /* important that this is > 0 */ + + , s_start_req_or_res + , s_res_or_resp_H + , s_start_res + , s_res_H + , s_res_HT + , s_res_HTT + , s_res_HTTP + , s_res_first_http_major + , s_res_http_major + , s_res_first_http_minor + , s_res_http_minor + , s_res_first_status_code + , s_res_status_code + , s_res_status_start + , s_res_status + , s_res_line_almost_done + + , s_start_req + + , s_req_method + , s_req_spaces_before_url + , s_req_schema + , s_req_schema_slash + , s_req_schema_slash_slash + , s_req_server_start + , s_req_server + , s_req_server_with_at + , s_req_path + , s_req_query_string_start + , s_req_query_string + , s_req_fragment_start + , s_req_fragment + , s_req_http_start + , s_req_http_H + , s_req_http_HT + , s_req_http_HTT + , s_req_http_HTTP + , s_req_first_http_major + , s_req_http_major + , s_req_first_http_minor + , s_req_http_minor + , s_req_line_almost_done + + , s_header_field_start + , s_header_field + , s_header_value_discard_ws + , s_header_value_discard_ws_almost_done + , s_header_value_discard_lws + , s_header_value_start + , s_header_value + , s_header_value_lws + + , s_header_almost_done + + , s_chunk_size_start + , s_chunk_size + , s_chunk_parameters + , s_chunk_size_almost_done + + , s_headers_almost_done + , s_headers_done + + /* Important: 's_headers_done' must be the last 'header' state. All + * states beyond this must be 'body' states. It is used for overflow + * checking. See the PARSING_HEADER() macro. + */ + + , s_chunk_data + , s_chunk_data_almost_done + , s_chunk_data_done + + , s_body_identity + , s_body_identity_eof + + , s_message_done + }; + + +#define PARSING_HEADER(state) (state <= s_headers_done) + + +enum header_states + { h_general = 0 + , h_C + , h_CO + , h_CON + + , h_matching_connection + , h_matching_proxy_connection + , h_matching_content_length + , h_matching_transfer_encoding + , h_matching_upgrade + + , h_connection + , h_content_length + , h_transfer_encoding + , h_upgrade + + , h_matching_transfer_encoding_chunked + , h_matching_connection_token_start + , h_matching_connection_keep_alive + , h_matching_connection_close + , h_matching_connection_upgrade + , h_matching_connection_token + + , h_transfer_encoding_chunked + , h_connection_keep_alive + , h_connection_close + , h_connection_upgrade + }; + +enum http_host_state + { + s_http_host_dead = 1 + , s_http_userinfo_start + , s_http_userinfo + , s_http_host_start + , s_http_host_v6_start + , s_http_host + , s_http_host_v6 + , s_http_host_v6_end + , s_http_host_v6_zone_start + , s_http_host_v6_zone + , s_http_host_port_start + , s_http_host_port +}; + +/* Macros for character classes; depends on strict-mode */ +#define CR '\r' +#define LF '\n' +#define LOWER(c) (unsigned char)(c | 0x20) +#define IS_ALPHA(c) (LOWER(c) >= 'a' && LOWER(c) <= 'z') +#define IS_NUM(c) ((c) >= '0' && (c) <= '9') +#define IS_ALPHANUM(c) (IS_ALPHA(c) || IS_NUM(c)) +#define IS_HEX(c) (IS_NUM(c) || (LOWER(c) >= 'a' && LOWER(c) <= 'f')) +#define IS_MARK(c) ((c) == '-' || (c) == '_' || (c) == '.' || \ + (c) == '!' || (c) == '~' || (c) == '*' || (c) == '\'' || (c) == '(' || \ + (c) == ')') +#define IS_USERINFO_CHAR(c) (IS_ALPHANUM(c) || IS_MARK(c) || (c) == '%' || \ + (c) == ';' || (c) == ':' || (c) == '&' || (c) == '=' || (c) == '+' || \ + (c) == '$' || (c) == ',') + +#define STRICT_TOKEN(c) (tokens[(unsigned char)c]) + +#if HTTP_PARSER_STRICT +#define TOKEN(c) (tokens[(unsigned char)c]) +#define IS_URL_CHAR(c) (BIT_AT(normal_url_char, (unsigned char)c)) +#define IS_HOST_CHAR(c) (IS_ALPHANUM(c) || (c) == '.' || (c) == '-') +#else +#define TOKEN(c) ((c == ' ') ? ' ' : tokens[(unsigned char)c]) +#define IS_URL_CHAR(c) \ + (BIT_AT(normal_url_char, (unsigned char)c) || ((c) & 0x80)) +#define IS_HOST_CHAR(c) \ + (IS_ALPHANUM(c) || (c) == '.' || (c) == '-' || (c) == '_') +#endif + +/** + * Verify that a char is a valid visible (printable) US-ASCII + * character or %x80-FF + **/ +#define IS_HEADER_CHAR(ch) \ + (ch == CR || ch == LF || ch == 9 || ((unsigned char)ch > 31 && ch != 127)) + +#define start_state (parser->type == HTTP_REQUEST ? s_start_req : s_start_res) + + +#if HTTP_PARSER_STRICT +# define STRICT_CHECK(cond) \ +do { \ + if (cond) { \ + SET_ERRNO(HPE_STRICT); \ + goto error; \ + } \ +} while (0) +# define NEW_MESSAGE() (http_should_keep_alive(parser) ? start_state : s_dead) +#else +# define STRICT_CHECK(cond) +# define NEW_MESSAGE() start_state +#endif + + +/* Map errno values to strings for human-readable output */ +#define HTTP_STRERROR_GEN(n, s) { "HPE_" #n, s }, +static struct { + const char *name; + const char *description; +} http_strerror_tab[] = { + HTTP_ERRNO_MAP(HTTP_STRERROR_GEN) +}; +#undef HTTP_STRERROR_GEN + +int http_message_needs_eof(const http_parser *parser); + +/* Our URL parser. + * + * This is designed to be shared by http_parser_execute() for URL validation, + * hence it has a state transition + byte-for-byte interface. In addition, it + * is meant to be embedded in http_parser_parse_url(), which does the dirty + * work of turning state transitions URL components for its API. + * + * This function should only be invoked with non-space characters. It is + * assumed that the caller cares about (and can detect) the transition between + * URL and non-URL states by looking for these. + */ +static enum state +parse_url_char(enum state s, const char ch) +{ + if (ch == ' ' || ch == '\r' || ch == '\n') { + return s_dead; + } + +#if HTTP_PARSER_STRICT + if (ch == '\t' || ch == '\f') { + return s_dead; + } +#endif + + switch (s) { + case s_req_spaces_before_url: + /* Proxied requests are followed by scheme of an absolute URI (alpha). + * All methods except CONNECT are followed by '/' or '*'. + */ + + if (ch == '/' || ch == '*') { + return s_req_path; + } + + if (IS_ALPHA(ch)) { + return s_req_schema; + } + + break; + + case s_req_schema: + if (IS_ALPHA(ch)) { + return s; + } + + if (ch == ':') { + return s_req_schema_slash; + } + + break; + + case s_req_schema_slash: + if (ch == '/') { + return s_req_schema_slash_slash; + } + + break; + + case s_req_schema_slash_slash: + if (ch == '/') { + return s_req_server_start; + } + + break; + + case s_req_server_with_at: + if (ch == '@') { + return s_dead; + } + + /* FALLTHROUGH */ + case s_req_server_start: + case s_req_server: + if (ch == '/') { + return s_req_path; + } + + if (ch == '?') { + return s_req_query_string_start; + } + + if (ch == '@') { + return s_req_server_with_at; + } + + if (IS_USERINFO_CHAR(ch) || ch == '[' || ch == ']') { + return s_req_server; + } + + break; + + case s_req_path: + if (IS_URL_CHAR(ch)) { + return s; + } + + switch (ch) { + case '?': + return s_req_query_string_start; + + case '#': + return s_req_fragment_start; + } + + break; + + case s_req_query_string_start: + case s_req_query_string: + if (IS_URL_CHAR(ch)) { + return s_req_query_string; + } + + switch (ch) { + case '?': + /* allow extra '?' in query string */ + return s_req_query_string; + + case '#': + return s_req_fragment_start; + } + + break; + + case s_req_fragment_start: + if (IS_URL_CHAR(ch)) { + return s_req_fragment; + } + + switch (ch) { + case '?': + return s_req_fragment; + + case '#': + return s; + } + + break; + + case s_req_fragment: + if (IS_URL_CHAR(ch)) { + return s; + } + + switch (ch) { + case '?': + case '#': + return s; + } + + break; + + default: + break; + } + + /* We should never fall out of the switch above unless there's an error */ + return s_dead; +} + +size_t http_parser_execute (http_parser *parser, + const http_parser_settings *settings, + const char *data, + size_t len) +{ + char c, ch; + int8_t unhex_val; + const char *p = data; + const char *header_field_mark = 0; + const char *header_value_mark = 0; + const char *url_mark = 0; + const char *body_mark = 0; + const char *status_mark = 0; + enum state p_state = (enum state) parser->state; + const unsigned int lenient = parser->lenient_http_headers; + + /* We're in an error state. Don't bother doing anything. */ + if (HTTP_PARSER_ERRNO(parser) != HPE_OK) { + return 0; + } + + if (len == 0) { + switch (CURRENT_STATE()) { + case s_body_identity_eof: + /* Use of CALLBACK_NOTIFY() here would erroneously return 1 byte read if + * we got paused. + */ + CALLBACK_NOTIFY_NOADVANCE(message_complete); + return 0; + + case s_dead: + case s_start_req_or_res: + case s_start_res: + case s_start_req: + return 0; + + default: + SET_ERRNO(HPE_INVALID_EOF_STATE); + return 1; + } + } + + + if (CURRENT_STATE() == s_header_field) + header_field_mark = data; + if (CURRENT_STATE() == s_header_value) + header_value_mark = data; + switch (CURRENT_STATE()) { + case s_req_path: + case s_req_schema: + case s_req_schema_slash: + case s_req_schema_slash_slash: + case s_req_server_start: + case s_req_server: + case s_req_server_with_at: + case s_req_query_string_start: + case s_req_query_string: + case s_req_fragment_start: + case s_req_fragment: + url_mark = data; + break; + case s_res_status: + status_mark = data; + break; + default: + break; + } + + for (p=data; p != data + len; p++) { + ch = *p; + + if (PARSING_HEADER(CURRENT_STATE())) + COUNT_HEADER_SIZE(1); + +reexecute: + switch (CURRENT_STATE()) { + + case s_dead: + /* this state is used after a 'Connection: close' message + * the parser will error out if it reads another message + */ + if (LIKELY(ch == CR || ch == LF)) + break; + + SET_ERRNO(HPE_CLOSED_CONNECTION); + goto error; + + case s_start_req_or_res: + { + if (ch == CR || ch == LF) + break; + parser->flags = 0; + parser->content_length = ULLONG_MAX; + + if (ch == 'H') { + UPDATE_STATE(s_res_or_resp_H); + + CALLBACK_NOTIFY(message_begin); + } else { + parser->type = HTTP_REQUEST; + UPDATE_STATE(s_start_req); + REEXECUTE(); + } + + break; + } + + case s_res_or_resp_H: + if (ch == 'T') { + parser->type = HTTP_RESPONSE; + UPDATE_STATE(s_res_HT); + } else { + if (UNLIKELY(ch != 'E')) { + SET_ERRNO(HPE_INVALID_CONSTANT); + goto error; + } + + parser->type = HTTP_REQUEST; + parser->method = HTTP_HEAD; + parser->index = 2; + UPDATE_STATE(s_req_method); + } + break; + + case s_start_res: + { + parser->flags = 0; + parser->content_length = ULLONG_MAX; + + switch (ch) { + case 'H': + UPDATE_STATE(s_res_H); + break; + + case CR: + case LF: + break; + + default: + SET_ERRNO(HPE_INVALID_CONSTANT); + goto error; + } + + CALLBACK_NOTIFY(message_begin); + break; + } + + case s_res_H: + STRICT_CHECK(ch != 'T'); + UPDATE_STATE(s_res_HT); + break; + + case s_res_HT: + STRICT_CHECK(ch != 'T'); + UPDATE_STATE(s_res_HTT); + break; + + case s_res_HTT: + STRICT_CHECK(ch != 'P'); + UPDATE_STATE(s_res_HTTP); + break; + + case s_res_HTTP: + STRICT_CHECK(ch != '/'); + UPDATE_STATE(s_res_first_http_major); + break; + + case s_res_first_http_major: + if (UNLIKELY(ch < '0' || ch > '9')) { + SET_ERRNO(HPE_INVALID_VERSION); + goto error; + } + + parser->http_major = ch - '0'; + UPDATE_STATE(s_res_http_major); + break; + + /* major HTTP version or dot */ + case s_res_http_major: + { + if (ch == '.') { + UPDATE_STATE(s_res_first_http_minor); + break; + } + + if (!IS_NUM(ch)) { + SET_ERRNO(HPE_INVALID_VERSION); + goto error; + } + + parser->http_major *= 10; + parser->http_major += ch - '0'; + + if (UNLIKELY(parser->http_major > 999)) { + SET_ERRNO(HPE_INVALID_VERSION); + goto error; + } + + break; + } + + /* first digit of minor HTTP version */ + case s_res_first_http_minor: + if (UNLIKELY(!IS_NUM(ch))) { + SET_ERRNO(HPE_INVALID_VERSION); + goto error; + } + + parser->http_minor = ch - '0'; + UPDATE_STATE(s_res_http_minor); + break; + + /* minor HTTP version or end of request line */ + case s_res_http_minor: + { + if (ch == ' ') { + UPDATE_STATE(s_res_first_status_code); + break; + } + + if (UNLIKELY(!IS_NUM(ch))) { + SET_ERRNO(HPE_INVALID_VERSION); + goto error; + } + + parser->http_minor *= 10; + parser->http_minor += ch - '0'; + + if (UNLIKELY(parser->http_minor > 999)) { + SET_ERRNO(HPE_INVALID_VERSION); + goto error; + } + + break; + } + + case s_res_first_status_code: + { + if (!IS_NUM(ch)) { + if (ch == ' ') { + break; + } + + SET_ERRNO(HPE_INVALID_STATUS); + goto error; + } + parser->status_code = ch - '0'; + UPDATE_STATE(s_res_status_code); + break; + } + + case s_res_status_code: + { + if (!IS_NUM(ch)) { + switch (ch) { + case ' ': + UPDATE_STATE(s_res_status_start); + break; + case CR: + UPDATE_STATE(s_res_line_almost_done); + break; + case LF: + UPDATE_STATE(s_header_field_start); + break; + default: + SET_ERRNO(HPE_INVALID_STATUS); + goto error; + } + break; + } + + parser->status_code *= 10; + parser->status_code += ch - '0'; + + if (UNLIKELY(parser->status_code > 999)) { + SET_ERRNO(HPE_INVALID_STATUS); + goto error; + } + + break; + } + + case s_res_status_start: + { + if (ch == CR) { + UPDATE_STATE(s_res_line_almost_done); + break; + } + + if (ch == LF) { + UPDATE_STATE(s_header_field_start); + break; + } + + MARK(status); + UPDATE_STATE(s_res_status); + parser->index = 0; + break; + } + + case s_res_status: + if (ch == CR) { + UPDATE_STATE(s_res_line_almost_done); + CALLBACK_DATA(status); + break; + } + + if (ch == LF) { + UPDATE_STATE(s_header_field_start); + CALLBACK_DATA(status); + break; + } + + break; + + case s_res_line_almost_done: + STRICT_CHECK(ch != LF); + UPDATE_STATE(s_header_field_start); + break; + + case s_start_req: + { + if (ch == CR || ch == LF) + break; + parser->flags = 0; + parser->content_length = ULLONG_MAX; + + if (UNLIKELY(!IS_ALPHA(ch))) { + SET_ERRNO(HPE_INVALID_METHOD); + goto error; + } + + parser->method = (enum http_method) 0; + parser->index = 1; + switch (ch) { + case 'A': parser->method = HTTP_ACL; break; + case 'B': parser->method = HTTP_BIND; break; + case 'C': parser->method = HTTP_CONNECT; /* or COPY, CHECKOUT */ break; + case 'D': parser->method = HTTP_DELETE; break; + case 'G': parser->method = HTTP_GET; break; + case 'H': parser->method = HTTP_HEAD; break; + case 'L': parser->method = HTTP_LOCK; /* or LINK */ break; + case 'M': parser->method = HTTP_MKCOL; /* or MOVE, MKACTIVITY, MERGE, M-SEARCH, MKCALENDAR */ break; + case 'N': parser->method = HTTP_NOTIFY; break; + case 'O': parser->method = HTTP_OPTIONS; break; + case 'P': parser->method = HTTP_POST; + /* or PROPFIND|PROPPATCH|PUT|PATCH|PURGE */ + break; + case 'R': parser->method = HTTP_REPORT; /* or REBIND */ break; + case 'S': parser->method = HTTP_SUBSCRIBE; /* or SEARCH */ break; + case 'T': parser->method = HTTP_TRACE; break; + case 'U': parser->method = HTTP_UNLOCK; /* or UNSUBSCRIBE, UNBIND, UNLINK */ break; + default: + SET_ERRNO(HPE_INVALID_METHOD); + goto error; + } + UPDATE_STATE(s_req_method); + + CALLBACK_NOTIFY(message_begin); + + break; + } + + case s_req_method: + { + const char *matcher; + if (UNLIKELY(ch == '\0')) { + SET_ERRNO(HPE_INVALID_METHOD); + goto error; + } + + matcher = method_strings[parser->method]; + if (ch == ' ' && matcher[parser->index] == '\0') { + UPDATE_STATE(s_req_spaces_before_url); + } else if (ch == matcher[parser->index]) { + ; /* nada */ + } else if (IS_ALPHA(ch)) { + + switch (parser->method << 16 | parser->index << 8 | ch) { +#define XX(meth, pos, ch, new_meth) \ + case (HTTP_##meth << 16 | pos << 8 | ch): \ + parser->method = HTTP_##new_meth; break; + + XX(POST, 1, 'U', PUT) + XX(POST, 1, 'A', PATCH) + XX(CONNECT, 1, 'H', CHECKOUT) + XX(CONNECT, 2, 'P', COPY) + XX(MKCOL, 1, 'O', MOVE) + XX(MKCOL, 1, 'E', MERGE) + XX(MKCOL, 2, 'A', MKACTIVITY) + XX(MKCOL, 3, 'A', MKCALENDAR) + XX(SUBSCRIBE, 1, 'E', SEARCH) + XX(REPORT, 2, 'B', REBIND) + XX(POST, 1, 'R', PROPFIND) + XX(PROPFIND, 4, 'P', PROPPATCH) + XX(PUT, 2, 'R', PURGE) + XX(LOCK, 1, 'I', LINK) + XX(UNLOCK, 2, 'S', UNSUBSCRIBE) + XX(UNLOCK, 2, 'B', UNBIND) + XX(UNLOCK, 3, 'I', UNLINK) +#undef XX + + default: + SET_ERRNO(HPE_INVALID_METHOD); + goto error; + } + } else if (ch == '-' && + parser->index == 1 && + parser->method == HTTP_MKCOL) { + parser->method = HTTP_MSEARCH; + } else { + SET_ERRNO(HPE_INVALID_METHOD); + goto error; + } + + ++parser->index; + break; + } + + case s_req_spaces_before_url: + { + if (ch == ' ') break; + + MARK(url); + if (parser->method == HTTP_CONNECT) { + UPDATE_STATE(s_req_server_start); + } + + UPDATE_STATE(parse_url_char(CURRENT_STATE(), ch)); + if (UNLIKELY(CURRENT_STATE() == s_dead)) { + SET_ERRNO(HPE_INVALID_URL); + goto error; + } + + break; + } + + case s_req_schema: + case s_req_schema_slash: + case s_req_schema_slash_slash: + case s_req_server_start: + { + switch (ch) { + /* No whitespace allowed here */ + case ' ': + case CR: + case LF: + SET_ERRNO(HPE_INVALID_URL); + goto error; + default: + UPDATE_STATE(parse_url_char(CURRENT_STATE(), ch)); + if (UNLIKELY(CURRENT_STATE() == s_dead)) { + SET_ERRNO(HPE_INVALID_URL); + goto error; + } + } + + break; + } + + case s_req_server: + case s_req_server_with_at: + case s_req_path: + case s_req_query_string_start: + case s_req_query_string: + case s_req_fragment_start: + case s_req_fragment: + { + switch (ch) { + case ' ': + UPDATE_STATE(s_req_http_start); + CALLBACK_DATA(url); + break; + case CR: + case LF: + parser->http_major = 0; + parser->http_minor = 9; + UPDATE_STATE((ch == CR) ? + s_req_line_almost_done : + s_header_field_start); + CALLBACK_DATA(url); + break; + default: + UPDATE_STATE(parse_url_char(CURRENT_STATE(), ch)); + if (UNLIKELY(CURRENT_STATE() == s_dead)) { + SET_ERRNO(HPE_INVALID_URL); + goto error; + } + } + break; + } + + case s_req_http_start: + switch (ch) { + case 'H': + UPDATE_STATE(s_req_http_H); + break; + case ' ': + break; + default: + SET_ERRNO(HPE_INVALID_CONSTANT); + goto error; + } + break; + + case s_req_http_H: + STRICT_CHECK(ch != 'T'); + UPDATE_STATE(s_req_http_HT); + break; + + case s_req_http_HT: + STRICT_CHECK(ch != 'T'); + UPDATE_STATE(s_req_http_HTT); + break; + + case s_req_http_HTT: + STRICT_CHECK(ch != 'P'); + UPDATE_STATE(s_req_http_HTTP); + break; + + case s_req_http_HTTP: + STRICT_CHECK(ch != '/'); + UPDATE_STATE(s_req_first_http_major); + break; + + /* first digit of major HTTP version */ + case s_req_first_http_major: + if (UNLIKELY(ch < '1' || ch > '9')) { + SET_ERRNO(HPE_INVALID_VERSION); + goto error; + } + + parser->http_major = ch - '0'; + UPDATE_STATE(s_req_http_major); + break; + + /* major HTTP version or dot */ + case s_req_http_major: + { + if (ch == '.') { + UPDATE_STATE(s_req_first_http_minor); + break; + } + + if (UNLIKELY(!IS_NUM(ch))) { + SET_ERRNO(HPE_INVALID_VERSION); + goto error; + } + + parser->http_major *= 10; + parser->http_major += ch - '0'; + + if (UNLIKELY(parser->http_major > 999)) { + SET_ERRNO(HPE_INVALID_VERSION); + goto error; + } + + break; + } + + /* first digit of minor HTTP version */ + case s_req_first_http_minor: + if (UNLIKELY(!IS_NUM(ch))) { + SET_ERRNO(HPE_INVALID_VERSION); + goto error; + } + + parser->http_minor = ch - '0'; + UPDATE_STATE(s_req_http_minor); + break; + + /* minor HTTP version or end of request line */ + case s_req_http_minor: + { + if (ch == CR) { + UPDATE_STATE(s_req_line_almost_done); + break; + } + + if (ch == LF) { + UPDATE_STATE(s_header_field_start); + break; + } + + /* XXX allow spaces after digit? */ + + if (UNLIKELY(!IS_NUM(ch))) { + SET_ERRNO(HPE_INVALID_VERSION); + goto error; + } + + parser->http_minor *= 10; + parser->http_minor += ch - '0'; + + if (UNLIKELY(parser->http_minor > 999)) { + SET_ERRNO(HPE_INVALID_VERSION); + goto error; + } + + break; + } + + /* end of request line */ + case s_req_line_almost_done: + { + if (UNLIKELY(ch != LF)) { + SET_ERRNO(HPE_LF_EXPECTED); + goto error; + } + + UPDATE_STATE(s_header_field_start); + break; + } + + case s_header_field_start: + { + if (ch == CR) { + UPDATE_STATE(s_headers_almost_done); + break; + } + + if (ch == LF) { + /* they might be just sending \n instead of \r\n so this would be + * the second \n to denote the end of headers*/ + UPDATE_STATE(s_headers_almost_done); + REEXECUTE(); + } + + c = TOKEN(ch); + + if (UNLIKELY(!c)) { + SET_ERRNO(HPE_INVALID_HEADER_TOKEN); + goto error; + } + + MARK(header_field); + + parser->index = 0; + UPDATE_STATE(s_header_field); + + switch (c) { + case 'c': + parser->header_state = h_C; + break; + + case 'p': + parser->header_state = h_matching_proxy_connection; + break; + + case 't': + parser->header_state = h_matching_transfer_encoding; + break; + + case 'u': + parser->header_state = h_matching_upgrade; + break; + + default: + parser->header_state = h_general; + break; + } + break; + } + + case s_header_field: + { + const char* start = p; + for (; p != data + len; p++) { + ch = *p; + c = TOKEN(ch); + + if (!c) + break; + + switch (parser->header_state) { + case h_general: + break; + + case h_C: + parser->index++; + parser->header_state = (c == 'o' ? h_CO : h_general); + break; + + case h_CO: + parser->index++; + parser->header_state = (c == 'n' ? h_CON : h_general); + break; + + case h_CON: + parser->index++; + switch (c) { + case 'n': + parser->header_state = h_matching_connection; + break; + case 't': + parser->header_state = h_matching_content_length; + break; + default: + parser->header_state = h_general; + break; + } + break; + + /* connection */ + + case h_matching_connection: + parser->index++; + if (parser->index > sizeof(CONNECTION)-1 + || c != CONNECTION[parser->index]) { + parser->header_state = h_general; + } else if (parser->index == sizeof(CONNECTION)-2) { + parser->header_state = h_connection; + } + break; + + /* proxy-connection */ + + case h_matching_proxy_connection: + parser->index++; + if (parser->index > sizeof(PROXY_CONNECTION)-1 + || c != PROXY_CONNECTION[parser->index]) { + parser->header_state = h_general; + } else if (parser->index == sizeof(PROXY_CONNECTION)-2) { + parser->header_state = h_connection; + } + break; + + /* content-length */ + + case h_matching_content_length: + parser->index++; + if (parser->index > sizeof(CONTENT_LENGTH)-1 + || c != CONTENT_LENGTH[parser->index]) { + parser->header_state = h_general; + } else if (parser->index == sizeof(CONTENT_LENGTH)-2) { + parser->header_state = h_content_length; + } + break; + + /* transfer-encoding */ + + case h_matching_transfer_encoding: + parser->index++; + if (parser->index > sizeof(TRANSFER_ENCODING)-1 + || c != TRANSFER_ENCODING[parser->index]) { + parser->header_state = h_general; + } else if (parser->index == sizeof(TRANSFER_ENCODING)-2) { + parser->header_state = h_transfer_encoding; + } + break; + + /* upgrade */ + + case h_matching_upgrade: + parser->index++; + if (parser->index > sizeof(UPGRADE)-1 + || c != UPGRADE[parser->index]) { + parser->header_state = h_general; + } else if (parser->index == sizeof(UPGRADE)-2) { + parser->header_state = h_upgrade; + } + break; + + case h_connection: + case h_content_length: + case h_transfer_encoding: + case h_upgrade: + if (ch != ' ') parser->header_state = h_general; + break; + + default: + assert(0 && "Unknown header_state"); + break; + } + } + + COUNT_HEADER_SIZE(p - start); + + if (p == data + len) { + --p; + break; + } + + if (ch == ':') { + UPDATE_STATE(s_header_value_discard_ws); + CALLBACK_DATA(header_field); + break; + } + + SET_ERRNO(HPE_INVALID_HEADER_TOKEN); + goto error; + } + + case s_header_value_discard_ws: + if (ch == ' ' || ch == '\t') break; + + if (ch == CR) { + UPDATE_STATE(s_header_value_discard_ws_almost_done); + break; + } + + if (ch == LF) { + UPDATE_STATE(s_header_value_discard_lws); + break; + } + + /* FALLTHROUGH */ + + case s_header_value_start: + { + MARK(header_value); + + UPDATE_STATE(s_header_value); + parser->index = 0; + + c = LOWER(ch); + + switch (parser->header_state) { + case h_upgrade: + parser->flags |= F_UPGRADE; + parser->header_state = h_general; + break; + + case h_transfer_encoding: + /* looking for 'Transfer-Encoding: chunked' */ + if ('c' == c) { + parser->header_state = h_matching_transfer_encoding_chunked; + } else { + parser->header_state = h_general; + } + break; + + case h_content_length: + if (UNLIKELY(!IS_NUM(ch))) { + SET_ERRNO(HPE_INVALID_CONTENT_LENGTH); + goto error; + } + + if (parser->flags & F_CONTENTLENGTH) { + SET_ERRNO(HPE_UNEXPECTED_CONTENT_LENGTH); + goto error; + } + + parser->flags |= F_CONTENTLENGTH; + parser->content_length = ch - '0'; + break; + + case h_connection: + /* looking for 'Connection: keep-alive' */ + if (c == 'k') { + parser->header_state = h_matching_connection_keep_alive; + /* looking for 'Connection: close' */ + } else if (c == 'c') { + parser->header_state = h_matching_connection_close; + } else if (c == 'u') { + parser->header_state = h_matching_connection_upgrade; + } else { + parser->header_state = h_matching_connection_token; + } + break; + + /* Multi-value `Connection` header */ + case h_matching_connection_token_start: + break; + + default: + parser->header_state = h_general; + break; + } + break; + } + + case s_header_value: + { + const char* start = p; + enum header_states h_state = (enum header_states) parser->header_state; + for (; p != data + len; p++) { + ch = *p; + if (ch == CR) { + UPDATE_STATE(s_header_almost_done); + parser->header_state = h_state; + CALLBACK_DATA(header_value); + break; + } + + if (ch == LF) { + UPDATE_STATE(s_header_almost_done); + COUNT_HEADER_SIZE(p - start); + parser->header_state = h_state; + CALLBACK_DATA_NOADVANCE(header_value); + REEXECUTE(); + } + + if (!lenient && !IS_HEADER_CHAR(ch)) { + SET_ERRNO(HPE_INVALID_HEADER_TOKEN); + goto error; + } + + c = LOWER(ch); + + switch (h_state) { + case h_general: + { + const char* p_cr; + const char* p_lf; + size_t limit = data + len - p; + + limit = MIN(limit, HTTP_MAX_HEADER_SIZE); + + p_cr = (const char*) memchr(p, CR, limit); + p_lf = (const char*) memchr(p, LF, limit); + if (p_cr != NULL) { + if (p_lf != NULL && p_cr >= p_lf) + p = p_lf; + else + p = p_cr; + } else if (UNLIKELY(p_lf != NULL)) { + p = p_lf; + } else { + p = data + len; + } + --p; + + break; + } + + case h_connection: + case h_transfer_encoding: + assert(0 && "Shouldn't get here."); + break; + + case h_content_length: + { + uint64_t t; + + if (ch == ' ') break; + + if (UNLIKELY(!IS_NUM(ch))) { + SET_ERRNO(HPE_INVALID_CONTENT_LENGTH); + parser->header_state = h_state; + goto error; + } + + t = parser->content_length; + t *= 10; + t += ch - '0'; + + /* Overflow? Test against a conservative limit for simplicity. */ + if (UNLIKELY((ULLONG_MAX - 10) / 10 < parser->content_length)) { + SET_ERRNO(HPE_INVALID_CONTENT_LENGTH); + parser->header_state = h_state; + goto error; + } + + parser->content_length = t; + break; + } + + /* Transfer-Encoding: chunked */ + case h_matching_transfer_encoding_chunked: + parser->index++; + if (parser->index > sizeof(CHUNKED)-1 + || c != CHUNKED[parser->index]) { + h_state = h_general; + } else if (parser->index == sizeof(CHUNKED)-2) { + h_state = h_transfer_encoding_chunked; + } + break; + + case h_matching_connection_token_start: + /* looking for 'Connection: keep-alive' */ + if (c == 'k') { + h_state = h_matching_connection_keep_alive; + /* looking for 'Connection: close' */ + } else if (c == 'c') { + h_state = h_matching_connection_close; + } else if (c == 'u') { + h_state = h_matching_connection_upgrade; + } else if (STRICT_TOKEN(c)) { + h_state = h_matching_connection_token; + } else if (c == ' ' || c == '\t') { + /* Skip lws */ + } else { + h_state = h_general; + } + break; + + /* looking for 'Connection: keep-alive' */ + case h_matching_connection_keep_alive: + parser->index++; + if (parser->index > sizeof(KEEP_ALIVE)-1 + || c != KEEP_ALIVE[parser->index]) { + h_state = h_matching_connection_token; + } else if (parser->index == sizeof(KEEP_ALIVE)-2) { + h_state = h_connection_keep_alive; + } + break; + + /* looking for 'Connection: close' */ + case h_matching_connection_close: + parser->index++; + if (parser->index > sizeof(CLOSE)-1 || c != CLOSE[parser->index]) { + h_state = h_matching_connection_token; + } else if (parser->index == sizeof(CLOSE)-2) { + h_state = h_connection_close; + } + break; + + /* looking for 'Connection: upgrade' */ + case h_matching_connection_upgrade: + parser->index++; + if (parser->index > sizeof(UPGRADE) - 1 || + c != UPGRADE[parser->index]) { + h_state = h_matching_connection_token; + } else if (parser->index == sizeof(UPGRADE)-2) { + h_state = h_connection_upgrade; + } + break; + + case h_matching_connection_token: + if (ch == ',') { + h_state = h_matching_connection_token_start; + parser->index = 0; + } + break; + + case h_transfer_encoding_chunked: + if (ch != ' ') h_state = h_general; + break; + + case h_connection_keep_alive: + case h_connection_close: + case h_connection_upgrade: + if (ch == ',') { + if (h_state == h_connection_keep_alive) { + parser->flags |= F_CONNECTION_KEEP_ALIVE; + } else if (h_state == h_connection_close) { + parser->flags |= F_CONNECTION_CLOSE; + } else if (h_state == h_connection_upgrade) { + parser->flags |= F_CONNECTION_UPGRADE; + } + h_state = h_matching_connection_token_start; + parser->index = 0; + } else if (ch != ' ') { + h_state = h_matching_connection_token; + } + break; + + default: + UPDATE_STATE(s_header_value); + h_state = h_general; + break; + } + } + parser->header_state = h_state; + + COUNT_HEADER_SIZE(p - start); + + if (p == data + len) + --p; + break; + } + + case s_header_almost_done: + { + if (UNLIKELY(ch != LF)) { + SET_ERRNO(HPE_LF_EXPECTED); + goto error; + } + + UPDATE_STATE(s_header_value_lws); + break; + } + + case s_header_value_lws: + { + if (ch == ' ' || ch == '\t') { + UPDATE_STATE(s_header_value_start); + REEXECUTE(); + } + + /* finished the header */ + switch (parser->header_state) { + case h_connection_keep_alive: + parser->flags |= F_CONNECTION_KEEP_ALIVE; + break; + case h_connection_close: + parser->flags |= F_CONNECTION_CLOSE; + break; + case h_transfer_encoding_chunked: + parser->flags |= F_CHUNKED; + break; + case h_connection_upgrade: + parser->flags |= F_CONNECTION_UPGRADE; + break; + default: + break; + } + + UPDATE_STATE(s_header_field_start); + REEXECUTE(); + } + + case s_header_value_discard_ws_almost_done: + { + STRICT_CHECK(ch != LF); + UPDATE_STATE(s_header_value_discard_lws); + break; + } + + case s_header_value_discard_lws: + { + if (ch == ' ' || ch == '\t') { + UPDATE_STATE(s_header_value_discard_ws); + break; + } else { + switch (parser->header_state) { + case h_connection_keep_alive: + parser->flags |= F_CONNECTION_KEEP_ALIVE; + break; + case h_connection_close: + parser->flags |= F_CONNECTION_CLOSE; + break; + case h_connection_upgrade: + parser->flags |= F_CONNECTION_UPGRADE; + break; + case h_transfer_encoding_chunked: + parser->flags |= F_CHUNKED; + break; + default: + break; + } + + /* header value was empty */ + MARK(header_value); + UPDATE_STATE(s_header_field_start); + CALLBACK_DATA_NOADVANCE(header_value); + REEXECUTE(); + } + } + + case s_headers_almost_done: + { + STRICT_CHECK(ch != LF); + + if (parser->flags & F_TRAILING) { + /* End of a chunked request */ + UPDATE_STATE(s_message_done); + CALLBACK_NOTIFY_NOADVANCE(chunk_complete); + REEXECUTE(); + } + + /* Cannot use chunked encoding and a content-length header together + per the HTTP specification. */ + if ((parser->flags & F_CHUNKED) && + (parser->flags & F_CONTENTLENGTH)) { + SET_ERRNO(HPE_UNEXPECTED_CONTENT_LENGTH); + goto error; + } + + UPDATE_STATE(s_headers_done); + + /* Set this here so that on_headers_complete() callbacks can see it */ + parser->upgrade = + ((parser->flags & (F_UPGRADE | F_CONNECTION_UPGRADE)) == + (F_UPGRADE | F_CONNECTION_UPGRADE) || + parser->method == HTTP_CONNECT); + + /* Here we call the headers_complete callback. This is somewhat + * different than other callbacks because if the user returns 1, we + * will interpret that as saying that this message has no body. This + * is needed for the annoying case of recieving a response to a HEAD + * request. + * + * We'd like to use CALLBACK_NOTIFY_NOADVANCE() here but we cannot, so + * we have to simulate it by handling a change in errno below. + */ + if (settings->on_headers_complete) { + switch (settings->on_headers_complete(parser)) { + case 0: + break; + + case 2: + parser->upgrade = 1; + + case 1: + parser->flags |= F_SKIPBODY; + break; + + default: + SET_ERRNO(HPE_CB_headers_complete); + RETURN(p - data); /* Error */ + } + } + + if (HTTP_PARSER_ERRNO(parser) != HPE_OK) { + RETURN(p - data); + } + + REEXECUTE(); + } + + case s_headers_done: + { + int hasBody; + STRICT_CHECK(ch != LF); + + parser->nread = 0; + + hasBody = parser->flags & F_CHUNKED || + (parser->content_length > 0 && parser->content_length != ULLONG_MAX); + if (parser->upgrade && (parser->method == HTTP_CONNECT || + (parser->flags & F_SKIPBODY) || !hasBody)) { + /* Exit, the rest of the message is in a different protocol. */ + UPDATE_STATE(NEW_MESSAGE()); + CALLBACK_NOTIFY(message_complete); + RETURN((p - data) + 1); + } + + if (parser->flags & F_SKIPBODY) { + UPDATE_STATE(NEW_MESSAGE()); + CALLBACK_NOTIFY(message_complete); + } else if (parser->flags & F_CHUNKED) { + /* chunked encoding - ignore Content-Length header */ + UPDATE_STATE(s_chunk_size_start); + } else { + if (parser->content_length == 0) { + /* Content-Length header given but zero: Content-Length: 0\r\n */ + UPDATE_STATE(NEW_MESSAGE()); + CALLBACK_NOTIFY(message_complete); + } else if (parser->content_length != ULLONG_MAX) { + /* Content-Length header given and non-zero */ + UPDATE_STATE(s_body_identity); + } else { + if (!http_message_needs_eof(parser)) { + /* Assume content-length 0 - read the next */ + UPDATE_STATE(NEW_MESSAGE()); + CALLBACK_NOTIFY(message_complete); + } else { + /* Read body until EOF */ + UPDATE_STATE(s_body_identity_eof); + } + } + } + + break; + } + + case s_body_identity: + { + uint64_t to_read = MIN(parser->content_length, + (uint64_t) ((data + len) - p)); + + assert(parser->content_length != 0 + && parser->content_length != ULLONG_MAX); + + /* The difference between advancing content_length and p is because + * the latter will automaticaly advance on the next loop iteration. + * Further, if content_length ends up at 0, we want to see the last + * byte again for our message complete callback. + */ + MARK(body); + parser->content_length -= to_read; + p += to_read - 1; + + if (parser->content_length == 0) { + UPDATE_STATE(s_message_done); + + /* Mimic CALLBACK_DATA_NOADVANCE() but with one extra byte. + * + * The alternative to doing this is to wait for the next byte to + * trigger the data callback, just as in every other case. The + * problem with this is that this makes it difficult for the test + * harness to distinguish between complete-on-EOF and + * complete-on-length. It's not clear that this distinction is + * important for applications, but let's keep it for now. + */ + CALLBACK_DATA_(body, p - body_mark + 1, p - data); + REEXECUTE(); + } + + break; + } + + /* read until EOF */ + case s_body_identity_eof: + MARK(body); + p = data + len - 1; + + break; + + case s_message_done: + UPDATE_STATE(NEW_MESSAGE()); + CALLBACK_NOTIFY(message_complete); + if (parser->upgrade) { + /* Exit, the rest of the message is in a different protocol. */ + RETURN((p - data) + 1); + } + break; + + case s_chunk_size_start: + { + assert(parser->nread == 1); + assert(parser->flags & F_CHUNKED); + + unhex_val = unhex[(unsigned char)ch]; + if (UNLIKELY(unhex_val == -1)) { + SET_ERRNO(HPE_INVALID_CHUNK_SIZE); + goto error; + } + + parser->content_length = unhex_val; + UPDATE_STATE(s_chunk_size); + break; + } + + case s_chunk_size: + { + uint64_t t; + + assert(parser->flags & F_CHUNKED); + + if (ch == CR) { + UPDATE_STATE(s_chunk_size_almost_done); + break; + } + + unhex_val = unhex[(unsigned char)ch]; + + if (unhex_val == -1) { + if (ch == ';' || ch == ' ') { + UPDATE_STATE(s_chunk_parameters); + break; + } + + SET_ERRNO(HPE_INVALID_CHUNK_SIZE); + goto error; + } + + t = parser->content_length; + t *= 16; + t += unhex_val; + + /* Overflow? Test against a conservative limit for simplicity. */ + if (UNLIKELY((ULLONG_MAX - 16) / 16 < parser->content_length)) { + SET_ERRNO(HPE_INVALID_CONTENT_LENGTH); + goto error; + } + + parser->content_length = t; + break; + } + + case s_chunk_parameters: + { + assert(parser->flags & F_CHUNKED); + /* just ignore this shit. TODO check for overflow */ + if (ch == CR) { + UPDATE_STATE(s_chunk_size_almost_done); + break; + } + break; + } + + case s_chunk_size_almost_done: + { + assert(parser->flags & F_CHUNKED); + STRICT_CHECK(ch != LF); + + parser->nread = 0; + + if (parser->content_length == 0) { + parser->flags |= F_TRAILING; + UPDATE_STATE(s_header_field_start); + } else { + UPDATE_STATE(s_chunk_data); + } + CALLBACK_NOTIFY(chunk_header); + break; + } + + case s_chunk_data: + { + uint64_t to_read = MIN(parser->content_length, + (uint64_t) ((data + len) - p)); + + assert(parser->flags & F_CHUNKED); + assert(parser->content_length != 0 + && parser->content_length != ULLONG_MAX); + + /* See the explanation in s_body_identity for why the content + * length and data pointers are managed this way. + */ + MARK(body); + parser->content_length -= to_read; + p += to_read - 1; + + if (parser->content_length == 0) { + UPDATE_STATE(s_chunk_data_almost_done); + } + + break; + } + + case s_chunk_data_almost_done: + assert(parser->flags & F_CHUNKED); + assert(parser->content_length == 0); + STRICT_CHECK(ch != CR); + UPDATE_STATE(s_chunk_data_done); + CALLBACK_DATA(body); + break; + + case s_chunk_data_done: + assert(parser->flags & F_CHUNKED); + STRICT_CHECK(ch != LF); + parser->nread = 0; + UPDATE_STATE(s_chunk_size_start); + CALLBACK_NOTIFY(chunk_complete); + break; + + default: + assert(0 && "unhandled state"); + SET_ERRNO(HPE_INVALID_INTERNAL_STATE); + goto error; + } + } + + /* Run callbacks for any marks that we have leftover after we ran our of + * bytes. There should be at most one of these set, so it's OK to invoke + * them in series (unset marks will not result in callbacks). + * + * We use the NOADVANCE() variety of callbacks here because 'p' has already + * overflowed 'data' and this allows us to correct for the off-by-one that + * we'd otherwise have (since CALLBACK_DATA() is meant to be run with a 'p' + * value that's in-bounds). + */ + + assert(((header_field_mark ? 1 : 0) + + (header_value_mark ? 1 : 0) + + (url_mark ? 1 : 0) + + (body_mark ? 1 : 0) + + (status_mark ? 1 : 0)) <= 1); + + CALLBACK_DATA_NOADVANCE(header_field); + CALLBACK_DATA_NOADVANCE(header_value); + CALLBACK_DATA_NOADVANCE(url); + CALLBACK_DATA_NOADVANCE(body); + CALLBACK_DATA_NOADVANCE(status); + + RETURN(len); + +error: + if (HTTP_PARSER_ERRNO(parser) == HPE_OK) { + SET_ERRNO(HPE_UNKNOWN); + } + + RETURN(p - data); +} + + +/* Does the parser need to see an EOF to find the end of the message? */ +int +http_message_needs_eof (const http_parser *parser) +{ + if (parser->type == HTTP_REQUEST) { + return 0; + } + + /* See RFC 2616 section 4.4 */ + if (parser->status_code / 100 == 1 || /* 1xx e.g. Continue */ + parser->status_code == 204 || /* No Content */ + parser->status_code == 304 || /* Not Modified */ + parser->flags & F_SKIPBODY) { /* response to a HEAD request */ + return 0; + } + + if ((parser->flags & F_CHUNKED) || parser->content_length != ULLONG_MAX) { + return 0; + } + + return 1; +} + + +int +http_should_keep_alive (const http_parser *parser) +{ + if (parser->http_major > 0 && parser->http_minor > 0) { + /* HTTP/1.1 */ + if (parser->flags & F_CONNECTION_CLOSE) { + return 0; + } + } else { + /* HTTP/1.0 or earlier */ + if (!(parser->flags & F_CONNECTION_KEEP_ALIVE)) { + return 0; + } + } + + return !http_message_needs_eof(parser); +} + + +const char * +http_method_str (enum http_method m) +{ + return ELEM_AT(method_strings, m, ""); +} + + +void +http_parser_init (http_parser *parser, enum http_parser_type t) +{ + void *data = parser->data; /* preserve application data */ + memset(parser, 0, sizeof(*parser)); + parser->data = data; + parser->type = t; + parser->state = (t == HTTP_REQUEST ? s_start_req : (t == HTTP_RESPONSE ? s_start_res : s_start_req_or_res)); + parser->http_errno = HPE_OK; +} + +void +http_parser_settings_init(http_parser_settings *settings) +{ + memset(settings, 0, sizeof(*settings)); +} + +const char * +http_errno_name(enum http_errno err) { + assert(((size_t) err) < ARRAY_SIZE(http_strerror_tab)); + return http_strerror_tab[err].name; +} + +const char * +http_errno_description(enum http_errno err) { + assert(((size_t) err) < ARRAY_SIZE(http_strerror_tab)); + return http_strerror_tab[err].description; +} + +static enum http_host_state +http_parse_host_char(enum http_host_state s, const char ch) { + switch(s) { + case s_http_userinfo: + case s_http_userinfo_start: + if (ch == '@') { + return s_http_host_start; + } + + if (IS_USERINFO_CHAR(ch)) { + return s_http_userinfo; + } + break; + + case s_http_host_start: + if (ch == '[') { + return s_http_host_v6_start; + } + + if (IS_HOST_CHAR(ch)) { + return s_http_host; + } + + break; + + case s_http_host: + if (IS_HOST_CHAR(ch)) { + return s_http_host; + } + + /* FALLTHROUGH */ + case s_http_host_v6_end: + if (ch == ':') { + return s_http_host_port_start; + } + + break; + + case s_http_host_v6: + if (ch == ']') { + return s_http_host_v6_end; + } + + /* FALLTHROUGH */ + case s_http_host_v6_start: + if (IS_HEX(ch) || ch == ':' || ch == '.') { + return s_http_host_v6; + } + + if (s == s_http_host_v6 && ch == '%') { + return s_http_host_v6_zone_start; + } + break; + + case s_http_host_v6_zone: + if (ch == ']') { + return s_http_host_v6_end; + } + + /* FALLTHROUGH */ + case s_http_host_v6_zone_start: + /* RFC 6874 Zone ID consists of 1*( unreserved / pct-encoded) */ + if (IS_ALPHANUM(ch) || ch == '%' || ch == '.' || ch == '-' || ch == '_' || + ch == '~') { + return s_http_host_v6_zone; + } + break; + + case s_http_host_port: + case s_http_host_port_start: + if (IS_NUM(ch)) { + return s_http_host_port; + } + + break; + + default: + break; + } + return s_http_host_dead; +} + +static int +http_parse_host(const char * buf, struct http_parser_url *u, int found_at) { + enum http_host_state s; + + const char *p; + size_t buflen = u->field_data[UF_HOST].off + u->field_data[UF_HOST].len; + + assert(u->field_set & (1 << UF_HOST)); + + u->field_data[UF_HOST].len = 0; + + s = found_at ? s_http_userinfo_start : s_http_host_start; + + for (p = buf + u->field_data[UF_HOST].off; p < buf + buflen; p++) { + enum http_host_state new_s = http_parse_host_char(s, *p); + + if (new_s == s_http_host_dead) { + return 1; + } + + switch(new_s) { + case s_http_host: + if (s != s_http_host) { + u->field_data[UF_HOST].off = p - buf; + } + u->field_data[UF_HOST].len++; + break; + + case s_http_host_v6: + if (s != s_http_host_v6) { + u->field_data[UF_HOST].off = p - buf; + } + u->field_data[UF_HOST].len++; + break; + + case s_http_host_v6_zone_start: + case s_http_host_v6_zone: + u->field_data[UF_HOST].len++; + break; + + case s_http_host_port: + if (s != s_http_host_port) { + u->field_data[UF_PORT].off = p - buf; + u->field_data[UF_PORT].len = 0; + u->field_set |= (1 << UF_PORT); + } + u->field_data[UF_PORT].len++; + break; + + case s_http_userinfo: + if (s != s_http_userinfo) { + u->field_data[UF_USERINFO].off = p - buf ; + u->field_data[UF_USERINFO].len = 0; + u->field_set |= (1 << UF_USERINFO); + } + u->field_data[UF_USERINFO].len++; + break; + + default: + break; + } + s = new_s; + } + + /* Make sure we don't end somewhere unexpected */ + switch (s) { + case s_http_host_start: + case s_http_host_v6_start: + case s_http_host_v6: + case s_http_host_v6_zone_start: + case s_http_host_v6_zone: + case s_http_host_port_start: + case s_http_userinfo: + case s_http_userinfo_start: + return 1; + default: + break; + } + + return 0; +} + +void +http_parser_url_init(struct http_parser_url *u) { + memset(u, 0, sizeof(*u)); +} + +int +http_parser_parse_url(const char *buf, size_t buflen, int is_connect, + struct http_parser_url *u) +{ + enum state s; + const char *p; + enum http_parser_url_fields uf, old_uf; + int found_at = 0; + + u->port = u->field_set = 0; + s = is_connect ? s_req_server_start : s_req_spaces_before_url; + old_uf = UF_MAX; + + for (p = buf; p < buf + buflen; p++) { + s = parse_url_char(s, *p); + + /* Figure out the next field that we're operating on */ + switch (s) { + case s_dead: + return 1; + + /* Skip delimeters */ + case s_req_schema_slash: + case s_req_schema_slash_slash: + case s_req_server_start: + case s_req_query_string_start: + case s_req_fragment_start: + continue; + + case s_req_schema: + uf = UF_SCHEMA; + break; + + case s_req_server_with_at: + found_at = 1; + + /* FALLTROUGH */ + case s_req_server: + uf = UF_HOST; + break; + + case s_req_path: + uf = UF_PATH; + break; + + case s_req_query_string: + uf = UF_QUERY; + break; + + case s_req_fragment: + uf = UF_FRAGMENT; + break; + + default: + assert(!"Unexpected state"); + return 1; + } + + /* Nothing's changed; soldier on */ + if (uf == old_uf) { + u->field_data[uf].len++; + continue; + } + + u->field_data[uf].off = p - buf; + u->field_data[uf].len = 1; + + u->field_set |= (1 << uf); + old_uf = uf; + } + + /* host must be present if there is a schema */ + /* parsing http:///toto will fail */ + if ((u->field_set & (1 << UF_SCHEMA)) && + (u->field_set & (1 << UF_HOST)) == 0) { + return 1; + } + + if (u->field_set & (1 << UF_HOST)) { + if (http_parse_host(buf, u, found_at) != 0) { + return 1; + } + } + + /* CONNECT requests can only contain "hostname:port" */ + if (is_connect && u->field_set != ((1 << UF_HOST)|(1 << UF_PORT))) { + return 1; + } + + if (u->field_set & (1 << UF_PORT)) { + /* Don't bother with endp; we've already validated the string */ + unsigned long v = strtoul(buf + u->field_data[UF_PORT].off, NULL, 10); + + /* Ports have a max value of 2^16 */ + if (v > 0xffff) { + return 1; + } + + u->port = (uint16_t) v; + } + + return 0; +} + +void +http_parser_pause(http_parser *parser, int paused) { + /* Users should only be pausing/unpausing a parser that is not in an error + * state. In non-debug builds, there's not much that we can do about this + * other than ignore it. + */ + if (HTTP_PARSER_ERRNO(parser) == HPE_OK || + HTTP_PARSER_ERRNO(parser) == HPE_PAUSED) { + SET_ERRNO((paused) ? HPE_PAUSED : HPE_OK); + } else { + assert(0 && "Attempting to pause parser in error state"); + } +} + +int +http_body_is_final(const struct http_parser *parser) { + return parser->state == s_message_done; +} + +unsigned long +http_parser_version(void) { + return HTTP_PARSER_VERSION_MAJOR * 0x10000 | + HTTP_PARSER_VERSION_MINOR * 0x00100 | + HTTP_PARSER_VERSION_PATCH * 0x00001; +} diff --git a/Sources/PerfectCHTTPParser/http_parser.h b/Sources/PerfectCHTTPParser/http_parser.h new file mode 100755 index 00000000..0534bb45 --- /dev/null +++ b/Sources/PerfectCHTTPParser/http_parser.h @@ -0,0 +1,2 @@ + +#include "include/http_parser.h" diff --git a/Sources/PerfectCHTTPParser/include/http_parser.h b/Sources/PerfectCHTTPParser/include/http_parser.h new file mode 100755 index 00000000..ea263948 --- /dev/null +++ b/Sources/PerfectCHTTPParser/include/http_parser.h @@ -0,0 +1,362 @@ +/* Copyright Joyent, Inc. and other Node contributors. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ +#ifndef http_parser_h +#define http_parser_h +#ifdef __cplusplus +extern "C" { +#endif + +/* Also update SONAME in the Makefile whenever you change these. */ +#define HTTP_PARSER_VERSION_MAJOR 2 +#define HTTP_PARSER_VERSION_MINOR 7 +#define HTTP_PARSER_VERSION_PATCH 1 + +#include +#if defined(_WIN32) && !defined(__MINGW32__) && \ + (!defined(_MSC_VER) || _MSC_VER<1600) && !defined(__WINE__) +#include +#include +typedef __int8 int8_t; +typedef unsigned __int8 uint8_t; +typedef __int16 int16_t; +typedef unsigned __int16 uint16_t; +typedef __int32 int32_t; +typedef unsigned __int32 uint32_t; +typedef __int64 int64_t; +typedef unsigned __int64 uint64_t; +#else +#include +#endif + +/* Compile with -DHTTP_PARSER_STRICT=0 to make less checks, but run + * faster + */ +#ifndef HTTP_PARSER_STRICT +# define HTTP_PARSER_STRICT 1 +#endif + +/* Maximium header size allowed. If the macro is not defined + * before including this header then the default is used. To + * change the maximum header size, define the macro in the build + * environment (e.g. -DHTTP_MAX_HEADER_SIZE=). To remove + * the effective limit on the size of the header, define the macro + * to a very large number (e.g. -DHTTP_MAX_HEADER_SIZE=0x7fffffff) + */ +#ifndef HTTP_MAX_HEADER_SIZE +# define HTTP_MAX_HEADER_SIZE (80*1024) +#endif + +typedef struct http_parser http_parser; +typedef struct http_parser_settings http_parser_settings; + + +/* Callbacks should return non-zero to indicate an error. The parser will + * then halt execution. + * + * The one exception is on_headers_complete. In a HTTP_RESPONSE parser + * returning '1' from on_headers_complete will tell the parser that it + * should not expect a body. This is used when receiving a response to a + * HEAD request which may contain 'Content-Length' or 'Transfer-Encoding: + * chunked' headers that indicate the presence of a body. + * + * Returning `2` from on_headers_complete will tell parser that it should not + * expect neither a body nor any futher responses on this connection. This is + * useful for handling responses to a CONNECT request which may not contain + * `Upgrade` or `Connection: upgrade` headers. + * + * http_data_cb does not return data chunks. It will be called arbitrarily + * many times for each string. E.G. you might get 10 callbacks for "on_url" + * each providing just a few characters more data. + */ +typedef int (*http_data_cb) (http_parser*, const char *at, size_t length); +typedef int (*http_cb) (http_parser*); + + +/* Request Methods */ +#define HTTP_METHOD_MAP(XX) \ + XX(0, DELETE, DELETE) \ + XX(1, GET, GET) \ + XX(2, HEAD, HEAD) \ + XX(3, POST, POST) \ + XX(4, PUT, PUT) \ + /* pathological */ \ + XX(5, CONNECT, CONNECT) \ + XX(6, OPTIONS, OPTIONS) \ + XX(7, TRACE, TRACE) \ + /* WebDAV */ \ + XX(8, COPY, COPY) \ + XX(9, LOCK, LOCK) \ + XX(10, MKCOL, MKCOL) \ + XX(11, MOVE, MOVE) \ + XX(12, PROPFIND, PROPFIND) \ + XX(13, PROPPATCH, PROPPATCH) \ + XX(14, SEARCH, SEARCH) \ + XX(15, UNLOCK, UNLOCK) \ + XX(16, BIND, BIND) \ + XX(17, REBIND, REBIND) \ + XX(18, UNBIND, UNBIND) \ + XX(19, ACL, ACL) \ + /* subversion */ \ + XX(20, REPORT, REPORT) \ + XX(21, MKACTIVITY, MKACTIVITY) \ + XX(22, CHECKOUT, CHECKOUT) \ + XX(23, MERGE, MERGE) \ + /* upnp */ \ + XX(24, MSEARCH, M-SEARCH) \ + XX(25, NOTIFY, NOTIFY) \ + XX(26, SUBSCRIBE, SUBSCRIBE) \ + XX(27, UNSUBSCRIBE, UNSUBSCRIBE) \ + /* RFC-5789 */ \ + XX(28, PATCH, PATCH) \ + XX(29, PURGE, PURGE) \ + /* CalDAV */ \ + XX(30, MKCALENDAR, MKCALENDAR) \ + /* RFC-2068, section 19.6.1.2 */ \ + XX(31, LINK, LINK) \ + XX(32, UNLINK, UNLINK) \ + +enum http_method + { +#define XX(num, name, string) HTTP_##name = num, + HTTP_METHOD_MAP(XX) +#undef XX + }; + + +enum http_parser_type { HTTP_REQUEST, HTTP_RESPONSE, HTTP_BOTH }; + + +/* Flag values for http_parser.flags field */ +enum flags + { F_CHUNKED = 1 << 0 + , F_CONNECTION_KEEP_ALIVE = 1 << 1 + , F_CONNECTION_CLOSE = 1 << 2 + , F_CONNECTION_UPGRADE = 1 << 3 + , F_TRAILING = 1 << 4 + , F_UPGRADE = 1 << 5 + , F_SKIPBODY = 1 << 6 + , F_CONTENTLENGTH = 1 << 7 + }; + + +/* Map for errno-related constants + * + * The provided argument should be a macro that takes 2 arguments. + */ +#define HTTP_ERRNO_MAP(XX) \ + /* No error */ \ + XX(OK, "success") \ + \ + /* Callback-related errors */ \ + XX(CB_message_begin, "the on_message_begin callback failed") \ + XX(CB_url, "the on_url callback failed") \ + XX(CB_header_field, "the on_header_field callback failed") \ + XX(CB_header_value, "the on_header_value callback failed") \ + XX(CB_headers_complete, "the on_headers_complete callback failed") \ + XX(CB_body, "the on_body callback failed") \ + XX(CB_message_complete, "the on_message_complete callback failed") \ + XX(CB_status, "the on_status callback failed") \ + XX(CB_chunk_header, "the on_chunk_header callback failed") \ + XX(CB_chunk_complete, "the on_chunk_complete callback failed") \ + \ + /* Parsing-related errors */ \ + XX(INVALID_EOF_STATE, "stream ended at an unexpected time") \ + XX(HEADER_OVERFLOW, \ + "too many header bytes seen; overflow detected") \ + XX(CLOSED_CONNECTION, \ + "data received after completed connection: close message") \ + XX(INVALID_VERSION, "invalid HTTP version") \ + XX(INVALID_STATUS, "invalid HTTP status code") \ + XX(INVALID_METHOD, "invalid HTTP method") \ + XX(INVALID_URL, "invalid URL") \ + XX(INVALID_HOST, "invalid host") \ + XX(INVALID_PORT, "invalid port") \ + XX(INVALID_PATH, "invalid path") \ + XX(INVALID_QUERY_STRING, "invalid query string") \ + XX(INVALID_FRAGMENT, "invalid fragment") \ + XX(LF_EXPECTED, "LF character expected") \ + XX(INVALID_HEADER_TOKEN, "invalid character in header") \ + XX(INVALID_CONTENT_LENGTH, \ + "invalid character in content-length header") \ + XX(UNEXPECTED_CONTENT_LENGTH, \ + "unexpected content-length header") \ + XX(INVALID_CHUNK_SIZE, \ + "invalid character in chunk size header") \ + XX(INVALID_CONSTANT, "invalid constant string") \ + XX(INVALID_INTERNAL_STATE, "encountered unexpected internal state")\ + XX(STRICT, "strict mode assertion failed") \ + XX(PAUSED, "parser is paused") \ + XX(UNKNOWN, "an unknown error occurred") + + +/* Define HPE_* values for each errno value above */ +#define HTTP_ERRNO_GEN(n, s) HPE_##n, +enum http_errno { + HTTP_ERRNO_MAP(HTTP_ERRNO_GEN) +}; +#undef HTTP_ERRNO_GEN + + +/* Get an http_errno value from an http_parser */ +#define HTTP_PARSER_ERRNO(p) ((enum http_errno) (p)->http_errno) + + +struct http_parser { + /** PRIVATE **/ + unsigned int type : 2; /* enum http_parser_type */ + unsigned int flags : 8; /* F_* values from 'flags' enum; semi-public */ + unsigned int state : 7; /* enum state from http_parser.c */ + unsigned int header_state : 7; /* enum header_state from http_parser.c */ + unsigned int index : 7; /* index into current matcher */ + unsigned int lenient_http_headers : 1; + + uint32_t nread; /* # bytes read in various scenarios */ + uint64_t content_length; /* # bytes in body (0 if no Content-Length header) */ + + /** READ-ONLY **/ + unsigned short http_major; + unsigned short http_minor; + unsigned int status_code : 16; /* responses only */ + unsigned int method : 8; /* requests only */ + unsigned int http_errno : 7; + + /* 1 = Upgrade header was present and the parser has exited because of that. + * 0 = No upgrade header present. + * Should be checked when http_parser_execute() returns in addition to + * error checking. + */ + unsigned int upgrade : 1; + + /** PUBLIC **/ + void *data; /* A pointer to get hook to the "connection" or "socket" object */ +}; + + +struct http_parser_settings { + http_cb on_message_begin; + http_data_cb on_url; + http_data_cb on_status; + http_data_cb on_header_field; + http_data_cb on_header_value; + http_cb on_headers_complete; + http_data_cb on_body; + http_cb on_message_complete; + /* When on_chunk_header is called, the current chunk length is stored + * in parser->content_length. + */ + http_cb on_chunk_header; + http_cb on_chunk_complete; +}; + + +enum http_parser_url_fields + { UF_SCHEMA = 0 + , UF_HOST = 1 + , UF_PORT = 2 + , UF_PATH = 3 + , UF_QUERY = 4 + , UF_FRAGMENT = 5 + , UF_USERINFO = 6 + , UF_MAX = 7 + }; + + +/* Result structure for http_parser_parse_url(). + * + * Callers should index into field_data[] with UF_* values iff field_set + * has the relevant (1 << UF_*) bit set. As a courtesy to clients (and + * because we probably have padding left over), we convert any port to + * a uint16_t. + */ +struct http_parser_url { + uint16_t field_set; /* Bitmask of (1 << UF_*) values */ + uint16_t port; /* Converted UF_PORT string */ + + struct { + uint16_t off; /* Offset into buffer in which field starts */ + uint16_t len; /* Length of run in buffer */ + } field_data[UF_MAX]; +}; + + +/* Returns the library version. Bits 16-23 contain the major version number, + * bits 8-15 the minor version number and bits 0-7 the patch level. + * Usage example: + * + * unsigned long version = http_parser_version(); + * unsigned major = (version >> 16) & 255; + * unsigned minor = (version >> 8) & 255; + * unsigned patch = version & 255; + * printf("http_parser v%u.%u.%u\n", major, minor, patch); + */ +unsigned long http_parser_version(void); + +void http_parser_init(http_parser *parser, enum http_parser_type type); + + +/* Initialize http_parser_settings members to 0 + */ +void http_parser_settings_init(http_parser_settings *settings); + + +/* Executes the parser. Returns number of parsed bytes. Sets + * `parser->http_errno` on error. */ +size_t http_parser_execute(http_parser *parser, + const http_parser_settings *settings, + const char *data, + size_t len); + + +/* If http_should_keep_alive() in the on_headers_complete or + * on_message_complete callback returns 0, then this should be + * the last message on the connection. + * If you are the server, respond with the "Connection: close" header. + * If you are the client, close the connection. + */ +int http_should_keep_alive(const http_parser *parser); + +/* Returns a string version of the HTTP method. */ +const char *http_method_str(enum http_method m); + +/* Return a string name of the given error */ +const char *http_errno_name(enum http_errno err); + +/* Return a string description of the given error */ +const char *http_errno_description(enum http_errno err); + +/* Initialize all http_parser_url members to 0 */ +void http_parser_url_init(struct http_parser_url *u); + +/* Parse a URL; return nonzero on failure */ +int http_parser_parse_url(const char *buf, size_t buflen, + int is_connect, + struct http_parser_url *u); + +/* Pause or un-pause the parser; a nonzero value pauses */ +void http_parser_pause(http_parser *parser, int paused); + +/* Checks if this is the final chunk of the body. */ +int http_body_is_final(const http_parser *parser); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/Sources/PerfectCHTTPParser/include/module.modulemap b/Sources/PerfectCHTTPParser/include/module.modulemap new file mode 100644 index 00000000..895f5a0f --- /dev/null +++ b/Sources/PerfectCHTTPParser/include/module.modulemap @@ -0,0 +1,4 @@ +module PerfectCHTTPParser { + header "http_parser.h" + export * +} diff --git a/Sources/PerfectCRUD/Coding/Coding.swift b/Sources/PerfectCRUD/Coding/Coding.swift new file mode 100644 index 00000000..4db198e2 --- /dev/null +++ b/Sources/PerfectCRUD/Coding/Coding.swift @@ -0,0 +1,73 @@ +// +// PerfectCRUDCoding.swift +// PerfectCRUD +// +// Created by Kyle Jessup on 2017-11-22. +// Copyright (C) 2017 PerfectlySoft, Inc. +// +// ===----------------------------------------------------------------------===// +// +// This source file is part of the Perfect.org open source project +// +// Copyright (c) 2015 - 2017 PerfectlySoft Inc. and the Perfect project authors +// Licensed under Apache License v2.0 +// +// See http://perfect.org/licensing.html for license information +// +// ===----------------------------------------------------------------------===// +// + +import Foundation + +public struct CRUDDecoderError: Error { + public let msg: String + public init(_ m: String) { + msg = m + CRUDLogging.log(.error, m) + } +} + +public struct CRUDEncoderError: Error { + public let msg: String + public init(_ m: String) { + msg = m + CRUDLogging.log(.error, m) + } +} + +public struct ColumnKey: CodingKey { + public var stringValue: String + public var intValue: Int? = nil + public init?(stringValue s: String) { + stringValue = s + } + public init?(intValue: Int) { + return nil + } +} + +public indirect enum SpecialType { + case uint8Array, int8Array, data, uuid, date, codable, url, wrapped + public init?(_ type: Any.Type) { + switch type { + case is WrappedCodableProvider.Type: + self = .wrapped + case is [Int8].Type: + self = .int8Array + case is [UInt8].Type: + self = .uint8Array + case is Data.Type: + self = .data + case is UUID.Type: + self = .uuid + case is Date.Type: + self = .date + case is URL.Type: + self = .url + case is Codable.Type: + self = .codable + default: + return nil + } + } +} diff --git a/Sources/PerfectCRUD/Coding/CodingBindings.swift b/Sources/PerfectCRUD/Coding/CodingBindings.swift new file mode 100644 index 00000000..907af203 --- /dev/null +++ b/Sources/PerfectCRUD/Coding/CodingBindings.swift @@ -0,0 +1,192 @@ +// +// PerfectCRUDCodingBindings.swift +// PerfectCRUD +// +// Created by Kyle Jessup on 2017-11-25. +// Copyright (C) 2017 PerfectlySoft, Inc. +// +// ===----------------------------------------------------------------------===// +// +// This source file is part of the Perfect.org open source project +// +// Copyright (c) 2015 - 2017 PerfectlySoft Inc. and the Perfect project authors +// Licensed under Apache License v2.0 +// +// See http://perfect.org/licensing.html for license information +// +// ===----------------------------------------------------------------------===// +// + +import Foundation + +// -- generates bindings for an object +class CRUDBindingsWriter: KeyedEncodingContainerProtocol { + typealias Key = K + let codingPath: [CodingKey] = [] + let parent: CRUDBindingsEncoder + init(_ p: CRUDBindingsEncoder) { + parent = p + } + func addBinding(_ key: Key, value: Expression) throws { + try parent.addBinding(key: key, value: value) + } + func encodeNil(forKey key: K) throws { + // !FIX! this is never called + // Expect this to change in the future + // When nulls are important we have to use the column named decoder first + // and pass in the list of optionals to CRUDBindingsEncoder + CRUDLogging.log(.info, "CRUDBindingsWriter.encodeNil started being called.") + // try addBinding(key, value: .null) + } + func encode(_ value: Bool, forKey key: K) throws { + try addBinding(key, value: .bool(value)) + } + func encode(_ value: Int, forKey key: K) throws { + try addBinding(key, value: .integer(value)) + } + func encode(_ value: Int8, forKey key: K) throws { + try addBinding(key, value: .integer(Int(value))) + } + func encode(_ value: Int16, forKey key: K) throws { + try addBinding(key, value: .integer(Int(value))) + } + func encode(_ value: Int32, forKey key: K) throws { + try addBinding(key, value: .integer(Int(value))) + } + func encode(_ value: Int64, forKey key: K) throws { + try addBinding(key, value: .integer(Int(value))) + } + func encode(_ value: UInt, forKey key: K) throws { + try addBinding(key, value: .integer(Int(value))) + } + func encode(_ value: UInt8, forKey key: K) throws { + try addBinding(key, value: .integer(Int(value))) + } + func encode(_ value: UInt16, forKey key: K) throws { + try addBinding(key, value: .integer(Int(value))) + } + func encode(_ value: UInt32, forKey key: K) throws { + try addBinding(key, value: .integer(Int(value))) + } + func encode(_ value: UInt64, forKey key: K) throws { + try addBinding(key, value: .integer(Int(value))) + } + func encode(_ value: Float, forKey key: K) throws { + try addBinding(key, value: .decimal(Double(value))) + } + func encode(_ value: Double, forKey key: K) throws { + try addBinding(key, value: .decimal(value)) + } + func encode(_ value: String, forKey key: K) throws { + try addBinding(key, value: .string(value)) + } + // swiftlint:disable force_cast + func encode(_ value: T, forKey key: K) throws where T: Encodable { + guard let special = SpecialType(T.self) else { + throw CRUDEncoderError("Unsupported encoding type: \(value) for key: \(key.stringValue)") + } + switch special { + case .uint8Array: + try addBinding(key, value: .blob((value as! [UInt8]))) + case .int8Array: + try addBinding(key, value: .blob((value as! [Int8]).map { UInt8($0) })) + case .data: + try addBinding(key, value: .blob((value as! Data).map { $0 })) + case .uuid: + try addBinding(key, value: .uuid(value as! UUID)) + case .date: + try addBinding(key, value: .date(value as! Date)) + case .url: + try addBinding(key, value: .url(value as! URL)) + case .codable: + let data = try JSONEncoder().encode(value) + if let str = String(data: data, encoding: .utf8) { + try addBinding(key, value: .string(str)) + } + case .wrapped: + guard let wrapped = value as? WrappedCodableProvider else { + throw CRUDEncoderError("Unsupported encoding type: wrapped(\(value)) for key: \(key.stringValue)") + } + let wrappedValue = wrapped.provideWrappedValue() + switch wrappedValue { + case let m as Bool: try encode(m, forKey: key) + case let m as Int: try encode(m, forKey: key) + case let m as Int8: try encode(m, forKey: key) + case let m as Int16: try encode(m, forKey: key) + case let m as Int32: try encode(m, forKey: key) + case let m as Int64: try encode(m, forKey: key) + case let m as UInt: try encode(m, forKey: key) + case let m as UInt8: try encode(m, forKey: key) + case let m as UInt16: try encode(m, forKey: key) + case let m as UInt32: try encode(m, forKey: key) + case let m as UInt64: try encode(m, forKey: key) + case let m as Float: try encode(m, forKey: key) + case let m as Double: try encode(m, forKey: key) + case let m as String: try encode(m, forKey: key) + case let m as [UInt8]: try encode(m, forKey: key) + case let m as [Int8]: try encode(m, forKey: key) + case let m as Data: try encode(m, forKey: key) + case let m as UUID: try encode(m, forKey: key) + case let m as Date: try encode(m, forKey: key) + case let m as URL: try encode(m, forKey: key) + default: + throw CRUDEncoderError("Unsupported encoding type: wrapped(\(wrappedValue)) for key: \(key.stringValue)") + } + + } + } + func nestedContainer(keyedBy keyType: NestedKey.Type, forKey key: K) -> KeyedEncodingContainer where NestedKey: CodingKey { + fatalError("Unimplemented") + } + func nestedUnkeyedContainer(forKey key: K) -> UnkeyedEncodingContainer { + fatalError("Unimplemented") + } + func superEncoder() -> Encoder { + fatalError("Unimplemented") + } + func superEncoder(forKey key: K) -> Encoder { + fatalError("Unimplemented") + } +} + +public class CRUDBindingsEncoder: Encoder { + public let codingPath: [CodingKey] = [] + public let userInfo: [CodingUserInfoKey: Any] = [:] + let delegate: SQLGenDelegate + private var collectedBinds: [(String, Expression)] = [] + + public init(delegate d: SQLGenDelegate) throws { + delegate = d + } + + public func completedBindings(allKeys: [String], ignoreKeys: Set) throws -> [(column: String, identifier: String)] { + let exprDict: [String: Expression] = .init(uniqueKeysWithValues: collectedBinds) + let ret: [(column: String, identifier: String)] = try allKeys.filter { !ignoreKeys.contains($0) }.map { key in + let bindId: String + if let expr = exprDict[key] { + bindId = try delegate.getBinding(for: expr) + } else { + bindId = try delegate.getBinding(for: .null) + } + return (key, bindId) + } + return ret + } + + func completedBindings(ignoreKeys: Set) throws -> [(column: String, identifier: String)] { + return try completedBindings(allKeys: collectedBinds.map { $0.0 }, ignoreKeys: ignoreKeys) + } + + func addBinding(key: Key, value: Expression) throws { + collectedBinds.append((key.stringValue, value)) + } + public func container(keyedBy type: Key.Type) -> KeyedEncodingContainer where Key: CodingKey { + return KeyedEncodingContainer(CRUDBindingsWriter(self)) + } + public func unkeyedContainer() -> UnkeyedEncodingContainer { + fatalError("Unimplemented") + } + public func singleValueContainer() -> SingleValueEncodingContainer { + fatalError("Unimplemented") + } +} diff --git a/Sources/PerfectCRUD/Coding/CodingJoins.swift b/Sources/PerfectCRUD/Coding/CodingJoins.swift new file mode 100644 index 00000000..707fc300 --- /dev/null +++ b/Sources/PerfectCRUD/Coding/CodingJoins.swift @@ -0,0 +1,169 @@ +// +// PerfectCRUDCodingJoinings.swift +// PerfectCRUD +// +// Created by Kyle Jessup on 2017-12-01. +// + +import Foundation + +class SQLTopRowReader: KeyedDecodingContainerProtocol { + typealias Key = K + var codingPath: [CodingKey] = [] + var allKeys: [Key] = [] + let exeDelegate: SQLTopExeDelegate + let subRowReader: KeyedDecodingContainer + init(exeDelegate e: SQLTopExeDelegate, subRowReader s: KeyedDecodingContainer) { + exeDelegate = e + subRowReader = s + } + func contains(_ key: Key) -> Bool { + return subRowReader.contains(key) || nil != exeDelegate.subObjects.index(forKey: key.stringValue) + } + func decodeNil(forKey key: Key) throws -> Bool { + if nil != exeDelegate.subObjects.index(forKey: key.stringValue) { + return false + } + return try subRowReader.decodeNil(forKey: key) + } + func decode(_ type: Bool.Type, forKey key: Key) throws -> Bool { + return try subRowReader.decode(type, forKey: key) + } + func decode(_ type: Int.Type, forKey key: Key) throws -> Int { + return try subRowReader.decode(type, forKey: key) + } + func decode(_ type: Int8.Type, forKey key: Key) throws -> Int8 { + return try subRowReader.decode(type, forKey: key) + } + func decode(_ type: Int16.Type, forKey key: Key) throws -> Int16 { + return try subRowReader.decode(type, forKey: key) + } + func decode(_ type: Int32.Type, forKey key: Key) throws -> Int32 { + return try subRowReader.decode(type, forKey: key) + } + func decode(_ type: Int64.Type, forKey key: Key) throws -> Int64 { + return try subRowReader.decode(type, forKey: key) + } + func decode(_ type: UInt.Type, forKey key: Key) throws -> UInt { + return try subRowReader.decode(type, forKey: key) + } + func decode(_ type: UInt8.Type, forKey key: Key) throws -> UInt8 { + return try subRowReader.decode(type, forKey: key) + } + func decode(_ type: UInt16.Type, forKey key: Key) throws -> UInt16 { + return try subRowReader.decode(type, forKey: key) + } + func decode(_ type: UInt32.Type, forKey key: Key) throws -> UInt32 { + return try subRowReader.decode(type, forKey: key) + } + func decode(_ type: UInt64.Type, forKey key: Key) throws -> UInt64 { + return try subRowReader.decode(type, forKey: key) + } + func decode(_ type: Float.Type, forKey key: Key) throws -> Float { + return try subRowReader.decode(type, forKey: key) + } + func decode(_ type: Double.Type, forKey key: Key) throws -> Double { + return try subRowReader.decode(type, forKey: key) + } + func decode(_ type: String.Type, forKey key: Key) throws -> String { + return try subRowReader.decode(type, forKey: key) + } + // main table join mechanism + // !FIX! to put cached sub objects in foreign key dictionary + func decode(_ intype: T.Type, forKey key: Key) throws -> T where T: Decodable { + if let (onKeyName, onKey, equalsKey, objects) = exeDelegate.subObjects[key.stringValue], + let columnKey = Key(stringValue: onKeyName), + let comparisonType = type(of: onKey).valueType as? Decodable.Type { + + // I could not get this to compile. because comparisonType isn't known at compile time? + // let keyValue = try subRowReader.decode(comparisonType, forKey: columnKey) + let theseObjs: [Any] + switch comparisonType { + case let i as Bool.Type: + let keyValue = try subRowReader.decode(i, forKey: columnKey) + theseObjs = filteredValues(objects, lhs: keyValue, rhsKey: equalsKey) + case let i as Int.Type: + let keyValue = try subRowReader.decode(i, forKey: columnKey) + theseObjs = filteredValues(objects, lhs: keyValue, rhsKey: equalsKey) + case let i as Int8.Type: + let keyValue = try subRowReader.decode(i, forKey: columnKey) + theseObjs = filteredValues(objects, lhs: keyValue, rhsKey: equalsKey) + case let i as Int16.Type: + let keyValue = try subRowReader.decode(i, forKey: columnKey) + theseObjs = filteredValues(objects, lhs: keyValue, rhsKey: equalsKey) + case let i as Int32.Type: + let keyValue = try subRowReader.decode(i, forKey: columnKey) + theseObjs = filteredValues(objects, lhs: keyValue, rhsKey: equalsKey) + case let i as Int64.Type: + let keyValue = try subRowReader.decode(i, forKey: columnKey) + theseObjs = filteredValues(objects, lhs: keyValue, rhsKey: equalsKey) + case let i as UInt.Type: + let keyValue = try subRowReader.decode(i, forKey: columnKey) + theseObjs = filteredValues(objects, lhs: keyValue, rhsKey: equalsKey) + case let i as UInt8.Type: + let keyValue = try subRowReader.decode(i, forKey: columnKey) + theseObjs = filteredValues(objects, lhs: keyValue, rhsKey: equalsKey) + case let i as UInt16.Type: + let keyValue = try subRowReader.decode(i, forKey: columnKey) + theseObjs = filteredValues(objects, lhs: keyValue, rhsKey: equalsKey) + case let i as UInt32.Type: + let keyValue = try subRowReader.decode(i, forKey: columnKey) + theseObjs = filteredValues(objects, lhs: keyValue, rhsKey: equalsKey) + case let i as UInt64.Type: + let keyValue = try subRowReader.decode(i, forKey: columnKey) + theseObjs = filteredValues(objects, lhs: keyValue, rhsKey: equalsKey) + case let i as Float.Type: + let keyValue = try subRowReader.decode(i, forKey: columnKey) + theseObjs = filteredValues(objects, lhs: keyValue, rhsKey: equalsKey) + case let i as Double.Type: + let keyValue = try subRowReader.decode(i, forKey: columnKey) + theseObjs = filteredValues(objects, lhs: keyValue, rhsKey: equalsKey) + case let i as String.Type: + let keyValue = try subRowReader.decode(i, forKey: columnKey) + theseObjs = filteredValues(objects, lhs: keyValue, rhsKey: equalsKey) + case let i as Date.Type: + let keyValue = try subRowReader.decode(i, forKey: columnKey) + theseObjs = filteredValues(objects, lhs: keyValue, rhsKey: equalsKey) + case let i as Data.Type: + let keyValue = try subRowReader.decode(i, forKey: columnKey) + theseObjs = filteredValues(objects, lhs: keyValue, rhsKey: equalsKey) + case let i as UUID.Type: + let keyValue = try subRowReader.decode(i, forKey: columnKey) + theseObjs = filteredValues(objects, lhs: keyValue, rhsKey: equalsKey) + default: + throw CRUDSQLExeError("Invalid join comparison type \(comparisonType).") + } + // swiftlint:disable force_cast + return theseObjs as! T + } + return try subRowReader.decode(intype, forKey: key) + } + private func filteredValues(_ values: [Any], lhs: ComparisonType, rhsKey: AnyKeyPath) -> [Any] { + return values.compactMap { + if let p = $0 as? PivotContainer { + guard let rhs = p.keys.first as? ComparisonType, + lhs == rhs else { + return nil + } + return p.instance + } + guard let rhs = $0[keyPath: rhsKey] as? ComparisonType, + lhs == rhs else { + return nil + } + return $0 + } + } + func nestedContainer(keyedBy type: NestedKey.Type, forKey key: Key) throws -> KeyedDecodingContainer where NestedKey: CodingKey { + return try subRowReader.nestedContainer(keyedBy: type, forKey: key) + } + func nestedUnkeyedContainer(forKey key: Key) throws -> UnkeyedDecodingContainer { + return try subRowReader.nestedUnkeyedContainer(forKey: key) + } + func superDecoder() throws -> Decoder { + return try subRowReader.superDecoder() + } + func superDecoder(forKey key: Key) throws -> Decoder { + return try subRowReader.superDecoder(forKey: key) + } +} diff --git a/Sources/PerfectCRUD/Coding/CodingKeyPaths.swift b/Sources/PerfectCRUD/Coding/CodingKeyPaths.swift new file mode 100644 index 00000000..71563fc0 --- /dev/null +++ b/Sources/PerfectCRUD/Coding/CodingKeyPaths.swift @@ -0,0 +1,457 @@ +// +// PerfectCRUDCodingKeyPaths.swift +// PerfectCRUD +// +// Created by Kyle Jessup on 2017-11-27. +// + +import Foundation + +class CRUDKeyPathsReader: KeyedDecodingContainerProtocol { + typealias Key = K + let codingPath: [CodingKey] = [] + let allKeys: [Key] = [] + let parent: CRUDKeyPathsDecoder + + init(_ p: CRUDKeyPathsDecoder) { + parent = p + } + func contains(_ key: Key) -> Bool { + return true + } + func decodeNil(forKey key: Key) throws -> Bool { + return false + } + func decode(_ type: Bool.Type, forKey key: Key) throws -> Bool { + return try parent.countBool(key) + } + func decode(_ type: Int.Type, forKey key: Key) throws -> Int { + return Int(parent.countKey(key)) + } + func decode(_ type: Int8.Type, forKey key: Key) throws -> Int8 { + return parent.countKey(key) + } + func decode(_ type: Int16.Type, forKey key: Key) throws -> Int16 { + return Int16(parent.countKey(key)) + } + func decode(_ type: Int32.Type, forKey key: Key) throws -> Int32 { + return Int32(parent.countKey(key)) + } + func decode(_ type: Int64.Type, forKey key: Key) throws -> Int64 { + return Int64(parent.countKey(key)) + } + func decode(_ type: UInt.Type, forKey key: Key) throws -> UInt { + return UInt(parent.countKey(key)) + } + func decode(_ type: UInt8.Type, forKey key: Key) throws -> UInt8 { + return UInt8(parent.countKey(key)) + } + func decode(_ type: UInt16.Type, forKey key: Key) throws -> UInt16 { + return UInt16(parent.countKey(key)) + } + func decode(_ type: UInt32.Type, forKey key: Key) throws -> UInt32 { + return UInt32(parent.countKey(key)) + } + func decode(_ type: UInt64.Type, forKey key: Key) throws -> UInt64 { + return UInt64(parent.countKey(key)) + } + func decode(_ type: Float.Type, forKey key: Key) throws -> Float { + return Float(parent.countKey(key)) + } + func decode(_ type: Double.Type, forKey key: Key) throws -> Double { + return Double(parent.countKey(key)) + } + func decode(_ type: String.Type, forKey key: Key) throws -> String { + return "\(parent.countKey(key))" + } + // swiftlint:disable comma force_cast + func decode(_ type: T.Type, forKey key: Key) throws -> T { + if type is WrappedCodableProvider.Type { + parent.wrappedKey = key + let decoded = try T(from: parent) + defer { + parent.wrappedKey = nil + } + return decoded + } + let counter = parent.countKey(key) + if let special = SpecialType(type) { + switch special { + case .uint8Array: + return [UInt8(counter)] as! T + case .int8Array: + return [Int8(counter)] as! T + case .data: + return Data([UInt8(counter)]) as! T + case .uuid: + return UUID(uuid: uuid_t(UInt8(counter),0,0,0,0,0,0,0,0,0,0,0,0,0,0,0)) as! T + case .date: + return Date(timeIntervalSinceReferenceDate: TimeInterval(counter)) as! T + case .url: + return URL(string: "http://localhost:\(counter)/")! as! T + case .codable: + let decoder = CRUDKeyPathsDecoder(depth: 1 + parent.depth) + let decoded = try T(from: decoder) + parent.subTypeMap.append((key.stringValue, type, decoder)) + return decoded + case .wrapped: + throw CRUDDecoderError("Unhandled decode type \(type)") + } + } else { + let decoder = CRUDKeyPathsDecoder(depth: 1 + parent.depth) + let decoded = try T(from: decoder) + parent.subTypeMap.append((key.stringValue, type, decoder)) + return decoded + } + } + func nestedContainer(keyedBy type: NestedKey.Type, forKey key: Key) throws -> KeyedDecodingContainer where NestedKey: CodingKey { + throw CRUDDecoderError("Unimplimented nestedContainer") + } + func nestedUnkeyedContainer(forKey key: Key) throws -> UnkeyedDecodingContainer { + throw CRUDDecoderError("Unimplimented nestedUnkeyedContainer") + } + func superDecoder() throws -> Decoder { + return parent + } + func superDecoder(forKey key: Key) throws -> Decoder { + throw CRUDDecoderError("Unimplimented superDecoder") + } +} + +class CRUDKeyPathsUnkeyedReader: UnkeyedDecodingContainer, SingleValueDecodingContainer { + let codingPath: [CodingKey] = [] + var count: Int? = 1 + var isAtEnd: Bool { return !(currentIndex < count ?? 0) } + var currentIndex: Int = 0 + let parent: CRUDKeyPathsDecoder + let wrappedKey: CodingKey + + init(_ p: CRUDKeyPathsDecoder, key: CodingKey) { + wrappedKey = key + parent = p + } + + func decodeNil() -> Bool { + return false + } + + func decode(_ type: Bool.Type) throws -> Bool { + return try parent.countBool(wrappedKey) + } + + func decode(_ type: Int.Type) throws -> Int { + return Int(parent.countKey(wrappedKey)) + } + + func decode(_ type: Int8.Type) throws -> Int8 { + return Int8(parent.countKey(wrappedKey)) + } + + func decode(_ type: Int16.Type) throws -> Int16 { + return Int16(parent.countKey(wrappedKey)) + } + + func decode(_ type: Int32.Type) throws -> Int32 { + return Int32(parent.countKey(wrappedKey)) + } + + func decode(_ type: Int64.Type) throws -> Int64 { + return Int64(parent.countKey(wrappedKey)) + } + + func decode(_ type: UInt.Type) throws -> UInt { + return UInt(parent.countKey(wrappedKey)) + } + + func decode(_ type: UInt8.Type) throws -> UInt8 { + return UInt8(parent.countKey(wrappedKey)) + } + + func decode(_ type: UInt16.Type) throws -> UInt16 { + return UInt16(parent.countKey(wrappedKey)) + } + + func decode(_ type: UInt32.Type) throws -> UInt32 { + return UInt32(parent.countKey(wrappedKey)) + } + + func decode(_ type: UInt64.Type) throws -> UInt64 { + return UInt64(parent.countKey(wrappedKey)) + } + + func decode(_ type: Float.Type) throws -> Float { + return Float(parent.countKey(wrappedKey)) + } + + func decode(_ type: Double.Type) throws -> Double { + return Double(parent.countKey(wrappedKey)) + } + + func decode(_ type: String.Type) throws -> String { + return "\(parent.countKey(wrappedKey))" + } + + // swiftlint:disable comma + func decode(_ type: T.Type) throws -> T { + // this is being called in some cases for primitive types like Int + // instead of the proper funtion above + switch type { + case let t as Bool.Type: return try decode(t) as! T + case let t as Int.Type: return try decode(t) as! T + case let t as Int8.Type: return try decode(t) as! T + case let t as Int16.Type: return try decode(t) as! T + case let t as Int32.Type: return try decode(t) as! T + case let t as Int64.Type: return try decode(t) as! T + case let t as UInt.Type: return try decode(t) as! T + case let t as UInt8.Type: return try decode(t) as! T + case let t as UInt16.Type: return try decode(t) as! T + case let t as UInt32.Type: return try decode(t) as! T + case let t as UInt64.Type: return try decode(t) as! T + case let t as Float.Type: return try decode(t) as! T + case let t as Double.Type: return try decode(t) as! T + case let t as String.Type: return try decode(t) as! T + default: () + } + currentIndex += 1 + let counter = parent.countKey(wrappedKey) + if let special = SpecialType(type) { + switch special { + case .uint8Array: + return [UInt8(counter)] as! T + case .int8Array: + return [Int8(counter)] as! T + case .data: + return Data([UInt8(counter)]) as! T + case .uuid: + return UUID(uuid: uuid_t(UInt8(counter),0,0,0,0,0,0,0,0,0,0,0,0,0,0,0)) as! T + case .date: + return Date(timeIntervalSinceReferenceDate: TimeInterval(counter)) as! T + case .url: + return URL(string: "http://localhost:\(counter)/")! as! T + case .codable: + let decoder = CRUDKeyPathsDecoder(depth: 1 + parent.depth) + let decoded = try T(from: decoder) + parent.subTypeMap.append((wrappedKey.stringValue, type, decoder)) + return decoded + case .wrapped: + throw CRUDDecoderError("Unhandled decode type \(type)") + } + } else { + let decoder = CRUDKeyPathsDecoder(depth: 1 + parent.depth) + let decoded = try T(from: decoder) + parent.subTypeMap.append((wrappedKey.stringValue, type, decoder)) + return decoded + } + } + + func nestedContainer(keyedBy type: NestedKey.Type) throws -> KeyedDecodingContainer where NestedKey: CodingKey { + throw CRUDDecoderError("Unimplimented nestedContainer") + } + + func nestedUnkeyedContainer() throws -> UnkeyedDecodingContainer { + throw CRUDDecoderError("Unimplimented nestedUnkeyedContainer") + } + + func superDecoder() throws -> Decoder { + currentIndex += 1 + return parent + } +} + +class MyUnkeyedDecodingContainer: UnkeyedDecodingContainer { + var codingPath: [CodingKey] = [] + var count: Int? = 0 + var isAtEnd: Bool = true + var currentIndex: Int = 0 + + func decodeNil() throws -> Bool { + throw CRUDDecoderError("MyUnkeyedDecodingContainer zero count") + } + + func decode(_ type: Bool.Type) throws -> Bool { + throw CRUDDecoderError("MyUnkeyedDecodingContainer zero count") + } + + func decode(_ type: String.Type) throws -> String { + throw CRUDDecoderError("MyUnkeyedDecodingContainer zero count") + } + + func decode(_ type: Double.Type) throws -> Double { + throw CRUDDecoderError("MyUnkeyedDecodingContainer zero count") + } + + func decode(_ type: Float.Type) throws -> Float { + throw CRUDDecoderError("MyUnkeyedDecodingContainer zero count") + } + + func decode(_ type: Int.Type) throws -> Int { + throw CRUDDecoderError("MyUnkeyedDecodingContainer zero count") + } + + func decode(_ type: Int8.Type) throws -> Int8 { + throw CRUDDecoderError("MyUnkeyedDecodingContainer zero count") + } + + func decode(_ type: Int16.Type) throws -> Int16 { + throw CRUDDecoderError("MyUnkeyedDecodingContainer zero count") + } + + func decode(_ type: Int32.Type) throws -> Int32 { + throw CRUDDecoderError("MyUnkeyedDecodingContainer zero count") + } + + func decode(_ type: Int64.Type) throws -> Int64 { + throw CRUDDecoderError("MyUnkeyedDecodingContainer zero count") + } + + func decode(_ type: UInt.Type) throws -> UInt { + throw CRUDDecoderError("MyUnkeyedDecodingContainer zero count") + } + + func decode(_ type: UInt8.Type) throws -> UInt8 { + throw CRUDDecoderError("MyUnkeyedDecodingContainer zero count") + } + + func decode(_ type: UInt16.Type) throws -> UInt16 { + throw CRUDDecoderError("MyUnkeyedDecodingContainer zero count") + } + + func decode(_ type: UInt32.Type) throws -> UInt32 { + throw CRUDDecoderError("MyUnkeyedDecodingContainer zero count") + } + + func decode(_ type: UInt64.Type) throws -> UInt64 { + throw CRUDDecoderError("MyUnkeyedDecodingContainer zero count") + } + + func decode(_ type: T.Type) throws -> T where T: Decodable { + throw CRUDDecoderError("MyUnkeyedDecodingContainer zero count") + } + + func nestedContainer(keyedBy type: NestedKey.Type) throws -> KeyedDecodingContainer where NestedKey: CodingKey { + throw CRUDDecoderError("MyUnkeyedDecodingContainer zero count") + } + + func nestedUnkeyedContainer() throws -> UnkeyedDecodingContainer { + throw CRUDDecoderError("MyUnkeyedDecodingContainer zero count") + } + + func superDecoder() throws -> Decoder { + throw CRUDDecoderError("MyUnkeyedDecodingContainer zero count") + } +} + +public class CRUDKeyPathsDecoder: Decoder { + public var codingPath: [CodingKey] = [] + public var userInfo: [CodingUserInfoKey: Any] = [:] + var counter: Int8 = 1 + var boolCounter: Int8 = 0 + var typeMap: [Int8: String] = [:] + var subTypeMap: [(String, Decodable.Type, CRUDKeyPathsDecoder)] = [] + let depth: Int + var wrappedKey: CodingKey? + + init(depth d: Int = 0) { + depth = d + } + + func countKey(_ key: CodingKey) -> Int8 { + counter += 1 + typeMap[counter] = key.stringValue + return counter + } + + func countBool(_ key: CodingKey) throws -> Bool { + guard boolCounter < 2 else { + throw CRUDDecoderError("Perfect-CRUD table types can have up to two Bool properties. Try using small ints (Int8) with bool 'var' accessors.") + } + typeMap[boolCounter] = key.stringValue + boolCounter += 1 + return boolCounter == 2 + } + + public func container(keyedBy type: Key.Type) throws -> KeyedDecodingContainer { + return KeyedDecodingContainer(CRUDKeyPathsReader(self)) + } + public func unkeyedContainer() throws -> UnkeyedDecodingContainer { + return MyUnkeyedDecodingContainer() + } + public func singleValueContainer() throws -> SingleValueDecodingContainer { + guard let wrappedKey = self.wrappedKey else { + throw CRUDDecoderError("No wrappedKey waiting for unkeyedContainer") + } + return CRUDKeyPathsUnkeyedReader(self, key: wrappedKey) + } + public func getKeyPathName(_ instance: Any, keyPath: AnyKeyPath) throws -> String? { + guard let v = instance[keyPath: keyPath] else { + return nil + } + return try getKeyPathName(fromValue: v) + } + private func getKeyPathName(fromValue v: Any) throws -> String? { + switch v { + case let b as Bool: + return typeMap[b ? 1 : 0] + case let s as String: + guard let v = Int8(s) else { + return nil + } + return typeMap[v] + case let i as Int: + return typeMap[Int8(i)] + case let i as Int8: + return typeMap[Int8(i)] + case let i as Int16: + return typeMap[Int8(i)] + case let i as Int32: + return typeMap[Int8(i)] + case let i as Int64: + return typeMap[Int8(i)] + case let i as UInt: + return typeMap[Int8(i)] + case let i as UInt8: + return typeMap[Int8(i)] + case let i as UInt16: + return typeMap[Int8(i)] + case let i as UInt32: + return typeMap[Int8(i)] + case let i as UInt64: + return typeMap[Int8(i)] + case let i as Float: + return typeMap[Int8(i)] + case let i as Double: + return typeMap[Int8(i)] + case let o as Any?: + guard let unType = o else { + return nil + } + if let found = subTypeMap.first(where: { $0.1 == type(of: unType) }) { + return found.0 + } + if let special = SpecialType(type(of: unType)) { + switch special { + case .uint8Array: + return typeMap[Int8((v as! [UInt8])[0])] + case .int8Array: + return typeMap[Int8((v as! [Int8])[0])] + case .data: + return typeMap[Int8((v as! Data).first!)] + case .uuid: + return typeMap[Int8((v as! UUID).uuid.0)] + case .date: + return typeMap[Int8((v as! Date).timeIntervalSinceReferenceDate)] + case .url: + return typeMap[Int8((v as! URL).port!)] + case .codable, .wrapped: + throw CRUDDecoderError("Unsupported operation on codable column.") + } + } + return nil + default: + guard let found = subTypeMap.first(where: { $0.1 == type(of: v) }) else { + return nil + } + return found.0 + } + } +} diff --git a/Sources/PerfectCRUD/Coding/CodingNames.swift b/Sources/PerfectCRUD/Coding/CodingNames.swift new file mode 100644 index 00000000..908a42b0 --- /dev/null +++ b/Sources/PerfectCRUD/Coding/CodingNames.swift @@ -0,0 +1,384 @@ +// +// PerfectCRUDCodingNames.swift +// PerfectCRUD +// +// Created by Kyle Jessup on 2017-11-25. +// Copyright (C) 2017 PerfectlySoft, Inc. +// +// ===----------------------------------------------------------------------===// +// +// This source file is part of the Perfect.org open source project +// +// Copyright (c) 2015 - 2017 PerfectlySoft Inc. and the Perfect project authors +// Licensed under Apache License v2.0 +// +// See http://perfect.org/licensing.html for license information +// +// ===----------------------------------------------------------------------===// +// + +import Foundation + +// -- reads and records the coding keys for an object +class CRUDColumnNamesReader: KeyedDecodingContainerProtocol { + typealias Key = K + var codingPath: [CodingKey] = [] + + var allKeys: [Key] = [] + var parent: CRUDColumnNameDecoder + var knownKeys = Set() + var isOptional = false + init(_ p: CRUDColumnNameDecoder) { + parent = p + } + func appendKey(_ key: Key, _ type: Any.Type) { + let s = key.stringValue + if !knownKeys.contains(s) { + parent.collectedKeys.append((s, isOptional, type)) + knownKeys.insert(s) + } + isOptional = false // reset + } + func contains(_ key: Key) -> Bool { + return true + } + func decodeNil(forKey key: Key) throws -> Bool { + isOptional = true + return false + } + func decode(_ type: Bool.Type, forKey key: Key) throws -> Bool { + appendKey(key, type) + return true + } + func decode(_ type: Int.Type, forKey key: Key) throws -> Int { + appendKey(key, type) + return 0 + } + func decode(_ type: Int8.Type, forKey key: Key) throws -> Int8 { + appendKey(key, type) + return 0 + } + func decode(_ type: Int16.Type, forKey key: Key) throws -> Int16 { + appendKey(key, type) + return 0 + } + func decode(_ type: Int32.Type, forKey key: Key) throws -> Int32 { + appendKey(key, type) + return 0 + } + func decode(_ type: Int64.Type, forKey key: Key) throws -> Int64 { + appendKey(key, type) + return 0 + } + func decode(_ type: UInt.Type, forKey key: Key) throws -> UInt { + appendKey(key, type) + return 0 + } + func decode(_ type: UInt8.Type, forKey key: Key) throws -> UInt8 { + appendKey(key, type) + return 0 + } + func decode(_ type: UInt16.Type, forKey key: Key) throws -> UInt16 { + appendKey(key, type) + return 0 + } + func decode(_ type: UInt32.Type, forKey key: Key) throws -> UInt32 { + appendKey(key, type) + return 0 + } + func decode(_ type: UInt64.Type, forKey key: Key) throws -> UInt64 { + appendKey(key, type) + return 0 + } + func decode(_ type: Float.Type, forKey key: Key) throws -> Float { + appendKey(key, type) + return 0 + } + func decode(_ type: Double.Type, forKey key: Key) throws -> Double { + appendKey(key, type) + return 0 + } + func decode(_ type: String.Type, forKey key: Key) throws -> String { + appendKey(key, type) + return "" + } + // swiftlint:disable force_cast + func decode(_ t: T.Type, forKey key: Key) throws -> T { + if let special = SpecialType(t) { + switch special { + case .uint8Array: + appendKey(key, t) + return [UInt8]() as! T + case .int8Array: + appendKey(key, t) + return [Int8]() as! T + case .data: + appendKey(key, t) + return Data() as! T + case .uuid: + appendKey(key, t) + return UUID() as! T + case .date: + appendKey(key, t) + return Date() as! T + case .url: + appendKey(key, t) + return URL(string: "http://localhost")! as! T + case .codable: + () + case .wrapped: + () +// guard let wrapped = t as? WrappedCodableProvider.Type else { +// throw CRUDEncoderError("Unsupported decoding type: \(t) for key: \(key.stringValue)") +// } +// let wrappedValueType = wrapped.provideWrappedValueType() +// switch wrappedValueType { +// case let m as Bool.Type: return try decode(m, forKey: key) +// case let m as Int.Type: return try decode(m, forKey: key) +// case let m as Int8.Type: return try decode(m, forKey: key) +// case let m as Int16.Type: return try decode(m, forKey: key) +// case let m as Int32.Type: return try decode(m, forKey: key) +// case let m as Int64.Type: return try decode(m, forKey: key) +// case let m as UInt.Type: return try decode(m, forKey: key) +// case let m as UInt8.Type: return try decode(m, forKey: key) +// case let m as UInt16.Type: return try decode(m, forKey: key) +// case let m as UInt32.Type: return try decode(m, forKey: key) +// case let m as UInt64.Type: return try decode(m, forKey: key) +// case let m as Float.Type: return try decode(m, forKey: key) +// case let m as Double.Type: return try decode(m, forKey: key) +// case let m as String.Type: return try decode(m, forKey: key) +// default: +// throw CRUDEncoderError("Unsupported decoding type: wrapped(\(wrappedValueType)) for key: \(key.stringValue)") +// } + } + } + return try decodeInner(t, forKey: key) + } + + func decodeInner(_ t: T.Type, forKey key: Key) throws -> T { + let sub = CRUDColumnNameDecoder(depth: 1 + parent.depth) + let ret = try T(from: sub) + if let ar = ret as? [Codable] { + if !ar.isEmpty { + let subType = type(of: ar[0]) + sub.codingPath.append(key) + sub.tableNamePath.append(subType.CRUDTableName) + ar[0].addSubTable(to: parent, name: key.stringValue, decoder: sub) + } + return ret +// } else if ret is WrappedValueTypeProvider { +// appendKey(key, type(of: ret as! WrappedValueTypeProvider).wrappedValueType()) +// return ret + } else if ret is Codable { // ... + appendKey(key, type(of: ret)) + return ret + } + throw CRUDSQLGenError("Unsupported sub-table type \(T.self)") + } + + func nestedContainer(keyedBy type: NestedKey.Type, forKey key: Key) throws -> KeyedDecodingContainer { + throw CRUDDecoderError("Unimplimented nestedContainer") + } + func nestedUnkeyedContainer(forKey key: Key) throws -> UnkeyedDecodingContainer { + throw CRUDDecoderError("Unimplimented nestedUnkeyedContainer") + } + func superDecoder() throws -> Decoder { + throw CRUDDecoderError("Unimplimented superDecoder") + } + func superDecoder(forKey key: Key) throws -> Decoder { + throw CRUDDecoderError("Unimplimented superDecoder") + } +} + +class CRUDColumnNameUnkeyedReader: UnkeyedDecodingContainer, SingleValueDecodingContainer { + let codingPath: [CodingKey] = [] + var count: Int? = 1 + var isAtEnd: Bool { return !(currentIndex < count ?? 0) } + var currentIndex: Int = 0 + let parent: CRUDColumnNameDecoder + var decodedType: Any.Type? + var typeDecoder: CRUDColumnNameDecoder? + init(parent p: CRUDColumnNameDecoder) { + parent = p + } + func advance(_ t: Any.Type) { + currentIndex += 1 + decodedType = t + } + func decodeNil() -> Bool { + return false + } + + func decode(_ type: Bool.Type) throws -> Bool { + advance(type) + return false + } + + func decode(_ type: Int.Type) throws -> Int { + advance(type) + return 0 + } + + func decode(_ type: Int8.Type) throws -> Int8 { + advance(type) + return 0 + } + + func decode(_ type: Int16.Type) throws -> Int16 { + advance(type) + return 0 + } + + func decode(_ type: Int32.Type) throws -> Int32 { + advance(type) + return 0 + } + + func decode(_ type: Int64.Type) throws -> Int64 { + advance(type) + return 0 + } + + func decode(_ type: UInt.Type) throws -> UInt { + advance(type) + return 0 + } + + func decode(_ type: UInt8.Type) throws -> UInt8 { + advance(type) + return 0 + } + + func decode(_ type: UInt16.Type) throws -> UInt16 { + advance(type) + return 0 + } + + func decode(_ type: UInt32.Type) throws -> UInt32 { + advance(type) + return 0 + } + func decode(_ type: UInt64.Type) throws -> UInt64 { + advance(type) + return 0 + } + func decode(_ type: Float.Type) throws -> Float { + advance(type) + return 0 + } + func decode(_ type: Double.Type) throws -> Double { + advance(type) + return 0 + } + func decode(_ type: String.Type) throws -> String { + advance(type) + return "" + } + // swiftlint:disable force_cast + func decode(_ t: T.Type) throws -> T { + advance(t) + if let special = SpecialType(t) { + switch special { + case .uint8Array: + return [UInt8]() as! T + case .int8Array: + return [Int8]() as! T + case .data: + return Data() as! T + case .uuid: + return UUID() as! T + case .date: + return Date() as! T + case .url: + return URL(string: "http://localhost")! as! T + case .codable, .wrapped: + () + } + } + return try T(from: parent) + } + func nestedContainer(keyedBy type: NestedKey.Type) throws -> KeyedDecodingContainer { + throw CRUDDecoderError("Unimplimented nestedContainer") + } + func nestedUnkeyedContainer() throws -> UnkeyedDecodingContainer { + throw CRUDDecoderError("Unimplimented nestedUnkeyedContainer") + } + func superDecoder() throws -> Decoder { + currentIndex += 1 + return parent + } +} + +protocol SubTableProto { + var name: String { get } +// var type: Decodable.Type { get } +// var decoder: CRUDColumnNameDecoder { get } + func tableStructure() throws -> TableStructure + func matches(_ type: T.Type) -> Bool +} + +struct SubTable: SubTableProto { + let name: String + let type: T.Type + let decoder: CRUDColumnNameDecoder + let realType: R.Type + func tableStructure() throws -> TableStructure { + return try type.self.CRUDTableStructure(columnDecoder: decoder) + } + func matches(_ type: U.Type) -> Bool { + return self.type == type + } +} + +extension Decodable where Self: Encodable { + @available(macOS 10.15.0, *) + func makeSubTable(name: String, decoder: CRUDColumnNameDecoder) -> some SubTableProto { + return SubTable(name: name, type: Self.self, decoder: decoder, realType: Self.self) + } + + func addSubTable(to: CRUDColumnNameDecoder, name: String, decoder: CRUDColumnNameDecoder) { + to.addSubTable(SubTable(name: name, type: Self.self, decoder: decoder, realType: Self.self)) + } +} + +public class CRUDColumnNameDecoder: Decoder { + public var codingPath: [CodingKey] = [] + public var userInfo: [CodingUserInfoKey: Any] = [:] + + var tableNamePath: [String] = [] + public var collectedKeys: [(name: String, optional: Bool, type: Any.Type)] = [] + var subTables: [SubTableProto] = [] + var pendingReader: CRUDColumnNameUnkeyedReader? + let depth: Int + public init(depth d: Int = 0) { + depth = d + } + func addSubTable(_ name: String, type: T.Type, decoder: CRUDColumnNameDecoder) { + guard subTables.filter({ $0.name == name }).count == 0 else { + return + } + subTables.append(SubTable(name: name, type: type, decoder: decoder, realType: type)) + } + func addSubTable(_ sub: T) { + guard subTables.filter({ $0.name == sub.name }).count == 0 else { + return + } + subTables.append(sub) + } + + public func container(keyedBy type: Key.Type) throws -> KeyedDecodingContainer { + return KeyedDecodingContainer(CRUDColumnNamesReader(self)) + } + public func unkeyedContainer() throws -> UnkeyedDecodingContainer { + let r = CRUDColumnNameUnkeyedReader(parent: self) + if depth > 1 { + r.count = 0 + } + pendingReader = r + return r + } + public func singleValueContainer() throws -> SingleValueDecodingContainer { + let r = CRUDColumnNameUnkeyedReader(parent: self) + return r + } +} diff --git a/Sources/PerfectCRUD/Coding/CodingRows.swift b/Sources/PerfectCRUD/Coding/CodingRows.swift new file mode 100644 index 00000000..241595b1 --- /dev/null +++ b/Sources/PerfectCRUD/Coding/CodingRows.swift @@ -0,0 +1,291 @@ +// +// PerfectCRUDCodingRows.swift +// PerfectCRUD +// +// Created by Kyle Jessup on 2017-11-25. +// Copyright (C) 2017 PerfectlySoft, Inc. +// +// ===----------------------------------------------------------------------===// +// +// This source file is part of the Perfect.org open source project +// +// Copyright (c) 2015 - 2017 PerfectlySoft Inc. and the Perfect project authors +// Licensed under Apache License v2.0 +// +// See http://perfect.org/licensing.html for license information +// +// ===----------------------------------------------------------------------===// +// + +import Foundation + +public class CRUDRowDecoder: Decoder { + public typealias Key = K + public var codingPath: [CodingKey] = [] + public var userInfo: [CodingUserInfoKey: Any] = [:] + let delegate: SQLExeDelegate + public init(delegate d: SQLExeDelegate) { + delegate = d + } + public func container(keyedBy type: Key.Type) throws -> KeyedDecodingContainer where Key: CodingKey { + guard let next: KeyedDecodingContainer = try delegate.next() else { + throw CRUDDecoderError("No row.") + } + return next + } + public func unkeyedContainer() throws -> UnkeyedDecodingContainer { + throw CRUDDecoderError("Unimplemented") + } + public func singleValueContainer() throws -> SingleValueDecodingContainer { + throw CRUDDecoderError("Unimplemented") + } +} + +public class CRUDColumnValueDecoder: Decoder, SingleValueDecodingContainer { + + public typealias Key = K + public var codingPath: [CodingKey] = [] + public var userInfo: [CodingUserInfoKey: Any] = [:] + // swiftlint:disable force_cast + var key: Key { codingPath.first! as! Key } + let source: KeyedDecodingContainer + public init(source: KeyedDecodingContainer, key: K) { + self.source = source + codingPath = [key] + } + public func container(keyedBy type: Key.Type) throws -> KeyedDecodingContainer where Key: CodingKey { + throw CRUDDecoderError("Unimplemented container") + } + public func unkeyedContainer() throws -> UnkeyedDecodingContainer { + throw CRUDDecoderError("Unimplemented") + } + public func singleValueContainer() throws -> SingleValueDecodingContainer { + return self // throw CRUDDecoderError("Unimplemented singleValueContainer") + } + + public func decodeNil() -> Bool { + return (try? source.decodeNil(forKey: key)) ?? false + } + + public func decode(_ type: Bool.Type) throws -> Bool { + return try source.decode(type, forKey: key) + } + + public func decode(_ type: String.Type) throws -> String { + return try source.decode(type, forKey: key) + } + + public func decode(_ type: Double.Type) throws -> Double { + return try source.decode(type, forKey: key) + } + + public func decode(_ type: Float.Type) throws -> Float { + return try source.decode(type, forKey: key) + } + + public func decode(_ type: Int.Type) throws -> Int { + return try source.decode(type, forKey: key) + } + + public func decode(_ type: Int8.Type) throws -> Int8 { + return try source.decode(type, forKey: key) + } + + public func decode(_ type: Int16.Type) throws -> Int16 { + return try source.decode(type, forKey: key) + } + + public func decode(_ type: Int32.Type) throws -> Int32 { + return try source.decode(type, forKey: key) + } + + public func decode(_ type: Int64.Type) throws -> Int64 { + return try source.decode(type, forKey: key) + } + + public func decode(_ type: UInt.Type) throws -> UInt { + return try source.decode(type, forKey: key) + } + + public func decode(_ type: UInt8.Type) throws -> UInt8 { + return try source.decode(type, forKey: key) + } + + public func decode(_ type: UInt16.Type) throws -> UInt16 { + return try source.decode(type, forKey: key) + } + + public func decode(_ type: UInt32.Type) throws -> UInt32 { + return try source.decode(type, forKey: key) + } + + public func decode(_ type: UInt64.Type) throws -> UInt64 { + return try source.decode(type, forKey: key) + } + + public func decode(_ type: T.Type) throws -> T where T: Decodable { + return try source.decode(type, forKey: key) + } +} + +struct PivotKey: Codable { + let _crud_pivot_id_: T +} + +public class CRUDPivotRowDecoder: Decoder { + public typealias Key = K + + public var codingPath: [CodingKey] = [] + public var userInfo: [CodingUserInfoKey: Any] = [:] + let delegate: SQLExeDelegate + let pivotOnType: Codable.Type + public var orderedKeys: [Codable] = [] + public init(delegate d: SQLExeDelegate, pivotOn p: Codable.Type) { + delegate = d + pivotOnType = p + } + public func container(keyedBy type: Key.Type) throws -> KeyedDecodingContainer where Key: CodingKey { + guard let next: KeyedDecodingContainer = try delegate.next(), + let columnKey = ColumnKey(stringValue: joinPivotIdColumnName) else { + throw CRUDDecoderError("No row.") + } + switch pivotOnType { + case let i as Bool.Type: + let keyValue = try next.decode(i, forKey: columnKey) + orderedKeys.append(keyValue) + case let i as Int.Type: + let keyValue = try next.decode(i, forKey: columnKey) + orderedKeys.append(keyValue) + case let i as Int8.Type: + let keyValue = try next.decode(i, forKey: columnKey) + orderedKeys.append(keyValue) + case let i as Int16.Type: + let keyValue = try next.decode(i, forKey: columnKey) + orderedKeys.append(keyValue) + case let i as Int32.Type: + let keyValue = try next.decode(i, forKey: columnKey) + orderedKeys.append(keyValue) + case let i as Int64.Type: + let keyValue = try next.decode(i, forKey: columnKey) + orderedKeys.append(keyValue) + case let i as UInt.Type: + let keyValue = try next.decode(i, forKey: columnKey) + orderedKeys.append(keyValue) + case let i as UInt8.Type: + let keyValue = try next.decode(i, forKey: columnKey) + orderedKeys.append(keyValue) + case let i as UInt16.Type: + let keyValue = try next.decode(i, forKey: columnKey) + orderedKeys.append(keyValue) + case let i as UInt32.Type: + let keyValue = try next.decode(i, forKey: columnKey) + orderedKeys.append(keyValue) + case let i as UInt64.Type: + let keyValue = try next.decode(i, forKey: columnKey) + orderedKeys.append(keyValue) + case let i as Float.Type: + let keyValue = try next.decode(i, forKey: columnKey) + orderedKeys.append(keyValue) + case let i as Double.Type: + let keyValue = try next.decode(i, forKey: columnKey) + orderedKeys.append(keyValue) + case let i as String.Type: + let keyValue = try next.decode(i, forKey: columnKey) + orderedKeys.append(keyValue) + case let i as Date.Type: + let keyValue = try next.decode(i, forKey: columnKey) + orderedKeys.append(keyValue) + case let i as Data.Type: + let keyValue = try next.decode(i, forKey: columnKey) + orderedKeys.append(keyValue) + case let i as UUID.Type: + let keyValue = try next.decode(i, forKey: columnKey) + orderedKeys.append(keyValue) + default: + throw CRUDSQLExeError("Invalid join comparison type \(pivotOnType).") + } + return KeyedDecodingContainer(CRUDPivotRowReader(subReader: next)) + } + public func unkeyedContainer() throws -> UnkeyedDecodingContainer { + throw CRUDDecoderError("Unimplemented") + } + public func singleValueContainer() throws -> SingleValueDecodingContainer { + throw CRUDDecoderError("Unimplemented") + } +} + +class CRUDPivotRowReader: KeyedDecodingContainerProtocol { + typealias Key = K + var codingPath: [CodingKey] = [] + var allKeys: [K] = [] + let subReader: KeyedDecodingContainer + init(subReader s: KeyedDecodingContainer) { + subReader = s + } + private func k(_ key: Key) -> K2 { + return K2(stringValue: key.stringValue)! + } + func contains(_ key: K) -> Bool { + return subReader.contains(k(key)) + } + func decodeNil(forKey key: K) throws -> Bool { + return try subReader.decodeNil(forKey: k(key)) + } + func decode(_ type: Bool.Type, forKey key: K) throws -> Bool { + return try subReader.decode(type, forKey: k(key)) + } + func decode(_ type: Int.Type, forKey key: K) throws -> Int { + return try subReader.decode(type, forKey: k(key)) + } + func decode(_ type: Int8.Type, forKey key: K) throws -> Int8 { + return try subReader.decode(type, forKey: k(key)) + } + func decode(_ type: Int16.Type, forKey key: K) throws -> Int16 { + return try subReader.decode(type, forKey: k(key)) + } + func decode(_ type: Int32.Type, forKey key: K) throws -> Int32 { + return try subReader.decode(type, forKey: k(key)) + } + func decode(_ type: Int64.Type, forKey key: K) throws -> Int64 { + return try subReader.decode(type, forKey: k(key)) + } + func decode(_ type: UInt.Type, forKey key: K) throws -> UInt { + return try subReader.decode(type, forKey: k(key)) + } + func decode(_ type: UInt8.Type, forKey key: K) throws -> UInt8 { + return try subReader.decode(type, forKey: k(key)) + } + func decode(_ type: UInt16.Type, forKey key: K) throws -> UInt16 { + return try subReader.decode(type, forKey: k(key)) + } + func decode(_ type: UInt32.Type, forKey key: K) throws -> UInt32 { + return try subReader.decode(type, forKey: k(key)) + } + func decode(_ type: UInt64.Type, forKey key: K) throws -> UInt64 { + return try subReader.decode(type, forKey: k(key)) + } + func decode(_ type: Float.Type, forKey key: K) throws -> Float { + return try subReader.decode(type, forKey: k(key)) + } + func decode(_ type: Double.Type, forKey key: K) throws -> Double { + return try subReader.decode(type, forKey: k(key)) + } + func decode(_ type: String.Type, forKey key: K) throws -> String { + return try subReader.decode(type, forKey: k(key)) + } + func decode(_ type: T.Type, forKey key: K) throws -> T where T: Decodable { + return try subReader.decode(type, forKey: k(key)) + } + func nestedContainer(keyedBy type: NestedKey.Type, forKey key: K) throws -> KeyedDecodingContainer where NestedKey: CodingKey { + return try subReader.nestedContainer(keyedBy: type, forKey: k(key)) + } + func nestedUnkeyedContainer(forKey key: K) throws -> UnkeyedDecodingContainer { + return try subReader.nestedUnkeyedContainer(forKey: k(key)) + } + func superDecoder() throws -> Decoder { + return try subReader.superDecoder() + } + func superDecoder(forKey key: K) throws -> Decoder { + return try subReader.superDecoder(forKey: k(key)) + } +} diff --git a/Sources/PerfectCRUD/Create.swift b/Sources/PerfectCRUD/Create.swift new file mode 100644 index 00000000..8cf4eed0 --- /dev/null +++ b/Sources/PerfectCRUD/Create.swift @@ -0,0 +1,303 @@ +// +// PerfectCRUDCreate.swift +// PerfectCRUD +// +// Created by Kyle Jessup on 2017-12-03. +// + +import Foundation + +public struct TableCreatePolicy: OptionSet { + public let rawValue: Int + public init(rawValue r: Int) { rawValue = r } + public static let shallow = TableCreatePolicy(rawValue: 1) + public static let dropTable = TableCreatePolicy(rawValue: 2) + public static let reconcileTable = TableCreatePolicy(rawValue: 4) + + public static let defaultPolicy: TableCreatePolicy = [] +} + +public class TableStructure { + public class Column { + public enum Property: Equatable { + case primaryKey + case foreignKey(String, String, ForeignKeyAction, ForeignKeyAction) // table, column, onDelete, onUpdate + } + public let name: String + public let type: Any.Type + public let optional: Bool + public let properties: [Property] + init(name: String, type: Any.Type, optional: Bool, properties: [Property]) { + self.name = name + self.type = type + self.optional = optional + self.properties = properties + } + } + public let tableName: String + public var primaryKeyName: String? { columns.first(where: { $0.properties.contains(.primaryKey) })?.name } + public let columns: [Column] + public var subTables: [TableStructure] + public let indexes: [String] + init(tableName: String, columns: [Column], subTables: [TableStructure], indexes: [String]) { + self.tableName = tableName + self.columns = columns + self.subTables = subTables + self.indexes = indexes + } +} + +public protocol WrappedCodableProvider: Codable { + static func provideWrappedValueType() -> Codable.Type + func provideWrappedValue() -> Codable +} + +protocol PrimaryKeyWrapper: WrappedCodableProvider {} + +@propertyWrapper +public struct PrimaryKey: PrimaryKeyWrapper, Codable { + public static func provideWrappedValueType() -> Codable.Type { Value.self } + public var wrappedValue: Value + public var projectedValue: Value { wrappedValue } + public init(wrappedValue: Value) { + self.wrappedValue = wrappedValue + } + public init(from decoder: Decoder) throws { + wrappedValue = try decoder.singleValueContainer().decode(Value.self) + } + public func encode(to encoder: Encoder) throws { + var c = encoder.singleValueContainer() + try c.encode(wrappedValue) + } + public func provideWrappedValue() -> Codable { + return wrappedValue + } +} + +public enum ForeignKeyAction { + case ignore, restrict, setNull, setDefault, cascade +} + +public protocol ForeignKeyActionProvider { + static var action: ForeignKeyAction { get } +} + +public struct ForeignKeyActionIgnore: ForeignKeyActionProvider { + static public var action = ForeignKeyAction.ignore +} + +public struct ForeignKeyActionRestrict: ForeignKeyActionProvider { + static public var action = ForeignKeyAction.restrict +} + +public struct ForeignKeyActionSetNull: ForeignKeyActionProvider { + static public var action = ForeignKeyAction.setNull +} + +public struct ForeignKeyActionSetDefault: ForeignKeyActionProvider { + static public var action = ForeignKeyAction.setDefault +} + +public struct ForeignKeyActionCascade: ForeignKeyActionProvider { + static public var action = ForeignKeyAction.cascade +} + +public let ignore = ForeignKeyActionIgnore() +public let restrict = ForeignKeyActionRestrict() +public let setNull = ForeignKeyActionSetNull() +public let setDefault = ForeignKeyActionSetDefault() +public let cascade = ForeignKeyActionCascade() + +protocol ForeignKeyWrapper: WrappedCodableProvider { + static func foreignTableStructure() throws -> TableStructure + static func foreignKeyDeleteAction() -> ForeignKeyAction + static func foreignKeyUpdateAction() -> ForeignKeyAction +} + +@propertyWrapper +public struct ForeignKey: ForeignKeyWrapper, Codable { + public static func provideWrappedValueType() -> Codable.Type { Value.self } + static func foreignKeyDeleteAction() -> ForeignKeyAction { DeleteAction.action } + static func foreignKeyUpdateAction() -> ForeignKeyAction { UpdateAction.action } + static func foreignTableStructure() throws -> TableStructure { + return try Table.CRUDTableStructure() + } + + public var wrappedValue: Value { + get { projectedValue! } + set { projectedValue = newValue } + } + public var projectedValue: Value? = nil + + public init(_ parent: Table.Type, onDelete: DeleteAction, onUpdate: UpdateAction, wrappedValue: Value) { + self.projectedValue = wrappedValue + } + public init(_ parent: Table.Type, onDelete: DeleteAction, onUpdate: UpdateAction) { + + } + public init(from decoder: Decoder) throws { + projectedValue = try decoder.singleValueContainer().decode(Value.self) + } + public func encode(to encoder: Encoder) throws { + var c = encoder.singleValueContainer() + try c.encode(projectedValue!) + } + public func provideWrappedValue() -> Codable { + return wrappedValue + } +} + +private var tableStructureCache: [String: TableStructure] = [:] + +// for tests +public func CRUDClearTableStructureCache() { + tableStructureCache.removeAll() +} + +extension Decodable { + static func CRUDTableStructure(primaryKey: PartialKeyPath? = nil) throws -> TableStructure { + let columnDecoder = CRUDColumnNameDecoder() + columnDecoder.tableNamePath.append("\(Self.CRUDTableName)") + _ = try Self.init(from: columnDecoder) + return try CRUDTableStructure(columnDecoder: columnDecoder, primaryKey: primaryKey) + } + static func CRUDTableStructure(columnDecoder: CRUDColumnNameDecoder, primaryKey: PartialKeyPath? = nil) throws -> TableStructure { + let cacheKey = "\(type(of: Self.self))" + if let cached = tableStructureCache[cacheKey] { + return cached + } + let primaryKeyName: String? + if let pkpk = primaryKey { + let pathDecoder = CRUDKeyPathsDecoder() + let pathInstance = try Self.init(from: pathDecoder) + guard let pkn = try pathDecoder.getKeyPathName(pathInstance, keyPath: pkpk) else { + throw CRUDSQLGenError("Could not get column name for primary key \(Self.self).") + } + primaryKeyName = pkn + } else if let key = columnDecoder.collectedKeys.filter({$0.type is PrimaryKeyWrapper.Type }).first { + primaryKeyName = key.name + } else if columnDecoder.collectedKeys.map({$0.0}).contains("id") { + primaryKeyName = "id" + } else { + primaryKeyName = nil + } + let thisTableName = columnDecoder.tableNamePath.last! + let tableStruct = TableStructure( + tableName: thisTableName, + columns: columnDecoder.collectedKeys.map { + var props: [TableStructure.Column.Property] = [] + if $0.0 == primaryKeyName { + props.append(.primaryKey) + } + if let foreignWrapper = $0.type as? ForeignKeyWrapper.Type, + let foreignInfo = try? foreignWrapper.foreignTableStructure(), + let foreignPK = foreignInfo.columns.first(where: { $0.properties.contains(.primaryKey) }) { + props.append(.foreignKey(foreignInfo.tableName, foreignPK.name, foreignWrapper.foreignKeyDeleteAction(), foreignWrapper.foreignKeyUpdateAction())) + } + let itype: Any.Type + if let wrapper = $0.type as? WrappedCodableProvider.Type { + itype = wrapper.provideWrappedValueType() + } else { + itype = $0.type + } + return .init(name: $0.name, type: itype, optional: $0.optional, properties: props) + }, + subTables: [], + indexes: []) + tableStructureCache[cacheKey] = tableStruct + tableStruct.subTables = try columnDecoder.subTables.filter { !$0.matches(Self.self) }.map { + return try $0.tableStructure() + } + return tableStruct + } +} + +public struct Create { + typealias OverAllForm = OAF + let fromDatabase: D + let policy: TableCreatePolicy + let tableStructure: TableStructure + init(fromDatabase ft: D, primaryKey: PartialKeyPath?, policy p: TableCreatePolicy) throws { + fromDatabase = ft + policy = p + tableStructure = try OverAllForm.CRUDTableStructure(primaryKey: primaryKey) + let delegate = fromDatabase.configuration.sqlGenDelegate + let sql = try delegate.getCreateTableSQL(forTable: tableStructure, policy: policy) + for stat in sql { + CRUDLogging.log(.query, stat) + let exeDelegate = try fromDatabase.configuration.sqlExeDelegate(forSQL: stat) + _ = try exeDelegate.hasNext() + } + } +} + +public struct Index: FromTableProtocol, TableProtocol { + public typealias Form = OAF + public typealias FromTableType = A + public typealias OverAllForm = OAF + public let fromTable: FromTableType + init(fromTable ft: FromTableType, keys: [PartialKeyPath], unique: Bool) throws { + fromTable = ft + let delegate = ft.databaseConfiguration.sqlGenDelegate + let tableName = "\(OverAllForm.CRUDTableName)" + let pathDecoder = CRUDKeyPathsDecoder() + let pathInstance = try OverAllForm.init(from: pathDecoder) + let keyNames: [String] = try keys.map { + guard let pkn = try pathDecoder.getKeyPathName(pathInstance, keyPath: $0) else { + throw CRUDSQLGenError("Could not get column name for index \(OverAllForm.self).") + } + return pkn + } + let sql = try delegate.getCreateIndexSQL(forTable: tableName, on: keyNames, unique: unique) + for stat in sql { + CRUDLogging.log(.query, stat) + let exeDelegate = try ft.databaseConfiguration.sqlExeDelegate(forSQL: stat) + _ = try exeDelegate.hasNext() + } + } + public func setState(state: inout SQLGenState) throws {} + public func setSQL(state: inout SQLGenState) throws {} +} + +public extension DatabaseProtocol { + @discardableResult + func create(_ type: A.Type, policy: TableCreatePolicy = .defaultPolicy) throws -> Table { + let _: Create = try Create(fromDatabase: self, primaryKey: nil, policy: policy) + return Table(database: self) + } + @discardableResult + func create(_ type: A.Type, primaryKey: KeyPath? = nil, policy: TableCreatePolicy = .defaultPolicy) throws -> Table { + let _: Create = try Create(fromDatabase: self, primaryKey: primaryKey, policy: policy) + return Table(database: self) + } +} +// swiftlint:disable line_length +public extension Table { + @discardableResult + func index(unique: Bool = false, _ keys: PartialKeyPath...) throws -> Index { + return try .init(fromTable: self, keys: keys, unique: unique) + } + // !FIX! Swift 4.0.2 seems to have a problem with type inference for the above func + // would not let \.name type references to be used + // this is an ugly work around + @discardableResult + func index(unique: Bool = false, _ key: KeyPath) throws -> Index { + return try .init(fromTable: self, keys: [key], unique: unique) + } + @discardableResult + func index(unique: Bool = false, _ key: KeyPath, _ key2: KeyPath) throws -> Index { + return try .init(fromTable: self, keys: [key, key2], unique: unique) + } + @discardableResult + func index(unique: Bool = false, _ key: KeyPath, _ key2: KeyPath, _ key3: KeyPath) throws -> Index { + return try .init(fromTable: self, keys: [key, key2, key3], unique: unique) + } + @discardableResult + func index(unique: Bool = false, _ key: KeyPath, _ key2: KeyPath, _ key3: KeyPath, _ key4: KeyPath) throws -> Index { + return try .init(fromTable: self, keys: [key, key2, key3, key4], unique: unique) + } + @discardableResult + func index(unique: Bool = false, _ key: KeyPath, _ key2: KeyPath, _ key3: KeyPath, _ key4: KeyPath, _ key5: KeyPath) throws -> Index { + return try .init(fromTable: self, keys: [key, key2, key3, key4, key5], unique: unique) + } +} diff --git a/Sources/PerfectCRUD/Database.swift b/Sources/PerfectCRUD/Database.swift new file mode 100644 index 00000000..f607f2cd --- /dev/null +++ b/Sources/PerfectCRUD/Database.swift @@ -0,0 +1,70 @@ +// +// PerfectCRUDDatabase.swift +// PerfectCRUD +// +// Created by Kyle Jessup on 2017-12-02. +// + +import Foundation + +public struct Database: DatabaseProtocol { + public typealias Configuration = C + public let configuration: Configuration + public init(configuration c: Configuration) { + configuration = c + } + public func table(_ form: T.Type) -> Table { + return .init(database: self) + } +} + +public extension Database { + func sql(_ sql: String, bindings: Bindings = []) throws { + CRUDLogging.log(.query, sql) + let delegate = try configuration.sqlExeDelegate(forSQL: sql) + try delegate.bind(bindings, skip: 0) + _ = try delegate.hasNext() + } + func sql(_ sql: String, bindings: Bindings = [], _ type: A.Type) throws -> [A] { + CRUDLogging.log(.query, sql) + let delegate = try configuration.sqlExeDelegate(forSQL: sql) + try delegate.bind(bindings, skip: 0) + var ret: [A] = [] + while try delegate.hasNext() { + let rowDecoder: CRUDRowDecoder = CRUDRowDecoder(delegate: delegate) + ret.append(try A(from: rowDecoder)) + } + return ret + } + func asyncSql(_ sql: String, bindings: Bindings = [], _ type: A.Type, completion: @escaping ([A], Error?) -> ()) throws { + CRUDLogging.log(.query, sql) + let delegate = try configuration.sqlExeDelegate(forSQL: sql) + try delegate.bind(bindings, skip: 0) + delegate.asyncExecute { delegate in + var ret: [A] = [] + do { + while try delegate.hasNext() { + let rowDecoder: CRUDRowDecoder = CRUDRowDecoder(delegate: delegate) + ret.append(try A(from: rowDecoder)) + } + completion(ret, nil) + } catch { + completion(ret, error) + } + } + } +} + +public extension Database { + func transaction(_ body: () throws -> T) throws -> T { + try sql("BEGIN") + do { + let r = try body() + try sql("COMMIT") + return r + } catch { + try sql("ROLLBACK") + throw error + } + } +} diff --git a/Sources/PerfectCRUD/Delete.swift b/Sources/PerfectCRUD/Delete.swift new file mode 100644 index 00000000..dfe046fc --- /dev/null +++ b/Sources/PerfectCRUD/Delete.swift @@ -0,0 +1,41 @@ +// +// PerfectCRUDDelete.swift +// PerfectCRUD +// +// Created by Kyle Jessup on 2017-12-03. +// + +import Foundation + +public protocol Deleteable: TableProtocol { + @discardableResult + func delete() throws -> Delete +} + +public extension Deleteable { + @discardableResult + func delete() throws -> Delete { + return try .init(fromTable: self) + } +} + +public struct Delete: FromTableProtocol, CommandProtocol { + public typealias FromTableType = A + public typealias OverAllForm = OAF + public let fromTable: FromTableType + public let sqlGenState: SQLGenState + init(fromTable ft: FromTableType) throws { + fromTable = ft + let delegate = ft.databaseConfiguration.sqlGenDelegate + var state = SQLGenState(delegate: delegate) + state.command = .delete + try ft.setState(state: &state) + try ft.setSQL(state: &state) + sqlGenState = state + for stat in state.statements { + let exeDelegate = try databaseConfiguration.sqlExeDelegate(forSQL: stat.sql) + try exeDelegate.bind(stat.bindings) + _ = try exeDelegate.hasNext() + } + } +} diff --git a/Sources/PerfectCRUD/Expression/Comparison.swift b/Sources/PerfectCRUD/Expression/Comparison.swift new file mode 100644 index 00000000..4b66a3c8 --- /dev/null +++ b/Sources/PerfectCRUD/Expression/Comparison.swift @@ -0,0 +1,73 @@ +// +// Comparison.swift +// PerfectCRUD +// +// Created by Kyle Jessup on 2018-02-18. +// + +import Foundation + +// < +public func < (lhs: KeyPath, rhs: String) -> CRUDBooleanExpression { + return RealBooleanExpression(.lessThan(lhs: .keyPath(lhs), rhs: .string(rhs))) +} +public func < (lhs: KeyPath, rhs: Double) -> CRUDBooleanExpression { + return RealBooleanExpression(.lessThan(lhs: .keyPath(lhs), rhs: .decimal(rhs))) +} +public func < (lhs: KeyPath, rhs: Bool) -> CRUDBooleanExpression { + return RealBooleanExpression(.lessThan(lhs: .keyPath(lhs), rhs: .bool(rhs))) +} +public func < (lhs: KeyPath, rhs: UUID) -> CRUDBooleanExpression { + return RealBooleanExpression(.lessThan(lhs: .keyPath(lhs), rhs: .uuid(rhs))) +} +public func < (lhs: KeyPath, rhs: Date) -> CRUDBooleanExpression { + return RealBooleanExpression(.lessThan(lhs: .keyPath(lhs), rhs: .date(rhs))) +} +// > +public func > (lhs: KeyPath, rhs: String) -> CRUDBooleanExpression { + return RealBooleanExpression(.greaterThan(lhs: .keyPath(lhs), rhs: .string(rhs))) +} +public func > (lhs: KeyPath, rhs: Double) -> CRUDBooleanExpression { + return RealBooleanExpression(.greaterThan(lhs: .keyPath(lhs), rhs: .decimal(rhs))) +} +public func > (lhs: KeyPath, rhs: Bool) -> CRUDBooleanExpression { + return RealBooleanExpression(.greaterThan(lhs: .keyPath(lhs), rhs: .bool(rhs))) +} +public func > (lhs: KeyPath, rhs: UUID) -> CRUDBooleanExpression { + return RealBooleanExpression(.greaterThan(lhs: .keyPath(lhs), rhs: .uuid(rhs))) +} +public func > (lhs: KeyPath, rhs: Date) -> CRUDBooleanExpression { + return RealBooleanExpression(.greaterThan(lhs: .keyPath(lhs), rhs: .date(rhs))) +} +// <= +public func <= (lhs: KeyPath, rhs: String) -> CRUDBooleanExpression { + return RealBooleanExpression(.lessThanEqual(lhs: .keyPath(lhs), rhs: .string(rhs))) +} +public func <= (lhs: KeyPath, rhs: Double) -> CRUDBooleanExpression { + return RealBooleanExpression(.lessThanEqual(lhs: .keyPath(lhs), rhs: .decimal(rhs))) +} +public func <= (lhs: KeyPath, rhs: Bool) -> CRUDBooleanExpression { + return RealBooleanExpression(.lessThanEqual(lhs: .keyPath(lhs), rhs: .bool(rhs))) +} +public func <= (lhs: KeyPath, rhs: UUID) -> CRUDBooleanExpression { + return RealBooleanExpression(.lessThanEqual(lhs: .keyPath(lhs), rhs: .uuid(rhs))) +} +public func <= (lhs: KeyPath, rhs: Date) -> CRUDBooleanExpression { + return RealBooleanExpression(.lessThanEqual(lhs: .keyPath(lhs), rhs: .date(rhs))) +} +// >= +public func >= (lhs: KeyPath, rhs: String) -> CRUDBooleanExpression { + return RealBooleanExpression(.greaterThanEqual(lhs: .keyPath(lhs), rhs: .string(rhs))) +} +public func >= (lhs: KeyPath, rhs: Double) -> CRUDBooleanExpression { + return RealBooleanExpression(.greaterThanEqual(lhs: .keyPath(lhs), rhs: .decimal(rhs))) +} +public func >= (lhs: KeyPath, rhs: Bool) -> CRUDBooleanExpression { + return RealBooleanExpression(.greaterThanEqual(lhs: .keyPath(lhs), rhs: .bool(rhs))) +} +public func >= (lhs: KeyPath, rhs: UUID) -> CRUDBooleanExpression { + return RealBooleanExpression(.greaterThanEqual(lhs: .keyPath(lhs), rhs: .uuid(rhs))) +} +public func >= (lhs: KeyPath, rhs: Date) -> CRUDBooleanExpression { + return RealBooleanExpression(.greaterThanEqual(lhs: .keyPath(lhs), rhs: .date(rhs))) +} diff --git a/Sources/PerfectCRUD/Expression/ComparisonInts.swift b/Sources/PerfectCRUD/Expression/ComparisonInts.swift new file mode 100644 index 00000000..504d5859 --- /dev/null +++ b/Sources/PerfectCRUD/Expression/ComparisonInts.swift @@ -0,0 +1,157 @@ +// +// ComparisonInts.swift +// PerfectCRUD +// +// Created by Kyle Jessup on 2018-03-11. +// + +import Foundation + +// < +public func < (lhs: KeyPath, rhs: Int) -> CRUDBooleanExpression { + return RealBooleanExpression(.lessThan(lhs: .keyPath(lhs), rhs: .integer(rhs))) +} +public func < (lhs: KeyPath, rhs: UInt) -> CRUDBooleanExpression { + return RealBooleanExpression(.lessThan(lhs: .keyPath(lhs), rhs: .uinteger(rhs))) +} +public func < (lhs: KeyPath, rhs: Int64) -> CRUDBooleanExpression { + return RealBooleanExpression(.lessThan(lhs: .keyPath(lhs), rhs: .integer64(rhs))) +} +public func < (lhs: KeyPath, rhs: UInt64) -> CRUDBooleanExpression { + return RealBooleanExpression(.lessThan(lhs: .keyPath(lhs), rhs: .uinteger64(rhs))) +} +public func < (lhs: KeyPath, rhs: Int32) -> CRUDBooleanExpression { + return RealBooleanExpression(.lessThan(lhs: .keyPath(lhs), rhs: .integer32(rhs))) +} +public func < (lhs: KeyPath, rhs: UInt32) -> CRUDBooleanExpression { + return RealBooleanExpression(.lessThan(lhs: .keyPath(lhs), rhs: .uinteger32(rhs))) +} +public func < (lhs: KeyPath, rhs: Int16) -> CRUDBooleanExpression { + return RealBooleanExpression(.lessThan(lhs: .keyPath(lhs), rhs: .integer16(rhs))) +} +public func < (lhs: KeyPath, rhs: UInt16) -> CRUDBooleanExpression { + return RealBooleanExpression(.lessThan(lhs: .keyPath(lhs), rhs: .uinteger16(rhs))) +} +public func < (lhs: KeyPath, rhs: Int8) -> CRUDBooleanExpression { + return RealBooleanExpression(.lessThan(lhs: .keyPath(lhs), rhs: .integer8(rhs))) +} +public func < (lhs: KeyPath, rhs: UInt8) -> CRUDBooleanExpression { + return RealBooleanExpression(.lessThan(lhs: .keyPath(lhs), rhs: .uinteger8(rhs))) +} +public func < (lhs: KeyPath, rhs: [UInt8]) -> CRUDBooleanExpression { + return RealBooleanExpression(.lessThan(lhs: .keyPath(lhs), rhs: .blob(rhs))) +} +public func < (lhs: KeyPath, rhs: [Int8]) -> CRUDBooleanExpression { + return RealBooleanExpression(.lessThan(lhs: .keyPath(lhs), rhs: .sblob(rhs))) +} +// > +public func > (lhs: KeyPath, rhs: Int) -> CRUDBooleanExpression { + return RealBooleanExpression(.greaterThan(lhs: .keyPath(lhs), rhs: .integer(rhs))) +} +public func > (lhs: KeyPath, rhs: UInt) -> CRUDBooleanExpression { + return RealBooleanExpression(.greaterThan(lhs: .keyPath(lhs), rhs: .uinteger(rhs))) +} +public func > (lhs: KeyPath, rhs: Int64) -> CRUDBooleanExpression { + return RealBooleanExpression(.greaterThan(lhs: .keyPath(lhs), rhs: .integer64(rhs))) +} +public func > (lhs: KeyPath, rhs: UInt64) -> CRUDBooleanExpression { + return RealBooleanExpression(.greaterThan(lhs: .keyPath(lhs), rhs: .uinteger64(rhs))) +} +public func > (lhs: KeyPath, rhs: Int32) -> CRUDBooleanExpression { + return RealBooleanExpression(.greaterThan(lhs: .keyPath(lhs), rhs: .integer32(rhs))) +} +public func > (lhs: KeyPath, rhs: UInt32) -> CRUDBooleanExpression { + return RealBooleanExpression(.greaterThan(lhs: .keyPath(lhs), rhs: .uinteger32(rhs))) +} +public func > (lhs: KeyPath, rhs: Int16) -> CRUDBooleanExpression { + return RealBooleanExpression(.greaterThan(lhs: .keyPath(lhs), rhs: .integer16(rhs))) +} +public func > (lhs: KeyPath, rhs: UInt16) -> CRUDBooleanExpression { + return RealBooleanExpression(.greaterThan(lhs: .keyPath(lhs), rhs: .uinteger16(rhs))) +} +public func > (lhs: KeyPath, rhs: Int8) -> CRUDBooleanExpression { + return RealBooleanExpression(.greaterThan(lhs: .keyPath(lhs), rhs: .integer8(rhs))) +} +public func > (lhs: KeyPath, rhs: UInt8) -> CRUDBooleanExpression { + return RealBooleanExpression(.greaterThan(lhs: .keyPath(lhs), rhs: .uinteger8(rhs))) +} +public func > (lhs: KeyPath, rhs: [UInt8]) -> CRUDBooleanExpression { + return RealBooleanExpression(.greaterThan(lhs: .keyPath(lhs), rhs: .blob(rhs))) +} +public func > (lhs: KeyPath, rhs: [Int8]) -> CRUDBooleanExpression { + return RealBooleanExpression(.greaterThan(lhs: .keyPath(lhs), rhs: .sblob(rhs))) +} +// <= +public func <= (lhs: KeyPath, rhs: Int) -> CRUDBooleanExpression { + return RealBooleanExpression(.lessThanEqual(lhs: .keyPath(lhs), rhs: .integer(rhs))) +} +public func <= (lhs: KeyPath, rhs: UInt) -> CRUDBooleanExpression { + return RealBooleanExpression(.lessThanEqual(lhs: .keyPath(lhs), rhs: .uinteger(rhs))) +} +public func <= (lhs: KeyPath, rhs: Int64) -> CRUDBooleanExpression { + return RealBooleanExpression(.lessThanEqual(lhs: .keyPath(lhs), rhs: .integer64(rhs))) +} +public func <= (lhs: KeyPath, rhs: UInt64) -> CRUDBooleanExpression { + return RealBooleanExpression(.lessThanEqual(lhs: .keyPath(lhs), rhs: .uinteger64(rhs))) +} +public func <= (lhs: KeyPath, rhs: Int32) -> CRUDBooleanExpression { + return RealBooleanExpression(.lessThanEqual(lhs: .keyPath(lhs), rhs: .integer32(rhs))) +} +public func <= (lhs: KeyPath, rhs: UInt32) -> CRUDBooleanExpression { + return RealBooleanExpression(.lessThanEqual(lhs: .keyPath(lhs), rhs: .uinteger32(rhs))) +} +public func <= (lhs: KeyPath, rhs: Int16) -> CRUDBooleanExpression { + return RealBooleanExpression(.lessThanEqual(lhs: .keyPath(lhs), rhs: .integer16(rhs))) +} +public func <= (lhs: KeyPath, rhs: UInt16) -> CRUDBooleanExpression { + return RealBooleanExpression(.lessThanEqual(lhs: .keyPath(lhs), rhs: .uinteger16(rhs))) +} +public func <= (lhs: KeyPath, rhs: Int8) -> CRUDBooleanExpression { + return RealBooleanExpression(.lessThanEqual(lhs: .keyPath(lhs), rhs: .integer8(rhs))) +} +public func <= (lhs: KeyPath, rhs: UInt8) -> CRUDBooleanExpression { + return RealBooleanExpression(.lessThanEqual(lhs: .keyPath(lhs), rhs: .uinteger8(rhs))) +} +public func <= (lhs: KeyPath, rhs: [UInt8]) -> CRUDBooleanExpression { + return RealBooleanExpression(.lessThanEqual(lhs: .keyPath(lhs), rhs: .blob(rhs))) +} +public func <= (lhs: KeyPath, rhs: [Int8]) -> CRUDBooleanExpression { + return RealBooleanExpression(.lessThanEqual(lhs: .keyPath(lhs), rhs: .sblob(rhs))) +} +// >= +public func >= (lhs: KeyPath, rhs: Int) -> CRUDBooleanExpression { + return RealBooleanExpression(.greaterThanEqual(lhs: .keyPath(lhs), rhs: .integer(rhs))) +} +public func >= (lhs: KeyPath, rhs: UInt) -> CRUDBooleanExpression { + return RealBooleanExpression(.greaterThanEqual(lhs: .keyPath(lhs), rhs: .uinteger(rhs))) +} +public func >= (lhs: KeyPath, rhs: Int64) -> CRUDBooleanExpression { + return RealBooleanExpression(.greaterThanEqual(lhs: .keyPath(lhs), rhs: .integer64(rhs))) +} +public func >= (lhs: KeyPath, rhs: UInt64) -> CRUDBooleanExpression { + return RealBooleanExpression(.greaterThanEqual(lhs: .keyPath(lhs), rhs: .uinteger64(rhs))) +} +public func >= (lhs: KeyPath, rhs: Int32) -> CRUDBooleanExpression { + return RealBooleanExpression(.greaterThanEqual(lhs: .keyPath(lhs), rhs: .integer32(rhs))) +} +public func >= (lhs: KeyPath, rhs: UInt32) -> CRUDBooleanExpression { + return RealBooleanExpression(.greaterThanEqual(lhs: .keyPath(lhs), rhs: .uinteger32(rhs))) +} +public func >= (lhs: KeyPath, rhs: Int16) -> CRUDBooleanExpression { + return RealBooleanExpression(.greaterThanEqual(lhs: .keyPath(lhs), rhs: .integer16(rhs))) +} +public func >= (lhs: KeyPath, rhs: UInt16) -> CRUDBooleanExpression { + return RealBooleanExpression(.greaterThanEqual(lhs: .keyPath(lhs), rhs: .uinteger16(rhs))) +} +public func >= (lhs: KeyPath, rhs: Int8) -> CRUDBooleanExpression { + return RealBooleanExpression(.greaterThanEqual(lhs: .keyPath(lhs), rhs: .integer8(rhs))) +} +public func >= (lhs: KeyPath, rhs: UInt8) -> CRUDBooleanExpression { + return RealBooleanExpression(.greaterThanEqual(lhs: .keyPath(lhs), rhs: .uinteger8(rhs))) +} +public func >= (lhs: KeyPath, rhs: [UInt8]) -> CRUDBooleanExpression { + return RealBooleanExpression(.greaterThanEqual(lhs: .keyPath(lhs), rhs: .blob(rhs))) +} +public func >= (lhs: KeyPath, rhs: [Int8]) -> CRUDBooleanExpression { + return RealBooleanExpression(.greaterThanEqual(lhs: .keyPath(lhs), rhs: .sblob(rhs))) +} diff --git a/Sources/PerfectCRUD/Expression/Equality.swift b/Sources/PerfectCRUD/Expression/Equality.swift new file mode 100644 index 00000000..42b82364 --- /dev/null +++ b/Sources/PerfectCRUD/Expression/Equality.swift @@ -0,0 +1,121 @@ +// +// Equality.swift +// PerfectCRUD +// +// Created by Kyle Jessup on 2018-02-18. +// + +import Foundation + +// == +public func == (lhs: KeyPath, rhs: String) -> CRUDBooleanExpression { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .string(rhs))) +} +public func == (lhs: KeyPath, rhs: Double) -> CRUDBooleanExpression { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .decimal(rhs))) +} +public func == (lhs: KeyPath, rhs: Float) -> CRUDBooleanExpression { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .float(rhs))) +} +public func == (lhs: KeyPath, rhs: Bool) -> CRUDBooleanExpression { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .bool(rhs))) +} +public func == (lhs: KeyPath, rhs: UUID) -> CRUDBooleanExpression { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .uuid(rhs))) +} +public func == (lhs: KeyPath, rhs: Date) -> CRUDBooleanExpression { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .date(rhs))) +} +// == ? +public func == (lhs: KeyPath, rhs: String?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .string(rhs))) + } + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .null)) +} +public func == (lhs: KeyPath, rhs: Double?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .decimal(rhs))) + } + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .null)) +} +public func == (lhs: KeyPath, rhs: Float?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .float(rhs))) + } + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .null)) +} +public func == (lhs: KeyPath, rhs: Bool?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .bool(rhs))) + } + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .null)) +} +public func == (lhs: KeyPath, rhs: UUID?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .uuid(rhs))) + } + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .null)) +} +public func == (lhs: KeyPath, rhs: Date?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .date(rhs))) + } + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .null)) +} +// != +public func != (lhs: KeyPath, rhs: String) -> CRUDBooleanExpression { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .string(rhs))) +} +public func != (lhs: KeyPath, rhs: Double) -> CRUDBooleanExpression { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .decimal(rhs))) +} +public func != (lhs: KeyPath, rhs: Float) -> CRUDBooleanExpression { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .float(rhs))) +} +public func != (lhs: KeyPath, rhs: Bool) -> CRUDBooleanExpression { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .bool(rhs))) +} +public func != (lhs: KeyPath, rhs: UUID) -> CRUDBooleanExpression { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .uuid(rhs))) +} +public func != (lhs: KeyPath, rhs: Date) -> CRUDBooleanExpression { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .date(rhs))) +} +// != ? +public func != (lhs: KeyPath, rhs: String?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .string(rhs))) + } + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .null)) +} +public func != (lhs: KeyPath, rhs: Double?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .decimal(rhs))) + } + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .null)) +} +public func != (lhs: KeyPath, rhs: Float?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .float(rhs))) + } + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .null)) +} +public func != (lhs: KeyPath, rhs: Bool?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .bool(rhs))) + } + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .null)) +} +public func != (lhs: KeyPath, rhs: UUID?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .uuid(rhs))) + } + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .null)) +} +public func != (lhs: KeyPath, rhs: Date?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .date(rhs))) + } + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .null)) +} diff --git a/Sources/PerfectCRUD/Expression/EqualityInts.swift b/Sources/PerfectCRUD/Expression/EqualityInts.swift new file mode 100644 index 00000000..f147cb67 --- /dev/null +++ b/Sources/PerfectCRUD/Expression/EqualityInts.swift @@ -0,0 +1,367 @@ +// +// EqualityInts.swift +// PerfectCRUD +// +// Created by Kyle Jessup on 2018-03-11. +// + +import Foundation + +// == +public func == (lhs: KeyPath, rhs: Int) -> CRUDBooleanExpression { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .integer(rhs))) +} +public func == (lhs: KeyPath, rhs: UInt) -> CRUDBooleanExpression { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .uinteger(rhs))) +} +public func == (lhs: KeyPath, rhs: Int) -> CRUDBooleanExpression { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .uinteger(UInt(rhs)))) +} +public func == (lhs: KeyPath, rhs: Int64) -> CRUDBooleanExpression { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .integer64(rhs))) +} +public func == (lhs: KeyPath, rhs: Int) -> CRUDBooleanExpression { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .integer64(Int64(rhs)))) +} +public func == (lhs: KeyPath, rhs: UInt64) -> CRUDBooleanExpression { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .uinteger64(rhs))) +} +public func == (lhs: KeyPath, rhs: Int) -> CRUDBooleanExpression { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .uinteger64(UInt64(rhs)))) +} +public func == (lhs: KeyPath, rhs: Int32) -> CRUDBooleanExpression { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .integer32(rhs))) +} +public func == (lhs: KeyPath, rhs: Int) -> CRUDBooleanExpression { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .integer32(Int32(rhs)))) +} +public func == (lhs: KeyPath, rhs: UInt32) -> CRUDBooleanExpression { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .uinteger32(rhs))) +} +public func == (lhs: KeyPath, rhs: Int) -> CRUDBooleanExpression { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .uinteger32(UInt32(rhs)))) +} +public func == (lhs: KeyPath, rhs: Int16) -> CRUDBooleanExpression { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .integer16(rhs))) +} +public func == (lhs: KeyPath, rhs: Int) -> CRUDBooleanExpression { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .integer16(Int16(rhs)))) +} +public func == (lhs: KeyPath, rhs: UInt16) -> CRUDBooleanExpression { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .uinteger16(rhs))) +} +public func == (lhs: KeyPath, rhs: Int) -> CRUDBooleanExpression { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .uinteger16(UInt16(rhs)))) +} +public func == (lhs: KeyPath, rhs: Int8) -> CRUDBooleanExpression { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .integer8(rhs))) +} +public func == (lhs: KeyPath, rhs: Int) -> CRUDBooleanExpression { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .integer8(Int8(rhs)))) +} +public func == (lhs: KeyPath, rhs: UInt8) -> CRUDBooleanExpression { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .uinteger8(rhs))) +} +public func == (lhs: KeyPath, rhs: Int) -> CRUDBooleanExpression { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .uinteger8(UInt8(rhs)))) +} +public func == (lhs: KeyPath, rhs: [UInt8]) -> CRUDBooleanExpression { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .blob(rhs))) +} +public func == (lhs: KeyPath, rhs: [Int8]) -> CRUDBooleanExpression { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .sblob(rhs))) +} +// == ? +public func == (lhs: KeyPath, rhs: Int?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .integer(rhs))) + } + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .null)) +} +public func == (lhs: KeyPath, rhs: UInt?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .uinteger(rhs))) + } + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .null)) +} +public func == (lhs: KeyPath, rhs: Int?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .uinteger(UInt(rhs)))) + } + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .null)) +} +public func == (lhs: KeyPath, rhs: Int64?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .integer64(rhs))) + } + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .null)) +} +public func == (lhs: KeyPath, rhs: Int?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .integer64(Int64(rhs)))) + } + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .null)) +} +public func == (lhs: KeyPath, rhs: UInt64?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .uinteger64(rhs))) + } + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .null)) +} +public func == (lhs: KeyPath, rhs: Int?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .uinteger64(UInt64(rhs)))) + } + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .null)) +} +public func == (lhs: KeyPath, rhs: Int32?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .integer32(rhs))) + } + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .null)) +} +public func == (lhs: KeyPath, rhs: Int?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .integer32(Int32(rhs)))) + } + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .null)) +} +public func == (lhs: KeyPath, rhs: UInt32?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .uinteger32(rhs))) + } + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .null)) +} +public func == (lhs: KeyPath, rhs: Int?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .uinteger32(UInt32(rhs)))) + } + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .null)) +} +public func == (lhs: KeyPath, rhs: Int16?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .integer16(rhs))) + } + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .null)) +} +public func == (lhs: KeyPath, rhs: Int?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .integer16(Int16(rhs)))) + } + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .null)) +} +public func == (lhs: KeyPath, rhs: UInt16?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .uinteger16(rhs))) + } + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .null)) +} +public func == (lhs: KeyPath, rhs: Int?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .uinteger16(UInt16(rhs)))) + } + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .null)) +} +public func == (lhs: KeyPath, rhs: Int8?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .integer8(rhs))) + } + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .null)) +} +public func == (lhs: KeyPath, rhs: Int?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .integer8(Int8(rhs)))) + } + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .null)) +} +public func == (lhs: KeyPath, rhs: UInt8?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .uinteger8(rhs))) + } + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .null)) +} +public func == (lhs: KeyPath, rhs: Int?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .uinteger8(UInt8(rhs)))) + } + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .null)) +} +public func == (lhs: KeyPath, rhs: [UInt8]?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .blob(rhs))) + } + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .null)) +} +public func == (lhs: KeyPath, rhs: [Int8]?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .sblob(rhs))) + } + return RealBooleanExpression(.equality(lhs: .keyPath(lhs), rhs: .null)) +} +// != +public func != (lhs: KeyPath, rhs: Int) -> CRUDBooleanExpression { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .integer(rhs))) +} +public func != (lhs: KeyPath, rhs: UInt) -> CRUDBooleanExpression { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .uinteger(rhs))) +} +public func != (lhs: KeyPath, rhs: Int) -> CRUDBooleanExpression { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .uinteger(UInt(rhs)))) +} +public func != (lhs: KeyPath, rhs: Int64) -> CRUDBooleanExpression { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .integer64(rhs))) +} +public func != (lhs: KeyPath, rhs: Int) -> CRUDBooleanExpression { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .integer64(Int64(rhs)))) +} +public func != (lhs: KeyPath, rhs: Int32) -> CRUDBooleanExpression { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .integer32(rhs))) +} +public func != (lhs: KeyPath, rhs: Int) -> CRUDBooleanExpression { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .integer32(Int32(rhs)))) +} +public func != (lhs: KeyPath, rhs: Int16) -> CRUDBooleanExpression { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .integer16(rhs))) +} +public func != (lhs: KeyPath, rhs: Int) -> CRUDBooleanExpression { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .integer16(Int16(rhs)))) +} +public func != (lhs: KeyPath, rhs: Int8) -> CRUDBooleanExpression { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .integer8(rhs))) +} +public func != (lhs: KeyPath, rhs: Int) -> CRUDBooleanExpression { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .integer8(Int8(rhs)))) +} +public func != (lhs: KeyPath, rhs: [UInt8]) -> CRUDBooleanExpression { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .blob(rhs))) +} +public func != (lhs: KeyPath, rhs: [Int8]) -> CRUDBooleanExpression { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .sblob(rhs))) +} +// != ? +public func != (lhs: KeyPath, rhs: Int?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .integer(rhs))) + } + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .null)) +} +public func != (lhs: KeyPath, rhs: UInt?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .uinteger(rhs))) + } + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .null)) +} +public func != (lhs: KeyPath, rhs: Int?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .uinteger(UInt(rhs)))) + } + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .null)) +} +public func != (lhs: KeyPath, rhs: Int64?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .integer64(rhs))) + } + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .null)) +} +public func != (lhs: KeyPath, rhs: Int?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .integer64(Int64(rhs)))) + } + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .null)) +} +public func != (lhs: KeyPath, rhs: UInt64?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .uinteger64(rhs))) + } + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .null)) +} +public func != (lhs: KeyPath, rhs: Int?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .uinteger64(UInt64(rhs)))) + } + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .null)) +} +public func != (lhs: KeyPath, rhs: Int32?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .integer32(rhs))) + } + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .null)) +} +public func != (lhs: KeyPath, rhs: Int?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .integer32(Int32(rhs)))) + } + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .null)) +} +public func != (lhs: KeyPath, rhs: UInt32?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .uinteger32(rhs))) + } + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .null)) +} +public func != (lhs: KeyPath, rhs: Int?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .uinteger32(UInt32(rhs)))) + } + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .null)) +} +public func != (lhs: KeyPath, rhs: Int16?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .integer16(rhs))) + } + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .null)) +} +public func != (lhs: KeyPath, rhs: Int?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .integer16(Int16(rhs)))) + } + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .null)) +} +public func != (lhs: KeyPath, rhs: UInt16?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .uinteger16(rhs))) + } + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .null)) +} +public func != (lhs: KeyPath, rhs: Int?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .uinteger16(UInt16(rhs)))) + } + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .null)) +} +public func != (lhs: KeyPath, rhs: Int8?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .integer8(rhs))) + } + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .null)) +} +public func != (lhs: KeyPath, rhs: Int?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .integer8(Int8(rhs)))) + } + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .null)) +} +public func != (lhs: KeyPath, rhs: UInt8?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .uinteger8(rhs))) + } + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .null)) +} +public func != (lhs: KeyPath, rhs: Int?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .uinteger8(UInt8(rhs)))) + } + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .null)) +} +public func != (lhs: KeyPath, rhs: [UInt8]?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .blob(rhs))) + } + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .null)) +} +public func != (lhs: KeyPath, rhs: [Int8]?) -> CRUDBooleanExpression { + if let rhs = rhs { + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .sblob(rhs))) + } + return RealBooleanExpression(.inequality(lhs: .keyPath(lhs), rhs: .null)) +} diff --git a/Sources/PerfectCRUD/Expression/Expression.swift b/Sources/PerfectCRUD/Expression/Expression.swift new file mode 100644 index 00000000..54092b47 --- /dev/null +++ b/Sources/PerfectCRUD/Expression/Expression.swift @@ -0,0 +1,204 @@ +// +// PerfectCRUDExpressions.swift +// PerfectCRUD +// +// Created by Kyle Jessup on 2017-11-22. +// Copyright (C) 2017 PerfectlySoft, Inc. +// +// ===----------------------------------------------------------------------===// +// +// This source file is part of the Perfect.org open source project +// +// Copyright (c) 2015 - 2017 PerfectlySoft Inc. and the Perfect project authors +// Licensed under Apache License v2.0 +// +// See http://perfect.org/licensing.html for license information +// +// ===----------------------------------------------------------------------===// +// +import Foundation + +public indirect enum CRUDExpression { + public typealias ExpressionProducer = () -> CRUDExpression + + case column(String) + case and(lhs: CRUDExpression, rhs: CRUDExpression) + case or(lhs: CRUDExpression, rhs: CRUDExpression) + case equality(lhs: CRUDExpression, rhs: CRUDExpression) + case inequality(lhs: CRUDExpression, rhs: CRUDExpression) + case not(rhs: CRUDExpression) + case lessThan(lhs: CRUDExpression, rhs: CRUDExpression) + case lessThanEqual(lhs: CRUDExpression, rhs: CRUDExpression) + case greaterThan(lhs: CRUDExpression, rhs: CRUDExpression) + case greaterThanEqual(lhs: CRUDExpression, rhs: CRUDExpression) + case `in`(lhs: CRUDExpression, rhs: [CRUDExpression]) + case like(lhs: CRUDExpression, wild1: Bool, String, wild2: Bool) + case lazy(ExpressionProducer) + case keyPath(AnyKeyPath) + + case integer(Int) + case uinteger(UInt) + case integer64(Int64) + case uinteger64(UInt64) + case integer32(Int32) + case uinteger32(UInt32) + case integer16(Int16) + case uinteger16(UInt16) + case integer8(Int8) + case uinteger8(UInt8) + + case decimal(Double) + case float(Float) + case string(String) + case blob([UInt8]) + case sblob([Int8]) + case bool(Bool) + case uuid(UUID) + case date(Date) + case url(URL) + case null + + // todo: + // .blob with Data + // .integer of varying width +} + +public protocol CRUDBooleanExpression { + var crudExpression: CRUDExpression { get } +} + +struct RealBooleanExpression: CRUDBooleanExpression { + let crudExpression: CRUDExpression + init(_ e: CRUDExpression) { + crudExpression = e + } +} + +infix operator ~: ComparisonPrecedence // IN, matches +infix operator !~: ComparisonPrecedence // NOT IN, matches +infix operator %=%: ComparisonPrecedence // LIKE %v% . string or regexp or in array +infix operator =%: ComparisonPrecedence // LIKE v% . string +infix operator %!=: ComparisonPrecedence // NOT LIKE %v . string +infix operator %!=%: ComparisonPrecedence // NOT LIKE %v% . string or regexp or array +infix operator !=%: ComparisonPrecedence // NOT LIKE v% . string + +extension CRUDExpression { + static func sqlSnippet(keyPath: AnyKeyPath, tableData: SQLGenState.TableData, state: SQLGenState) throws -> String { + let delegate = state.delegate + let rootType = type(of: keyPath).rootType + guard let modelInstance = tableData.modelInstance else { + throw CRUDSQLGenError("Unable to get table for KeyPath root \(rootType).") + } + guard let keyName = try tableData.keyPathDecoder.getKeyPathName(modelInstance, keyPath: keyPath) else { + throw CRUDSQLGenError("Unable to get KeyPath name for table \(rootType).") + } + let nameQ = try delegate.quote(identifier: keyName) + switch state.command { + case .select, .count: + let aliasQ = try delegate.quote(identifier: tableData.alias) + return "\(aliasQ).\(nameQ)" + case .insert, .update, .delete: + return nameQ + case .unknown: + throw CRUDSQLGenError("Can not process unknown command.") + } + } + func sqlSnippet(state: SQLGenState) throws -> String { + let delegate = state.delegate + switch self { + case .column(let name): + return try delegate.quote(identifier: name) + case .and(let lhs, let rhs): + return try binparen(state, "AND", lhs, rhs) + case .or(let lhs, let rhs): + return try binparen(state, "OR", lhs, rhs) + case .equality(let lhs, let rhs): + if case .null = rhs { + return "\(try lhs.sqlSnippet(state: state)) IS NULL" + } + return try bin(state, "=", lhs, rhs) + case .inequality(let lhs, let rhs): + if case .null = rhs { + return "\(try lhs.sqlSnippet(state: state)) IS NOT NULL" + } + return try bin(state, "!=", lhs, rhs) + case .not(let rhs): + let rhsStr = try rhs.sqlSnippet(state: state) + return "NOT (\(rhsStr))" + case .lessThan(let lhs, let rhs): + return try bin(state, "<", lhs, rhs) + case .lessThanEqual(let lhs, let rhs): + return try bin(state, "<=", lhs, rhs) + case .greaterThan(let lhs, let rhs): + return try bin(state, ">", lhs, rhs) + case .greaterThanEqual(let lhs, let rhs): + return try bin(state, ">=", lhs, rhs) + case .keyPath(let k): + let rootType = type(of: k).rootType + guard let tableData = state.getTableData(type: rootType) else { + throw CRUDSQLGenError("Unable to get table for KeyPath root \(rootType).") + } + return try CRUDExpression.sqlSnippet(keyPath: k, tableData: tableData, state: state) + case .null: + return "NULL" + case .lazy(let e): + return try e().sqlSnippet(state: state) + case .integer(_), .uinteger(_), .integer64(_), .uinteger64(_), .integer32(_), .uinteger32(_), .integer16(_), .uinteger16(_), .integer8(_), .uinteger8(_): + return try delegate.getBinding(for: self) + case .decimal(_), .float(_), .string(_), .blob(_), .sblob(_), .bool(_), .uuid(_), .date(_), .url(_): + return try delegate.getBinding(for: self) + case .in(let lhs, let exprs): + return "\(try lhs.sqlSnippet(state: state)) IN (\(try exprs.map { try $0.sqlSnippet(state: state) }.joined(separator: ",")))" + case .like(let lhs, let wild1, let match, let wild2): + let rhs = "\(wild1 ? "%" : "")\(match.replacingOccurrences(of: "%", with: "\\%"))\(wild2 ? "%" : "")" + return try bin(state, "LIKE", lhs, .string(rhs)) + } + } + private func bin(_ state: SQLGenState, _ op: String, _ lhs: CRUDExpression, _ rhs: CRUDExpression) throws -> String { + return "\(try lhs.sqlSnippet(state: state)) \(op) \(try rhs.sqlSnippet(state: state))" + } + private func binparen(_ state: SQLGenState, _ op: String, _ lhs: CRUDExpression, _ rhs: CRUDExpression) throws -> String { + return "(\(try lhs.sqlSnippet(state: state)) \(op) \(try rhs.sqlSnippet(state: state)))" + } + private func un(_ state: SQLGenState, _ op: String, _ rhs: CRUDExpression) throws -> String { + return "\(op) \(try rhs.sqlSnippet(state: state))" + } + func referencedTypes() -> [Any.Type] { + switch self { + case .column(_): + return [] + case .and(let lhs, let rhs): + return lhs.referencedTypes() + rhs.referencedTypes() + case .or(let lhs, let rhs): + return lhs.referencedTypes() + rhs.referencedTypes() + case .equality(let lhs, let rhs): + return lhs.referencedTypes() + rhs.referencedTypes() + case .inequality(let lhs, let rhs): + return lhs.referencedTypes() + rhs.referencedTypes() + case .not(let rhs): + return rhs.referencedTypes() + case .lessThan(let lhs, let rhs): + return lhs.referencedTypes() + rhs.referencedTypes() + case .lessThanEqual(let lhs, let rhs): + return lhs.referencedTypes() + rhs.referencedTypes() + case .greaterThan(let lhs, let rhs): + return lhs.referencedTypes() + rhs.referencedTypes() + case .greaterThanEqual(let lhs, let rhs): + return lhs.referencedTypes() + rhs.referencedTypes() + case .keyPath(let k): + return [type(of: k).rootType] + case .null: + return [] + case .lazy(let e): + return e().referencedTypes() + case .integer(_), .uinteger(_), .integer64(_), .uinteger64(_), .integer32(_), .uinteger32(_), .integer16(_), .uinteger16(_), .integer8(_), .uinteger8(_): + return [] + case .decimal(_), .float(_), .string(_), .blob(_), .sblob(_), .bool(_), .uuid(_), .date(_), .url(_): + return [] + case .in(let lhs, let exprs): + return lhs.referencedTypes() + exprs.flatMap { $0.referencedTypes() } + case .like(let lhs, _, _, _): + return lhs.referencedTypes() + } + } +} diff --git a/Sources/PerfectCRUD/Expression/In.swift b/Sources/PerfectCRUD/Expression/In.swift new file mode 100644 index 00000000..7ab57e91 --- /dev/null +++ b/Sources/PerfectCRUD/Expression/In.swift @@ -0,0 +1,47 @@ +// +// In.swift +// PerfectCRUD +// +// Created by Kyle Jessup on 2018-02-18. +// + +import Foundation + +// ~ IN +public func ~ (lhs: KeyPath, rhs: [String]) -> CRUDBooleanExpression { + return RealBooleanExpression(.in(lhs: .keyPath(lhs), rhs: rhs.map { .string($0) })) +} +public func ~ (lhs: KeyPath, rhs: [Double]) -> CRUDBooleanExpression { + return RealBooleanExpression(.in(lhs: .keyPath(lhs), rhs: rhs.map { .decimal($0) })) +} +public func ~ (lhs: KeyPath, rhs: [UUID]) -> CRUDBooleanExpression { + return RealBooleanExpression(.in(lhs: .keyPath(lhs), rhs: rhs.map { .uuid($0) })) +} +public func ~ (lhs: KeyPath, rhs: [Date]) -> CRUDBooleanExpression { + return RealBooleanExpression(.in(lhs: .keyPath(lhs), rhs: rhs.map { .date($0) })) +} +public func ~ (lhs: KeyPath, rhs: [String]) -> CRUDBooleanExpression { + return RealBooleanExpression(.in(lhs: .keyPath(lhs), rhs: rhs.map { .string($0) })) +} +public func ~ (lhs: KeyPath, rhs: [Double]) -> CRUDBooleanExpression { + return RealBooleanExpression(.in(lhs: .keyPath(lhs), rhs: rhs.map { .decimal($0) })) +} +public func ~ (lhs: KeyPath, rhs: [UUID]) -> CRUDBooleanExpression { + return RealBooleanExpression(.in(lhs: .keyPath(lhs), rhs: rhs.map { .uuid($0) })) +} +public func ~ (lhs: KeyPath, rhs: [Date]) -> CRUDBooleanExpression { + return RealBooleanExpression(.in(lhs: .keyPath(lhs), rhs: rhs.map { .date($0) })) +} +// !~ NOT IN +public func !~ (lhs: KeyPath, rhs: [String]) -> CRUDBooleanExpression { + return !(lhs ~ rhs) +} +public func !~ (lhs: KeyPath, rhs: [Double]) -> CRUDBooleanExpression { + return !(lhs ~ rhs) +} +public func !~ (lhs: KeyPath, rhs: [UUID]) -> CRUDBooleanExpression { + return !(lhs ~ rhs) +} +public func !~ (lhs: KeyPath, rhs: [Date]) -> CRUDBooleanExpression { + return !(lhs ~ rhs) +} diff --git a/Sources/PerfectCRUD/Expression/InInts.swift b/Sources/PerfectCRUD/Expression/InInts.swift new file mode 100644 index 00000000..2ad64a24 --- /dev/null +++ b/Sources/PerfectCRUD/Expression/InInts.swift @@ -0,0 +1,131 @@ +// +// InInts.swift +// PerfectCRUD +// +// Created by Kyle Jessup on 2018-03-11. +// + +import Foundation + +// ~ IN +public func ~ (lhs: KeyPath, rhs: [Int]) -> CRUDBooleanExpression { + return RealBooleanExpression(.in(lhs: .keyPath(lhs), rhs: rhs.map { .integer($0) })) +} +public func ~ (lhs: KeyPath, rhs: [Int]) -> CRUDBooleanExpression { + return RealBooleanExpression(.in(lhs: .keyPath(lhs), rhs: rhs.map { .integer($0) })) +} +public func ~ (lhs: KeyPath, rhs: [UInt]) -> CRUDBooleanExpression { + return RealBooleanExpression(.in(lhs: .keyPath(lhs), rhs: rhs.map { .uinteger($0) })) +} +public func ~ (lhs: KeyPath, rhs: [UInt]) -> CRUDBooleanExpression { + return RealBooleanExpression(.in(lhs: .keyPath(lhs), rhs: rhs.map { .uinteger($0) })) +} +public func ~ (lhs: KeyPath, rhs: [Int64]) -> CRUDBooleanExpression { + return RealBooleanExpression(.in(lhs: .keyPath(lhs), rhs: rhs.map { .integer64($0) })) +} +public func ~ (lhs: KeyPath, rhs: [Int64]) -> CRUDBooleanExpression { + return RealBooleanExpression(.in(lhs: .keyPath(lhs), rhs: rhs.map { .integer64($0) })) +} +public func ~ (lhs: KeyPath, rhs: [UInt64]) -> CRUDBooleanExpression { + return RealBooleanExpression(.in(lhs: .keyPath(lhs), rhs: rhs.map { .uinteger64($0) })) +} +public func ~ (lhs: KeyPath, rhs: [UInt64]) -> CRUDBooleanExpression { + return RealBooleanExpression(.in(lhs: .keyPath(lhs), rhs: rhs.map { .uinteger64($0) })) +} +public func ~ (lhs: KeyPath, rhs: [Int32]) -> CRUDBooleanExpression { + return RealBooleanExpression(.in(lhs: .keyPath(lhs), rhs: rhs.map { .integer32($0) })) +} +public func ~ (lhs: KeyPath, rhs: [Int32]) -> CRUDBooleanExpression { + return RealBooleanExpression(.in(lhs: .keyPath(lhs), rhs: rhs.map { .integer32($0) })) +} +public func ~ (lhs: KeyPath, rhs: [UInt32]) -> CRUDBooleanExpression { + return RealBooleanExpression(.in(lhs: .keyPath(lhs), rhs: rhs.map { .uinteger32($0) })) +} +public func ~ (lhs: KeyPath, rhs: [UInt32]) -> CRUDBooleanExpression { + return RealBooleanExpression(.in(lhs: .keyPath(lhs), rhs: rhs.map { .uinteger32($0) })) +} +public func ~ (lhs: KeyPath, rhs: [Int16]) -> CRUDBooleanExpression { + return RealBooleanExpression(.in(lhs: .keyPath(lhs), rhs: rhs.map { .integer16($0) })) +} +public func ~ (lhs: KeyPath, rhs: [Int16]) -> CRUDBooleanExpression { + return RealBooleanExpression(.in(lhs: .keyPath(lhs), rhs: rhs.map { .integer16($0) })) +} +public func ~ (lhs: KeyPath, rhs: [UInt16]) -> CRUDBooleanExpression { + return RealBooleanExpression(.in(lhs: .keyPath(lhs), rhs: rhs.map { .uinteger16($0) })) +} +public func ~ (lhs: KeyPath, rhs: [UInt16]) -> CRUDBooleanExpression { + return RealBooleanExpression(.in(lhs: .keyPath(lhs), rhs: rhs.map { .uinteger16($0) })) +} +public func ~ (lhs: KeyPath, rhs: [Int8]) -> CRUDBooleanExpression { + return RealBooleanExpression(.in(lhs: .keyPath(lhs), rhs: rhs.map { .integer8($0) })) +} +public func ~ (lhs: KeyPath, rhs: [Int8]) -> CRUDBooleanExpression { + return RealBooleanExpression(.in(lhs: .keyPath(lhs), rhs: rhs.map { .integer8($0) })) +} +public func ~ (lhs: KeyPath, rhs: [UInt8]) -> CRUDBooleanExpression { + return RealBooleanExpression(.in(lhs: .keyPath(lhs), rhs: rhs.map { .uinteger8($0) })) +} +public func ~ (lhs: KeyPath, rhs: [UInt8]) -> CRUDBooleanExpression { + return RealBooleanExpression(.in(lhs: .keyPath(lhs), rhs: rhs.map { .uinteger8($0) })) +} +// !~ NOT IN +public func !~ (lhs: KeyPath, rhs: [Int]) -> CRUDBooleanExpression { + return !(lhs ~ rhs) +} +public func !~ (lhs: KeyPath, rhs: [Int]) -> CRUDBooleanExpression { + return !(lhs ~ rhs) +} +public func !~ (lhs: KeyPath, rhs: [UInt]) -> CRUDBooleanExpression { + return !(lhs ~ rhs) +} +public func !~ (lhs: KeyPath, rhs: [UInt]) -> CRUDBooleanExpression { + return !(lhs ~ rhs) +} +public func !~ (lhs: KeyPath, rhs: [Int64]) -> CRUDBooleanExpression { + return !(lhs ~ rhs) +} +public func !~ (lhs: KeyPath, rhs: [Int64]) -> CRUDBooleanExpression { + return !(lhs ~ rhs) +} +public func !~ (lhs: KeyPath, rhs: [UInt64]) -> CRUDBooleanExpression { + return !(lhs ~ rhs) +} +public func !~ (lhs: KeyPath, rhs: [UInt64]) -> CRUDBooleanExpression { + return !(lhs ~ rhs) +} +public func !~ (lhs: KeyPath, rhs: [Int32]) -> CRUDBooleanExpression { + return !(lhs ~ rhs) +} +public func !~ (lhs: KeyPath, rhs: [Int32]) -> CRUDBooleanExpression { + return !(lhs ~ rhs) +} +public func !~ (lhs: KeyPath, rhs: [UInt32]) -> CRUDBooleanExpression { + return !(lhs ~ rhs) +} +public func !~ (lhs: KeyPath, rhs: [UInt32]) -> CRUDBooleanExpression { + return !(lhs ~ rhs) +} +public func !~ (lhs: KeyPath, rhs: [Int16]) -> CRUDBooleanExpression { + return !(lhs ~ rhs) +} +public func !~ (lhs: KeyPath, rhs: [Int16]) -> CRUDBooleanExpression { + return !(lhs ~ rhs) +} +public func !~ (lhs: KeyPath, rhs: [UInt16]) -> CRUDBooleanExpression { + return !(lhs ~ rhs) +} +public func !~ (lhs: KeyPath, rhs: [UInt16]) -> CRUDBooleanExpression { + return !(lhs ~ rhs) +} +public func !~ (lhs: KeyPath, rhs: [Int8]) -> CRUDBooleanExpression { + return !(lhs ~ rhs) +} +public func !~ (lhs: KeyPath, rhs: [Int8]) -> CRUDBooleanExpression { + return !(lhs ~ rhs) +} +public func !~ (lhs: KeyPath, rhs: [UInt8]) -> CRUDBooleanExpression { + return !(lhs ~ rhs) +} +public func !~ (lhs: KeyPath, rhs: [UInt8]) -> CRUDBooleanExpression { + return !(lhs ~ rhs) +} diff --git a/Sources/PerfectCRUD/Expression/Like.swift b/Sources/PerfectCRUD/Expression/Like.swift new file mode 100644 index 00000000..6d808f59 --- /dev/null +++ b/Sources/PerfectCRUD/Expression/Like.swift @@ -0,0 +1,49 @@ +// +// Like.swift +// PerfectCRUD +// +// Created by Kyle Jessup on 2018-02-18. +// + +import Foundation + +// %=% LIKE +public func %=% (lhs: KeyPath, rhs: String) -> CRUDBooleanExpression { + return RealBooleanExpression(.like(lhs: .keyPath(lhs), wild1: true, rhs, wild2: true)) +} +public func %=% (lhs: KeyPath, rhs: String) -> CRUDBooleanExpression { + return RealBooleanExpression(.like(lhs: .keyPath(lhs), wild1: true, rhs, wild2: true)) +} +// *~ LIKE v% +public func =% (lhs: KeyPath, rhs: String) -> CRUDBooleanExpression { + return RealBooleanExpression(.like(lhs: .keyPath(lhs), wild1: false, rhs, wild2: true)) +} +public func =% (lhs: KeyPath, rhs: String) -> CRUDBooleanExpression { + return RealBooleanExpression(.like(lhs: .keyPath(lhs), wild1: false, rhs, wild2: true)) +} +// ~* LIKE %v +public func %= (lhs: KeyPath, rhs: String) -> CRUDBooleanExpression { + return RealBooleanExpression(.like(lhs: .keyPath(lhs), wild1: true, rhs, wild2: false)) +} +public func %= (lhs: KeyPath, rhs: String) -> CRUDBooleanExpression { + return RealBooleanExpression(.like(lhs: .keyPath(lhs), wild1: true, rhs, wild2: false)) +} +// !~ NOT LIKE +public func %!=% (lhs: KeyPath, rhs: String) -> CRUDBooleanExpression { + return !(lhs %=% rhs) +} +public func %!=% (lhs: KeyPath, rhs: String) -> CRUDBooleanExpression { + return !(lhs %=% rhs) +} +public func !=% (lhs: KeyPath, rhs: String) -> CRUDBooleanExpression { + return !(lhs =% rhs) +} +public func !=% (lhs: KeyPath, rhs: String) -> CRUDBooleanExpression { + return !(lhs =% rhs) +} +public func %!= (lhs: KeyPath, rhs: String) -> CRUDBooleanExpression { + return !(lhs %= rhs) +} +public func %!= (lhs: KeyPath, rhs: String) -> CRUDBooleanExpression { + return !(lhs %= rhs) +} diff --git a/Sources/PerfectCRUD/Expression/Logical.swift b/Sources/PerfectCRUD/Expression/Logical.swift new file mode 100644 index 00000000..cb12fa6e --- /dev/null +++ b/Sources/PerfectCRUD/Expression/Logical.swift @@ -0,0 +1,21 @@ +// +// Logical.swift +// PerfectCRUD +// +// Created by Kyle Jessup on 2018-02-18. +// + +import Foundation + +// && +public func && (lhs: CRUDBooleanExpression, rhs: CRUDBooleanExpression) -> CRUDBooleanExpression { + return RealBooleanExpression(.and(lhs: lhs.crudExpression, rhs: rhs.crudExpression)) +} +// || +public func || (lhs: CRUDBooleanExpression, rhs: CRUDBooleanExpression) -> CRUDBooleanExpression { + return RealBooleanExpression(.or(lhs: lhs.crudExpression, rhs: rhs.crudExpression)) +} +// ! +public prefix func ! (rhs: CRUDBooleanExpression) -> CRUDBooleanExpression { + return RealBooleanExpression(.not(rhs: rhs.crudExpression)) +} diff --git a/Sources/PerfectCRUD/Insert.swift b/Sources/PerfectCRUD/Insert.swift new file mode 100644 index 00000000..7f95af3a --- /dev/null +++ b/Sources/PerfectCRUD/Insert.swift @@ -0,0 +1,102 @@ +// +// PerfectCRUDInsert.swift +// PerfectCRUD +// +// Created by Kyle Jessup on 2017-12-02. +// + +import Foundation + +public struct Insert: FromTableProtocol, CommandProtocol { + public typealias FromTableType = A + public typealias OverAllForm = OAF + public let fromTable: FromTableType + public let sqlGenState: SQLGenState + init(fromTable ft: FromTableType, instances: [OAF], includeKeys: [PartialKeyPath], excludeKeys: [PartialKeyPath]) throws { + fromTable = ft + let delegate = ft.databaseConfiguration.sqlGenDelegate + var state = SQLGenState(delegate: delegate) + state.command = .insert + try ft.setState(state: &state) + let td = state.tableData[0] + let kpDecoder = td.keyPathDecoder + guard let kpInstance = td.modelInstance else { + throw CRUDSQLGenError("Could not get model instance for key path decoder \(OAF.self)") + } + let includeNames: [String] + if includeKeys.isEmpty { + let columnDecoder = CRUDColumnNameDecoder() + _ = try OverAllForm.init(from: columnDecoder) + includeNames = columnDecoder.collectedKeys.map { $0.name } + } else { + includeNames = try includeKeys.map { + guard let n = try kpDecoder.getKeyPathName(kpInstance, keyPath: $0) else { + throw CRUDSQLGenError("Could not get key path name for \(OAF.self) \($0)") + } + return n + } + } + let excludeNames: [String] = try excludeKeys.map { + guard let n = try kpDecoder.getKeyPathName(kpInstance, keyPath: $0) else { + throw CRUDSQLGenError("Could not get key path name for \(OAF.self) \($0)") + } + return n + } + + let encoder = try CRUDBindingsEncoder(delegate: delegate) + try instances[0].encode(to: encoder) + + let bindings = try encoder.completedBindings(allKeys: includeNames, ignoreKeys: Set(excludeNames)) + let columnNames = try bindings.map { try delegate.quote(identifier: $0.column) } + let bindIdentifiers = bindings.map { $0.identifier } + + let nameQ = try delegate.quote(identifier: "\(OAF.CRUDTableName)") + let sqlStr: String + if columnNames.isEmpty { + sqlStr = "INSERT INTO \(nameQ) \(delegate.getEmptyInsertSnippet())" + } else { + sqlStr = "INSERT INTO \(nameQ) (\(columnNames.joined(separator: ", "))) VALUES (\(bindIdentifiers.joined(separator: ", ")))" + } + CRUDLogging.log(.query, sqlStr) + sqlGenState = state + let exeDelegate = try databaseConfiguration.sqlExeDelegate(forSQL: sqlStr) + try exeDelegate.bind(delegate.bindings) + _ = try exeDelegate.hasNext() + + for instance in instances[1...] { + let delegate = databaseConfiguration.sqlGenDelegate + let encoder = try CRUDBindingsEncoder(delegate: delegate) + try instance.encode(to: encoder) + _ = try encoder.completedBindings(allKeys: includeNames, ignoreKeys: Set(excludeNames)) + try exeDelegate.bind(delegate.bindings) + _ = try exeDelegate.hasNext() + } + } +} + +public extension Table { + @discardableResult + func insert(_ instances: [Form]) throws -> Insert { + return try .init(fromTable: self, instances: instances, includeKeys: [], excludeKeys: []) + } + @discardableResult + func insert(_ instance: Form) throws -> Insert { + return try .init(fromTable: self, instances: [instance], includeKeys: [], excludeKeys: []) + } + @discardableResult + func insert(_ instances: [Form], setKeys: KeyPath, _ rest: PartialKeyPath...) throws -> Insert { + return try .init(fromTable: self, instances: instances, includeKeys: [setKeys] + rest, excludeKeys: []) + } + @discardableResult + func insert(_ instance: Form, setKeys: KeyPath, _ rest: PartialKeyPath...) throws -> Insert { + return try .init(fromTable: self, instances: [instance], includeKeys: [setKeys] + rest, excludeKeys: []) + } + @discardableResult + func insert(_ instances: [Form], ignoreKeys: KeyPath, _ rest: PartialKeyPath...) throws -> Insert { + return try .init(fromTable: self, instances: instances, includeKeys: [], excludeKeys: [ignoreKeys] + rest) + } + @discardableResult + func insert(_ instance: Form, ignoreKeys: KeyPath, _ rest: PartialKeyPath...) throws -> Insert { + return try .init(fromTable: self, instances: [instance], includeKeys: [], excludeKeys: [ignoreKeys] + rest) + } +} diff --git a/Sources/PerfectCRUD/Join.swift b/Sources/PerfectCRUD/Join.swift new file mode 100644 index 00000000..b48d8b71 --- /dev/null +++ b/Sources/PerfectCRUD/Join.swift @@ -0,0 +1,205 @@ +// +// PerfectCRUDJoin.swift +// PerfectCRUD +// +// Created by Kyle Jessup on 2018-01-03. +// + +import Foundation + +let joinPivotIdColumnName = "_crud_pivot_id_" + +let joinWord = "LEFT JOIN" + +public struct Join: TableProtocol, FromTableProtocol, Joinable, Selectable, Whereable, Orderable, Limitable { + public typealias Form = B + public typealias FromTableType = A + public typealias ComparisonType = O + public typealias OverAllForm = OAF + public let fromTable: FromTableType + let to: KeyPath + let on: KeyPath + let equals: KeyPath + public func setState(state: inout SQLGenState) throws { + try fromTable.setState(state: &state) + try state.addTable(type: Form.self, joinData: .init(to: to, on: on, equals: equals, pivot: nil)) + } + public func setSQL(state: inout SQLGenState) throws { + let (orderings, limit) = state.consumeState() + try fromTable.setSQL(state: &state) + let delegate = state.delegate + guard let poppedTableData = state.popTableData() else { + throw CRUDSQLGenError("No tables specified.") + } + let myTable = poppedTableData.myTable + let firstTable = poppedTableData.firstTable + let joinTables = poppedTableData.remainingTables + let nameQ = try delegate.quote(identifier: Form.CRUDTableName) + let aliasQ = try delegate.quote(identifier: myTable.alias) + let fNameQ = try delegate.quote(identifier: firstTable.type.CRUDTableName) + let fAliasQ = try delegate.quote(identifier: firstTable.alias) + let lhsStr = try Expression.sqlSnippet(keyPath: on, tableData: firstTable, state: state) + let rhsStr = try Expression.sqlSnippet(keyPath: equals, tableData: myTable, state: state) + switch state.command { + case .count: + () // joins do nothing on .count except limit master # + case .select: + var sqlStr = + """ + SELECT DISTINCT \(aliasQ).* + FROM \(nameQ) AS \(aliasQ) + \(joinWord) \(fNameQ) AS \(fAliasQ) ON \(lhsStr) = \(rhsStr) + + """ + if let whereExpr = state.whereExpr { + let referencedTypes = whereExpr.referencedTypes() + for type in referencedTypes { + guard type != firstTable.type && type != Form.self else { + continue + } + guard let joinTable = joinTables.first(where: { type == $0.type }) else { + throw CRUDSQLGenError("Unknown type included in where clause \(type).") + } + guard let joinData = joinTable.joinData else { + throw CRUDSQLGenError("Join without a clause \(type).") + } + let nameQ = try delegate.quote(identifier: joinTable.type.CRUDTableName) + let aliasQ = try delegate.quote(identifier: joinTable.alias) + let lhsStr = try Expression.keyPath(joinData.on).sqlSnippet(state: state) + let rhsStr = try Expression.keyPath(joinData.equals).sqlSnippet(state: state) + sqlStr += "\(joinWord) \(nameQ) AS \(aliasQ) ON \(lhsStr) = \(rhsStr)\n" + } + sqlStr += "WHERE \(try whereExpr.sqlSnippet(state: state))\n" + } + if !orderings.isEmpty { + let m = try orderings.map { "\(try Expression.keyPath($0.key).sqlSnippet(state: state))\($0.desc ? " DESC" : "")" } + sqlStr += "ORDER BY \(m.joined(separator: ", "))\n" + } + if let (max, skip) = limit { + if max > 0 { + sqlStr += "LIMIT \(max)\n" + } + if skip > 0 { + sqlStr += "OFFSET \(skip)\n" + } + } + state.statements.append(.init(sql: sqlStr, bindings: delegate.bindings)) + state.delegate.bindings = [] + CRUDLogging.log(.query, sqlStr) + // ordering + case .insert, .update, .delete:() + // state.fromStr.append("\(myTable)") + case .unknown: + throw CRUDSQLGenError("SQL command was not set.") + } + } +} + +public struct JoinPivot: TableProtocol, FromTableProtocol, Joinable, Selectable, Whereable, Orderable, Limitable { + public typealias Form = MyForm + public typealias FromTableType = MasterTable + public typealias PivotTableType = With + public typealias ComparisonType = PivotCompType + public typealias ComparisonType2 = PivotCompType2 + public typealias OverAllForm = OAF + + public let fromTable: FromTableType + let to: KeyPath + let on: KeyPath + let equals: KeyPath + let and: KeyPath + let alsoEquals: KeyPath + + public func setState(state: inout SQLGenState) throws { + try fromTable.setState(state: &state) + try state.addTable(type: Form.self, joinData: .init(to: to, on: on, equals: equals, pivot: PivotTableType.self)) + try state.addTable(type: PivotTableType.self) + } + public func setSQL(state: inout SQLGenState) throws { + let (orderings, limit) = state.consumeState() + try fromTable.setSQL(state: &state) + let delegate = state.delegate + + guard let poppedTableData1 = state.popTableData(), + let poppedTableData2 = state.popTableData() else { + throw CRUDSQLGenError("No tables specified.") + } + let myTable = poppedTableData1.myTable + let firstTable = poppedTableData1.firstTable + let joinTables = poppedTableData1.remainingTables + let pivotTable = poppedTableData2.myTable + + let myNameQ = try delegate.quote(identifier: myTable.type.CRUDTableName) + let myAliasQ = try delegate.quote(identifier: myTable.alias) + + let firstNameQ = try delegate.quote(identifier: firstTable.type.CRUDTableName) + let firstAliasQ = try delegate.quote(identifier: firstTable.alias) + + let lhsStr = try Expression.sqlSnippet(keyPath: on, tableData: firstTable, state: state) + let rhsStr = try Expression.sqlSnippet(keyPath: equals, tableData: pivotTable, state: state) + + let pivotNameQ = try delegate.quote(identifier: pivotTable.type.CRUDTableName) + let pivotAliasQ = try delegate.quote(identifier: pivotTable.alias) + + let lhsStr2 = try Expression.sqlSnippet(keyPath: and, tableData: myTable, state: state) + let rhsStr2 = try Expression.sqlSnippet(keyPath: alsoEquals, tableData: pivotTable, state: state) + + let tempColumnNameQ = try delegate.quote(identifier: joinPivotIdColumnName) + + switch state.command { + case .count: + () // joins do nothing on .count except limit master # + case .select: + var sqlStr = + """ + SELECT DISTINCT \(myAliasQ).*, \(lhsStr) AS \(tempColumnNameQ) + FROM \(myNameQ) AS \(myAliasQ) + \(joinWord) \(pivotNameQ) AS \(pivotAliasQ) ON \(lhsStr2) = \(rhsStr2) + \(joinWord) \(firstNameQ) AS \(firstAliasQ) ON \(lhsStr) = \(rhsStr) + + """ + if let whereExpr = state.whereExpr { + let referencedTypes = whereExpr.referencedTypes() + for type in referencedTypes { + guard type != firstTable.type, + type != Form.self, + type != PivotTableType.self else { + continue + } + guard let joinTable = joinTables.first(where: { type == $0.type }) else { + throw CRUDSQLGenError("Unknown type included in where clause \(type).") + } + guard let joinData = joinTable.joinData else { + throw CRUDSQLGenError("Join without a clause \(type).") + } + let nameQ = try delegate.quote(identifier: joinTable.type.CRUDTableName) + let aliasQ = try delegate.quote(identifier: joinTable.alias) + let lhsStr = try Expression.keyPath(joinData.on).sqlSnippet(state: state) + let rhsStr = try Expression.keyPath(joinData.equals).sqlSnippet(state: state) + sqlStr += "\(joinWord) \(nameQ) AS \(aliasQ) ON \(lhsStr) = \(rhsStr)\n" + } + sqlStr += "WHERE \(try whereExpr.sqlSnippet(state: state))\n" + } + if !orderings.isEmpty { + let m = try orderings.map { "\(try Expression.keyPath($0.key).sqlSnippet(state: state))\($0.desc ? " DESC" : "")" } + sqlStr += "ORDER BY \(m.joined(separator: ", "))\n" + } + if let (max, skip) = limit { + if max > 0 { + sqlStr += "LIMIT \(max)\n" + } + if skip > 0 { + sqlStr += "OFFSET \(skip)\n" + } + } + state.statements.append(.init(sql: sqlStr, bindings: delegate.bindings)) + state.delegate.bindings = [] + CRUDLogging.log(.query, sqlStr) + // ordering + case .insert, .update, .delete:() + // state.fromStr.append("\(myTable)") + case .unknown: + throw CRUDSQLGenError("SQL command was not set.") + } + } +} diff --git a/Sources/PerfectCRUD/Logging.swift b/Sources/PerfectCRUD/Logging.swift new file mode 100644 index 00000000..6e20ffd2 --- /dev/null +++ b/Sources/PerfectCRUD/Logging.swift @@ -0,0 +1,166 @@ +// +// PerfectCRUDLogging.swift +// PerfectCRUD +// +// Created by Kyle Jessup on 2017-11-24. +// Copyright (C) 2017 PerfectlySoft, Inc. +// +// ===----------------------------------------------------------------------===// +// +// This source file is part of the Perfect.org open source project +// +// Copyright (c) 2015 - 2017 PerfectlySoft Inc. and the Perfect project authors +// Licensed under Apache License v2.0 +// +// See http://perfect.org/licensing.html for license information +// +// ===----------------------------------------------------------------------===// +// + +import Foundation +import Dispatch + +public struct CRUDSQLGenError: Error, CustomStringConvertible { + public let description: String + public init(_ msg: String) { + description = msg + CRUDLogging.log(.error, msg) + } +} +public struct CRUDSQLExeError: Error, CustomStringConvertible { + public let description: String + public init(_ msg: String) { + description = msg + CRUDLogging.log(.error, msg) + } +} + +public enum CRUDLogDestination { + case none + case console + case file(String) + case custom((CRUDLogEvent) -> ()) + + func handleEvent(_ event: CRUDLogEvent) { + switch self { + case .none: + () + case .console: + print("\(event)") + case .file(let name): + let fm = FileManager() + guard fm.isWritableFile(atPath: name) || fm.createFile(atPath: name, contents: nil, attributes: nil), + let fileHandle = FileHandle(forWritingAtPath: name), + let data = "\(event)\n".data(using: .utf8) else { + print("[ERR] Unable to open file at \"\(name)\" to log event \(event)") + return + } + defer { + fileHandle.closeFile() + } + fileHandle.seekToEndOfFile() + fileHandle.write(data) + case .custom(let code): + code(event) + } + } +} + +public enum CRUDLogEventType: CustomStringConvertible { + case info, warning, error, query + public var description: String { + switch self { + case .info: + return "INFO" + case .warning: + return "WARN" + case .error: + return "ERR" + case .query: + return "QUERY" + } + } +} + +public struct CRUDLogEvent: CustomStringConvertible { + public let time: Date + public let type: CRUDLogEventType + public let msg: String + public var description: String { + let formatter = DateFormatter() + formatter.dateFormat = "EEE, dd MMM yyyy HH:mm:ss ZZ" + return "[\(formatter.string(from: time))] [\(type)] \(msg)" + } +} + +public struct CRUDLogging { + private static var _queryLogDestinations: [CRUDLogDestination] = [.console] + private static var _errorLogDestinations: [CRUDLogDestination] = [.console] + private static var pendingEvents: [CRUDLogEvent] = [] + private static var loggingQueue: DispatchQueue = { + let q = DispatchQueue(label: "CRUDLoggingQueue", qos: .background) + scheduleLogCheck(q) + return q + }() + private static func logCheckReschedulingInSerialQueue() { + logCheckInSerialQueue() + scheduleLogCheck(loggingQueue) + } + private static func logCheckInSerialQueue() { + guard !pendingEvents.isEmpty else { + return + } + let eventsToLog = pendingEvents + pendingEvents = [] + eventsToLog.forEach { + logEventInSerialQueue($0) + } + } + private static func logEventInSerialQueue(_ event: CRUDLogEvent) { + if case .query = event.type { + _queryLogDestinations.forEach { $0.handleEvent(event) } + } else { + _errorLogDestinations.forEach { $0.handleEvent(event) } + } + } + private static func scheduleLogCheck(_ queue: DispatchQueue) { + queue.asyncAfter(deadline: .now() + 0.5, execute: logCheckReschedulingInSerialQueue) + } +} + +public extension CRUDLogging { + static func flush() { + loggingQueue.sync { + logCheckInSerialQueue() + } + } + static var queryLogDestinations: [CRUDLogDestination] { + get { + return loggingQueue.sync { return _queryLogDestinations } + } + set { + loggingQueue.async { _queryLogDestinations = newValue } + } + } + static var errorLogDestinations: [CRUDLogDestination] { + get { + return loggingQueue.sync { return _errorLogDestinations } + } + set { + loggingQueue.async { _errorLogDestinations = newValue } + } + } + static func log(_ type: CRUDLogEventType, _ msg: String) { + let now = Date() + #if DEBUG || Xcode + loggingQueue.sync { + pendingEvents.append(.init(time: now, type: type, msg: msg)) + logCheckInSerialQueue() + } + #else + loggingQueue.async { + pendingEvents.append(.init(time: now, type: type, msg: msg)) + } + #endif + } +} diff --git a/Sources/PerfectCRUD/PerfectCRUD.swift b/Sources/PerfectCRUD/PerfectCRUD.swift new file mode 100644 index 00000000..1ffbc372 --- /dev/null +++ b/Sources/PerfectCRUD/PerfectCRUD.swift @@ -0,0 +1,434 @@ +// +// TestDeleteMe.swift +// PerfectCRUD +// +// Created by Kyle Jessup on 2017-11-26. +// + +import Foundation + +public typealias Expression = CRUDExpression +public typealias Bindings = [(String, Expression)] + +public protocol QueryItem { + associatedtype OverAllForm: Codable + func setState(state: inout SQLGenState) throws + func setSQL(state: inout SQLGenState) throws +} + +public protocol TableProtocol: QueryItem { + associatedtype Form: Codable + var databaseConfiguration: DatabaseConfigurationProtocol { get } +} + +public protocol FromTableProtocol { + associatedtype FromTableType: TableProtocol + var fromTable: FromTableType { get } +} + +public protocol CommandProtocol: QueryItem { + var sqlGenState: SQLGenState { get } +} + +public protocol SelectProtocol: Sequence, FromTableProtocol, CommandProtocol { + var fromTable: FromTableType { get } +} +// swiftlint:disable class_delegate_protocol +public protocol SQLGenDelegate { + var bindings: Bindings { get set } + func getBinding(for: Expression) throws -> String + func quote(identifier: String) throws -> String + func getCreateTableSQL(forTable: TableStructure, policy: TableCreatePolicy) throws -> [String] + func getCreateIndexSQL(forTable name: String, on columns: [String], unique: Bool) throws -> [String] + func getEmptyInsertSnippet() -> String // usually DEFAULT VALUES vs. VALUES () +} + +public extension SQLGenDelegate { + func getEmptyInsertSnippet() -> String { + return "DEFAULT VALUES" + } +} +// swiftlint:disable class_delegate_protocol +public protocol SQLExeDelegate { + func bind(_ bindings: Bindings, skip: Int) throws + func hasNext() throws -> Bool + func next() throws -> KeyedDecodingContainer? + func asyncExecute(completion: @escaping (SQLExeDelegate) -> ()) +} + +public protocol DatabaseConfigurationProtocol { + var sqlGenDelegate: SQLGenDelegate { get } + func sqlExeDelegate(forSQL: String) throws -> SQLExeDelegate + + init(url: String?, name: String?, host: String?, port: Int?, user: String?, pass: String?) throws +} + +public protocol DatabaseProtocol { + associatedtype Configuration: DatabaseConfigurationProtocol + var configuration: Configuration { get } + func table(_ form: T.Type) -> Table +} + +public protocol TableNameProvider { + static var tableName: String { get } +} + +public protocol Joinable: TableProtocol { + func join(_ to: KeyPath, on: KeyPath, equals: KeyPath) throws -> Join +} + +public protocol Selectable: TableProtocol { + func select() throws -> Select + func count() throws -> Int +} + +public protocol Whereable: TableProtocol { + func `where`(_ expr: CRUDBooleanExpression) -> Where +} + +public protocol Orderable: TableProtocol { + func order(by: PartialKeyPath
...) -> Ordering + func order(descending by: PartialKeyPath...) -> Ordering +} + +public protocol Limitable: TableProtocol { + func limit(_ max: Int, skip: Int) -> Limit +} +// swiftlint:disable line_length +public extension Joinable { + func join(_ to: KeyPath, on: KeyPath, equals: KeyPath) throws -> Join { + return .init(fromTable: self, to: to, on: on, equals: equals) + } + + func join(_ to: KeyPath, with: Pivot.Type, on: KeyPath, equals: KeyPath, and: KeyPath, is: KeyPath) throws -> JoinPivot { + return .init(fromTable: self, to: to, on: on, equals: equals, and: and, alsoEquals: `is`) + } +} + +public extension Selectable { + func select() throws -> Select { + return try .init(fromTable: self) + } + func count() throws -> Int { + var state = SQLGenState(delegate: databaseConfiguration.sqlGenDelegate) + state.command = .count + try setState(state: &state) + try setSQL(state: &state) + guard state.statements.count == 1 else { + throw CRUDSQLGenError("Too many statements for count().") + } + let stat = state.statements[0] + let exeDelegate = try databaseConfiguration.sqlExeDelegate(forSQL: stat.sql) + try exeDelegate.bind(stat.bindings) + guard try exeDelegate.hasNext(), + let container: KeyedDecodingContainer = try exeDelegate.next() else { + throw CRUDSQLGenError("No rows returned in count().") + } + return try container.decode(Int.self, forKey: ColumnKey(stringValue: "count")!) + } + func first() throws -> OverAllForm? { + var i = try select().makeIterator() + return try i.nextElement() + } +} + +public extension Selectable where Self: Limitable { + func first() throws -> OverAllForm? { + var i = try limit(1).select().makeIterator() + return try i.nextElement() + } +} + +public extension Whereable { + func `where`(_ expr: CRUDBooleanExpression) -> Where { + return .init(fromTable: self, expression: expr.crudExpression) + } +} +// swiftlint:disable line_length +public extension Orderable { + func order(by: PartialKeyPath...) -> Ordering { + return .init(fromTable: self, keys: by, descending: false) + } + func order(descending by: PartialKeyPath...) -> Ordering { + return .init(fromTable: self, keys: by, descending: true) + } + // !FIX! Swift 4.0.2 seems to have a problem with type inference for the above two funcs + // would not let \.name type references to be used + // this is an ugly work around + func order(by: KeyPath) -> Ordering { + return .init(fromTable: self, keys: [by], descending: false) + } + func order(by: KeyPath, _ thenBy: KeyPath) -> Ordering { + return .init(fromTable: self, keys: [by, thenBy], descending: false) + } + func order(by: KeyPath, _ thenBy: KeyPath, _ thenBy2: KeyPath) -> Ordering { + return .init(fromTable: self, keys: [by, thenBy, thenBy2], descending: false) + } + func order(by: KeyPath, _ thenBy: KeyPath, _ thenBy2: KeyPath, _ thenBy3: KeyPath) -> Ordering { + return .init(fromTable: self, keys: [by, thenBy, thenBy2, thenBy3], descending: false) + } + func order(by: KeyPath, _ thenBy: KeyPath, _ thenBy2: KeyPath, _ thenBy3: KeyPath, _ thenBy4: KeyPath) -> Ordering { + return .init(fromTable: self, keys: [by, thenBy, thenBy2, thenBy3, thenBy4], descending: false) + } + // desc + func order(descending by: KeyPath) -> Ordering { + return .init(fromTable: self, keys: [by], descending: true) + } + func order(descending by: KeyPath, _ thenBy: KeyPath) -> Ordering { + return .init(fromTable: self, keys: [by, thenBy], descending: true) + } + func order(descending by: KeyPath, _ thenBy: KeyPath, _ thenBy2: KeyPath) -> Ordering { + return .init(fromTable: self, keys: [by, thenBy, thenBy2], descending: true) + } + func order(descending by: KeyPath, _ thenBy: KeyPath, _ thenBy2: KeyPath, _ thenBy3: KeyPath) -> Ordering { + return .init(fromTable: self, keys: [by, thenBy, thenBy2, thenBy3], descending: true) + } + func order(descending by: KeyPath, _ thenBy: KeyPath, _ thenBy2: KeyPath, _ thenBy3: KeyPath, _ thenBy4: KeyPath) -> Ordering { + return .init(fromTable: self, keys: [by, thenBy, thenBy2, thenBy3, thenBy4], descending: true) + } +} + +public extension Limitable { + func limit(_ max: Int = 0, skip: Int = 0) -> Limit { + return .init(fromTable: self, max: max, skip: skip) + } + func limit(_ range: Range) -> Limit { + return limit(range.count, skip: range.lowerBound) + } + func limit(_ range: ClosedRange) -> Limit { + return limit(range.count, skip: range.lowerBound) + } + func limit(_ range: PartialRangeFrom) -> Limit { + return limit(skip: range.lowerBound) + } + func limit(_ range: PartialRangeThrough) -> Limit { + return limit(range.upperBound + 1) + } + func limit(_ range: PartialRangeUpTo) -> Limit { + return limit(range.upperBound) + } +} + +extension FromTableProtocol { + public var databaseConfiguration: DatabaseConfigurationProtocol { return fromTable.databaseConfiguration } +} + +extension CommandProtocol { + public func setState(state: inout SQLGenState) throws {} + public func setSQL(state: inout SQLGenState) throws {} +} + +public extension SQLExeDelegate { + func bind(_ bindings: Bindings) throws { return try bind(bindings, skip: 0) } +} + +public extension Decodable { + static var CRUDTableName: String { + if let p = self as? TableNameProvider.Type { + return p.tableName + } + return "\(Self.self)" + } +} + +struct PivotContainer { + let instance: Codable + let keys: [Codable] +} + +struct SQLTopExeDelegate: SQLExeDelegate { + func asyncExecute(completion: @escaping (SQLExeDelegate) -> ()) { + // place holder + } + + let genState: SQLGenState + let master: (table: SQLGenState.TableData, delegate: SQLExeDelegate) + let subObjects: [String:(onKeyName: String, onKey: AnyKeyPath, equalsKey: AnyKeyPath, objects: [Any])] + init(genState state: SQLGenState, configurator: DatabaseConfigurationProtocol) throws { + genState = state + let delegates: [(table: SQLGenState.TableData, delegate: SQLExeDelegate)] = try zip(state.tableData, state.statements).map { + let sd = try configurator.sqlExeDelegate(forSQL: $0.1.sql) + try sd.bind($0.1.bindings) + return ($0.0, sd) + } + guard !delegates.isEmpty else { + throw CRUDSQLExeError("No tables in query.") + } + master = delegates[0] + guard let modelInstance = master.table.modelInstance else { + throw CRUDSQLExeError("No model instance for type \(master.table.type).") + } + let joins = delegates[1...] + let keyPathDecoder = master.table.keyPathDecoder + subObjects = Dictionary(uniqueKeysWithValues: try joins.map { + let (joinTable, joinDelegate) = $0 + guard let joinData = joinTable.joinData, + let keyStr = try keyPathDecoder.getKeyPathName(modelInstance, keyPath: joinData.to), + let onKeyStr = try keyPathDecoder.getKeyPathName(modelInstance, keyPath: joinData.on) else { + throw CRUDSQLExeError("No join data on \(joinTable.type)") + } + var ary: [Any] = [] + let joinTableType = joinTable.type + if nil != joinData.pivot { + guard let onType = type(of: joinData.on).valueType as? Codable.Type else { + throw CRUDSQLExeError("Invalid join comparison type \(joinData.on).") + } + while try joinDelegate.hasNext() { + let decoder = CRUDPivotRowDecoder(delegate: joinDelegate, pivotOn: onType) + let instance = try joinTableType.init(from: decoder) + let keys = decoder.orderedKeys + ary.append(PivotContainer(instance: instance, keys: keys)) + } + } else { + while try joinDelegate.hasNext() { + let decoder = CRUDRowDecoder(delegate: joinDelegate) + ary.append(try joinTableType.init(from: decoder)) + } + } + return (keyStr, (onKeyStr, joinData.on, joinData.equals, ary)) + }) + } + func bind(_ bindings: Bindings, skip: Int) throws { + try master.delegate.bind(bindings, skip: skip) + } + func hasNext() throws -> Bool { + return try master.delegate.hasNext() + } + func next() throws -> KeyedDecodingContainer? where A: CodingKey { + guard let k: KeyedDecodingContainer = try master.delegate.next() else { + return nil + } + return KeyedDecodingContainer(SQLTopRowReader(exeDelegate: self, subRowReader: k)) + } +} + +public struct SQLGenState { + public enum Command { + case select, insert, update, delete, unknown + case count + } + public struct TableData { + public let type: Codable.Type + public let alias: String + public let modelInstance: Codable? + public let keyPathDecoder: CRUDKeyPathsDecoder + public let joinData: PropertyJoinData? + } + public struct PropertyJoinData { + public let to: AnyKeyPath + public let on: AnyKeyPath + public let equals: AnyKeyPath + public let pivot: Codable.Type? + } + public struct Statement { + public let sql: String + public let bindings: Bindings + } + struct MyTableData { + let firstTable: TableData + let myTable: TableData + let remainingTables: [TableData] + } + typealias Ordering = (key: AnyKeyPath, desc: Bool) + var delegate: SQLGenDelegate + var aliasCounter = 0 + public var tableData: [TableData] = [] + var tablePopCount = 0 + public var command: Command = .unknown + var whereExpr: Expression? + public var statements: [Statement] = [] // statements count must match tableData count for exe to succeed + var accumulatedOrderings: [Ordering] = [] + var currentLimit: (max: Int, skip: Int)? + public var bindingsEncoder: CRUDBindingsEncoder? + public var columnFilters: (include: [String], exclude: [String]) = ([], []) + public init(delegate d: SQLGenDelegate) { + delegate = d + } + mutating func consumeState() -> ([Ordering], (max: Int, skip: Int)?) { + defer { + accumulatedOrderings = [] + currentLimit = nil + } + return (accumulatedOrderings, currentLimit) + } + mutating func addTable(type: A.Type, joinData: PropertyJoinData? = nil) throws { + let decoder = CRUDKeyPathsDecoder() + let model = try A(from: decoder) + tableData.append(.init(type: type, + alias: nextAlias(), + modelInstance: model, + keyPathDecoder: decoder, + joinData: joinData)) + } + mutating func getAlias(type: A.Type) -> String? { + return tableData.first { $0.type == type }?.alias + } + func getTableData(type: Any.Type) -> TableData? { + return tableData.first { $0.type == type } + } + mutating func nextAlias() -> String { + defer { aliasCounter += 1 } + return "t\(aliasCounter)" + } + func getTableName(type: A.Type) -> String { + return type.CRUDTableName + } + mutating func getKeyName(type: A.Type, key: PartialKeyPath) throws -> String? { + guard let td = getTableData(type: type), + let instance = td.modelInstance as? A, + let name = try td.keyPathDecoder.getKeyPathName(instance, keyPath: key) else { + return nil + } + return name + } + mutating func popTableData() -> MyTableData? { + guard !tableData.isEmpty else { + return nil + } + let myTableIndex = tablePopCount + tablePopCount += 1 + return MyTableData(firstTable: tableData[0], + myTable: tableData[myTableIndex], + remainingTables: myTableIndex == 0 ? Array(tableData[1...]) : Array(tableData[1.. String { + let dateFormatter = DateFormatter() + dateFormatter.locale = Locale(identifier: "en_US_POSIX") + dateFormatter.timeZone = TimeZone(abbreviation: "GMT") + dateFormatter.dateFormat = "yyyy-MM-dd'T'HH:mm:ss.SSS" + let ret = dateFormatter.string(from: self) + "Z" + return ret + } + + init?(fromISO8601 string: String) { + let dateFormatter = DateFormatter() + dateFormatter.locale = Locale(identifier: "en_US_POSIX") + dateFormatter.timeZone = TimeZone.current + let validFormats = [ + "yyyy-MM-dd'T'HH:mm:ss.SSSZ", + "yyyy-MM-dd HH:mm:ss.SSSx", + "yyyy-MM-dd'T'HH:mm:ssZ", + "yyyy-MM-dd HH:mm:ssx"] + for fmt in validFormats { + dateFormatter.dateFormat = fmt + if let slf = dateFormatter.date(from: string) { + self = slf + return + } + } + return nil + } +} + +#if swift(>=4.1) +#else +// Added for Swift 4.0/4.1 compat +extension Collection { + func compactMap(_ transform: (Element) throws -> ElementOfResult?) rethrows -> [ElementOfResult] { + return try flatMap(transform) + } +} +#endif diff --git a/Sources/PerfectCRUD/Select.swift b/Sources/PerfectCRUD/Select.swift new file mode 100644 index 00000000..c6f97126 --- /dev/null +++ b/Sources/PerfectCRUD/Select.swift @@ -0,0 +1,105 @@ +// +// PerfectCRUDSelect.swift +// PerfectCRUD +// +// Created by Kyle Jessup on 2017-12-02. +// + +import Foundation + +public struct SelectIterator: IteratorProtocol { + public typealias Element = A.OverAllForm + let select: A? + let exeDelegate: SQLExeDelegate? + init(select s: A) throws { + select = s + exeDelegate = try SQLTopExeDelegate(genState: s.sqlGenState, configurator: s.fromTable.databaseConfiguration) + } + init() { + select = nil + exeDelegate = nil + } + public mutating func next() -> Element? { + guard let delegate = exeDelegate else { + return nil + } + do { + if try delegate.hasNext() { + let rowDecoder: CRUDRowDecoder = CRUDRowDecoder(delegate: delegate) + return try Element(from: rowDecoder) + } + } catch { + CRUDLogging.log(.error, "Error thrown in SelectIterator.next(). Caught: \(error)") + } + return nil + } + public mutating func nextElement() throws -> Element? { + guard let delegate = exeDelegate else { + return nil + } + if try delegate.hasNext() { + let rowDecoder: CRUDRowDecoder = CRUDRowDecoder(delegate: delegate) + return try Element(from: rowDecoder) + } + return nil + } +} + +public struct Select: SelectProtocol { + public typealias Iterator = SelectIterator +
+ + +
+
+ + +
+ + + + +{{ > inc/footer }} diff --git a/templates/pages/accountreset.mustache b/templates/pages/accountreset.mustache new file mode 100644 index 00000000..5a8b1c63 --- /dev/null +++ b/templates/pages/accountreset.mustache @@ -0,0 +1,25 @@ +{{ > inc/header }} + +{{ > inc/footer }} diff --git a/templates/pages/inc/footer.mustache b/templates/pages/inc/footer.mustache new file mode 100644 index 00000000..56e17a19 --- /dev/null +++ b/templates/pages/inc/footer.mustache @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/templates/pages/inc/header.mustache b/templates/pages/inc/header.mustache new file mode 100644 index 00000000..22fa28e4 --- /dev/null +++ b/templates/pages/inc/header.mustache @@ -0,0 +1,77 @@ + + + + + + + + + + + Welcome! + + +
+ diff --git a/templates/pages/index.mustache b/templates/pages/index.mustache new file mode 100644 index 00000000..af305b0e --- /dev/null +++ b/templates/pages/index.mustache @@ -0,0 +1,4 @@ +{{ > inc/header }} +

Welcome!

+ +{{ > inc/footer }} diff --git a/test-linux.sh b/test-linux.sh new file mode 100755 index 00000000..e20346ed --- /dev/null +++ b/test-linux.sh @@ -0,0 +1,5 @@ +#!/bin/bash +cp smtp.test.json /tmp/ && \ +swift test --filter PerfectSMTPTests --filter PerfectSQLiteTests --filter PerfectHTTPTests && \ +swift test --skip PerfectSMTPTests --skip PerfectSQLiteTests --skip PerfectHTTPTests + diff --git a/test-macos.sh b/test-macos.sh new file mode 100755 index 00000000..b357c831 --- /dev/null +++ b/test-macos.sh @@ -0,0 +1,3 @@ +#!/bin/bash +cp smtp.test.json /tmp/ && \ +swift package generate-xcodeproj && xcodebuild test -project Perfect.xcodeproj/ -scheme Perfect-Package diff --git a/webroot/favicon.ico b/webroot/favicon.ico new file mode 100644 index 00000000..44aa32e1 Binary files /dev/null and b/webroot/favicon.ico differ diff --git a/webroot/img/logo.svg b/webroot/img/logo.svg new file mode 100644 index 00000000..a43d8937 --- /dev/null +++ b/webroot/img/logo.svg @@ -0,0 +1,61 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/webroot/js/api.js b/webroot/js/api.js new file mode 100644 index 00000000..d616686f --- /dev/null +++ b/webroot/js/api.js @@ -0,0 +1,79 @@ +function getMeta(name) { + const elements = document.getElementsByTagName("meta"); + const m = Array.prototype.slice.call(elements).find( m => m.getAttribute("http-equiv") == name ); + return m ? m.getAttribute("content") : ""; +} +function preparePost(uri, form) { + const xauth = "Authorization"; + const xcsrf = "X-CSRF-Token"; + var headers = {}; + headers[xauth] = getMeta(xauth); + headers[xcsrf] = getMeta(xcsrf); + return { + url: uri, dataType: "json", contentType: "application/json", + headers: headers, + data: JSON.stringify(form) + }; +} +function onClickSignUp() { + const form = {"email": $('#email').val()}; + const body = preparePost("/api/invite", form); + $.post(body).done( () => { + alert("Please check your email for an invitation link."); + }).fail( + error => alert(error.responseText) + ); +} +function onClickSignIn() { + const form = { + "email": $('#email').val(), + "password": $('#password').val() + }; + const body = preparePost("/api/login", form); + $.post(body).done( (data) => { + if (data.token.length > 0) { + $('head').append(``); + $('#popupLogin').modal('toggle'); + } else { + alert("sorry, unable to login."); + console.dir(data); + } + }).fail( + error => alert(error.responseText) + ); +} +function onClickJoin() { + const form = { + "code": $('#code').val(), + "password": $('#newPassword').val() + }; + const body = preparePost("/api/register", form); + $.post(body).done( () => { + alert("Success! Please login!"); + window.location = "/"; + }).fail( + error => alert(error.responseText) + ); +} +function onClickReset() { + const form = {"email": $('#email').val()}; + const body = preparePost("/api/reset/attempt", form); + $.post(body).done( () => { + alert("Please check your email for an password reset link."); + }).fail( + error => alert(error.responseText) + ); +} +function onClickConfirmReset() { + const form = { + "code": $('#code').val(), + "password": $('#newPassword').val() + }; + const body = preparePost("/api/reset/confirm", form); + $.post(body).done( () => { + alert("Success! Please login!"); + window.location = "/"; + }).fail( + error => alert(error.responseText) + ); +} \ No newline at end of file