Skip to content

Commit

Permalink
chore(model): refactor text generation task (#222)
Browse files Browse the repository at this point in the history
Because

- to refactor the current text generation task 

This commit

- divides the text generation task into basic, chat, and visual
question-answering modes
- uses a repeated message structure to replace the JSON-object string literal

---------

Signed-off-by: tony.wang.10101 <[email protected]>
  • Loading branch information
tonywang10101 authored Dec 3, 2023
1 parent 3705d49 commit d75a1db
Show file tree
Hide file tree
Showing 11 changed files with 354 additions and 35 deletions.
Binary file added .DS_Store
Binary file not shown.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@
gen
#IntelliJ
.idea

.DS_Store
10 changes: 7 additions & 3 deletions common/task/v1alpha/task.proto
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,14 @@ enum Task {
TASK_TEXT_TO_IMAGE = 7;
// Task: TEXT GENERATION
TASK_TEXT_GENERATION = 8;
// Task: TEXT GENERATION CHAT
TASK_TEXT_GENERATION_CHAT = 9;
// Task: VISUAL QUESTION ANSWERING
TASK_VISUAL_QUESTION_ANSWERING = 10;
// Task: IMAGE TO IMAGE
TASK_IMAGE_TO_IMAGE = 9;
TASK_IMAGE_TO_IMAGE = 11;
// Task: TEXT EMBEDDINGS
TASK_TEXT_EMBEDDINGS = 10;
TASK_TEXT_EMBEDDINGS = 12;
// Task: SPEECH RECOGNITION
TASK_SPEECH_RECOGNITION = 11;
TASK_SPEECH_RECOGNITION = 13;
}
17 changes: 17 additions & 0 deletions model/model/v1alpha/common.proto
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,20 @@ message BoundingBox {
// Bounding box height value
float height = 4 [(google.api.field_behavior) = OUTPUT_ONLY];
}

// Additional hyperparameters for model inference
// or other configuration not listed in protobuf
message ExtraParamObject {
  // Hyperparameter name.
  string param_name = 1;

  // Hyperparameter value, carried as a string.
  string param_value = 2;
}

// Conversation based prompt for text generation model
message ConversationObject {
  // Role name attached to this entry of the conversation.
  string role = 1;

  // Content of this entry of the conversation.
  string content = 2;
}
33 changes: 27 additions & 6 deletions model/model/v1alpha/model.proto
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@ import "google/protobuf/timestamp.proto";
import "model/model/v1alpha/model_definition.proto";
import "model/model/v1alpha/task_classification.proto";
import "model/model/v1alpha/task_detection.proto";
import "model/model/v1alpha/task_image_to_image.proto";
import "model/model/v1alpha/task_instance_segmentation.proto";
import "model/model/v1alpha/task_keypoint.proto";
import "model/model/v1alpha/task_ocr.proto";
import "model/model/v1alpha/task_semantic_segmentation.proto";
import "model/model/v1alpha/task_text_generation.proto";
import "model/model/v1alpha/task_text_generation_chat.proto";
import "model/model/v1alpha/task_text_to_image.proto";
import "model/model/v1alpha/task_unspecified.proto";
import "model/model/v1alpha/task_visual_question_answering.proto";
import "protoc-gen-openapiv2/options/annotations.proto";

// LivenessRequest represents a request to check a service liveness status
Expand Down Expand Up @@ -471,10 +474,16 @@ message TaskInput {
SemanticSegmentationInput semantic_segmentation = 6;
// The text to image input
TextToImageInput text_to_image = 7;
// The image to image input
ImageToImageInput image_to_image = 8;
// The text generation input
TextGenerationInput text_generation = 8;
TextGenerationInput text_generation = 9;
// The text generation chat input
TextGenerationChatInput text_generation_chat = 10;
// The visual question answering input
VisualQuestionAnsweringInput visual_question_answering = 11;
// The unspecified task input
UnspecifiedInput unspecified = 9;
UnspecifiedInput unspecified = 12;
}
}

Expand All @@ -496,10 +505,16 @@ message TaskInputStream {
SemanticSegmentationInputStream semantic_segmentation = 6;
// The text to image input
TextToImageInput text_to_image = 7;
// The image to image input
ImageToImageInput image_to_image = 8;
// The text generation input
TextGenerationInput text_generation = 8;
TextGenerationInput text_generation = 9;
// The text generation chat input
TextGenerationChatInput text_generation_chat = 10;
// The visual question answering input
VisualQuestionAnsweringInput visual_question_answering = 11;
// The unspecified task input
UnspecifiedInput unspecified = 9;
UnspecifiedInput unspecified = 12;
}
}

Expand All @@ -521,10 +536,16 @@ message TaskOutput {
SemanticSegmentationOutput semantic_segmentation = 6 [(google.api.field_behavior) = OUTPUT_ONLY];
// The text to image output
TextToImageOutput text_to_image = 7 [(google.api.field_behavior) = OUTPUT_ONLY];
// The image to image output
ImageToImageOutput image_to_image = 8 [(google.api.field_behavior) = OUTPUT_ONLY];
// The text generation output
TextGenerationOutput text_generation = 8 [(google.api.field_behavior) = OUTPUT_ONLY];
TextGenerationOutput text_generation = 9 [(google.api.field_behavior) = OUTPUT_ONLY];
// The text generation chat output
TextGenerationChatOutput text_generation_chat = 10 [(google.api.field_behavior) = OUTPUT_ONLY];
// The visual question answering output
VisualQuestionAnsweringOutput visual_question_answering = 11 [(google.api.field_behavior) = OUTPUT_ONLY];
// The unspecified task output
UnspecifiedOutput unspecified = 9 [(google.api.field_behavior) = OUTPUT_ONLY];
UnspecifiedOutput unspecified = 12 [(google.api.field_behavior) = OUTPUT_ONLY];
}
}

Expand Down
36 changes: 36 additions & 0 deletions model/model/v1alpha/task_image_to_image.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
syntax = "proto3";

package model.model.v1alpha;

// Google api
import "google/api/field_behavior.proto";
import "model/model/v1alpha/common.proto";

// ImageToImageInput represents the input of image to image task
message ImageToImageInput {
  // The prompt image, only for multimodal input.
  // At most one of the two representations below can be set.
  oneof type {
    // Prompt image given as a URL
    string prompt_image_url = 1;
    // Prompt image given as a base64 string
    string prompt_image_base64 = 2;
  }

  // The prompt text.
  // NOTE(review): declared `optional` while annotated REQUIRED; presumably
  // presence is validated by the server - confirm.
  optional string prompt = 3 [(google.api.field_behavior) = REQUIRED];

  // The steps, default is 5
  optional int32 steps = 4 [(google.api.field_behavior) = OPTIONAL];

  // The guidance scale, default is 7.5
  optional float cfg_scale = 5 [(google.api.field_behavior) = OPTIONAL];

  // The seed, default is 0
  optional int32 seed = 6 [(google.api.field_behavior) = OPTIONAL];

  // The number of generated samples, default is 1
  optional int32 samples = 7 [(google.api.field_behavior) = OPTIONAL];

  // The extra parameters, one ExtraParamObject per name/value pair.
  // Declared `repeated` so that more than one pair can be supplied, matching
  // the text generation inputs. A singular message field and a repeated one
  // share the same wire encoding, so a payload carrying a single entry still
  // parses as a one-element list.
  repeated ExtraParamObject extra_params = 8 [(google.api.field_behavior) = OPTIONAL];
}

// ImageToImageOutput represents the output of image to image task
message ImageToImageOutput {
  // Generated images, each carried as a string
  // (the encoding is not specified in this file).
  repeated string images = 1 [(google.api.field_behavior) = OUTPUT_ONLY];
}
20 changes: 6 additions & 14 deletions model/model/v1alpha/task_text_generation.proto
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,22 @@ package model.model.v1alpha;

// Google api
import "google/api/field_behavior.proto";
import "model/model/v1alpha/common.proto";

// TextGenerationInput represents the input of text generation task
message TextGenerationInput {
  // Field numbers and names of removed fields. They stay reserved so that
  // payloads written against the earlier layout are never reinterpreted:
  //   2 = prompt_image_url (string), 3 = prompt_image_base64 (string),
  //   5 = stop_words_list (string),
  //   9 = extra_params as a JSON-object string literal.
  reserved 2, 3, 5, 9;
  reserved "prompt_image_url", "prompt_image_base64", "stop_words_list";

  // The prompt text
  string prompt = 1 [(google.api.field_behavior) = REQUIRED];

  // The maximum number of tokens for model to generate
  optional int32 max_new_tokens = 4 [(google.api.field_behavior) = OPTIONAL];

  // The temperature for sampling
  optional float temperature = 6 [(google.api.field_behavior) = OPTIONAL];

  // Top k for sampling
  optional int32 top_k = 7 [(google.api.field_behavior) = OPTIONAL];

  // The seed
  optional int32 seed = 8 [(google.api.field_behavior) = OPTIONAL];

  // The extra parameters, one ExtraParamObject per name/value pair.
  // This replaces the former JSON-object string and therefore uses a fresh
  // field number instead of reusing 9 (or any number of a removed field).
  repeated ExtraParamObject extra_params = 10 [(google.api.field_behavior) = OPTIONAL];
}

// TextGenerationOutput represents the output of text generation task
Expand Down
29 changes: 29 additions & 0 deletions model/model/v1alpha/task_text_generation_chat.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
syntax = "proto3";

package model.model.v1alpha;

// Google api
import "google/api/field_behavior.proto";
import "model/model/v1alpha/common.proto";

// TextGenerationChatInput represents the input of text generation chat task
message TextGenerationChatInput {
  // The conversation used as the prompt: a list of role/content entries
  repeated ConversationObject conversation = 1 [(google.api.field_behavior) = REQUIRED];

  // Upper bound on the number of tokens the model generates
  optional int32 max_new_tokens = 2 [(google.api.field_behavior) = OPTIONAL];

  // Sampling temperature
  optional float temperature = 3 [(google.api.field_behavior) = OPTIONAL];

  // Top k used for sampling
  optional int32 top_k = 4 [(google.api.field_behavior) = OPTIONAL];

  // Seed value
  optional int32 seed = 5 [(google.api.field_behavior) = OPTIONAL];

  // Extra parameters, one ExtraParamObject per name/value pair
  repeated ExtraParamObject extra_params = 6 [(google.api.field_behavior) = OPTIONAL];
}

// TextGenerationChatOutput represents the output of text generation chat task
message TextGenerationChatOutput {
  // Text produced by the model
  string text = 1 [(google.api.field_behavior) = OUTPUT_ONLY];
}
3 changes: 2 additions & 1 deletion model/model/v1alpha/task_text_to_image.proto
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package model.model.v1alpha;

// Google api
import "google/api/field_behavior.proto";
import "model/model/v1alpha/common.proto";

// TextToImageInput represents the input of text to image task
message TextToImageInput {
Expand All @@ -25,7 +26,7 @@ message TextToImageInput {
// The number of generated samples, default is 1
optional int32 samples = 7 [(google.api.field_behavior) = OPTIONAL];
// The extra parameters
optional string extra_params = 8 [(google.api.field_behavior) = OPTIONAL];
optional ExtraParamObject extra_params = 8 [(google.api.field_behavior) = OPTIONAL];
}

// TextToImageOutput represents the output of text to image task
Expand Down
36 changes: 36 additions & 0 deletions model/model/v1alpha/task_visual_question_answering.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
syntax = "proto3";

package model.model.v1alpha;

// Google api
import "google/api/field_behavior.proto";
import "model/model/v1alpha/common.proto";

// VisualQuestionAnsweringInput represents the input of visual question answering task
message VisualQuestionAnsweringInput {
  // Text of the prompt
  string prompt = 1 [(google.api.field_behavior) = REQUIRED];

  // The prompt image, only for multimodal input.
  // At most one of the two representations below can be set.
  oneof type {
    // Prompt image given as a URL
    string prompt_image_url = 2;
    // Prompt image given as a base64 string
    string prompt_image_base64 = 3;
  }

  // Upper bound on the number of tokens the model generates
  optional int32 max_new_tokens = 4 [(google.api.field_behavior) = OPTIONAL];

  // Sampling temperature
  optional float temperature = 5 [(google.api.field_behavior) = OPTIONAL];

  // Top k used for sampling
  optional int32 top_k = 6 [(google.api.field_behavior) = OPTIONAL];

  // Seed value
  optional int32 seed = 7 [(google.api.field_behavior) = OPTIONAL];

  // Extra parameters, one ExtraParamObject per name/value pair
  repeated ExtraParamObject extra_params = 8 [(google.api.field_behavior) = OPTIONAL];
}

// VisualQuestionAnsweringOutput represents the output of visual question answering task
message VisualQuestionAnsweringOutput {
  // Text produced by the model
  string text = 1 [(google.api.field_behavior) = OUTPUT_ONLY];
}
Loading

0 comments on commit d75a1db

Please sign in to comment.