fix: better error handling (#20)
Co-authored-by: vansangpfiev <[email protected]>
vansangpfiev and sangjanai authored Jul 28, 2024
1 parent 828eb7b commit 253c19b
Showing 2 changed files with 38 additions and 10 deletions.
47 changes: 37 additions & 10 deletions src/onnx_engine.cc
@@ -127,7 +127,7 @@ OnnxEngine::OnnxEngine() {
void OnnxEngine::LoadModel(
std::shared_ptr<Json::Value> json_body,
std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
-  auto path = json_body->get("model_path", "").asString();
+  path_ = json_body->get("model_path", "").asString();
user_prompt_ = json_body->get("user_prompt", "USER: ").asString();
ai_prompt_ = json_body->get("ai_prompt", "ASSISTANT: ").asString();
system_prompt_ =
@@ -136,7 +136,7 @@ void OnnxEngine::LoadModel(
max_history_chat_ = json_body->get("max_history_chat", 2).asInt();
try {
std::cout << "Creating model..." << std::endl;
-    oga_model_ = OgaModel::Create(path.c_str());
+    oga_model_ = OgaModel::Create(path_.c_str());
std::cout << "Creating tokenizer..." << std::endl;
tokenizer_ = OgaTokenizer::Create(*oga_model_);
tokenizer_stream_ = OgaTokenizerStream::Create(*tokenizer_);
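Note: for context, LoadModel pulls its configuration straight from the JSON body. A request constructed in C++ might look like the sketch below; the field names are taken from the diff, while the values are illustrative placeholders:

    #include <json/json.h>  // jsoncpp, as used by the engine

    // Illustrative LoadModel request body; field names come from the
    // diff above, values are placeholders, not part of this commit.
    Json::Value body;
    body["model_path"] = "/models/example-onnx";
    body["user_prompt"] = "USER: ";
    body["ai_prompt"] = "ASSISTANT: ";
    body["max_history_chat"] = 2;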
@@ -149,7 +149,7 @@
status["status_code"] = k200OK;
callback(std::move(status), std::move(json_resp));
model_id_ = GetModelId(*json_body);
LOG_INFO << "Model loaded successfully: " << path
LOG_INFO << "Model loaded successfully: " << path_
<< ", model_id: " << model_id_;
model_loaded_ = true;
start_time_ = std::chrono::system_clock::now().time_since_epoch() /
@@ -158,7 +158,7 @@
q_ = std::make_unique<trantor::ConcurrentTaskQueue>(1, model_id_);
}
} catch (const std::exception& e) {
-    std::cout << e.what() << std::endl;
+    std::cout << "Failed to load model: " << e.what() << std::endl;
oga_model_.reset();
tokenizer_.reset();
tokenizer_stream_.reset();
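Note: the catch branch in LoadModel now prefixes the exception message and releases every ORT GenAI handle, so a failed load leaves no half-initialized state behind. A minimal sketch of that reset-and-report pattern follows; the status fields and the k500InternalServerError constant are assumptions inferred from the surrounding code, since the hunk is truncated here:

    // Sketch of LoadModel's reset-and-report error path; the status
    // fields and error constant are assumptions, not shown in this hunk.
    try {
      oga_model_ = OgaModel::Create(path_.c_str());
      tokenizer_ = OgaTokenizer::Create(*oga_model_);
      tokenizer_stream_ = OgaTokenizerStream::Create(*tokenizer_);
    } catch (const std::exception& e) {
      std::cout << "Failed to load model: " << e.what() << std::endl;
      oga_model_.reset();  // drop any partially constructed handle
      tokenizer_.reset();
      tokenizer_stream_.reset();
      Json::Value json_resp, status;
      json_resp["message"] = std::string("Failed to load model: ") + e.what();
      status["has_error"] = true;                       // assumed field
      status["is_stream"] = false;                      // assumed field
      status["status_code"] = k500InternalServerError;  // assumed constant
      callback(std::move(status), std::move(json_resp));
    }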
@@ -182,8 +182,8 @@ void OnnxEngine::HandleChatCompletion(
auto is_stream = json_body->get("stream", false).asBool();

std::string formatted_output = pre_prompt_;
-  int history_max = max_history_chat_ * 2;  // both user and assistant

+  int history_max = max_history_chat_ * 2;  // both user and assistant
int index = 0;
for (const auto& message : req.messages) {
std::string input_role = message["role"].asString();
@@ -272,8 +272,32 @@ void OnnxEngine::HandleChatCompletion(
auto duration_ms =
std::chrono::duration_cast<std::chrono::milliseconds>(end - start)
.count();
-  LOG_DEBUG << "Generated tokens per second: "
-            << generated_tokens / duration_ms * 1000;
+  std::cout << "Generated tokens per second: "
+            << generated_tokens / duration_ms * 1000 << std::endl;
+  if ((generated_tokens / duration_ms * 1000) < 1.0f) {
+    max_history_chat_ = std::max(1, max_history_chat_ / 2);
+    tokenizer_stream_.reset();
+    tokenizer_.reset();
+    oga_model_.reset();
+    generator.reset();
+    params.reset();
+    sequences.reset();
+    model_loaded_ = false;
+    LOG_WARN << "Something wrong happened, restart model and try again";
+    LOG_INFO << "Creating model...";
+    oga_model_ = OgaModel::Create(path_.c_str());
+    LOG_INFO << "Creating tokenizer...";
+    tokenizer_ = OgaTokenizer::Create(*oga_model_);
+    tokenizer_stream_ = OgaTokenizerStream::Create(*tokenizer_);
+    LOG_INFO << "Model loaded successfully: " << path_
+             << ", model_id: " << model_id_;
+    model_loaded_ = true;
+    start_time_ = std::chrono::system_clock::now().time_since_epoch() /
+                  std::chrono::milliseconds(1);
+    if (q_ == nullptr) {
+      q_ = std::make_unique<trantor::ConcurrentTaskQueue>(1, model_id_);
+    }
+  }
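Note: one caveat worth flagging here. If generated_tokens and duration_ms are both integral types (their declarations are not visible in this hunk), generated_tokens / duration_ms truncates before the * 1000, so the computed rate is 0 whenever fewer tokens than elapsed milliseconds were produced, and the < 1.0f restart check then fires for any throughput below roughly 1000 tokens/s. A hedged sketch of a floating-point version, with variable names mirroring the diff:

    // Floating-point tokens/sec; avoids the truncation that integer
    // division would cause if both counters are integral
    // (e.g. 50 / 2000 == 0 in integer math).
    double tokens_per_sec =
        duration_ms > 0
            ? static_cast<double>(generated_tokens) * 1000.0 / duration_ms
            : 0.0;
    std::cout << "Generated tokens per second: " << tokens_per_sec
              << std::endl;
    if (tokens_per_sec < 1.0) {
      // restart path as introduced in this commit
    }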

LOG_INFO << "End of result";
Json::Value resp_data;
@@ -326,10 +326,13 @@ void OnnxEngine::HandleChatCompletion(
status["has_error"] = false;
status["is_stream"] = false;
status["status_code"] = k200OK;
-  cb(std::move(status), std::move(resp_data));
+  cb(std::move(status), std::move(resp_data));
}
} catch (const std::exception& e) {
-    std::cout << e.what() << std::endl;
+    tokenizer_stream_.reset();
+    tokenizer_.reset();
+    oga_model_.reset();
+    std::cout << "Error during inference: " << e.what() << std::endl;
Json::Value json_resp;
json_resp["message"] = "Error during inference";
Json::Value status;
1 change: 1 addition & 0 deletions src/onnx_engine.h
@@ -51,5 +51,6 @@ class OnnxEngine : public EngineI {
uint64_t start_time_;
int max_history_chat_;
std::unique_ptr<trantor::ConcurrentTaskQueue> q_;
+  std::string path_;
};
} // namespace cortex_onnx
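Note: the new path_ member is what makes the in-flight restart above possible; the engine can recreate the model without a second LoadModel request. A hypothetical helper, not part of this commit, could consolidate the reset-then-recreate sequence that now appears in both LoadModel and HandleChatCompletion:

    // Hypothetical ReloadModel() consolidating the duplicated
    // reset-then-recreate sequence; a sketch, not part of this commit.
    void OnnxEngine::ReloadModel() {
      tokenizer_stream_.reset();
      tokenizer_.reset();
      oga_model_.reset();
      model_loaded_ = false;
      oga_model_ = OgaModel::Create(path_.c_str());
      tokenizer_ = OgaTokenizer::Create(*oga_model_);
      tokenizer_stream_ = OgaTokenizerStream::Create(*tokenizer_);
      model_loaded_ = true;
    }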
