From 20e9b50fe506d3beb686d227d9271d94b4526e74 Mon Sep 17 00:00:00 2001 From: Dimitris Koutsogiorgas Date: Wed, 30 Aug 2023 16:20:37 -0700 Subject: [PATCH] [swift] Generate a default value for sub fields with only optional fields. --- .../com/squareup/wire/swift/SwiftGenerator.kt | 23 ++++++++++++------- .../src/main/swift/ContainsDuration.swift | 7 +++--- .../src/main/swift/ContainsTimestamp.swift | 7 +++--- .../no-manifest/src/main/swift/AllTypes.swift | 14 ++++++----- .../no-manifest/src/main/swift/FooBar.swift | 7 +++--- .../src/main/swift/OuterMessage.swift | 7 +++--- .../src/main/swift/VersionOne.swift | 7 +++--- .../src/main/swift/VersionTwo.swift | 7 +++--- 8 files changed, 47 insertions(+), 32 deletions(-) diff --git a/wire-swift-generator/src/main/java/com/squareup/wire/swift/SwiftGenerator.kt b/wire-swift-generator/src/main/java/com/squareup/wire/swift/SwiftGenerator.kt index dc2ceaedad..9f3de9bbc2 100644 --- a/wire-swift-generator/src/main/java/com/squareup/wire/swift/SwiftGenerator.kt +++ b/wire-swift-generator/src/main/java/com/squareup/wire/swift/SwiftGenerator.kt @@ -158,6 +158,14 @@ class SwiftGenerator private constructor( else -> null } + private val Field.defaultedValue: CodeBlock? + get() = default?.let { + return defaultFieldInitializer(type!!, it) + } ?: if (isMessage && !isRequiredParameter && !isCollection) { + val subType = schema.getType(type!!) as MessageType + if (subType!!.fields.all { !it.isRequiredParameter }) CodeBlock.of("%T()", subType.typeName) else null + } else null + // see https://protobuf.dev/programming-guides/proto3/#default private val Field.proto3InitialValue: String get() = when { @@ -651,7 +659,7 @@ class SwiftGenerator private constructor( } } addStatement( - if (field.default != null) "_%N.wrappedValue = %L" else { "self.%N = %L" }, + if (field.defaultedValue != null) "_%N.wrappedValue = %L" else { "self.%N = %L" }, field.name, initializer, ) @@ -831,7 +839,7 @@ class SwiftGenerator private constructor( .map { CodeBlock.of("%S", it) } .joinToCode() - val prefix = if (field.default != null) { "_%1N.wrappedValue" } else { "self.%1N" } + val prefix = if (field.defaultedValue != null) { "_%1N.wrappedValue" } else { "self.%1N" } addStatement( "$prefix = try container.$decode($typeArg%2T.self, $forKeys: $keys)", field.name, @@ -1146,7 +1154,7 @@ class SwiftGenerator private constructor( .apply { type.fields.filter { it.isRequiredParameter }.forEach { field -> addStatement( - if (field.default != null) "_%1N.wrappedValue = %1N" else { "self.%1N = %1N" }, + if (field.defaultedValue != null) "_%1N.wrappedValue = %1N" else { "self.%1N = %1N" }, field.name, ) } @@ -1172,7 +1180,7 @@ class SwiftGenerator private constructor( .apply { type.fields.forEach { field -> addStatement( - if (field.default != null) "_%1N.wrappedValue = %1N" else { "self.%1N = %1N" }, + if (field.defaultedValue != null) "_%1N.wrappedValue = %1N" else { "self.%1N = %1N" }, field.name, ) } @@ -1224,10 +1232,9 @@ class SwiftGenerator private constructor( if (isIndirect(type, field)) { property.addAttribute(AttributeSpec.builder(indirect).build()) } - val default = field.default - if (default != null) { - val defaultValue = defaultFieldInitializer(field.type!!, default) - property.addAttribute(AttributeSpec.builder(defaulted).addArgument("defaultValue: $defaultValue").build()) + val defaultedValue = field.defaultedValue + if (defaultedValue != null) { + property.addAttribute(AttributeSpec.builder(defaulted).addArgument("defaultValue: $defaultedValue").build()) } if (field.isMap) { diff --git a/wire-tests-proto3-swift/src/main/swift/ContainsDuration.swift b/wire-tests-proto3-swift/src/main/swift/ContainsDuration.swift index ab4027b43b..f9a0e94244 100644 --- a/wire-tests-proto3-swift/src/main/swift/ContainsDuration.swift +++ b/wire-tests-proto3-swift/src/main/swift/ContainsDuration.swift @@ -5,6 +5,7 @@ import Wire public struct ContainsDuration { + @Defaulted(defaultValue: Duration()) public var duration: Duration? public var unknownFields: Foundation.Data = .init() @@ -20,7 +21,7 @@ extension ContainsDuration { @_disfavoredOverload @available(*, deprecated) public init(duration: Duration? = nil) { - self.duration = duration + _duration.wrappedValue = duration } } @@ -63,7 +64,7 @@ extension ContainsDuration : Proto3Codable { } self.unknownFields = try protoReader.endMessage(token: token) - self.duration = duration + _duration.wrappedValue = duration } public func encode(to protoWriter: Wire.ProtoWriter) throws { @@ -78,7 +79,7 @@ extension ContainsDuration : Codable { public init(from decoder: Swift.Decoder) throws { let container = try decoder.container(keyedBy: Wire.StringLiteralCodingKeys.self) - self.duration = try container.decodeIfPresent(Duration.self, forKey: "duration") + _duration.wrappedValue = try container.decodeIfPresent(Duration.self, forKey: "duration") } public func encode(to encoder: Swift.Encoder) throws { diff --git a/wire-tests-proto3-swift/src/main/swift/ContainsTimestamp.swift b/wire-tests-proto3-swift/src/main/swift/ContainsTimestamp.swift index f38bd33a8a..af012d9ca0 100644 --- a/wire-tests-proto3-swift/src/main/swift/ContainsTimestamp.swift +++ b/wire-tests-proto3-swift/src/main/swift/ContainsTimestamp.swift @@ -5,6 +5,7 @@ import Wire public struct ContainsTimestamp { + @Defaulted(defaultValue: Timestamp()) public var timestamp: Timestamp? public var unknownFields: Foundation.Data = .init() @@ -20,7 +21,7 @@ extension ContainsTimestamp { @_disfavoredOverload @available(*, deprecated) public init(timestamp: Timestamp? = nil) { - self.timestamp = timestamp + _timestamp.wrappedValue = timestamp } } @@ -63,7 +64,7 @@ extension ContainsTimestamp : Proto3Codable { } self.unknownFields = try protoReader.endMessage(token: token) - self.timestamp = timestamp + _timestamp.wrappedValue = timestamp } public func encode(to protoWriter: Wire.ProtoWriter) throws { @@ -78,7 +79,7 @@ extension ContainsTimestamp : Codable { public init(from decoder: Swift.Decoder) throws { let container = try decoder.container(keyedBy: Wire.StringLiteralCodingKeys.self) - self.timestamp = try container.decodeIfPresent(Timestamp.self, forKey: "timestamp") + _timestamp.wrappedValue = try container.decodeIfPresent(Timestamp.self, forKey: "timestamp") } public func encode(to encoder: Swift.Encoder) throws { diff --git a/wire-tests-swift/no-manifest/src/main/swift/AllTypes.swift b/wire-tests-swift/no-manifest/src/main/swift/AllTypes.swift index 600361c18a..2fdcf3da41 100644 --- a/wire-tests-swift/no-manifest/src/main/swift/AllTypes.swift +++ b/wire-tests-swift/no-manifest/src/main/swift/AllTypes.swift @@ -1860,7 +1860,7 @@ extension AllTypes.Storage { self.opt_string = opt_string self.opt_bytes = opt_bytes self.opt_nested_enum = opt_nested_enum - self.opt_nested_message = opt_nested_message + _opt_nested_message.wrappedValue = opt_nested_message self.req_int32 = req_int32 self.req_uint32 = req_uint32 self.req_sint32 = req_sint32 @@ -1957,7 +1957,7 @@ extension AllTypes.Storage { self.ext_opt_string = ext_opt_string self.ext_opt_bytes = ext_opt_bytes self.ext_opt_nested_enum = ext_opt_nested_enum - self.ext_opt_nested_message = ext_opt_nested_message + _ext_opt_nested_message.wrappedValue = ext_opt_nested_message self.ext_rep_int32 = ext_rep_int32 self.ext_rep_uint32 = ext_rep_uint32 self.ext_rep_sint32 = ext_rep_sint32 @@ -2017,6 +2017,7 @@ extension AllTypes { public var opt_string: Swift.String? public var opt_bytes: Foundation.Data? public var opt_nested_enum: AllTypes.NestedEnum? + @Wire.Defaulted(defaultValue: AllTypes.NestedMessage()) public var opt_nested_message: AllTypes.NestedMessage? public var req_int32: Swift.Int32 public var req_uint32: Swift.UInt32 @@ -2130,6 +2131,7 @@ extension AllTypes { public var ext_opt_string: Swift.String? public var ext_opt_bytes: Foundation.Data? public var ext_opt_nested_enum: AllTypes.NestedEnum? + @Wire.Defaulted(defaultValue: AllTypes.NestedMessage()) public var ext_opt_nested_message: AllTypes.NestedMessage? public var ext_rep_int32: [Swift.Int32] = [] public var ext_rep_uint32: [Swift.UInt32] = [] @@ -2549,7 +2551,7 @@ extension AllTypes.Storage : Proto2Codable { self.opt_string = opt_string self.opt_bytes = opt_bytes self.opt_nested_enum = opt_nested_enum - self.opt_nested_message = opt_nested_message + _opt_nested_message.wrappedValue = opt_nested_message self.req_int32 = try AllTypes.checkIfMissing(req_int32, "req_int32") self.req_uint32 = try AllTypes.checkIfMissing(req_uint32, "req_uint32") self.req_sint32 = try AllTypes.checkIfMissing(req_sint32, "req_sint32") @@ -2646,7 +2648,7 @@ extension AllTypes.Storage : Proto2Codable { self.ext_opt_string = ext_opt_string self.ext_opt_bytes = ext_opt_bytes self.ext_opt_nested_enum = ext_opt_nested_enum - self.ext_opt_nested_message = ext_opt_nested_message + _ext_opt_nested_message.wrappedValue = ext_opt_nested_message self.ext_rep_int32 = ext_rep_int32 self.ext_rep_uint32 = ext_rep_uint32 self.ext_rep_sint32 = ext_rep_sint32 @@ -2852,7 +2854,7 @@ extension AllTypes.Storage : Codable { self.opt_string = try container.decodeIfPresent(Swift.String.self, firstOfKeys: "optString", "opt_string") self.opt_bytes = try container.decodeIfPresent(stringEncoded: Foundation.Data.self, firstOfKeys: "optBytes", "opt_bytes") self.opt_nested_enum = try container.decodeIfPresent(AllTypes.NestedEnum.self, firstOfKeys: "optNestedEnum", "opt_nested_enum") - self.opt_nested_message = try container.decodeIfPresent(AllTypes.NestedMessage.self, firstOfKeys: "optNestedMessage", "opt_nested_message") + _opt_nested_message.wrappedValue = try container.decodeIfPresent(AllTypes.NestedMessage.self, firstOfKeys: "optNestedMessage", "opt_nested_message") self.req_int32 = try container.decode(Swift.Int32.self, firstOfKeys: "reqInt32", "req_int32") self.req_uint32 = try container.decode(Swift.UInt32.self, firstOfKeys: "reqUint32", "req_uint32") self.req_sint32 = try container.decode(Swift.Int32.self, firstOfKeys: "reqSint32", "req_sint32") @@ -2949,7 +2951,7 @@ extension AllTypes.Storage : Codable { self.ext_opt_string = try container.decodeIfPresent(Swift.String.self, firstOfKeys: "extOptString", "ext_opt_string") self.ext_opt_bytes = try container.decodeIfPresent(stringEncoded: Foundation.Data.self, firstOfKeys: "extOptBytes", "ext_opt_bytes") self.ext_opt_nested_enum = try container.decodeIfPresent(AllTypes.NestedEnum.self, firstOfKeys: "extOptNestedEnum", "ext_opt_nested_enum") - self.ext_opt_nested_message = try container.decodeIfPresent(AllTypes.NestedMessage.self, firstOfKeys: "extOptNestedMessage", "ext_opt_nested_message") + _ext_opt_nested_message.wrappedValue = try container.decodeIfPresent(AllTypes.NestedMessage.self, firstOfKeys: "extOptNestedMessage", "ext_opt_nested_message") self.ext_rep_int32 = try container.decodeProtoArray(Swift.Int32.self, firstOfKeys: "extRepInt32", "ext_rep_int32") self.ext_rep_uint32 = try container.decodeProtoArray(Swift.UInt32.self, firstOfKeys: "extRepUint32", "ext_rep_uint32") self.ext_rep_sint32 = try container.decodeProtoArray(Swift.Int32.self, firstOfKeys: "extRepSint32", "ext_rep_sint32") diff --git a/wire-tests-swift/no-manifest/src/main/swift/FooBar.swift b/wire-tests-swift/no-manifest/src/main/swift/FooBar.swift index 8affc4d7e1..4cd1bc7757 100644 --- a/wire-tests-swift/no-manifest/src/main/swift/FooBar.swift +++ b/wire-tests-swift/no-manifest/src/main/swift/FooBar.swift @@ -7,6 +7,7 @@ public struct FooBar { public var foo: Int32? public var bar: String? + @Defaulted(defaultValue: FooBar.Nested()) public var baz: FooBar.Nested? public var qux: UInt64? public var fred: [Float] = [] @@ -42,7 +43,7 @@ extension FooBar { ) { self.foo = foo self.bar = bar - self.baz = baz + _baz.wrappedValue = baz self.qux = qux self.fred = fred self.daisy = daisy @@ -124,7 +125,7 @@ extension FooBar : Proto2Codable { self.foo = foo self.bar = bar - self.baz = baz + _baz.wrappedValue = baz self.qux = qux self.fred = fred self.daisy = daisy @@ -157,7 +158,7 @@ extension FooBar : Codable { let container = try decoder.container(keyedBy: Wire.StringLiteralCodingKeys.self) self.foo = try container.decodeIfPresent(Swift.Int32.self, forKey: "foo") self.bar = try container.decodeIfPresent(Swift.String.self, forKey: "bar") - self.baz = try container.decodeIfPresent(FooBar.Nested.self, forKey: "baz") + _baz.wrappedValue = try container.decodeIfPresent(FooBar.Nested.self, forKey: "baz") self.qux = try container.decodeIfPresent(stringEncoded: Swift.UInt64.self, forKey: "qux") self.fred = try container.decodeProtoArray(Swift.Float.self, forKey: "fred") self.daisy = try container.decodeIfPresent(Swift.Double.self, forKey: "daisy") diff --git a/wire-tests-swift/no-manifest/src/main/swift/OuterMessage.swift b/wire-tests-swift/no-manifest/src/main/swift/OuterMessage.swift index 5a1c1d0b73..6d9ff20669 100644 --- a/wire-tests-swift/no-manifest/src/main/swift/OuterMessage.swift +++ b/wire-tests-swift/no-manifest/src/main/swift/OuterMessage.swift @@ -6,6 +6,7 @@ import Wire public struct OuterMessage { public var outer_number_before: Int32? + @Defaulted(defaultValue: EmbeddedMessage()) public var embedded_message: EmbeddedMessage? public var unknownFields: Foundation.Data = .init() @@ -22,7 +23,7 @@ extension OuterMessage { @available(*, deprecated) public init(outer_number_before: Swift.Int32? = nil, embedded_message: EmbeddedMessage? = nil) { self.outer_number_before = outer_number_before - self.embedded_message = embedded_message + _embedded_message.wrappedValue = embedded_message } } @@ -68,7 +69,7 @@ extension OuterMessage : Proto2Codable { self.unknownFields = try protoReader.endMessage(token: token) self.outer_number_before = outer_number_before - self.embedded_message = embedded_message + _embedded_message.wrappedValue = embedded_message } public func encode(to protoWriter: Wire.ProtoWriter) throws { @@ -85,7 +86,7 @@ extension OuterMessage : Codable { public init(from decoder: Swift.Decoder) throws { let container = try decoder.container(keyedBy: Wire.StringLiteralCodingKeys.self) self.outer_number_before = try container.decodeIfPresent(Swift.Int32.self, firstOfKeys: "outerNumberBefore", "outer_number_before") - self.embedded_message = try container.decodeIfPresent(EmbeddedMessage.self, firstOfKeys: "embeddedMessage", "embedded_message") + _embedded_message.wrappedValue = try container.decodeIfPresent(EmbeddedMessage.self, firstOfKeys: "embeddedMessage", "embedded_message") } public func encode(to encoder: Swift.Encoder) throws { diff --git a/wire-tests-swift/no-manifest/src/main/swift/VersionOne.swift b/wire-tests-swift/no-manifest/src/main/swift/VersionOne.swift index af553e8109..dae0dbb476 100644 --- a/wire-tests-swift/no-manifest/src/main/swift/VersionOne.swift +++ b/wire-tests-swift/no-manifest/src/main/swift/VersionOne.swift @@ -6,6 +6,7 @@ import Wire public struct VersionOne { public var i: Int32? + @Defaulted(defaultValue: NestedVersionOne()) public var obj: NestedVersionOne? public var en: EnumVersionOne? public var unknownFields: Foundation.Data = .init() @@ -27,7 +28,7 @@ extension VersionOne { en: EnumVersionOne? = nil ) { self.i = i - self.obj = obj + _obj.wrappedValue = obj self.en = en } @@ -76,7 +77,7 @@ extension VersionOne : Proto2Codable { self.unknownFields = try protoReader.endMessage(token: token) self.i = i - self.obj = obj + _obj.wrappedValue = obj self.en = en } @@ -95,7 +96,7 @@ extension VersionOne : Codable { public init(from decoder: Swift.Decoder) throws { let container = try decoder.container(keyedBy: Wire.StringLiteralCodingKeys.self) self.i = try container.decodeIfPresent(Swift.Int32.self, forKey: "i") - self.obj = try container.decodeIfPresent(NestedVersionOne.self, forKey: "obj") + _obj.wrappedValue = try container.decodeIfPresent(NestedVersionOne.self, forKey: "obj") self.en = try container.decodeIfPresent(EnumVersionOne.self, forKey: "en") } diff --git a/wire-tests-swift/no-manifest/src/main/swift/VersionTwo.swift b/wire-tests-swift/no-manifest/src/main/swift/VersionTwo.swift index 40347c6af3..9deac6ff8b 100644 --- a/wire-tests-swift/no-manifest/src/main/swift/VersionTwo.swift +++ b/wire-tests-swift/no-manifest/src/main/swift/VersionTwo.swift @@ -11,6 +11,7 @@ public struct VersionTwo { public var v2_f32: UInt32? public var v2_f64: UInt64? public var v2_rs: [String] = [] + @Defaulted(defaultValue: NestedVersionTwo()) public var obj: NestedVersionTwo? public var en: EnumVersionTwo? public var unknownFields: Foundation.Data = .init() @@ -42,7 +43,7 @@ extension VersionTwo { self.v2_f32 = v2_f32 self.v2_f64 = v2_f64 self.v2_rs = v2_rs - self.obj = obj + _obj.wrappedValue = obj self.en = en } @@ -106,7 +107,7 @@ extension VersionTwo : Proto2Codable { self.v2_f32 = v2_f32 self.v2_f64 = v2_f64 self.v2_rs = v2_rs - self.obj = obj + _obj.wrappedValue = obj self.en = en } @@ -135,7 +136,7 @@ extension VersionTwo : Codable { self.v2_f32 = try container.decodeIfPresent(Swift.UInt32.self, firstOfKeys: "v2F32", "v2_f32") self.v2_f64 = try container.decodeIfPresent(stringEncoded: Swift.UInt64.self, firstOfKeys: "v2F64", "v2_f64") self.v2_rs = try container.decodeProtoArray(Swift.String.self, firstOfKeys: "v2Rs", "v2_rs") - self.obj = try container.decodeIfPresent(NestedVersionTwo.self, forKey: "obj") + _obj.wrappedValue = try container.decodeIfPresent(NestedVersionTwo.self, forKey: "obj") self.en = try container.decodeIfPresent(EnumVersionTwo.self, forKey: "en") }