Skip to content

Commit

Permalink
cpplint passed
Browse files Browse the repository at this point in the history
  • Loading branch information
Your Name committed Jan 23, 2025
1 parent b7c91b4 commit 42d6d24
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 50 deletions.
109 changes: 75 additions & 34 deletions sherpa-onnx/csrc/offline-tts-cache-mechanism.cc
Original file line number Diff line number Diff line change
@@ -1,30 +1,58 @@
// sherpa-onnx/csrc/offline-tts-cache-mechanism.cc
//
// @mah92 From Iranian people to the comunity with love
// Copyright (c) 2025 @mah92 From Iranian people to the community with love

#include "sherpa-onnx/csrc/offline-tts-cache-mechanism.h"

#include <algorithm>
#include <fstream>
#include <iostream>
#include <filesystem>
#include <chrono>
#include <algorithm>
#include <iostream>
#include <limits>

#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/wave-reader.h"
#include "sherpa-onnx/csrc/wave-writer.h"

// Platform-specific time functions
#if defined(_WIN32)
#include <windows.h>
#else
#include <sys/time.h>
#include <unistd.h>
#endif

namespace sherpa_onnx {

CacheMechanism::CacheMechanism(const std::string &cache_dir, int32_t cache_size)
: cache_dir_(cache_dir), cache_size_bytes_(cache_size), used_cache_size_bytes_(0) {
// Helper function to get the current time in seconds
static int64_t GetCurrentTimeInSeconds() {
#if defined(_WIN32)
// Windows implementation
FILETIME ft;
GetSystemTimeAsFileTime(&ft);
uint64_t time = ((uint64_t)ft.dwHighDateTime << 32) | ft.dwLowDateTime;
return static_cast<int64_t>(time / 10000000ULL - 11644473600ULL);
#else
// Unix implementation
struct timeval tv;
gettimeofday(&tv, nullptr);
return static_cast<int64_t>(tv.tv_sec);
#endif
}

CacheMechanism::CacheMechanism(const std::string &cache_dir,
int32_t cache_size)
: cache_dir_(cache_dir),
cache_size_bytes_(cache_size),
used_cache_size_bytes_(0) {

// Create the cache directory if it doesn't exist
if (!std::filesystem::exists(cache_dir_)) {
bool dir_created = std::filesystem::create_directory(cache_dir_);
if (!dir_created) {
SHERPA_ONNX_LOGE("Unable to create cache directory: %s", cache_dir_.c_str());
SHERPA_ONNX_LOGE("Unable to create cache directory: %s",
cache_dir_.c_str());
SHERPA_ONNX_LOGE("Cache mechanism disabled!");
cache_mechanism_inited_ = false;
return;
Expand All @@ -38,37 +66,41 @@ CacheMechanism::CacheMechanism(const std::string &cache_dir, int32_t cache_size)
UpdateCacheVector();

// Initialize the last save time
last_save_time_ = std::chrono::steady_clock::now();
last_save_time_ = GetCurrentTimeInSeconds();

// Indicate that initialization has been successful
cache_mechanism_inited_ = true;
}

CacheMechanism::~CacheMechanism() {
if(cache_mechanism_inited_ == false) return;
if (cache_mechanism_inited_ == false) return;

// Save the repeat counts on destruction
SaveRepeatCounts();
}

void CacheMechanism::AddWavFile(const std::string &text_hash, const std::vector<float> &samples, int32_t sample_rate) {
void CacheMechanism::AddWavFile(
const std::string &text_hash,
const std::vector<float> &samples,
const int32_t sample_rate) {
std::lock_guard<std::recursive_mutex> lock(mutex_);

if(cache_mechanism_inited_ == false) return;
if (cache_mechanism_inited_ == false) return;

std::string file_path = cache_dir_ + "/" + text_hash + ".wav";

// Check if the file physically exists in the cache directory
bool file_exists = std::filesystem::exists(file_path);

if (!file_exists) { // If the file does not exist, add it to the cache
if (!file_exists) { // If the file does not exist, add it to the cache
// Ensure the cache does not exceed its size limit
EnsureCacheLimit();

// Write the audio samples to a WAV file
bool success = WriteWave(file_path, sample_rate, samples.data(), samples.size());
bool success = WriteWave(file_path,
sample_rate, samples.data(), samples.size());
if (success) {
// Calculate the size of the new WAV file and add it to the total cache size
// Calculate size of the new WAV file and add it to the total cache size
std::ifstream file(file_path, std::ios::binary | std::ios::ate);
if (file.is_open()) {
used_cache_size_bytes_ += file.tellg();
Expand All @@ -79,18 +111,20 @@ void CacheMechanism::AddWavFile(const std::string &text_hash, const std::vector<
}
}

std::vector<float> CacheMechanism::GetWavFile(const std::string &text_hash, int32_t &sample_rate) {
std::vector<float> CacheMechanism::GetWavFile(
const std::string &text_hash,
int32_t *sample_rate) {
std::lock_guard<std::recursive_mutex> lock(mutex_);

std::vector<float> samples;

if(cache_mechanism_inited_ == false) return samples;
if (cache_mechanism_inited_ == false) return samples;

std::string file_path = cache_dir_ + "/" + text_hash + ".wav";

if (std::filesystem::exists(file_path)) {
bool is_ok = false;
samples = ReadWave(file_path, &sample_rate, &is_ok);
samples = ReadWave(file_path, sample_rate, &is_ok);

if (is_ok == false) {
SHERPA_ONNX_LOGE("Failed to read cached file: %s", file_path.c_str());
Expand All @@ -99,14 +133,14 @@ std::vector<float> CacheMechanism::GetWavFile(const std::string &text_hash, int3

// Ensure the text_hash exists in the map before incrementing the count
if (repeat_counts_.find(text_hash) == repeat_counts_.end()) {
repeat_counts_[text_hash] = 1; // Initialize if it doesn't exist
repeat_counts_[text_hash] = 1; // Initialize if it doesn't exist
} else {
repeat_counts_[text_hash]++; // Increment the repeat count
repeat_counts_[text_hash]++; // Increment the repeat count
}

// Save the repeat counts every 10 minutes
auto now = std::chrono::steady_clock::now();
if (std::chrono::duration_cast<std::chrono::seconds>(now - last_save_time_).count() >= 10 * 60) {
int64_t now = GetCurrentTimeInSeconds();
if (now - last_save_time_ >= 10 * 60) {
SaveRepeatCounts();
last_save_time_ = now;
}
Expand All @@ -115,15 +149,15 @@ std::vector<float> CacheMechanism::GetWavFile(const std::string &text_hash, int3
}

int32_t CacheMechanism::GetCacheSize() const {
if(cache_mechanism_inited_ == false) return 0;
if (cache_mechanism_inited_ == false) return 0;

return cache_size_bytes_;
}

void CacheMechanism::SetCacheSize(int32_t cache_size) {
std::lock_guard<std::recursive_mutex> lock(mutex_);

if(cache_mechanism_inited_ == false) return;
if (cache_mechanism_inited_ == false) return;

cache_size_bytes_ = cache_size;

Expand All @@ -133,7 +167,7 @@ void CacheMechanism::SetCacheSize(int32_t cache_size) {
void CacheMechanism::ClearCache() {
std::lock_guard<std::recursive_mutex> lock(mutex_);

if(cache_mechanism_inited_ == false) return;
if (cache_mechanism_inited_ == false) return;

// Remove all WAV files in the cache directory
for (const auto &entry : std::filesystem::directory_iterator(cache_dir_)) {
Expand All @@ -156,7 +190,7 @@ void CacheMechanism::ClearCache() {
int32_t CacheMechanism::GetTotalUsedCacheSize() const {
std::lock_guard<std::recursive_mutex> lock(mutex_);

if(cache_mechanism_inited_ == false) return 0;
if (cache_mechanism_inited_ == false) return 0;

return used_cache_size_bytes_;
}
Expand All @@ -174,7 +208,8 @@ void CacheMechanism::LoadRepeatCounts() {
// Open the file for reading
std::ifstream ifs(repeat_count_file);
if (!ifs.is_open()) {
SHERPA_ONNX_LOGE("Failed to open repeat count file: %s", repeat_count_file.c_str());
SHERPA_ONNX_LOGE("Failed to open repeat count file: %s",
repeat_count_file.c_str());
return; // Skip loading if the file cannot be opened
}

Expand All @@ -196,15 +231,17 @@ void CacheMechanism::SaveRepeatCounts() {
// Open the file for writing
std::ofstream ofs(repeat_count_file);
if (!ofs.is_open()) {
SHERPA_ONNX_LOGE("Failed to open repeat count file for writing: %s", repeat_count_file.c_str());
SHERPA_ONNX_LOGE("Failed to open repeat count file for writing: %s",
repeat_count_file.c_str());
return; // Skip saving if the file cannot be opened
}

// Write the repeat counts to the file
for (const auto &entry : repeat_counts_) {
ofs << entry.first << " " << entry.second;
if (!ofs) {
SHERPA_ONNX_LOGE("Failed to write repeat count for text hash: %s", entry.first.c_str());
SHERPA_ONNX_LOGE("Failed to write repeat count for text hash: %s",
entry.first.c_str());
return; // Stop writing if an error occurs
}
ofs << std::endl;
Expand All @@ -226,12 +263,14 @@ void CacheMechanism::RemoveWavFile(const std::string &text_hash) {
// Remove the entry from the repeat counts and cache vector
if (repeat_counts_.find(text_hash) != repeat_counts_.end()) {
repeat_counts_.erase(text_hash);
cache_vector_.erase(std::remove(cache_vector_.begin(), cache_vector_.end(), text_hash), cache_vector_.end());
cache_vector_.erase(
std::remove(cache_vector_.begin(), cache_vector_.end(), text_hash),
cache_vector_.end());
}
}

void CacheMechanism::UpdateCacheVector() {
used_cache_size_bytes_ = 0; // Reset the total cache size before recalculating
used_cache_size_bytes_ = 0; // Reset total cache size before recalculating

for (const auto &entry : std::filesystem::directory_iterator(cache_dir_)) {
if (entry.path().extension() == ".wav") {
Expand All @@ -252,9 +291,11 @@ void CacheMechanism::UpdateCacheVector() {
}

void CacheMechanism::EnsureCacheLimit() {
if(used_cache_size_bytes_ > cache_size_bytes_) {
auto target_cache_size = std::max(static_cast<int> (cache_size_bytes_*0.95), 0); //Remove more to prevent deleting every step
while (used_cache_size_bytes_> 0 && used_cache_size_bytes_ > target_cache_size) {
if (used_cache_size_bytes_ > cache_size_bytes_) {
auto target_cache_size
= std::max(static_cast<int> (cache_size_bytes_*0.95), 0);
while (used_cache_size_bytes_> 0
&& used_cache_size_bytes_ > target_cache_size) {
// Cache is full, remove the least repeated file
std::string least_repeated_file = GetLeastRepeatedFile();
RemoveWavFile(least_repeated_file);
Expand All @@ -281,4 +322,4 @@ std::string CacheMechanism::GetLeastRepeatedFile() {
return least_repeated_file;
}

} // namespace sherpa_onnx
} // namespace sherpa_onnx
22 changes: 13 additions & 9 deletions sherpa-onnx/csrc/offline-tts-cache-mechanism.h
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
// sherpa-onnx/csrc/offline-tts-cache-mechanism.h
//
// @mah92 From Iranian people to the comunity with love
// Copyright (c) 2025 @mah92 From Iranian people to the community with love

#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_CACHE_MECHANISM_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_CACHE_MECHANISM_H_

#include <string>
#include <vector>
#include <unordered_map>
#include <mutex>
#include <chrono>
#include <mutex> // NOLINT

namespace sherpa_onnx {

Expand All @@ -19,18 +18,23 @@ class CacheMechanism {
~CacheMechanism();

// Add a new wav file to the cache
void AddWavFile(const std::string &text_hash, const std::vector<float> &samples, int32_t sample_rate);
void AddWavFile(
const std::string &text_hash,
const std::vector<float> &samples,
const int32_t sample_rate);

// Get the cached wav file if it exists
std::vector<float> GetWavFile(const std::string &text_hash, int32_t &sample_rate);
std::vector<float> GetWavFile(
const std::string &text_hash,
int32_t *sample_rate);

// Get the current cache size in bytes
int32_t GetCacheSize() const;

// Set the cache size in bytes
void SetCacheSize(int32_t cache_size);

// Remove all the wav files in the cache
// Remove all the wav files in the cache
void ClearCache();

// To get total used cache size(for wav files) in bytes
Expand Down Expand Up @@ -73,13 +77,13 @@ class CacheMechanism {
// Mutex for thread safety (recursive to avoid deadlocks)
mutable std::recursive_mutex mutex_;

// Time of last save
std::chrono::steady_clock::time_point last_save_time_;
// Time of last save (in seconds since epoch)
int64_t last_save_time_;

// if cache mechanism is inited successfully
bool cache_mechanism_inited_;
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_CACHE_MECHANISM_H_
#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_CACHE_MECHANISM_H_
18 changes: 12 additions & 6 deletions sherpa-onnx/csrc/offline-tts.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,15 @@ std::string OfflineTtsConfig::ToString() const {
}

OfflineTts::OfflineTts(const OfflineTtsConfig &config)
: config_(config), impl_(OfflineTtsImpl::Create(config)), cache_mechanism_(nullptr) {}
: config_(config),
impl_(OfflineTtsImpl::Create(config)),
cache_mechanism_(nullptr) {}

template <typename Manager>
OfflineTts::OfflineTts(Manager *mgr, const OfflineTtsConfig &config)
: config_(config), impl_(OfflineTtsImpl::Create(mgr, config)), cache_mechanism_(nullptr) {}
: config_(config),
impl_(OfflineTtsImpl::Create(mgr, config)),
cache_mechanism_(nullptr) {}

OfflineTts::~OfflineTts() = default;

Expand All @@ -105,14 +109,16 @@ GeneratedAudio OfflineTts::Generate(
// Check if the cache mechanism is active and if the audio is already cached
if (cache_mechanism_) {
int32_t sample_rate;
std::vector<float> samples = cache_mechanism_->GetWavFile(text_hash, sample_rate);
std::vector<float> samples
= cache_mechanism_->GetWavFile(text_hash, &sample_rate);

if (!samples.empty()) {
SHERPA_ONNX_LOGE("Returning cached audio for hash:%s", text_hash.c_str());

// If a callback is provided, call it with the cached audio
if (callback) {
int32_t result = callback(samples.data(), samples.size(), 1.0f /* progress */);
int32_t result
= callback(samples.data(), samples.size(), 1.0f /* progress */);
if (result == 0) {
// If the callback returns 0, stop further processing
SHERPA_ONNX_LOGE("Callback requested to stop processing.");
Expand All @@ -127,7 +133,6 @@ GeneratedAudio OfflineTts::Generate(

// Generate the audio if not cached
GeneratedAudio audio = impl_->Generate(text, sid, speed, std::move(callback));
// SHERPA_ONNX_LOGE("Generated audio: sample rate: %d, sample count: %d", audio.sample_rate, audio.samples.size());

// Cache the generated audio if the cache mechanism is active
if (cache_mechanism_) {
Expand All @@ -148,7 +153,8 @@ void OfflineTts::SetCacheSize(const int32_t cache_size) {
if (cache_size > 0) {
if (!cache_mechanism_) {
// Initialize the cache mechanism if it hasn't been initialized yet
cache_mechanism_ = std::make_unique<CacheMechanism>(config_.cache_dir, cache_size);
cache_mechanism_
= std::make_unique<CacheMechanism>(config_.cache_dir, cache_size);
} else {
// Update the cache size if the cache mechanism is already initialized
cache_mechanism_->SetCacheSize(cache_size);
Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/offline-tts.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ class OfflineTts {

// To get total used cache size(for wav files) in bytes
int32_t GetTotalUsedCacheSize();

// Number of supported speakers.
// If it supports only a single speaker, then it return 0 or 1.
int32_t NumSpeakers() const;
Expand Down

0 comments on commit 42d6d24

Please sign in to comment.