From 54fe6043bc4a0ad251efe26f85b4e734463a4918 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 17 Dec 2024 12:04:37 -0800 Subject: [PATCH] fix deletion of non-evaled arrays with siblings --- mlx/array.cpp | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/mlx/array.cpp b/mlx/array.cpp index 70ecab40d..eeda019a7 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -277,7 +277,19 @@ array::ArrayDesc::~ArrayDesc() { } ad.inputs.clear(); for (auto& [_, a] : input_map) { - if (a.array_desc_.use_count() <= a.siblings().size() + 1) { + bool is_deletable = + (a.array_desc_.use_count() <= a.siblings().size() + 1); + // An array with siblings is deletable only if all of its siblings + // are deletable + for (auto& s : a.siblings()) { + if (!is_deletable) { + break; + } + int is_input = (input_map.find(s.id()) != input_map.end()); + is_deletable &= + s.array_desc_.use_count() <= a.siblings().size() + is_input; + } + if (is_deletable) { for_deletion.push_back(std::move(a.array_desc_)); } }