-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
chore(model): refactor text generation task (#222)
Because: to refactor the current text generation task. This commit: (1) divides the text generation task into basic, chat, and visual question-answering modes; (2) uses a repeated structure to replace the JSON-object string literal. --------- Signed-off-by: tony.wang.10101 <[email protected]>
- Loading branch information
1 parent
3705d49
commit d75a1db
Showing
11 changed files
with
354 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,3 +2,5 @@ | |
gen | ||
#IntelliJ | ||
.idea | ||
|
||
.DS_Store |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
syntax = "proto3"; | ||
|
||
package model.model.v1alpha; | ||
|
||
// Google api | ||
import "google/api/field_behavior.proto"; | ||
import "model/model/v1alpha/common.proto"; | ||
|
||
// ImageToImageInput is the input payload of the image-to-image task.
message ImageToImageInput {
  // Prompt image (only for multimodal input). The members of this oneof are
  // mutually exclusive: setting one clears the other.
  oneof type {
    // The prompt image given as a URL.
    string prompt_image_url = 1;
    // The prompt image given as a base64 string.
    string prompt_image_base64 = 2;
  }

  // The prompt text.
  // NOTE(review): declared `optional` (explicit presence) while annotated
  // REQUIRED — confirm which of the two is intended.
  optional string prompt = 3 [
    (google.api.field_behavior) = REQUIRED
  ];

  // Number of steps. Documented default: 5 (the default is not encoded in
  // this schema).
  optional int32 steps = 4 [
    (google.api.field_behavior) = OPTIONAL
  ];

  // Guidance scale. Documented default: 7.5 (not encoded in this schema).
  optional float cfg_scale = 5 [
    (google.api.field_behavior) = OPTIONAL
  ];

  // Seed. Documented default: 0 (not encoded in this schema).
  optional int32 seed = 6 [
    (google.api.field_behavior) = OPTIONAL
  ];

  // Number of samples to generate. Documented default: 1 (not encoded in
  // this schema).
  optional int32 samples = 7 [
    (google.api.field_behavior) = OPTIONAL
  ];

  // Extra parameters.
  // NOTE(review): this is a single ExtraParamObject, not a repeated field —
  // confirm that a single object is intended.
  optional ExtraParamObject extra_params = 8 [
    (google.api.field_behavior) = OPTIONAL
  ];
}
|
||
// ImageToImageOutput is the output payload of the image-to-image task.
message ImageToImageOutput {
  // The generated images, one string per image (output only).
  // NOTE(review): the string encoding of each image is not specified here —
  // confirm against the service implementation.
  repeated string images = 1 [
    (google.api.field_behavior) = OUTPUT_ONLY
  ];
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
syntax = "proto3"; | ||
|
||
package model.model.v1alpha; | ||
|
||
// Google api | ||
import "google/api/field_behavior.proto"; | ||
import "model/model/v1alpha/common.proto"; | ||
|
||
// TextGenerationChatInput is the input payload of the text generation chat
// task.
message TextGenerationChatInput {
  // The conversation, as a list of ConversationObject entries (required).
  // NOTE(review): ConversationObject is not defined in this file; presumably
  // it comes from the imported common.proto — confirm.
  repeated ConversationObject conversation = 1 [
    (google.api.field_behavior) = REQUIRED
  ];

  // Maximum number of tokens the model may generate.
  optional int32 max_new_tokens = 2 [
    (google.api.field_behavior) = OPTIONAL
  ];

  // Sampling temperature.
  optional float temperature = 3 [
    (google.api.field_behavior) = OPTIONAL
  ];

  // Top-k value used for sampling.
  optional int32 top_k = 4 [
    (google.api.field_behavior) = OPTIONAL
  ];

  // The seed.
  optional int32 seed = 5 [
    (google.api.field_behavior) = OPTIONAL
  ];

  // Extra parameters, as a list of ExtraParamObject entries.
  repeated ExtraParamObject extra_params = 6 [
    (google.api.field_behavior) = OPTIONAL
  ];
}
|
||
// TextGenerationChatOutput is the output payload of the text generation chat
// task.
message TextGenerationChatOutput {
  // Text produced by the model (output only).
  string text = 1 [
    (google.api.field_behavior) = OUTPUT_ONLY
  ];
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
syntax = "proto3"; | ||
|
||
package model.model.v1alpha; | ||
|
||
// Google api | ||
import "google/api/field_behavior.proto"; | ||
import "model/model/v1alpha/common.proto"; | ||
|
||
// VisualQuestionAnsweringInput is the input payload of the visual question
// answering task.
message VisualQuestionAnsweringInput {
  // The prompt text (required).
  string prompt = 1 [
    (google.api.field_behavior) = REQUIRED
  ];

  // Prompt image (only for multimodal input). The members of this oneof are
  // mutually exclusive: setting one clears the other.
  oneof type {
    // The prompt image given as a URL.
    string prompt_image_url = 2;
    // The prompt image given as a base64 string.
    string prompt_image_base64 = 3;
  }

  // Maximum number of tokens the model may generate.
  optional int32 max_new_tokens = 4 [
    (google.api.field_behavior) = OPTIONAL
  ];

  // Sampling temperature.
  optional float temperature = 5 [
    (google.api.field_behavior) = OPTIONAL
  ];

  // Top-k value used for sampling.
  optional int32 top_k = 6 [
    (google.api.field_behavior) = OPTIONAL
  ];

  // The seed.
  optional int32 seed = 7 [
    (google.api.field_behavior) = OPTIONAL
  ];

  // Extra parameters, as a list of ExtraParamObject entries.
  repeated ExtraParamObject extra_params = 8 [
    (google.api.field_behavior) = OPTIONAL
  ];
}
|
||
// VisualQuestionAnsweringOutput is the output payload of the visual question
// answering task.
message VisualQuestionAnsweringOutput {
  // Text produced by the model (output only).
  string text = 1 [
    (google.api.field_behavior) = OUTPUT_ONLY
  ];
}
Oops, something went wrong.