Skip to content

Commit

Permalink
Compatibility with TF 2.4 TF_TString:
Browse files Browse the repository at this point in the history
In tensorflow 2.4, TF_StringDecode, TF_StringEncode, and TF_StringEncodedSize are replaced by TF_TString.
  • Loading branch information
ljn917 committed Sep 21, 2020
1 parent c38997c commit 81615cc
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion src/Model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ void Model::init() {
}

void Model::save(const std::string &ckpt) {
#ifdef TENSORFLOW_C_TF_TSTRING_H_
std::unique_ptr<TF_TString, decltype(&TF_TString_Dealloc)> tstr(new TF_TString, &TF_TString_Dealloc);
TF_TString_Copy(tstr.get(), ckpt.c_str(), ckpt.size());
auto deallocator = [](void* data, size_t len, void* arg) {};
TF_Tensor* t = TF_NewTensor(TF_STRING, nullptr, 0, tstr.get(), 1, deallocator, nullptr);
#else
// Encode file_name to tensor
size_t size = 8 + TF_StringEncodedSize(ckpt.length());
TF_Tensor* t = TF_AllocateTensor(TF_STRING, nullptr, 0, size);
Expand All @@ -67,6 +73,7 @@ void Model::save(const std::string &ckpt) {

memset(data, 0, 8); // 8-byte offset of first string.
TF_StringEncode(ckpt.c_str(), ckpt.length(), (char*)(data + 8), size - 8, status);
#endif // TENSORFLOW_C_TF_TSTRING_H_

// Check errors
if (!this->status_check(false)) {
Expand Down Expand Up @@ -95,13 +102,19 @@ void Model::save(const std::string &ckpt) {
}

void Model::restore(const std::string& ckpt) {

#ifdef TENSORFLOW_C_TF_TSTRING_H_
std::unique_ptr<TF_TString, decltype(&TF_TString_Dealloc)> tstr(new TF_TString, &TF_TString_Dealloc);
TF_TString_Copy(tstr.get(), ckpt.c_str(), ckpt.size());
auto deallocator = [](void* data, size_t len, void* arg) {};
TF_Tensor* t = TF_NewTensor(TF_STRING, nullptr, 0, tstr.get(), 1, deallocator, nullptr);
#else
// Encode file_name to tensor
size_t size = 8 + TF_StringEncodedSize(ckpt.size());
TF_Tensor* t = TF_AllocateTensor(TF_STRING, nullptr, 0, size);
char* data = static_cast<char *>(TF_TensorData(t));
for (int i=0; i<8; i++) {data[i]=0;}
TF_StringEncode(ckpt.c_str(), ckpt.size(), data + 8, size - 8, status);
#endif // TENSORFLOW_C_TF_TSTRING_H_

// Check errors
if (!this->status_check(false)) {
Expand Down

0 comments on commit 81615cc

Please sign in to comment.