Skip to content

Commit

Permalink
Add Type Argument To Greedy Decoding
Browse files Browse the repository at this point in the history
  • Loading branch information
apaniukov committed Jan 30, 2024
1 parent 2e0bd47 commit 7f601d0
Showing 1 changed file with 4 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,12 @@ def greedy_decoder(input) -> Model:
return token_ids.output(0)


def add_greedy_decoding(text_generation_model: Model, logits_output: str = LOGITS_OUTPUT_NAME) -> Model:
def add_greedy_decoding(
text_generation_model: Model, logits_output: str = LOGITS_OUTPUT_NAME, output_type: Type = Type.i64
) -> Model:
ppp = PrePostProcessor(text_generation_model)
ppp.output(logits_output).postprocess().custom(greedy_decoder)
ppp.output(logits_output).tensor().set_element_type(output_type)
model = ppp.build()
model.output(logits_output).tensor.set_names({TOKEN_IDS_OUTPUT_NAME})
return model
Expand Down

0 comments on commit 7f601d0

Please sign in to comment.