Skip to content

Commit

Permalink
Fix x86 tests (#1691)
Browse files Browse the repository at this point in the history
* fix x86 tests

* comment
  • Loading branch information
awni authored Dec 11, 2024
1 parent 4f9b60d commit f3dfa36
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 9 deletions.
2 changes: 1 addition & 1 deletion mlx/backend/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ add_custom_command(
COMMAND
/bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER}
${PROJECT_SOURCE_DIR} ${CLANG}
${PROJECT_SOURCE_DIR} ${CLANG} ${CMAKE_SYSTEM_PROCESSOR}
DEPENDS make_compiled_preamble.sh
compiled_preamble.h
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
Expand Down
11 changes: 6 additions & 5 deletions mlx/backend/common/make_compiled_preamble.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,16 @@ OUTPUT_FILE=$1
GCC=$2
SRCDIR=$3
CLANG=$4
ARCH=$5

if [ "$CLANG" = "TRUE" ]; then
read -r -d '' INCLUDES <<- EOM
#include <cmath>
#include <complex>
#include <cstdint>
#include <vector>
#include <cmath>
#include <complex>
#include <cstdint>
#include <vector>
EOM
CC_FLAGS=""
CC_FLAGS="-arch ${ARCH}"
else
CC_FLAGS="-std=c++17"
fi
Expand Down
4 changes: 2 additions & 2 deletions tests/autograd_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -437,10 +437,10 @@ TEST_CASE("test op vjps") {
// Test erf
{
auto out = vjp([](array in) { return erf(in); }, array(inf), array(1.0f));
CHECK_EQ(out.second.item<float>(), 0.0f);
CHECK_EQ(out.second.item<float>(), doctest::Approx(0.0f));

out = vjp([](array in) { return erf(in); }, array(-inf), array(2.0f));
CHECK_EQ(out.second.item<float>(), 0.0f);
CHECK_EQ(out.second.item<float>(), doctest::Approx(0.0f));

out = vjp([](array in) { return erf(in); }, array(0.0f), array(1.0f));
CHECK_EQ(out.second.item<float>(), static_cast<float>(M_2_SQRTPI));
Expand Down
2 changes: 1 addition & 1 deletion tests/ops_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1227,7 +1227,7 @@ TEST_CASE("test arithmetic unary ops") {
CHECK(array_equal(exp(array({})), array({})).item<bool>());

x = array(neginf);
CHECK_EQ(exp(x).item<float>(), 0.0f);
CHECK_EQ(exp(x).item<float>(), doctest::Approx(0.0f));

// Integer input type
x = array(2);
Expand Down

0 comments on commit f3dfa36

Please sign in to comment.