Skip to content

Commit

Permalink
more random cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
will-cromar committed Nov 9, 2023
1 parent e1064e0 commit 461f643
Show file tree
Hide file tree
Showing 3 changed files with 1 addition and 141 deletions.
102 changes: 0 additions & 102 deletions torch_xla/csrc/runtime/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,57 +24,6 @@ namespace torch_xla {
namespace runtime {
namespace util {

// Invokes `fn`, translating any escaping std::exception into an internal
// xla::Status carrying the exception message. Returns an OK status when the
// callable completes without throwing.
template <typename F>
xla::Status CheckedCall(const F& fn) {
  try {
    fn();
    return xla::Status();
  } catch (const std::exception& ex) {
    return tsl::errors::Internal(ex.what());
  }
}

// RAII helper that invokes a user-supplied callback from its destructor,
// passing the most recently stored status value. Call Release() to disarm
// the callback (e.g. on the success path). Movable, not copyable.
template <typename T>
class Cleanup {
 public:
  using StatusType = T;

  explicit Cleanup(std::function<void(StatusType)> func)
      : func_(std::move(func)) {}
  // Transfer the cleanup action. std::exchange() guarantees the moved-from
  // object is left with an empty callback, so its destructor cannot fire a
  // second time — a merely moved-from std::function is only "valid but
  // unspecified" and is not guaranteed to compare equal to nullptr.
  Cleanup(Cleanup&& ref) noexcept
      : func_(std::exchange(ref.func_, nullptr)),
        status_(std::move(ref.status_)) {}
  Cleanup(const Cleanup&) = delete;

  ~Cleanup() {
    if (func_ != nullptr) {
      func_(std::move(status_));
    }
  }

  Cleanup& operator=(const Cleanup&) = delete;

  Cleanup& operator=(Cleanup&& ref) noexcept {
    if (this != &ref) {
      // As in the move constructor, explicitly empty the source's callback.
      func_ = std::exchange(ref.func_, nullptr);
      status_ = std::move(ref.status_);
    }
    return *this;
  }

  // Disarms the cleanup: the callback will not run at destruction.
  void Release() { func_ = nullptr; }

  // Stores the value handed to the callback when the cleanup fires.
  void SetStatus(StatusType status) { status_ = std::move(status); }

  const StatusType& GetStatus() const { return status_; }

 private:
  std::function<void(StatusType)> func_;
  StatusType status_;
};

// Convenience aliases: a cleanup carrying a captured in-flight exception, and
// one carrying an xla::Status, respectively.
using ExceptionCleanup = Cleanup<std::exception_ptr>;
using StatusCleanup = Cleanup<xla::Status>;

// Allows APIs which might return const references and values, to not be forced
// to return values in the signature.
template <typename T>
Expand All @@ -96,10 +45,6 @@ class MaybeRef {
const T& ref_;
};

// Functor policy computing the midpoint (floor of half) of a size.
struct MidPolicy {
  size_t operator()(size_t n) const {
    return n / 2;
  }
};

template <class T>
class MaybePtr {
public:
Expand All @@ -121,48 +66,6 @@ class MaybePtr {
absl::optional<T> storage_;
};

// Collects non-owning const raw pointers from a container of shared (or
// shared-pointer-like) objects. The input container keeps ownership.
template <typename C>
std::vector<const typename C::value_type::element_type*> GetConstSharedPointers(
    const C& shared_pointers) {
  using ConstPtr = const typename C::value_type::element_type*;
  std::vector<ConstPtr> result;
  result.reserve(shared_pointers.size());
  for (const auto& sp : shared_pointers) {
    result.push_back(sp.get());
  }
  return result;
}

// Collects non-owning mutable raw pointers from a container of shared (or
// shared-pointer-like) objects. The input container keeps ownership.
template <typename C>
std::vector<typename C::value_type::element_type*> GetSharedPointers(
    const C& shared_pointers) {
  using Ptr = typename C::value_type::element_type*;
  std::vector<Ptr> result;
  result.reserve(shared_pointers.size());
  for (const auto& sp : shared_pointers) {
    result.push_back(sp.get());
  }
  return result;
}

// Inserts `value` under `key`; if `key` is already present, replaces the
// stored value with combiner(existing_value, value) instead.
template <typename C, typename K, typename T, typename F>
void InsertCombined(C* map, const K& key, const T& value, const F& combiner) {
  auto pos = map->find(key);
  if (pos != map->end()) {
    pos->second = combiner(pos->second, value);
  } else {
    map->emplace(key, value);
  }
}

// Builds the arithmetic sequence {init, init + incr, init + 2*incr, ...}
// containing `size` elements.
template <typename T>
std::vector<T> Iota(size_t size, T init = 0, T incr = 1) {
  std::vector<T> seq;
  seq.reserve(size);
  T cur = init;
  for (size_t i = 0; i < size; ++i) {
    seq.push_back(cur);
    cur += incr;
  }
  return seq;
}

template <typename T>
std::vector<T> Range(T start, T end, T step = 1) {
std::vector<T> result;
Expand Down Expand Up @@ -220,11 +123,6 @@ const typename T::mapped_type& MapInsert(T* cont,
return it->second;
}

// Converts an enum (or enum class) value to its underlying integer type.
template <typename T>
typename std::underlying_type<T>::type GetEnumValue(T value) {
  using Underlying = typename std::underlying_type<T>::type;
  return static_cast<Underlying>(value);
}

template <typename T, typename S>
T Multiply(const S& input) {
return std::accumulate(input.begin(), input.end(), T(1),
Expand Down
38 changes: 0 additions & 38 deletions torch_xla/csrc/runtime/util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,36 +15,6 @@ namespace util {

using ::testing::ElementsAre;

TEST(UtilTest, Cleanup) {
  bool fired = false;

  // The callback receives the last status set before destruction.
  {
    Cleanup<bool> cleanup([&fired](bool status) { fired = status; });
    cleanup.SetStatus(true);
  }
  EXPECT_TRUE(fired);

  {
    Cleanup<bool> cleanup([&fired](bool status) { fired = status; });
    cleanup.SetStatus(false);
  }
  EXPECT_FALSE(fired);

  // A released cleanup must not invoke its callback at all.
  {
    Cleanup<bool> cleanup([&fired](bool status) { fired = status; });
    cleanup.SetStatus(true);
    cleanup.Release();
  }
  EXPECT_FALSE(fired);
}

TEST(UtilTest, Iota) {
  // Five elements starting at 0, stepping by 2.
  const auto seq = Iota<int16_t>(5, 0, 2);
  EXPECT_THAT(seq, ElementsAre(0, 2, 4, 6, 8));
}

TEST(UtilTest, Range) {
EXPECT_THAT(Range<int16_t>(0, 10, 2), ElementsAre(0, 2, 4, 6, 8));
EXPECT_THAT(Range<int16_t>(10, 0, -2), ElementsAre(10, 8, 6, 4, 2));
Expand Down Expand Up @@ -75,14 +45,6 @@ TEST(UtilTest, MapInsert) {
EXPECT_EQ(MapInsert(&v, 1, [] { return 12; }), 1);
}

TEST(UtilTest, GetEnumValue) {
  enum E { A = 0, B, C, D };
  // Each enumerator maps to its position in the declaration order.
  const std::pair<E, int> cases[] = {{E::A, 0}, {E::B, 1}, {E::C, 2}, {E::D, 3}};
  for (const auto& [enumerator, expected] : cases) {
    EXPECT_EQ(GetEnumValue(enumerator), expected);
  }
}

TEST(UtilTest, Multiply) {
std::vector<int32_t> t = {1, 2, 3, 4, 5};
EXPECT_EQ(Multiply<int32_t>(t), 120);
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/tensor_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ namespace {
// Bundles the per-transfer state kept alive while tensor data is moved to the
// backend asynchronously.
// NOTE(review): this span is a diff hunk — `handle_unlockers` appears twice,
// once with the pre-change type (runtime::util::ExceptionCleanup) and once
// with the post-change type (torch::lazy::ExceptionCleanup); only the latter
// line exists in the final file.
struct DataAsync {
  // Host-side tensor sources feeding the transfer.
  std::vector<runtime::ComputationClient::TensorSource> source_tensors;
  // Backend data handles produced by the async transfer.
  std::vector<torch::lazy::BackendDataPtr> async_datas;
  std::vector<runtime::util::ExceptionCleanup> handle_unlockers;
  // Presumably releases/unlocks the data handles if an exception fires
  // mid-transfer — TODO confirm against the code that populates it.
  std::vector<torch::lazy::ExceptionCleanup> handle_unlockers;
};

bool ShouldUseBF16() {
Expand Down

0 comments on commit 461f643

Please sign in to comment.