Skip to content

Commit

Permalink
[swift] Provide a default value for sub fields and common types as pe…
Browse files Browse the repository at this point in the history
…r proto spec
  • Loading branch information
dnkoutso committed Oct 11, 2023
1 parent b9f6835 commit ea299a0
Show file tree
Hide file tree
Showing 22 changed files with 367 additions and 249 deletions.
28 changes: 16 additions & 12 deletions wire-runtime-swift/src/test/swift/sample/Dinosaur.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,17 @@ public struct Dinosaur {
/**
* Common name of this dinosaur, like "Stegosaurus".
*/
@Defaulted(defaultValue: "")
public var name: String?
/**
* URLs with images of this dinosaur.
*/
public var picture_urls: [String] = []
@Defaulted(defaultValue: 0)
public var length_meters: Double?
@Defaulted(defaultValue: 0)
public var mass_kilograms: Double?
@Defaulted(defaultValue: Period.CRETACEOUS)
public var period: Period?
public var unknownFields: Foundation.Data = .init()

Expand All @@ -36,11 +40,11 @@ extension Dinosaur {
mass_kilograms: Swift.Double? = nil,
period: Period? = nil
) {
self.name = name
_name.wrappedValue = name
self.picture_urls = picture_urls
self.length_meters = length_meters
self.mass_kilograms = mass_kilograms
self.period = period
_length_meters.wrappedValue = length_meters
_mass_kilograms.wrappedValue = mass_kilograms
_period.wrappedValue = period
}

}
Expand Down Expand Up @@ -91,11 +95,11 @@ extension Dinosaur : Proto2Codable {
}
self.unknownFields = try protoReader.endMessage(token: token)

self.name = name
_name.wrappedValue = name
self.picture_urls = picture_urls
self.length_meters = length_meters
self.mass_kilograms = mass_kilograms
self.period = period
_length_meters.wrappedValue = length_meters
_mass_kilograms.wrappedValue = mass_kilograms
_period.wrappedValue = period
}

public func encode(to protoWriter: Wire.ProtoWriter) throws {
Expand All @@ -114,11 +118,11 @@ extension Dinosaur : Codable {

public init(from decoder: Swift.Decoder) throws {
let container = try decoder.container(keyedBy: Wire.StringLiteralCodingKeys.self)
self.name = try container.decodeIfPresent(Swift.String.self, forKey: "name")
_name.wrappedValue = try container.decodeIfPresent(Swift.String.self, forKey: "name")
self.picture_urls = try container.decodeProtoArray(Swift.String.self, firstOfKeys: "pictureUrls", "picture_urls")
self.length_meters = try container.decodeIfPresent(Swift.Double.self, firstOfKeys: "lengthMeters", "length_meters")
self.mass_kilograms = try container.decodeIfPresent(Swift.Double.self, firstOfKeys: "massKilograms", "mass_kilograms")
self.period = try container.decodeIfPresent(Period.self, forKey: "period")
_length_meters.wrappedValue = try container.decodeIfPresent(Swift.Double.self, firstOfKeys: "lengthMeters", "length_meters")
_mass_kilograms.wrappedValue = try container.decodeIfPresent(Swift.Double.self, firstOfKeys: "massKilograms", "mass_kilograms")
_period.wrappedValue = try container.decodeIfPresent(Period.self, forKey: "period")
}

public func encode(to encoder: Swift.Encoder) throws {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,44 @@ class SwiftGenerator private constructor(
else -> null
}

private val Field.defaultedValue: CodeBlock?
get() {
when (val default = default) {
// identity values
null -> {
if (!isOptional) return null
if (type == ProtoType.ANY) return null
if (isMessage && !isRequiredParameter && !isCollection) {
val messageType = schema.getType(type!!) as MessageType
return if (messageType.fields.any { it.isRequiredParameter }) null else CodeBlock.of("%T()", messageType.typeName.makeNonOptional())
}
if (isEnum) {
val enumType = schema.getType(type!!) as EnumType
return if (enumType.constants.getOrNull(0) != null) CodeBlock.of("%T.%L", typeName.makeNonOptional(), enumType.constants[0].name) else null
}
when (typeName.makeNonOptional()) {
BOOL -> return CodeBlock.of("%L", false)
INT -> return CodeBlock.of("%L", 0)
INT32 -> return CodeBlock.of("%L", 0)
INT64 -> return CodeBlock.of("%L", 0)
UINT32 -> return CodeBlock.of("%L", 0)
UINT64 -> return CodeBlock.of("%L", 0)
FLOAT -> return CodeBlock.of("%L", 0)
DOUBLE -> return CodeBlock.of("%L", 0)
STRING -> return CodeBlock.of("%S", "")
DATA -> return CodeBlock.of(
"%T(base64Encoded: %S)!",
FOUNDATION_DATA,
"".encode(charset = Charsets.ISO_8859_1).base64(),
)
}
}
else -> return defaultFieldInitializer(type!!, default)
}

return null
}

// see https://protobuf.dev/programming-guides/proto3/#default
private val Field.proto3InitialValue: String
get() = when {
Expand Down Expand Up @@ -651,7 +689,7 @@ class SwiftGenerator private constructor(
}
}
addStatement(
if (field.default != null) "_%N.wrappedValue = %L" else { "self.%N = %L" },
if (!isIndirect(type, field) && field.defaultedValue != null) "_%N.wrappedValue = %L" else { "self.%N = %L" },
field.name,
initializer,
)
Expand Down Expand Up @@ -831,7 +869,7 @@ class SwiftGenerator private constructor(
.map { CodeBlock.of("%S", it) }
.joinToCode()

val prefix = if (field.default != null) { "_%1N.wrappedValue" } else { "self.%1N" }
val prefix = if (!isIndirect(type, field) && field.defaultedValue != null) { "_%1N.wrappedValue" } else { "self.%1N" }
addStatement(
"$prefix = try container.$decode($typeArg%2T.self, $forKeys: $keys)",
field.name,
Expand Down Expand Up @@ -1146,7 +1184,7 @@ class SwiftGenerator private constructor(
.apply {
type.fields.filter { it.isRequiredParameter }.forEach { field ->
addStatement(
if (field.default != null) "_%1N.wrappedValue = %1N" else { "self.%1N = %1N" },
if (!isIndirect(type, field) && field.defaultedValue != null) "_%1N.wrappedValue = %1N" else { "self.%1N = %1N" },
field.name,
)
}
Expand All @@ -1172,7 +1210,7 @@ class SwiftGenerator private constructor(
.apply {
type.fields.forEach { field ->
addStatement(
if (field.default != null) "_%1N.wrappedValue = %1N" else { "self.%1N = %1N" },
if (!isIndirect(type, field) && field.defaultedValue != null) "_%1N.wrappedValue = %1N" else { "self.%1N = %1N" },
field.name,
)
}
Expand Down Expand Up @@ -1224,10 +1262,9 @@ class SwiftGenerator private constructor(
if (isIndirect(type, field)) {
property.addAttribute(AttributeSpec.builder(indirect).build())
}
val default = field.default
if (default != null) {
val defaultValue = defaultFieldInitializer(field.type!!, default)
property.addAttribute(AttributeSpec.builder(defaulted).addArgument("defaultValue: $defaultValue").build())
val defaultedValue = field.defaultedValue
if (!isIndirect(type, field) && defaultedValue != null) {
property.addAttribute(AttributeSpec.builder(defaulted).addArgument("defaultValue: $defaultedValue").build())
}

if (field.isMap) {
Expand Down Expand Up @@ -1278,7 +1315,8 @@ class SwiftGenerator private constructor(
typeName == DOUBLE -> defaultValue.toDoubleFieldInitializer()
typeName == STRING -> CodeBlock.of("%S", stringLiteralWithQuotes2(defaultValue.toString()))
typeName == DATA -> CodeBlock.of(
"Foundation.Data(base64Encoded: %S)!",
"%T(base64Encoded: %S)!",
FOUNDATION_DATA,
defaultValue.toString().encode(charset = Charsets.ISO_8859_1).base64(),
)
protoType.isEnum -> CodeBlock.of("%T.%L", typeName, defaultValue)
Expand Down
7 changes: 4 additions & 3 deletions wire-tests-proto3-swift/src/main/swift/ContainsDuration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import Wire

public struct ContainsDuration {

@Defaulted(defaultValue: Duration())
public var duration: Duration?
public var unknownFields: Foundation.Data = .init()

Expand All @@ -20,7 +21,7 @@ extension ContainsDuration {
@_disfavoredOverload
@available(*, deprecated)
public init(duration: Duration? = nil) {
self.duration = duration
_duration.wrappedValue = duration
}

}
Expand Down Expand Up @@ -63,7 +64,7 @@ extension ContainsDuration : Proto3Codable {
}
self.unknownFields = try protoReader.endMessage(token: token)

self.duration = duration
_duration.wrappedValue = duration
}

public func encode(to protoWriter: Wire.ProtoWriter) throws {
Expand All @@ -78,7 +79,7 @@ extension ContainsDuration : Codable {

public init(from decoder: Swift.Decoder) throws {
let container = try decoder.container(keyedBy: Wire.StringLiteralCodingKeys.self)
self.duration = try container.decodeIfPresent(Duration.self, forKey: "duration")
_duration.wrappedValue = try container.decodeIfPresent(Duration.self, forKey: "duration")
}

public func encode(to encoder: Swift.Encoder) throws {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import Wire

public struct ContainsTimestamp {

@Defaulted(defaultValue: Timestamp())
public var timestamp: Timestamp?
public var unknownFields: Foundation.Data = .init()

Expand All @@ -20,7 +21,7 @@ extension ContainsTimestamp {
@_disfavoredOverload
@available(*, deprecated)
public init(timestamp: Timestamp? = nil) {
self.timestamp = timestamp
_timestamp.wrappedValue = timestamp
}

}
Expand Down Expand Up @@ -63,7 +64,7 @@ extension ContainsTimestamp : Proto3Codable {
}
self.unknownFields = try protoReader.endMessage(token: token)

self.timestamp = timestamp
_timestamp.wrappedValue = timestamp
}

public func encode(to protoWriter: Wire.ProtoWriter) throws {
Expand All @@ -78,7 +79,7 @@ extension ContainsTimestamp : Codable {

public init(from decoder: Swift.Decoder) throws {
let container = try decoder.container(keyedBy: Wire.StringLiteralCodingKeys.self)
self.timestamp = try container.decodeIfPresent(Timestamp.self, forKey: "timestamp")
_timestamp.wrappedValue = try container.decodeIfPresent(Timestamp.self, forKey: "timestamp")
}

public func encode(to encoder: Swift.Encoder) throws {
Expand Down
Loading

0 comments on commit ea299a0

Please sign in to comment.