Skip to content

Commit

Permalink
Moved from/to vector to tensor.hpp
Browse files Browse the repository at this point in the history
  • Loading branch information
omilyutin-tt committed Dec 16, 2024
1 parent 25b402e commit 59f5c5e
Show file tree
Hide file tree
Showing 10 changed files with 173 additions and 200 deletions.
28 changes: 13 additions & 15 deletions tests/ttnn/unit_tests/gtests/tensor/test_distributed_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
namespace ttnn::distributed::test {

using ::testing::ElementsAre;
using ::ttnn::experimental::xtensor::from_vector;
using ::ttnn::experimental::xtensor::to_vector;

using TensorDistributionTest = T3kMultiDeviceFixture;

Expand All @@ -25,7 +23,7 @@ TensorSpec get_tensor_spec(const ttnn::SimpleShape& shape, DataType dtype) {
}

TEST_F(TensorDistributionTest, Replication) {
Tensor input_tensor = from_vector(
Tensor input_tensor = Tensor::from_vector(
std::vector<float>{42.F, 13.F, -99.F}, get_tensor_spec(ttnn::SimpleShape{1, 1, 1, 3}, DataType::FLOAT32));

auto mapper = replicate_tensor_to_mesh_mapper(*mesh_device_);
Expand All @@ -34,13 +32,13 @@ TEST_F(TensorDistributionTest, Replication) {
std::vector<Tensor> device_tensors = get_device_tensors(replicated_tensor);
EXPECT_EQ(device_tensors.size(), mesh_device_->num_devices());
for (const auto& device_tensor : device_tensors) {
EXPECT_THAT(to_vector<float>(device_tensor), ElementsAre(42.F, 13.F, -99.F));
EXPECT_THAT(device_tensor.to_vector<float>(), ElementsAre(42.F, 13.F, -99.F));
}
}

TEST_F(TensorDistributionTest, Shard1DInvalidDim) {
const int num_devices = mesh_device_->num_devices();
Tensor input_tensor = from_vector(
Tensor input_tensor = Tensor::from_vector(
std::vector<float>(num_devices, 0),
get_tensor_spec(ttnn::SimpleShape{1, 1, 1, num_devices}, DataType::FLOAT32));

Expand All @@ -58,7 +56,7 @@ TEST_F(TensorDistributionTest, Shard1DInvalidDim) {
TEST_F(TensorDistributionTest, Shard1DTooFewShards) {
const int num_devices = mesh_device_->num_devices();
ASSERT_LT(3, num_devices);
Tensor input_tensor = from_vector(
Tensor input_tensor = Tensor::from_vector(
std::vector<float>{42.F, 13.F, -99.F}, get_tensor_spec(ttnn::SimpleShape{1, 1, 1, 3}, DataType::FLOAT32));

EXPECT_ANY_THROW({
Expand All @@ -74,22 +72,22 @@ TEST_F(TensorDistributionTest, Shard1D) {
test_data.insert(test_data.end(), {i * 1.F, i * 2.F, i * 3.F});
}
Tensor input_tensor =
from_vector(test_data, get_tensor_spec(ttnn::SimpleShape{1, num_devices, 3, 1}, DataType::FLOAT32));
Tensor::from_vector(test_data, get_tensor_spec(ttnn::SimpleShape{1, num_devices, 3, 1}, DataType::FLOAT32));

auto mapper = shard_tensor_to_mesh_mapper(*mesh_device_, 1);
Tensor sharded_tensor = distribute_tensor(input_tensor, *mesh_device_, *mapper);

std::vector<Tensor> device_tensors = get_device_tensors(sharded_tensor);
EXPECT_EQ(device_tensors.size(), mesh_device_->num_devices());
for (int i = 0; i < device_tensors.size(); i++) {
EXPECT_THAT(to_vector<float>(device_tensors[i]), ElementsAre(i * 1.F, i * 2.F, i * 3.F));
EXPECT_THAT(device_tensors[i].to_vector<float>(), ElementsAre(i * 1.F, i * 2.F, i * 3.F));
}

auto composer = concat_mesh_to_tensor_composer(/*dim=*/0);
Tensor concatenated_tensor = aggregate_tensor(sharded_tensor, *composer);

Tensor expected_tensor =
from_vector(test_data, get_tensor_spec(ttnn::SimpleShape{num_devices, 1, 3, 1}, DataType::FLOAT32));
Tensor::from_vector(test_data, get_tensor_spec(ttnn::SimpleShape{num_devices, 1, 3, 1}, DataType::FLOAT32));
EXPECT_TRUE(ttnn::allclose<float>(concatenated_tensor, expected_tensor));
}

Expand Down Expand Up @@ -121,7 +119,7 @@ TEST_F(TensorDistributionTest, Shard2DReplicateDim) {

std::vector<float> test_data = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0};
Tensor input_tensor =
from_vector(test_data, get_tensor_spec(ttnn::SimpleShape{1, num_rows, num_cols, 1}, DataType::FLOAT32));
Tensor::from_vector(test_data, get_tensor_spec(ttnn::SimpleShape{1, num_rows, num_cols, 1}, DataType::FLOAT32));
input_tensor.print();

auto mapper = shard_tensor_2d_to_mesh_mapper(
Expand All @@ -138,10 +136,10 @@ TEST_F(TensorDistributionTest, Shard2DReplicateDim) {

int i = 0;
for (; i < 4; i++) {
EXPECT_THAT(to_vector<float>(device_tensors[i]), ElementsAre(0.0, 1.0, 2.0, 3.0));
EXPECT_THAT(device_tensors[i].to_vector<float>(), ElementsAre(0.0, 1.0, 2.0, 3.0));
}
for (; i < device_tensors.size(); i++) {
EXPECT_THAT(to_vector<float>(device_tensors[i]), ElementsAre(4.0, 5.0, 6.0, 7.0));
EXPECT_THAT(device_tensors[i].to_vector<float>(), ElementsAre(4.0, 5.0, 6.0, 7.0));
}
}

Expand All @@ -156,7 +154,7 @@ TEST_F(TensorDistributionTest, Shard2D) {
test_data.insert(test_data.end(), {i * 1.F, i * 2.F, i * 3.F});
}
Tensor input_tensor =
from_vector(test_data, get_tensor_spec(ttnn::SimpleShape{1, num_rows, num_cols, 3}, DataType::FLOAT32));
Tensor::from_vector(test_data, get_tensor_spec(ttnn::SimpleShape{1, num_rows, num_cols, 3}, DataType::FLOAT32));

auto mapper = shard_tensor_2d_to_mesh_mapper(
*mesh_device_,
Expand All @@ -170,7 +168,7 @@ TEST_F(TensorDistributionTest, Shard2D) {
std::vector<Tensor> device_tensors = get_device_tensors(sharded_tensor);
EXPECT_EQ(device_tensors.size(), mesh_device_->num_devices());
for (int i = 0; i < device_tensors.size(); i++) {
EXPECT_THAT(to_vector<float>(device_tensors[i]), ElementsAre(i * 1.F, i * 2.F, i * 3.F));
EXPECT_THAT(device_tensors[i].to_vector<float>(), ElementsAre(i * 1.F, i * 2.F, i * 3.F));
}

auto composer = concat_mesh_2d_to_tensor_composer(
Expand All @@ -182,7 +180,7 @@ TEST_F(TensorDistributionTest, Shard2D) {
Tensor concatenated_tensor = aggregate_tensor(sharded_tensor, *composer);

Tensor expected_tensor =
from_vector(test_data, get_tensor_spec(ttnn::SimpleShape{num_rows, 1, num_cols, 3}, DataType::FLOAT32));
Tensor::from_vector(test_data, get_tensor_spec(ttnn::SimpleShape{num_rows, 1, num_cols, 3}, DataType::FLOAT32));
EXPECT_TRUE(ttnn::allclose<float>(concatenated_tensor, expected_tensor));
}

Expand Down
1 change: 0 additions & 1 deletion tests/ttnn/unit_tests/gtests/tensor/test_partition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ using ::testing::SizeIs;
using ::tt::tt_metal::Tensor;
using ::ttnn::experimental::xtensor::chunk;
using ::ttnn::experimental::xtensor::concatenate;
using ::ttnn::experimental::xtensor::from_vector;

TEST(PartitionTest, ChunkBasicNonDivisible3) {
// Create a 1D tensor: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Expand Down
23 changes: 12 additions & 11 deletions tests/ttnn/unit_tests/gtests/tensor/test_vector_conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ namespace {

using ::testing::Eq;
using ::testing::Pointwise;
using ::ttnn::experimental::xtensor::from_vector;
using ::ttnn::experimental::xtensor::to_vector;

const std::vector<ttnn::SimpleShape>& get_shapes_for_test() {
static auto* shapes = new std::vector<ttnn::SimpleShape>{
Expand Down Expand Up @@ -58,8 +56,8 @@ TYPED_TEST_SUITE(VectorConversionTest, TestTypes);
TYPED_TEST(VectorConversionTest, Roundtrip) {
for (const auto& shape : get_shapes_for_test()) {
auto input = arange<TypeParam>(0, static_cast<int64_t>(shape.volume()), 1);
auto output =
to_vector<TypeParam>(from_vector(input, get_tensor_spec(shape, convert_to_data_type<TypeParam>())));
auto output = Tensor::from_vector(input, get_tensor_spec(shape, convert_to_data_type<TypeParam>()))
.template to_vector<TypeParam>();
EXPECT_THAT(output, Pointwise(Eq(), input)) << "for shape: " << shape;
}
}
Expand All @@ -69,18 +67,20 @@ TYPED_TEST(VectorConversionTest, InvalidSize) {
auto input = arange<TypeParam>(0, 42, 1);

ASSERT_NE(input.size(), shape.volume());
EXPECT_ANY_THROW(from_vector(input, get_tensor_spec(shape, convert_to_data_type<TypeParam>())));
EXPECT_ANY_THROW(Tensor::from_vector(input, get_tensor_spec(shape, convert_to_data_type<TypeParam>())));
}

TYPED_TEST(VectorConversionTest, RoundtripTilezedLayout) {
ttnn::SimpleShape shape{128, 128};

auto input = arange<TypeParam>(0, shape.volume(), 1);
// TODO: Support this.
EXPECT_ANY_THROW(from_vector(input, get_tensor_spec(shape, convert_to_data_type<TypeParam>(), Layout::TILE)));
EXPECT_ANY_THROW(
Tensor::from_vector(input, get_tensor_spec(shape, convert_to_data_type<TypeParam>(), Layout::TILE)));

auto output = to_vector<TypeParam>(
from_vector(input, get_tensor_spec(shape, convert_to_data_type<TypeParam>())).to(Layout::TILE));
auto output = Tensor::from_vector(input, get_tensor_spec(shape, convert_to_data_type<TypeParam>()))
.to(Layout::TILE)
.template to_vector<TypeParam>();
EXPECT_THAT(output, Pointwise(Eq(), input));
}

Expand All @@ -89,7 +89,7 @@ TYPED_TEST(VectorConversionTest, InvalidDtype) {
auto input = arange<TypeParam>(0, 42, 1);

ASSERT_NE(input.size(), shape.volume());
EXPECT_ANY_THROW(from_vector(
EXPECT_ANY_THROW(Tensor::from_vector(
input,
get_tensor_spec(
shape,
Expand All @@ -106,10 +106,11 @@ TEST(FloatVectorConversionTest, RoundtripBfloat16Representation) {
return bf.to_float();
});

auto output_bf16 = to_vector<bfloat16>(from_vector(input_ft, get_tensor_spec(shape, DataType::BFLOAT16)));
auto output_bf16 =
Tensor::from_vector(input_ft, get_tensor_spec(shape, DataType::BFLOAT16)).to_vector<bfloat16>();
EXPECT_THAT(output_bf16, Pointwise(Eq(), input_bf16)) << "for shape: " << shape;

auto output_ft = to_vector<float>(from_vector(input_bf16, get_tensor_spec(shape, DataType::BFLOAT16)));
auto output_ft = Tensor::from_vector(input_bf16, get_tensor_spec(shape, DataType::BFLOAT16)).to_vector<float>();
EXPECT_THAT(output_ft, Pointwise(Eq(), input_ft)) << "for shape: " << shape;
}
}
Expand Down
2 changes: 1 addition & 1 deletion tt-train/sources/ttml/core/tt_tensor_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ T get_median(std::vector<T>& vec) {
template <typename T>
void print_tensor_stats_(const tt::tt_metal::Tensor& tensor, const std::string& name) {
auto tensor_shape = tensor.get_shape();
auto tensor_vec = ttml::core::to_vector<T>(tensor);
auto tensor_vec = tensor.to_vector<T>();

auto median = get_median(tensor_vec);
auto mean = std::accumulate(tensor_vec.begin(), tensor_vec.end(), 0.F) / static_cast<float>(tensor_vec.size());
Expand Down
4 changes: 2 additions & 2 deletions tt-train/sources/ttml/core/tt_tensor_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ template <class VectorType = float, DataType TensorType = DataType::BFLOAT16>

template <class T = float>
[[nodiscard]] std::vector<T> to_vector(const tt::tt_metal::Tensor& tensor) {
return ttnn::experimental::xtensor::to_vector<T>(tensor);
return tensor.to_vector<T>();
}

[[nodiscard]] bool is_tensor_initialized(const tt::tt_metal::Tensor& tensor);
Expand All @@ -56,7 +56,7 @@ template <class T = float, DataType TensorType = DataType::BFLOAT16>

template <class T = float>
[[nodiscard]] xt::xarray<T> to_xtensor(const tt::tt_metal::Tensor& tensor) {
auto vec = to_vector<T>(tensor);
auto vec = tensor.to_vector<T>();
const auto& shape = tensor.get_shape().logical_shape();
std::vector<size_t> shape_vec(shape.cbegin(), shape.cend());
return xt::adapt(std::move(vec), shape_vec);
Expand Down
1 change: 0 additions & 1 deletion ttnn/cpp/ttnn/tensor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ set(TENSOR_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/layout/page_config.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layout/size.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layout/tensor_layout.cpp
${CMAKE_CURRENT_SOURCE_DIR}/xtensor/conversion_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/xtensor/partition.cpp
CACHE INTERNAL
"Tensor sources to reuse in ttnn build"
Expand Down
134 changes: 124 additions & 10 deletions ttnn/cpp/ttnn/tensor/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,9 @@

using namespace tt::constants;

namespace tt {

namespace tt_metal {

namespace tt::tt_metal {
namespace {
namespace CMAKE_UNIQUE_NAMESPACE {

MemoryConfig extract_memory_config(const Storage& storage) {
return std::visit(
[](const auto& storage) -> MemoryConfig {
Expand All @@ -50,7 +47,60 @@ MemoryConfig extract_memory_config(const Storage& storage) {
},
storage);
}
} // namespace CMAKE_UNIQUE_NAMESPACE

// Copies `data` into a freshly allocated host buffer and wraps it in a tensor described by `spec`.
// Only row-major specs are accepted; any other layout triggers TT_FATAL.
template <typename T>
Tensor create_owned_tensor_from_span(tt::stl::Span<const T> data, const TensorSpec& spec) {
    // TODO: support tilized layouts.
    TT_FATAL(spec.layout() == Layout::ROW_MAJOR, "Unsupported layout: {}", spec.layout());

    // The span is non-owning, so take a private copy of the elements for the tensor to own.
    std::vector<T> host_copy(data.begin(), data.end());
    OwnedStorage owned_storage{tt::tt_metal::owned_buffer::create(std::move(host_copy))};
    return Tensor{std::move(owned_storage), spec};
}

// Extracts the logical (unpadded) elements of a host tensor into a flat, row-major vector.
//
// The host buffer is read as `InternalType` and is addressed using the tensor's padded shape;
// only positions that fall inside the logical shape are copied out, in row-major order.
// Elements are converted to `T`: bfloat16 storage is widened with `to_float()` when `T` is a
// different type, otherwise elements are assigned directly.
//
// NOTE(review): assumes the logical and padded shapes have the same rank (the per-dimension
// index vector is shared between them) - confirm against callers.
template <typename T, typename InternalType>
std::vector<T> untile_tensor_to_vec(const Tensor& cpu_tensor) {
    auto tiled_buffer = tt::tt_metal::host_buffer::get_as<InternalType>(cpu_tensor);
    auto untiled_shape = cpu_tensor.get_logical_shape();
    auto tiled_shape = cpu_tensor.get_padded_shape();

    // Number of elements in the logical (unpadded) tensor.
    const size_t total_size = untiled_shape.volume();
    std::vector<T> untiled_data(total_size);

    const int rank = static_cast<int>(tiled_shape.rank());

    // Row-major strides of the padded shape, computed once up front instead of per element.
    // Kept in size_t so that offsets into large buffers do not wrap at 32 bits.
    std::vector<size_t> tiled_strides(rank, 1);
    for (int dim = rank - 2; dim >= 0; --dim) {
        tiled_strides[dim] = tiled_strides[dim + 1] * static_cast<size_t>(tiled_shape[dim + 1]);
    }

    // Current multi-dimensional position inside the logical shape, plus the matching flat
    // offset into the padded buffer. The flat offset into the output is simply `idx`, because
    // the logical shape is walked in row-major order.
    std::vector<uint32_t> indices(rank, 0);
    size_t tiled_index = 0;

    for (size_t idx = 0; idx < total_size; ++idx) {
        if constexpr (std::is_same_v<InternalType, bfloat16> && !std::is_same_v<T, bfloat16>) {
            untiled_data[idx] = tiled_buffer[tiled_index].to_float();
        } else {
            // Same element type (including bfloat16 -> bfloat16): copy without a float round trip.
            untiled_data[idx] = tiled_buffer[tiled_index];
        }

        // Advance to the next logical position (innermost dimension first), keeping
        // `tiled_index` in sync: step forward by one stride, and on wrap-around rewind the
        // whole extent that was walked along that dimension before carrying outward.
        for (int dim = rank - 1; dim >= 0; --dim) {
            tiled_index += tiled_strides[dim];
            if (++indices[dim] < untiled_shape[dim]) {
                break;
            }
            tiled_index -= static_cast<size_t>(indices[dim]) * tiled_strides[dim];
            indices[dim] = 0;
        }
    }

    return untiled_data;
}

} // namespace

Tensor::TensorAttributes::TensorAttributes() :
Expand Down Expand Up @@ -111,7 +161,7 @@ Tensor::Tensor(
tile->get_tile_shape());
}
}
auto memory_config = CMAKE_UNIQUE_NAMESPACE::extract_memory_config(storage);
auto memory_config = extract_memory_config(storage);
init(
std::move(storage),
TensorSpec(
Expand Down Expand Up @@ -559,6 +609,72 @@ const Storage& Tensor::get_storage() const {
return this->tensor_attributes->storage;
}

// Specialization for float input: builds a host tensor whose element count must equal the
// logical volume of `spec`. FLOAT32 specs store the data as-is; BFLOAT16 specs store each value
// converted to bfloat16. Any other data type is rejected.
template <>
Tensor Tensor::from_span<float>(tt::stl::Span<const float> buffer, const TensorSpec& spec) {
    const size_t expected_volume = spec.logical_shape().volume();
    TT_FATAL(
        buffer.size() == expected_volume,
        "Current buffer size is {} different from shape volume {}",
        buffer.size(),
        expected_volume);

    const DataType target_dtype = spec.data_type();
    if (target_dtype == DataType::FLOAT32) {
        return create_owned_tensor_from_span(buffer, spec);
    }
    if (target_dtype != DataType::BFLOAT16) {
        // TODO: support bf8 and bf4
        TT_THROW("Unsupported data type for from_span<float>: {}", target_dtype);
    }

    // Narrow every float to bfloat16 before handing the data over.
    std::vector<bfloat16> narrowed;
    narrowed.reserve(buffer.size());
    for (const float value : buffer) {
        narrowed.push_back(bfloat16(value));
    }
    return create_owned_tensor_from_span(tt::stl::Span<const bfloat16>(narrowed.data(), narrowed.size()), spec);
}

// Generic case: builds a host tensor from `buffer`. The element count must equal the logical
// volume of `spec`, and the spec's data type must be the one that corresponds to `T`.
template <typename T>
Tensor Tensor::from_span(tt::stl::Span<const T> buffer, const TensorSpec& spec) {
    const size_t expected_volume = spec.logical_shape().volume();
    TT_FATAL(
        buffer.size() == expected_volume,
        "Current buffer size is {} different from shape volume {}",
        buffer.size(),
        expected_volume);

    const DataType expected_dtype = convert_to_data_type<T>();
    TT_FATAL(
        spec.data_type() == expected_dtype,
        "Unsupported data type for from_span: got {}, expected: {}",
        spec.data_type(),
        expected_dtype);

    return create_owned_tensor_from_span(buffer, spec);
}

// Specialization for float output: brings the tensor to the host in row-major layout and
// returns its logical elements as floats. BFLOAT16 and FLOAT32 tensors are supported; any other
// data type is rejected.
template <>
std::vector<float> Tensor::to_vector<float>() const {
    Tensor host_tensor = this->cpu().to(Layout::ROW_MAJOR);
    const DataType stored_dtype = host_tensor.get_dtype();

    if (stored_dtype == DataType::BFLOAT16) {
        return untile_tensor_to_vec<float, bfloat16>(host_tensor);
    }
    if (stored_dtype == DataType::FLOAT32) {
        return untile_tensor_to_vec<float, float>(host_tensor);
    }
    // TODO: support bf4, bf8.
    TT_THROW("Cannot convert tensor to vector for data type: {}", stored_dtype);
}

template <typename T>
std::vector<T> Tensor::to_vector() const {
auto cpu_tensor = this->cpu().to(Layout::ROW_MAJOR);
TT_FATAL(
cpu_tensor.get_dtype() == convert_to_data_type<T>(),
"Unsupported data type for to_vector: got {}, expected: {}",
cpu_tensor.get_dtype(),
convert_to_data_type<T>());
return untile_tensor_to_vec<T, T>(cpu_tensor);
}

// Instantiate the generic from_span / to_vector templates explicitly for the supported
// non-float element types, so their definitions are emitted from this translation unit.
// `float` is not listed here because it is handled by the explicit specializations
// Tensor::from_span<float> and Tensor::to_vector<float> defined earlier in this file.
template Tensor Tensor::from_span<bfloat16>(tt::stl::Span<const bfloat16> buffer, const TensorSpec& spec);
template Tensor Tensor::from_span<uint32_t>(tt::stl::Span<const uint32_t> buffer, const TensorSpec& spec);
template Tensor Tensor::from_span<int32_t>(tt::stl::Span<const int32_t> buffer, const TensorSpec& spec);
template std::vector<bfloat16> Tensor::to_vector<bfloat16>() const;
template std::vector<uint32_t> Tensor::to_vector<uint32_t>() const;
template std::vector<int32_t> Tensor::to_vector<int32_t>() const;

Tensor Tensor::to(Device* target_device, const MemoryConfig& mem_config,uint8_t cq_id,
const std::vector<SubDeviceId>& sub_device_ids) const {
return tensor_ops::tensor_to(*this, target_device, mem_config, cq_id, sub_device_ids);
Expand Down Expand Up @@ -959,6 +1075,4 @@ bool validate_worker_modes(const std::vector<Device*>& workers) {
return worker_modes_match;
}

} // namespace tt_metal

} // namespace tt
} // namespace tt::tt_metal
Loading

0 comments on commit 59f5c5e

Please sign in to comment.