Skip to content

Commit

Permalink
Squashed commit of the following:
Browse files Browse the repository at this point in the history
commit 8b7da3e4275e0e8def53b680d746e77649f840bb
Author: Anton Versal <[email protected]>
Date:   Mon Nov 12 15:51:15 2018 +0100

    remames generated messages

commit 57fad651cf61a72c20a6564b38fac5ada32fbce7
Author: Anton Versal <[email protected]>
Date:   Mon Nov 12 14:19:52 2018 +0100

    adds interface I

commit cd713b5a4b9d7f9038d544b43bb2d613e42a597f
Author: Anton Versal <[email protected]>
Date:   Mon Nov 12 10:39:27 2018 +0100

    wip

commit 53dcf5dcd7a90c8282a1df2fbb1d05d752826344
Author: Anton Versal <[email protected]>
Date:   Mon Nov 12 10:39:17 2018 +0100

    wip
  • Loading branch information
antonversal committed Nov 12, 2018
1 parent feef403 commit 0dc743a
Show file tree
Hide file tree
Showing 9 changed files with 965 additions and 234 deletions.
68 changes: 37 additions & 31 deletions generator/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,17 @@ func (g *generator) isServiceDeprecated(field *google_protobuf.ServiceDescriptor
return *field.Options.Deprecated
}

func (g *generator) getTsTypeFromMessage(typeName *string) string {
func (g *generator) getTsTypeFromMessage(typeName *string, isInterface bool) string {
names := strings.Split(*typeName, ".")
importName := g.GetImportNameForMessage(*g.protoFile.Name, *typeName)
interfaceName := names[len(names)-1]
if isInterface {
interfaceName = "I" + interfaceName
}
if importName == "" {
return names[len(names)-1]
return interfaceName
}
return importName + "." + names[len(names)-1]
return importName + "." + interfaceName
}

func (g *generator) getTsFieldType(field *google_protobuf.FieldDescriptorProto) string {
Expand All @@ -172,7 +176,7 @@ func (g *generator) getTsFieldType(field *google_protobuf.FieldDescriptorProto)

if *field.Type == google_protobuf.FieldDescriptorProto_TYPE_MESSAGE ||
*field.Type == google_protobuf.FieldDescriptorProto_TYPE_ENUM {
return g.getTsTypeFromMessage(field.TypeName)
return g.getTsTypeFromMessage(field.TypeName, *field.Type == google_protobuf.FieldDescriptorProto_TYPE_MESSAGE)
}

return g.getTsFieldTypeForScalar(*field.Type)
Expand Down Expand Up @@ -212,7 +216,7 @@ func (g *generator) generateMessageInterface(message *google_protobuf.Descriptor
g.P("* @deprecated")
g.P("*/")
}
g.P(fmt.Sprintf("export interface %s {", g.messageName(message)))
g.P(fmt.Sprintf("export interface I%s {", g.messageName(message)))
for _, field := range message.Field {
g.generateField(field, false)
}
Expand All @@ -221,7 +225,7 @@ func (g *generator) generateMessageInterface(message *google_protobuf.Descriptor

func (g *generator) generateConstructor(message *google_protobuf.DescriptorProto) {
name := g.messageName(message)
g.P(fmt.Sprintf("constructor(attrs?: %s){", name))
g.P(fmt.Sprintf("constructor(attrs?: I%s){", name))
g.P("Object.assign(this, attrs)")
g.P("}")
}
Expand Down Expand Up @@ -375,17 +379,17 @@ func (g *generator) generateEncode(message *google_protobuf.DescriptorProto) {
g.P(fmt.Sprintf("for (const value of this.%s) {", name))
g.P("if (!value) { continue; }")
if field.GetTypeName() == ".google.protobuf.Timestamp" {
g.P(fmt.Sprintf("const msg = new %sMsg({seconds: Math.floor(value.getTime() / 1000) , nanos: value.getMilliseconds() * 1000000});", g.getTsTypeFromMessage(field.TypeName)))
g.P(fmt.Sprintf("const msg = new %s({seconds: Math.floor(value.getTime() / 1000) , nanos: value.getMilliseconds() * 1000000});", g.getTsTypeFromMessage(field.TypeName, false)))
} else {
g.P(fmt.Sprintf("const msg = new %sMsg(value);", g.getTsTypeFromMessage(field.TypeName)))
g.P(fmt.Sprintf("const msg = new %s(value);", g.getTsTypeFromMessage(field.TypeName, false)))
}
g.P(fmt.Sprintf("msg.encode(writer.uint32(%d).fork()).ldelim();", g.getFieldIndex(field)))
g.P("}")
} else {
if field.GetTypeName() == ".google.protobuf.Timestamp" {
g.P(fmt.Sprintf("const msg = new %sMsg({seconds: Math.floor(this.%s.getTime() / 1000) , nanos: this.%s.getMilliseconds() * 1000000});", g.getTsTypeFromMessage(field.TypeName), name, name))
g.P(fmt.Sprintf("const msg = new %s({seconds: Math.floor(this.%s.getTime() / 1000) , nanos: this.%s.getMilliseconds() * 1000000});", g.getTsTypeFromMessage(field.TypeName, false), name, name))
} else {
g.P(fmt.Sprintf("const msg = new %sMsg(this.%s);", g.getTsTypeFromMessage(field.TypeName), name))
g.P(fmt.Sprintf("const msg = new %s(this.%s);", g.getTsTypeFromMessage(field.TypeName, false), name))
}
g.P(fmt.Sprintf("msg.encode(writer.uint32(%d).fork()).ldelim();", g.getFieldIndex(field)))
}
Expand Down Expand Up @@ -427,7 +431,7 @@ func (g *generator) generateDecode(message *google_protobuf.DescriptorProto) {
g.P("? protobufjs.Reader.create(inReader)")
g.P(": inReader")
g.P("const end = length === undefined ? reader.len : reader.pos + length;")
g.P(fmt.Sprintf("const message = new %sMsg();", g.getTsTypeFromMessage(message.Name)))
g.P(fmt.Sprintf("const message = new %s();", g.getTsTypeFromMessage(message.Name, false)))
g.P("while (reader.pos < end) {")
g.P("const tag = reader.uint32()")
g.P("switch (tag >>> 3) {")
Expand Down Expand Up @@ -467,17 +471,17 @@ func (g *generator) generateDecode(message *google_protobuf.DescriptorProto) {
g.P(fmt.Sprintf("message.%s = [];", name))
g.P("}")
if field.GetTypeName() == ".google.protobuf.Timestamp" {
g.P(fmt.Sprintf("const %s = %sMsg.decode(reader, reader.uint32());", name, g.getTsTypeFromMessage(field.TypeName)))
g.P(fmt.Sprintf("const %s = %s.decode(reader, reader.uint32());", name, g.getTsTypeFromMessage(field.TypeName, false)))
g.P(fmt.Sprintf("message.%s.push(new Date(((%s.seconds || 0) * 1000) + ((%s.nanos || 0) / 1000000)));", name, name, name))
} else {
g.P(fmt.Sprintf("message.%s.push(%sMsg.decode(reader, reader.uint32()));", name, g.getTsTypeFromMessage(field.TypeName)))
g.P(fmt.Sprintf("message.%s.push(%s.decode(reader, reader.uint32()));", name, g.getTsTypeFromMessage(field.TypeName, false)))
}
} else {
if field.GetTypeName() == ".google.protobuf.Timestamp" {
g.P(fmt.Sprintf("const %s = %sMsg.decode(reader, reader.uint32());", name, g.getTsTypeFromMessage(field.TypeName)))
g.P(fmt.Sprintf("const %s = %s.decode(reader, reader.uint32());", name, g.getTsTypeFromMessage(field.TypeName, false)))
g.P(fmt.Sprintf("message.%s = new Date(((%s.seconds || 0) * 1000) + ((%s.nanos || 0) / 1000000));", name, name, name))
} else {
g.P(fmt.Sprintf("message.%s = %sMsg.decode(reader, reader.uint32());", name, g.getTsTypeFromMessage(field.TypeName)))
g.P(fmt.Sprintf("message.%s = %s.decode(reader, reader.uint32());", name, g.getTsTypeFromMessage(field.TypeName, false)))
}
}
} else {
Expand Down Expand Up @@ -538,7 +542,7 @@ func (g *generator) generateMessageClass(message *google_protobuf.DescriptorProt
g.P("*/")
}
name := g.messageName(message)
g.P(fmt.Sprintf("export class %sMsg implements %s{", name, name))
g.P(fmt.Sprintf("export class %s implements I%s{", name, name))
g.generateDecode(message)
for _, field := range message.Field {
g.generateField(field, true)
Expand Down Expand Up @@ -596,14 +600,16 @@ func (g *generator) generateDefinition(service *google_protobuf.ServiceDescripto
serverStreaming = *method.ServerStreaming
}
g.P(fmt.Sprintf("responseStream: %s,", strconv.FormatBool(serverStreaming)))
requestType := g.getTsTypeFromMessage(method.InputType)
g.P(fmt.Sprintf("requestType: %sMsg,", requestType))
responseType := g.getTsTypeFromMessage(method.OutputType)
g.P(fmt.Sprintf("responseType: %sMsg,", responseType))
g.P(fmt.Sprintf("requestSerialize: (args: %s) => new %sMsg(args).encode().finish() as Buffer,", requestType, requestType))
g.P(fmt.Sprintf("requestDeserialize: (argBuf: Buffer) => %sMsg.decode(argBuf),", requestType))
g.P(fmt.Sprintf("responseSerialize: (args: %s) => new %sMsg(args).encode().finish() as Buffer,", responseType, responseType))
g.P(fmt.Sprintf("responseDeserialize: (argBuf: Buffer) => %sMsg.decode(argBuf),", responseType))
requestType := g.getTsTypeFromMessage(method.InputType, false)
iRequestType := g.getTsTypeFromMessage(method.InputType, true)
g.P(fmt.Sprintf("requestType: %s,", requestType))
responseType := g.getTsTypeFromMessage(method.OutputType, false)
iResponseType := g.getTsTypeFromMessage(method.OutputType, true)
g.P(fmt.Sprintf("responseType: %s,", responseType))
g.P(fmt.Sprintf("requestSerialize: (args: %s) => new %s(args).encode().finish() as Buffer,", iRequestType, requestType))
g.P(fmt.Sprintf("requestDeserialize: (argBuf: Buffer) => %s.decode(argBuf),", requestType))
g.P(fmt.Sprintf("responseSerialize: (args: %s) => new %s(args).encode().finish() as Buffer,", iResponseType, responseType))
g.P(fmt.Sprintf("responseDeserialize: (argBuf: Buffer) => %s.decode(argBuf),", responseType))
g.P("},")
}
g.P("}")
Expand All @@ -621,8 +627,8 @@ func (g *generator) generateImplementation(service *google_protobuf.ServiceDescr

for _, method := range service.Method {
g.methodDeprecated(method)
inputTypeName := g.getTsTypeFromMessage(method.InputType)
outputTypeName := g.getTsTypeFromMessage(method.OutputType)
inputTypeName := g.getTsTypeFromMessage(method.InputType, true)
outputTypeName := g.getTsTypeFromMessage(method.OutputType, true)
if method.ServerStreaming != nil && *method.ServerStreaming && method.ClientStreaming != nil && *method.ClientStreaming {
g.P(fmt.Sprintf("%s(call: grpc.ServerDuplexStream<%s, %s>): void;", g.toLowerFirst(*method.Name), inputTypeName, outputTypeName))
} else if method.ServerStreaming != nil && *method.ServerStreaming {
Expand Down Expand Up @@ -653,28 +659,28 @@ func (g *generator) generateClient(service *google_protobuf.ServiceDescriptorPro
g.P(fmt.Sprintf("super(%sServiceDefinition, address, credentials, trace, options);", g.toLowerFirst(*service.Name)))
g.P("}")
for _, method := range service.Method {
inputTypeName := g.getTsTypeFromMessage(method.InputType)
inputTypeName := g.getTsTypeFromMessage(method.InputType, true)
g.methodDeprecated(method)
if method.ServerStreaming != nil && *method.ServerStreaming && method.ClientStreaming != nil && *method.ClientStreaming {
outputTypeName := g.getTsTypeFromMessage(method.OutputType)
outputTypeName := g.getTsTypeFromMessage(method.OutputType, true)
g.P(fmt.Sprintf("public %s(metadata?: grpcts.Metadata) {", g.toLowerFirst(*method.Name)))
g.methodDeprecatedLog(method)
g.P(fmt.Sprintf("return super.makeBidiStreamRequest<%s, %s>('%s', metadata);", inputTypeName, outputTypeName, g.toLowerFirst(*method.Name)))
g.P(fmt.Sprint("}"))
} else if method.ServerStreaming != nil && *method.ServerStreaming {
outputTypeName := g.getTsTypeFromMessage(method.OutputType)
outputTypeName := g.getTsTypeFromMessage(method.OutputType, true)
g.P(fmt.Sprintf("public %s(req: %s, metadata?: grpcts.Metadata) {", g.toLowerFirst(*method.Name), inputTypeName))
g.methodDeprecatedLog(method)
g.P(fmt.Sprintf("return super.makeServerStreamRequest<%s, %s>('%s', req, metadata);", inputTypeName, outputTypeName, g.toLowerFirst(*method.Name)))
g.P(fmt.Sprint("}"))
} else if method.ClientStreaming != nil && *method.ClientStreaming {
outputTypeName := g.getTsTypeFromMessage(method.OutputType)
outputTypeName := g.getTsTypeFromMessage(method.OutputType, true)
g.P(fmt.Sprintf("public %s(metadata?: grpcts.Metadata) {", g.toLowerFirst(*method.Name)))
g.methodDeprecatedLog(method)
g.P(fmt.Sprintf("return super.makeClientStreamRequest<%s, %s>('%s', metadata);", inputTypeName, outputTypeName, g.toLowerFirst(*method.Name)))
g.P(fmt.Sprint("};"))
} else {
outputTypeName := g.getTsTypeFromMessage(method.OutputType)
outputTypeName := g.getTsTypeFromMessage(method.OutputType, true)
g.P(fmt.Sprintf("public %s(req: %s, metadata?: grpcts.Metadata) {", g.toLowerFirst(*method.Name), inputTypeName))
g.methodDeprecatedLog(method)
g.P(fmt.Sprintf("return super.makeUnaryRequest<%s, %s>('%s', req, metadata);", inputTypeName, outputTypeName, g.toLowerFirst(*method.Name)))
Expand Down
4 changes: 2 additions & 2 deletions integrationTests/__tests__/decode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ describe('decode', () => {
beforeEach(() => {
const message = PbTest.fromObject(values)
buffer = PbTest.encode(message).finish()
decoded = Foo.TestMsg.decode(buffer)
decoded = Foo.Test.decode(buffer)
})

describe.each([
Expand Down Expand Up @@ -141,7 +141,7 @@ describe('decode changed protos', () => {
beforeEach(() => {
const message = PbTest.fromObject(values)
buffer = PbTest.encode(message).finish()
decoded = Foo.TestMsg.decode(buffer)
decoded = Foo.Test.decode(buffer)
})

it('ignores missing field', () => {
Expand Down
2 changes: 1 addition & 1 deletion integrationTests/__tests__/encode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ describe('encode', () => {
}

beforeEach(() => {
const user = new Foo.TestMsg(values)
const user = new Foo.Test(values)
buffer = user.encode().finish()
decoded = PbTest.toObject(PbTest.decode(buffer), {
enums: String,
Expand Down
Loading

0 comments on commit 0dc743a

Please sign in to comment.