diff --git a/Makefile b/Makefile index 1485357..ae9ef19 100644 --- a/Makefile +++ b/Makefile @@ -10,14 +10,14 @@ lint: ${LIBS} cpplint ./lib/*/* cpplint ./test/*/* -build: ${LIBS} +build: ${LIBS} ${TESTS} mkdir -p ${BUILD_DIR}; \ cd build; \ cmake ..; \ make -j 2 -test: ${BUILD_DIR} - ${BUILD_DIR}/test/mathTest +test: build + ${BUILD_DIR}/test/runTest format: clang-format -i $(FILE) diff --git a/README.md b/README.md index 39e69e2..0412a89 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ # Library List ## データ構造 - [x] UnionFind -- [ ] SegmentTree +- [x] SegmentTree - [ ] BIT(Binary-Indexed-Tree) - [ ] Treap diff --git a/lib/DataStructure/segment_tree.h b/lib/DataStructure/segment_tree.h new file mode 100644 index 0000000..732ac4d --- /dev/null +++ b/lib/DataStructure/segment_tree.h @@ -0,0 +1,61 @@ +#pragma once +#include "template.h" + +// 0-indexed +template +class SegmentTree { + private: + using Func = function; + int n; // 最下段の数 + vector segmentTree; // セグ木本体 + const Func f; // 二項演算 + const Monoid identityElement; // モノイドの単位元 + + public: + SegmentTree(vector vec, const Func f, const Monoid identityElement); + void Update(int idx, Monoid val); + Monoid Query(int a, int b, int k = 0, int l = 0, + int r = -1); // 使う時は区間[a, b)のみ指定すれば良い + Monoid GetNum(int idx); // 元の要素番号から最下層の値を取得 +}; + +template +SegmentTree::SegmentTree(vector vec, const Func f, + const Monoid identityElement) + : f(f), identityElement(identityElement) { + int sz = vec.size(); + n = 1; + while (n < sz) n *= 2; + segmentTree.assign(2 * n - 1, identityElement); + for (int i = 0; i < sz; i++) segmentTree[i + n - 1] = vec[i]; + for (int i = n - 2; i >= 0; i--) + segmentTree[i] = f(segmentTree[2 * i + 1], segmentTree[2 * i + 2]); +} + +template +void SegmentTree::Update(int idx, Monoid val) { + idx += n - 1; + segmentTree[idx] = val; + while (idx > 0) { + idx = (idx - 1) / 2; + segmentTree[idx] = f(segmentTree[2 * idx + 1], segmentTree[2 * idx + 2]); + } +} + +template +Monoid SegmentTree::Query(int a, int b, int k, int l, int r) { + if (r < 0) r = n; + + if (r <= a || b <= l) return identityElement; + + if (a <= l && r <= b) return segmentTree[k]; + + int vl = Query(a, b, 2 * k + 1, l, (l + r) / 2); + int vr = Query(a, b, 2 * k + 2, (l + r) / 2, r); + return f(vl, vr); +} + +template +Monoid SegmentTree::GetNum(int idx) { + return segmentTree[idx + n - 1]; +} diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index dc7943d..077a349 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1,5 +1,5 @@ cmake_minimum_required(VERSION 2.8) -file(GLOB MATH_TEST_FILES "Math/test_*.cpp") -add_executable(mathTest ${MATH_TEST_FILES}) -target_link_libraries(mathTest ProconMath gtest gtest_main pthread) +file(GLOB TEST_FILES "*/test_*.cpp") +add_executable(runTest ${TEST_FILES}) +target_link_libraries(runTest ProconMath gtest gtest_main pthread) diff --git a/test/DataStructure/test_segment_tree.cpp b/test/DataStructure/test_segment_tree.cpp new file mode 100644 index 0000000..ddd0b0d --- /dev/null +++ b/test/DataStructure/test_segment_tree.cpp @@ -0,0 +1,29 @@ +#include "gtest/gtest.h" +#include "lib/DataStructure/segment_tree.h" + +// Range Minimum Query +TEST(DataStructureTest, segment_tree_RMQ) { + vector vec(3, 1e9); + SegmentTree segmentTree(vec, [](int a, int b) { return min(a, b); }, + 1e9); + + segmentTree.Update(0, 1); + segmentTree.Update(1, 2); + segmentTree.Update(2, 3); + + EXPECT_EQ(segmentTree.Query(0, 2 + 1), 1); + EXPECT_EQ(segmentTree.Query(1, 2 + 1), 2); +} + +// Range Sum Query +TEST(DataStructureTest, segment_tree_RSQ) { + vector vec(3, 0); + SegmentTree segmentTree(vec, [](int a, int b) { return a + b; }, 0); + + segmentTree.Update(0, 1); + segmentTree.Update(1, 2); + segmentTree.Update(2, 3); + + EXPECT_EQ(segmentTree.Query(0, 1 + 1), 3); + EXPECT_EQ(segmentTree.Query(1, 1 + 1), 2); +}