Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a replace operation #61

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,8 @@ set(mlxdata-src
${CMAKE_CURRENT_LIST_DIR}/mlx/data/op/Squeeze.cpp
${CMAKE_CURRENT_LIST_DIR}/mlx/data/op/Tokenize.cpp
${CMAKE_CURRENT_LIST_DIR}/mlx/data/op/ImageTransform.cpp
${CMAKE_CURRENT_LIST_DIR}/mlx/data/op/RemoveValue.cpp)
${CMAKE_CURRENT_LIST_DIR}/mlx/data/op/RemoveValue.cpp
${CMAKE_CURRENT_LIST_DIR}/mlx/data/op/Replace.cpp)

if(AWSSDK_FOUND)
list(APPEND mlxdata-src
Expand Down
26 changes: 26 additions & 0 deletions mlx/data/Dataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "mlx/data/op/ReadFromTAR.h"
#include "mlx/data/op/RemoveValue.h"
#include "mlx/data/op/RenameKey.h"
#include "mlx/data/op/Replace.h"
#include "mlx/data/op/SampleTransform.h"
#include "mlx/data/op/SaveImage.h"
#include "mlx/data/op/Shape.h"
Expand Down Expand Up @@ -633,6 +634,31 @@ T Dataset<T, B>::remove_value_if(
}
}

template <class T, class B>
T Dataset<T, B>::replace(
const std::string& key,
const std::string& old,
const std::string& replacement,
int count) {
return transform_(
std::make_shared<op::Replace>(key, old, replacement, count));
}

template <class T, class B>
T Dataset<T, B>::replace_if(
bool cond,
const std::string& key,
const std::string& old,
const std::string& replacement,
int count) {
if (cond) {
return transform_(
std::make_shared<op::Replace>(key, old, replacement, count));
} else {
return T(self_);
}
}

template <class T, class B>
T Dataset<T, B>::rename_key(const std::string& ikey, const std::string& okey)
const {
Expand Down
12 changes: 12 additions & 0 deletions mlx/data/Dataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,18 @@ class Dataset {
double value,
double pad) const;

T replace(
const std::string& key,
const std::string& old,
const std::string& replacement,
int count = -1);
T replace_if(
bool cond,
const std::string& key,
const std::string& old,
const std::string& replacement,
int count = -1);

T rename_key(const std::string& ikey, const std::string& okey) const;
T rename_key_if(bool cond, const std::string& ikey, const std::string& okey)
const;
Expand Down
73 changes: 71 additions & 2 deletions mlx/data/core/Utils.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#include "mlx/data/core/Utils.h"
#include <iostream>

namespace {

Expand Down Expand Up @@ -57,6 +56,7 @@ void uniq_t(
}
}
}

template <class T>
void remove_t(
std::shared_ptr<Array> dst,
Expand Down Expand Up @@ -102,6 +102,65 @@ void remove_t(
}
}
}

template <typename T>
void replace_t(
std::shared_ptr<Array>& result,
const std::shared_ptr<Array> src,
const std::shared_ptr<Array> old,
const std::shared_ptr<Array> replacement,
int count) {
int64_t src_size = src->size();
int64_t old_size = old->size();
int64_t replacement_size = replacement->size();

T* src_buffer = src->data<T>();
T* old_buffer = old->data<T>();
T* replacement_buffer = replacement->data<T>();

// Calculate the result size. If this ends up being slow we can try
// a single pass algorithm that grows the buffer using realloc. We can also
// try a better search algorithm because this has a worst case complexity
// O(src_size old_size).
int64_t result_size = src_size;
int matches = 0;
if (old_size != replacement_size) {
for (int64_t i = 0; i < src_size; i++) {
if (std::equal(old_buffer, old_buffer + old_size, src_buffer + i)) {
i += old_size - 1;
result_size += replacement_size - old_size;
matches++;
}
if (matches == count) {
break;
}
}
}

result = std::make_shared<Array>(src->type(), result_size);
T* result_buffer = result->data<T>();

matches = 0;
for (int64_t i = 0, j = 0; i < src_size; i++, j++) {
if (std::equal(old_buffer, old_buffer + old_size, src_buffer + i)) {
std::copy(
replacement_buffer,
replacement_buffer + replacement_size,
result_buffer + j);
i += old_size - 1;
j += replacement_size - 1;
matches++;
} else {
result_buffer[j] = src_buffer[i];
}
if (matches == count) {
std::copy(
src_buffer + i + 1, src_buffer + src_size, result_buffer + j + 1);
break;
}
}
}

} // namespace
namespace mlx {
namespace data {
Expand Down Expand Up @@ -192,6 +251,16 @@ Sample merge_batch(
return sample_batch;
}

std::shared_ptr<Array> replace(
const std::shared_ptr<Array> src,
const std::shared_ptr<Array> old,
const std::shared_ptr<Array> replacement,
int count) {
std::shared_ptr<Array> result;
ARRAY_DISPATCH(src, replace_t, result, src, old, replacement, count);
return result;
}

} // namespace core
} // namespace data
} // namespace mlx
8 changes: 7 additions & 1 deletion mlx/data/core/Utils.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#include "mlx/data/Array.h"
#include "mlx/data/Sample.h"
Expand All @@ -20,6 +20,12 @@ std::pair<std::shared_ptr<Array>, std::shared_ptr<Array>> remove(
double value,
double pad);

std::shared_ptr<Array> replace(
const std::shared_ptr<Array> src,
const std::shared_ptr<Array> old,
const std::shared_ptr<Array> replacement,
int count);

Sample merge_batch(
const std::vector<Sample>& samples,
const std::unordered_map<std::string, double>& pad_values = {},
Expand Down
30 changes: 30 additions & 0 deletions mlx/data/op/Replace.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright © 2024 Apple Inc.

#include "mlx/data/op/Replace.h"
#include "mlx/data/core/Utils.h"

namespace mlx {
namespace data {
namespace op {

Replace::Replace(
const std::string& key,
const std::string& old,
const std::string& replacement,
int count)
: key_(key),
old_(std::make_shared<Array>(old)),
replacement_(std::make_shared<Array>(replacement)),
count_(count) {}

Sample Replace::apply(const Sample& sample) const {
auto value = sample::check_key(sample, key_, old_->type());
value = core::replace(value, old_, replacement_, count_);
auto new_sample = sample;
new_sample[key_] = value;
return new_sample;
}

} // namespace op
} // namespace data
} // namespace mlx
30 changes: 30 additions & 0 deletions mlx/data/op/Replace.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/data/op/Op.h"

namespace mlx {
namespace data {
namespace op {

class Replace : public Op {
public:
Replace(
const std::string& key,
const std::string& old,
const std::string& replacement,
int count);

virtual Sample apply(const Sample& sample) const override;

private:
std::string key_;
std::shared_ptr<Array> old_;
std::shared_ptr<Array> replacement_;
int count_;
};

} // namespace op
} // namespace data
} // namespace mlx
36 changes: 36 additions & 0 deletions python/src/wrap_dataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -945,6 +945,42 @@ void mlx_data_export_dataset(py::class_<T, P>& base) {
py::arg("pad") = 0,
"Conditional :meth:`Buffer.remove_value`.");

base.def(
"replace",
&T::replace,
py::call_guard<py::gil_scoped_release>(),
py::arg("key"),
py::arg("old"),
py::arg("replacement"),
py::arg("count") = -1,
R"pbdoc(
Replace ``old`` with ``replacement`` in the array at ``key``.

Example:

.. code-block:: python

# Replace ' ' with '▁' to prepare for SPM tokenization.
dset = dset.replace("text", " ", "\u2581")

Args:
key (str): The sample key that contains the array we are operating on.
old (str): The character sequence that we are replacing.
replacement (str): The character sequence that we are replacing with.
count (int): Perform at most ``count`` replacements. Ignore if negative.
Default: ``-1``.
)pbdoc");
base.def(
"replace_if",
&T::replace_if,
py::call_guard<py::gil_scoped_release>(),
py::arg("cond"),
py::arg("key"),
py::arg("old"),
py::arg("replacement"),
py::arg("count") = -1,
"Conditional :meth:`Buffer.replace`.");

base.def(
"rename_key",
&T::rename_key,
Expand Down
8 changes: 6 additions & 2 deletions python/tests/test_buffer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# Copyright © 2024 Apple Inc.

from unittest import TestCase
import unittest

import mlx.data as dx


class TestBuffer(TestCase):
class TestBuffer(unittest.TestCase):
def test__getitem__(self):
n = 5
b = dx.buffer_from_vector(list(dict(i=i) for i in range(n)))
Expand All @@ -18,3 +18,7 @@ def test__getitem__(self):
_ = b[n]
with self.assertRaises(IndexError):
_ = b[-(n + 1)]


if __name__ == "__main__":
unittest.main()
24 changes: 24 additions & 0 deletions python/tests/test_replace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright © 2024 Apple Inc.

import unittest

import mlx.data as dx


class TestReplace(unittest.TestCase):
def test_replace(self):
s = "Hello world".encode()
dset = dx.buffer_from_vector([dict(text=s)])

ds = dset.replace("text", "world", "everybody!")
self.assertEqual(bytes(ds[0]["text"]), b"Hello everybody!")

ds = dset.replace("text", "l", "b")
self.assertEqual(bytes(ds[0]["text"]), b"Hebbo worbd")

ds = dset.replace("text", "l", "b", 2)
self.assertEqual(bytes(ds[0]["text"]), b"Hebbo world")


if __name__ == "__main__":
unittest.main()