Skip to content

Commit

Permalink
修改推理bug和单词拼写时空格处理问题 (Fix an inference bug and the handling of spaces when assembling spelled-out words)
Browse files: browse the repository at this point in the history
  • Loading branch information
chenkui164 committed Feb 1, 2023
1 parent b4d5f05 commit 77a85dc
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 37 deletions.
101 changes: 67 additions & 34 deletions src/lib/Vocab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,58 +67,91 @@ bool Vocab::isChinese(string ch)
return false;
}


string Vocab::vector2stringV2(vector<int> in)
{
int i;
list<string> words;

int strstatus = 0;
bool isAllChinese = true;
bool isAllAlpha = true;

string combine = "";
int is_combining = false;
int is_pre_english = false;
int pre_english_len = 0;

int is_combining = false;
string combine = "";

for (auto it = in.begin(); it != in.end(); it++) {
string word = vocab[*it];
// words.push_back(word);
int sub_word = !(word.find("@@") == string::npos);
// cout << word << "," << word.size() << "," << isChinese(word) << ","
// << sub_word << endl;

// step1 空白字符不处理
if (word == "<s>" || word == "</s>" || word == "<unk>")
continue;
if (sub_word) {
combine += word.erase(word.length() - 2);
is_combining = true;
continue;
} else if (is_combining) {
combine += word;
is_combining = false;
// cout << combine << endl;
word = combine;
combine = "";

// step2 将音素拼成完整的单词
{
int sub_word = !(word.find("@@") == string::npos);

//处理单词起始和中间部分
if (sub_word) {
combine += word.erase(word.length() - 2);
is_combining = true;
continue;
}
//处理单词结束部分, combine结束
else if (is_combining) {
combine += word;
is_combining = false;
word = combine;
combine = "";
}
}

if (isChinese(word)) {
words.push_back(word);
is_pre_english = false;
} else {
if (is_pre_english && pre_english_len > 1) {
words.push_back(" ");
if (word.size() == 1) {
word[0] = word[0] - 32;
}
// step3 处理英文单词,单词之间需要加入空格,缩写转成大写。
{

//输入是汉字不需要处理
if (isChinese(word)) {
words.push_back(word);
} else {
if (word.size() == 1) {
is_pre_english = false;
}
//输入是英文单词
else {

// 如果前面是汉字,不论当前是多个字母还是单个字母的单词,都不需要加空格
if (!is_pre_english) {
word[0] = word[0] - 32;
words.push_back(word);
pre_english_len = word.size();

}
words.push_back(word);

// 如果前面是单词
else {

// 单个字母的单词变大写
if (word.size() == 1) {
word[0] = word[0] - 32;
}

// 前面单词的长度是大于1的,说明当前输入不属于缩写的部分,需要和前面的单词分割开,加空格
if (pre_english_len > 1) {
words.push_back(" ");
words.push_back(word);
pre_english_len = word.size();
}
// 前面单词的长度是等于1,可能是属于缩写
else {
// 当前长度大于1, 不可能是缩写,所以需要插入空格
if (word.size() > 1) {
words.push_back(" ");
}
words.push_back(word);
pre_english_len = word.size();
}
}

is_pre_english = true;

}
is_pre_english = true;
pre_english_len = word.size();
}
}

Expand Down
11 changes: 8 additions & 3 deletions src/lib/paraformer/Predictor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,17 +150,22 @@ void Predictor::forward(Tensor<float> *&din)
}
// printf("sum_alphas is %f\n", sum_alphas);

int len_labels = round(sum_alphas + 0.45);
// printf("len_labels is %d\n", len_labels);
// int len_labels = round(sum_alphas + 0.45);
int token_num = int(sum_alphas + 0.45);

Tensor<float> *tout = new Tensor<float>(len_labels, 512);
Tensor<float> *tout = new Tensor<float>(token_num, 512);
tout->zeros();

int token_idx = 0;
offset = 0;
for (i = 0; i < mm; i++) {
if (fires[i] >= 1) {
memcpy(tout->buff + offset, frames.buff + i * 512,
512 * sizeof(float));
offset += 512;
token_idx++;
if (token_idx == token_num)
break;
}
}
free(conv_im2col);
Expand Down

0 comments on commit 77a85dc

Please sign in to comment.