From f3d63cbd93acc5b1e666908cf0522f5806c98ed6 Mon Sep 17 00:00:00 2001 From: Michael Broughton Date: Fri, 17 Dec 2021 02:32:21 -0800 Subject: [PATCH 1/5] Adds working bitstring Sample function to MPS. --- lib/mps_statespace.h | 64 +++++++++++ tests/mps_statespace_test.cc | 204 +++++++++++++++++++++++++++++++++++ 2 files changed, 268 insertions(+) diff --git a/lib/mps_statespace.h b/lib/mps_statespace.h index acdf69db..641d0332 100644 --- a/lib/mps_statespace.h +++ b/lib/mps_statespace.h @@ -25,7 +25,9 @@ #include #include #include +#include #include +#include #include "../eigen/Eigen/Dense" #include "../eigen/unsupported/Eigen/CXX11/Tensor" @@ -372,6 +374,68 @@ class MPSStateSpace { out = t_4d.contract(t_2d, product_dims); } + // Draw a single bitstring sample from state using scratch and scratch2 + // as working space. + static void SampleOnce(MPS& state, MPS& scratch, MPS& scratch2, + std::mt19937* random_gen, std::vector* sample) { + const auto bond_dim = state.bond_dim(); + const auto num_qubits = state.num_qubits(); + std::default_random_engine generator; + fp_type* scratch_raw = scratch.get(); + fp_type rdm[8]; + + sample->reserve(num_qubits); + Copy(state, scratch); + Copy(state, scratch2); + + // Sample left block. + ReduceDensityMatrix(scratch, scratch2, 0, rdm); + auto p0 = rdm[0] / (rdm[0] + rdm[6]); + std::bernoulli_distribution distribution(1 - p0); + auto bit_val = distribution(*random_gen); + + sample->push_back(bit_val); + MatrixMap tensor_block((Complex*)scratch_raw, 2, bond_dim); + tensor_block.row(!bit_val).setZero(); + tensor_block.imag() *= -1; + + // Sample internal blocks. + for (unsigned i = 1; i < num_qubits - 1; i++) { + ReduceDensityMatrix(scratch, scratch2, i, rdm); + p0 = rdm[0] / (rdm[0] + rdm[6]); + distribution = std::bernoulli_distribution(1 - p0); + bit_val = distribution(*random_gen); + + sample->push_back(bit_val); + const auto mem_start = GetBlockOffset(scratch, i); + new (&tensor_block) MatrixMap((Complex*)(scratch_raw + mem_start), + bond_dim * 2, bond_dim); + for (unsigned j = !bit_val; j < 2 * bond_dim; j += 2) { + tensor_block.row(j).setZero(); + } + tensor_block.imag() *= -1; + } + + // Sample right block. + ReduceDensityMatrix(scratch, scratch2, num_qubits - 1, rdm); + p0 = rdm[0] / (rdm[0] + rdm[6]); + distribution = std::bernoulli_distribution(1 - p0); + bit_val = distribution(*random_gen); + sample->push_back(bit_val); + } + + // Draw num_samples bitstring samples from state and store the result + // bit vectors in results. Uses scratch and scratch2 as workspace. + static void Sample(MPS& state, MPS& scratch, MPS& scratch2, + unsigned num_samples, unsigned seed, + std::vector>* results) { + std::mt19937 rand_source(seed); + results->reserve(num_samples); + for (unsigned i = 0; i < num_samples; i++) { + SampleOnce(state, scratch, scratch2, &rand_source, &(*results)[i]); + } + } + // Testing only. Convert the MPS to a wavefunction under "normal" ordering. // Requires: wf be allocated beforehand with bond_dim * 2 ^ num_qubits -1 // memory. diff --git a/tests/mps_statespace_test.cc b/tests/mps_statespace_test.cc index 190f08df..28c92ed3 100644 --- a/tests/mps_statespace_test.cc +++ b/tests/mps_statespace_test.cc @@ -16,6 +16,7 @@ #include "../lib/formux.h" #include "gtest/gtest.h" +#include "gmock/gmock.h" namespace qsim { @@ -900,6 +901,209 @@ TEST(MPSStateSpaceTest, ReduceDensityMatrixLarge){ } +TEST(MPSStateSpaceTest, SampleOnceSimple){ + auto ss = MPSStateSpace(1); + auto mps = ss.Create(3, 4); + auto scratch = ss.Create(3, 4); + auto scratch2 = ss.Create(3, 4); + std::mt19937 rand_source(1234); + std::vector results; + + // Set to |100>. + results.clear(); + ss.SetStateZero(mps); + mps.get()[0] = 0; + mps.get()[8] = 1; + ss.SampleOnce(mps, scratch, scratch2, &rand_source, &results); + ASSERT_THAT(results, testing::ElementsAre(1, 0, 0)); + //EXPECT_EQ(1,2); + + // Set to |010>. + results.clear(); + ss.SetStateZero(mps); + mps.get()[16] = 0; + mps.get()[24] = 1; + ss.SampleOnce(mps, scratch, scratch2, &rand_source, &results); + ASSERT_THAT(results, testing::ElementsAre(0, 1, 0)); + + // Set to |001>. + results.clear(); + ss.SetStateZero(mps); + mps.get()[80] = 0; + mps.get()[82] = 1; + ss.SampleOnce(mps, scratch, scratch2, &rand_source, &results); + ASSERT_THAT(results, testing::ElementsAre(0, 0, 1)); + + // Set to |101>. + results.clear(); + ss.SetStateZero(mps); + mps.get()[0] = 0; + mps.get()[8] = 1; + mps.get()[80] = 0; + mps.get()[82] = 1; + ss.SampleOnce(mps, scratch, scratch2, &rand_source, &results); + ASSERT_THAT(results, testing::ElementsAre(1, 0, 1)); +} + +TEST(MPSStateSpaceTest, SampleGHZ){ + const int num_samples = 10000; + auto ss = MPSStateSpace(1); + auto mps = ss.Create(3, 4); + auto scratch = ss.Create(3, 4); + auto scratch2 = ss.Create(3, 4); + std::vector> results( + num_samples, std::vector({})); + + memset(mps.get(), 0, ss.RawSize(mps)); + mps.get()[0] = 1; + mps.get()[10] = 1; + mps.get()[16] = 1; + mps.get()[42] = -1; + mps.get()[80] = 0.70710677; + mps.get()[86] = -0.70710677; + + float count = 0; + ss.Sample(mps, scratch, scratch2, num_samples, 1234, &results); + for(int i = 0 ; i < num_samples; i++){ + ASSERT_THAT(results[i], testing::AnyOf(testing::ElementsAre(1, 1, 1), + testing::ElementsAre(0, 0, 0))); + count += results[i][0]; + } + EXPECT_NEAR(count / float(num_samples), 0.5, 1e-2); +} + +TEST(MPSStateSpaceTest, SampleComplex){ + const int num_samples = 10000; + auto ss = MPSStateSpace(1); + auto mps = ss.Create(4, 4); + auto scratch = ss.Create(4, 4); + auto scratch2 = ss.Create(4, 4); + std::vector> results( + num_samples, std::vector({})); + + memset(mps.get(), 0, ss.RawSize(mps)); + mps.get()[ 0 ] = 0.033688569334715854 ; + mps.get()[ 1 ] = -0.10444182602180123 ; + mps.get()[ 2 ] = 0.9076354671683359 ; + mps.get()[ 3 ] = 0.405160344657187 ; + mps.get()[ 8 ] = -0.9595253512026178 ; + mps.get()[ 9 ] = -0.25936097827312377 ; + mps.get()[ 10 ] = -0.03987001675676861 ; + mps.get()[ 11 ] = 0.10224185693597321 ; + mps.get()[ 16 ] = -0.4350591822776815 ; + mps.get()[ 17 ] = 0.22228546667942578 ; + mps.get()[ 18 ] = -0.6285732819602607 ; + mps.get()[ 19 ] = 0.5943422063507785 ; + mps.get()[ 20 ] = -0.02428345908884816 ; + mps.get()[ 21 ] = 0.026256572727652475 ; + mps.get()[ 22 ] = 0.0728063572325396 ; + mps.get()[ 23 ] = -0.07991114142962712 ; + mps.get()[ 24 ] = -0.1642035571020447 ; + mps.get()[ 25 ] = -0.8209212529030018 ; + mps.get()[ 26 ] = 0.21124207331921135 ; + mps.get()[ 27 ] = 0.4033152234452636 ; + mps.get()[ 28 ] = -0.11315780332634073 ; + mps.get()[ 29 ] = -0.18477947087021204 ; + mps.get()[ 30 ] = -0.11199707215175961 ; + mps.get()[ 31 ] = -0.17985444082650426 ; + mps.get()[ 32 ] = -0.18059162771674087 ; + mps.get()[ 33 ] = -0.12173196101857839 ; + mps.get()[ 34 ] = 0.19817171168239098 ; + mps.get()[ 35 ] = 0.063054719070231 ; + mps.get()[ 36 ] = 0.45024015745008505 ; + mps.get()[ 37 ] = 0.11068157212593255 ; + mps.get()[ 38 ] = 0.8106501683288581 ; + mps.get()[ 39 ] = 0.19287240226762353 ; + mps.get()[ 40 ] = -0.08702392413741208 ; + mps.get()[ 41 ] = 0.07370887245848737 ; + mps.get()[ 42 ] = -0.018412987786278347 ; + mps.get()[ 43 ] = 0.027921764369775018 ; + mps.get()[ 44 ] = 0.42351439662329743 ; + mps.get()[ 45 ] = -0.7466201917698305 ; + mps.get()[ 46 ] = -0.24298735917837008 ; + mps.get()[ 47 ] = 0.4359199055764641 ; + mps.get()[ 80 ] = 0.422436255430577 ; + mps.get()[ 81 ] = 0.0 ; + mps.get()[ 82 ] = 0.1211402132186689 ; + mps.get()[ 83 ] = -0.819174648113452 ; + mps.get()[ 84 ] = 0.0 ; + mps.get()[ 85 ] = -7.333691512826885e-20 ; + mps.get()[ 88 ] = -0.8676720638499252 ; + mps.get()[ 89 ] = 0.1360568551008419 ; + mps.get()[ 90 ] = 0.011793867549118184 ; + mps.get()[ 91 ] = -0.44097834083157716 ; + mps.get()[ 96 ] = -0.08978879818674973 ; + mps.get()[ 97 ] = 0.0 ; + mps.get()[ 98 ] = -0.021807957717091198 ; + mps.get()[ 99 ] = -0.05873893136151775 ; + mps.get()[ 100 ] = 0.0 ; + mps.get()[ 101 ] = -9.89518074266081e-19 ; + mps.get()[ 104 ] = 0.14307979940517454 ; + mps.get()[ 105 ] = 0.06032194765563529 ; + mps.get()[ 106 ] = 0.22589440044405648 ; + mps.get()[ 107 ] = -0.2397609424987549 ; + mps.get()[ 112 ] = 0.12734430206944722 ; + mps.get()[ 113 ] = 0.0 ; + mps.get()[ 114 ] = -0.003114595079760157 ; + mps.get()[ 115 ] = 0.06816683893204967 ; + mps.get()[ 116 ] = 2.722010100011512e-17 ; + mps.get()[ 117 ] = 1.9629880528929172e-18 ; + mps.get()[ 120 ] = 0.014022255715263434 ; + mps.get()[ 121 ] = 0.017127855001478075 ; + mps.get()[ 122 ] = 0.025812082320798548 ; + mps.get()[ 123 ] = -0.027110021000464 ; + mps.get()[ 128 ] = -0.018262196707018574 ; + mps.get()[ 129 ] = 0.0 ; + mps.get()[ 130 ] = 0.0032725358458428836 ; + mps.get()[ 131 ] = 0.0310845568816579 ; + mps.get()[ 132 ] = 1.935877805637811e-17 ; + mps.get()[ 133 ] = -1.0370773958773989e-18 ; + mps.get()[ 136 ] = -0.028632305212090994 ; + mps.get()[ 137 ] = 0.012199896816087576 ; + mps.get()[ 138 ] = 0.0009323445588941451 ; + mps.get()[ 139 ] = -0.014212789540748644 ; + mps.get()[ 144 ] = -0.07762944756130831 ; + mps.get()[ 145 ] = -0.25063255485414937 ; + mps.get()[ 146 ] = 0.515385895406013 ; + mps.get()[ 147 ] = 0.7314486807404007 ; + mps.get()[ 148 ] = 0.20689214104052 ; + mps.get()[ 149 ] = 0.2781707321332216 ; + mps.get()[ 150 ] = 0.08286244916183945 ; + mps.get()[ 151 ] = 0.05888783848647657 ; + + ss.Sample(mps, scratch, scratch2, num_samples, 12345, &results); + + std::vector expected({ + 0.00467637, + 0.0020386, + 0.00112952, + 0.0269848, + 0.00704221, + 0.00147802, + 0.00243688, + 0.0350753, + 0.0324814, + 0.0141599, + 0.0412371, + 0.0780275, + 0.0644363, + 0.140995, + 0.0355866, + 0.512215, + }); + std::vector hist(16, 0); + for(int i =0;i Date: Fri, 17 Dec 2021 02:35:52 -0800 Subject: [PATCH 2/5] remove iostream include. --- lib/mps_statespace.h | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/mps_statespace.h b/lib/mps_statespace.h index 641d0332..c017e650 100644 --- a/lib/mps_statespace.h +++ b/lib/mps_statespace.h @@ -25,7 +25,6 @@ #include #include #include -#include #include #include From fa36fd74e8baafec61d1944c961ed5e902633bd3 Mon Sep 17 00:00:00 2001 From: Michael Broughton Date: Fri, 17 Dec 2021 13:08:26 -0800 Subject: [PATCH 3/5] remove use of gmock. --- tests/mps_statespace_test.cc | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/tests/mps_statespace_test.cc b/tests/mps_statespace_test.cc index 28c92ed3..9473eaf9 100644 --- a/tests/mps_statespace_test.cc +++ b/tests/mps_statespace_test.cc @@ -16,7 +16,6 @@ #include "../lib/formux.h" #include "gtest/gtest.h" -#include "gmock/gmock.h" namespace qsim { @@ -915,8 +914,9 @@ TEST(MPSStateSpaceTest, SampleOnceSimple){ mps.get()[0] = 0; mps.get()[8] = 1; ss.SampleOnce(mps, scratch, scratch2, &rand_source, &results); - ASSERT_THAT(results, testing::ElementsAre(1, 0, 0)); - //EXPECT_EQ(1,2); + EXPECT_EQ(results[0], 1); + EXPECT_EQ(results[1], 0); + EXPECT_EQ(results[2], 0); // Set to |010>. results.clear(); @@ -924,7 +924,9 @@ TEST(MPSStateSpaceTest, SampleOnceSimple){ mps.get()[16] = 0; mps.get()[24] = 1; ss.SampleOnce(mps, scratch, scratch2, &rand_source, &results); - ASSERT_THAT(results, testing::ElementsAre(0, 1, 0)); + EXPECT_EQ(results[0], 0); + EXPECT_EQ(results[1], 1); + EXPECT_EQ(results[2], 0); // Set to |001>. results.clear(); @@ -932,7 +934,9 @@ TEST(MPSStateSpaceTest, SampleOnceSimple){ mps.get()[80] = 0; mps.get()[82] = 1; ss.SampleOnce(mps, scratch, scratch2, &rand_source, &results); - ASSERT_THAT(results, testing::ElementsAre(0, 0, 1)); + EXPECT_EQ(results[0], 0); + EXPECT_EQ(results[1], 0); + EXPECT_EQ(results[2], 1); // Set to |101>. results.clear(); @@ -942,7 +946,9 @@ TEST(MPSStateSpaceTest, SampleOnceSimple){ mps.get()[80] = 0; mps.get()[82] = 1; ss.SampleOnce(mps, scratch, scratch2, &rand_source, &results); - ASSERT_THAT(results, testing::ElementsAre(1, 0, 1)); + EXPECT_EQ(results[0], 1); + EXPECT_EQ(results[1], 0); + EXPECT_EQ(results[2], 1); } TEST(MPSStateSpaceTest, SampleGHZ){ @@ -965,8 +971,10 @@ TEST(MPSStateSpaceTest, SampleGHZ){ float count = 0; ss.Sample(mps, scratch, scratch2, num_samples, 1234, &results); for(int i = 0 ; i < num_samples; i++){ - ASSERT_THAT(results[i], testing::AnyOf(testing::ElementsAre(1, 1, 1), - testing::ElementsAre(0, 0, 0))); + bool all_same = 1; + all_same &= results[i][0] == results[i][1]; + all_same &= results[i][1] == results[i][2]; + EXPECT_EQ(all_same, 1); count += results[i][0]; } EXPECT_NEAR(count / float(num_samples), 0.5, 1e-2); From 70f7a2bfe24a65635efa63f41514c3ddb5efbf03 Mon Sep 17 00:00:00 2001 From: Michael Broughton Date: Tue, 21 Dec 2021 04:21:19 -0800 Subject: [PATCH 4/5] Moved to O(n) sampling algorithm. --- lib/mps_statespace.h | 153 ++++++++++++++++++++++--- tests/mps_statespace_test.cc | 210 ++++++++++++++++++----------------- 2 files changed, 243 insertions(+), 120 deletions(-) diff --git a/lib/mps_statespace.h b/lib/mps_statespace.h index c017e650..721b626e 100644 --- a/lib/mps_statespace.h +++ b/lib/mps_statespace.h @@ -22,9 +22,11 @@ #include #endif +#include #include #include #include +#include #include #include @@ -324,7 +326,7 @@ class MPSStateSpace { // Merge top into partial_contract2. new (&top) ConstMatrixMap((Complex*)(scratch_raw + offset), bond_dim, 2 * bond_dim); - // [bd, bd] = [2bd, bd] @ [bd, 2bd] + // [bd, bd] = [bd, 2bd] @ [bd, 2bd] partial_contract.noalias() = top * partial_contract2.adjoint(); } @@ -377,46 +379,165 @@ class MPSStateSpace { // as working space. static void SampleOnce(MPS& state, MPS& scratch, MPS& scratch2, std::mt19937* random_gen, std::vector* sample) { + // TODO: carefully profile with perf and optimize temp storage + // locations for cache friendliness. const auto bond_dim = state.bond_dim(); const auto num_qubits = state.num_qubits(); + const auto end = Size(state); + const auto left_frontier_offset = GetBlockOffset(state, num_qubits + 1); std::default_random_engine generator; + fp_type* state_raw = state.get(); fp_type* scratch_raw = scratch.get(); + fp_type* scratch2_raw = scratch2.get(); fp_type rdm[8]; sample->reserve(num_qubits); Copy(state, scratch); Copy(state, scratch2); - // Sample left block. - ReduceDensityMatrix(scratch, scratch2, 0, rdm); + // Store prefix contractions in scratch2. + auto offset = GetBlockOffset(state, num_qubits - 1); + ConstMatrixMap top((Complex*)(state_raw + offset), bond_dim, 2); + ConstMatrixMap bot((Complex*)(scratch_raw + offset), bond_dim, 2); + MatrixMap partial_contract((Complex*)(scratch2_raw + offset), bond_dim, + bond_dim); + MatrixMap partial_contract2((Complex*)(scratch_raw + end), bond_dim, + 2 * bond_dim); + partial_contract.noalias() = top * bot.adjoint(); + + for (unsigned i = num_qubits - 2; i > 0; --i) { + offset = GetBlockOffset(state, i); + // reshape: + new (&partial_contract2) + MatrixMap((Complex*)(scratch_raw + end), 2 * bond_dim, bond_dim); + + // Merge bot into left boundary merged tensor. + new (&bot) ConstMatrixMap((Complex*)(scratch_raw + offset), 2 * bond_dim, + bond_dim); + partial_contract2.noalias() = bot * partial_contract.adjoint(); + + // reshape: + new (&partial_contract2) + MatrixMap((Complex*)(scratch_raw + end), bond_dim, 2 * bond_dim); + + // Merge top into partial_contract2. + new (&top) ConstMatrixMap((Complex*)(state_raw + offset), bond_dim, + 2 * bond_dim); + + // merge into partial_contract -> scracth2_raw. + new (&partial_contract) + MatrixMap((Complex*)(scratch2_raw + offset), bond_dim, bond_dim); + partial_contract.noalias() = top * partial_contract2.adjoint(); + } + + // Compute RDM-0 and draw first sample. + offset = GetBlockOffset(state, 1); + new (&top) ConstMatrixMap((Complex*)state_raw, 2, bond_dim); + new (&bot) ConstMatrixMap((Complex*)scratch_raw, 2, bond_dim); + new (&partial_contract) + MatrixMap((Complex*)(scratch2_raw + offset), bond_dim, bond_dim); + new (&partial_contract2) + MatrixMap((Complex*)(scratch_raw + end), 2, bond_dim); + + partial_contract2.noalias() = bot * partial_contract.adjoint(); + + new (&partial_contract) MatrixMap((Complex*)rdm, 2, 2); + partial_contract.noalias() = top * partial_contract2.adjoint(); auto p0 = rdm[0] / (rdm[0] + rdm[6]); std::bernoulli_distribution distribution(1 - p0); auto bit_val = distribution(*random_gen); - sample->push_back(bit_val); - MatrixMap tensor_block((Complex*)scratch_raw, 2, bond_dim); - tensor_block.row(!bit_val).setZero(); - tensor_block.imag() *= -1; - // Sample internal blocks. + // collapse state. + new (&partial_contract) MatrixMap((Complex*)scratch_raw, 2, bond_dim); + partial_contract.row(!bit_val).setZero(); + + // Prepare left contraction frontier. + new (&partial_contract2) MatrixMap( + (Complex*)(scratch2_raw + left_frontier_offset), bond_dim, bond_dim); + partial_contract2.noalias() = + partial_contract.transpose() * partial_contract.conjugate(); + + // Compute RDM-i and draw internal tensor samples. for (unsigned i = 1; i < num_qubits - 1; i++) { - ReduceDensityMatrix(scratch, scratch2, i, rdm); + // Get leftmost [bd, bd] contraction and contract with top. + offset = GetBlockOffset(state, i); + new (&partial_contract) MatrixMap( + (Complex*)(scratch2_raw + left_frontier_offset), bond_dim, bond_dim); + new (&top) ConstMatrixMap((Complex*)(state_raw + offset), bond_dim, + 2 * bond_dim); + new (&partial_contract2) + MatrixMap((Complex*)(state_raw + end), bond_dim, 2 * bond_dim); + partial_contract2.noalias() = partial_contract * top.conjugate(); + + // Contract top again for correct shape. + MatrixMap partial_contract3((Complex*)(scratch_raw + end), 2 * bond_dim, + 2 * bond_dim); + partial_contract3.noalias() = top.transpose() * partial_contract2; + + // Conduct final tensor contraction operations. Cannot be easily compiled + // to matmul. Perf reports shows only ~6% of runtime spent here on large + // systems. + offset = GetBlockOffset(state, i + 1); + const Eigen::TensorMap> + t_4d((Complex*)(scratch_raw + end), 2, bond_dim, 2, bond_dim); + const Eigen::TensorMap> + t_2d((Complex*)(scratch2_raw + offset), bond_dim, bond_dim); + + const Eigen::array, 2> product_dims = { + Eigen::IndexPair(1, 0), + Eigen::IndexPair(3, 1), + }; + Eigen::TensorMap> out( + (Complex*)rdm, 2, 2); + out = t_4d.contract(t_2d, product_dims); + + // Sample bit and collapse state. p0 = rdm[0] / (rdm[0] + rdm[6]); distribution = std::bernoulli_distribution(1 - p0); bit_val = distribution(*random_gen); sample->push_back(bit_val); - const auto mem_start = GetBlockOffset(scratch, i); - new (&tensor_block) MatrixMap((Complex*)(scratch_raw + mem_start), - bond_dim * 2, bond_dim); + offset = GetBlockOffset(state, i); + new (&partial_contract) + MatrixMap((Complex*)(scratch_raw + offset), bond_dim * 2, bond_dim); for (unsigned j = !bit_val; j < 2 * bond_dim; j += 2) { - tensor_block.row(j).setZero(); + partial_contract.row(j).setZero(); } - tensor_block.imag() *= -1; + + // Update left frontier. + new (&partial_contract) MatrixMap( + (Complex*)(scratch2_raw + left_frontier_offset), bond_dim, bond_dim); + + // reshape: + new (&partial_contract2) + MatrixMap((Complex*)(state_raw + end), bond_dim, 2 * bond_dim); + + // Merge bot into left boundary merged tensor. + new (&bot) ConstMatrixMap((Complex*)(scratch_raw + offset), bond_dim, + 2 * bond_dim); + partial_contract2.noalias() = partial_contract * bot.conjugate(); + + // reshape: + new (&partial_contract2) + MatrixMap((Complex*)(state_raw + end), 2 * bond_dim, bond_dim); + + // Merge top into partial_contract2. + new (&top) ConstMatrixMap((Complex*)(scratch_raw + offset), 2 * bond_dim, + bond_dim); + partial_contract.noalias() = top.transpose() * partial_contract2; } - // Sample right block. - ReduceDensityMatrix(scratch, scratch2, num_qubits - 1, rdm); + // Compute RDM-(n-1) and sample. + offset = GetBlockOffset(state, num_qubits - 1); + new (&partial_contract2) + MatrixMap((Complex*)(state_raw + end), bond_dim, 2); + + new (&top) ConstMatrixMap((Complex*)(state_raw + offset), bond_dim, 2); + partial_contract2.noalias() = partial_contract * top.conjugate(); + new (&partial_contract) MatrixMap((Complex*)rdm, 2, 2); + partial_contract.noalias() = top.transpose() * partial_contract2; + p0 = rdm[0] / (rdm[0] + rdm[6]); distribution = std::bernoulli_distribution(1 - p0); bit_val = distribution(*random_gen); diff --git a/tests/mps_statespace_test.cc b/tests/mps_statespace_test.cc index 9473eaf9..1ea037c5 100644 --- a/tests/mps_statespace_test.cc +++ b/tests/mps_statespace_test.cc @@ -976,6 +976,7 @@ TEST(MPSStateSpaceTest, SampleGHZ){ all_same &= results[i][1] == results[i][2]; EXPECT_EQ(all_same, 1); count += results[i][0]; + EXPECT_EQ(results[i].size(), 3); } EXPECT_NEAR(count / float(num_samples), 0.5, 1e-2); } @@ -990,126 +991,127 @@ TEST(MPSStateSpaceTest, SampleComplex){ num_samples, std::vector({})); memset(mps.get(), 0, ss.RawSize(mps)); - mps.get()[ 0 ] = 0.033688569334715854 ; - mps.get()[ 1 ] = -0.10444182602180123 ; - mps.get()[ 2 ] = 0.9076354671683359 ; - mps.get()[ 3 ] = 0.405160344657187 ; - mps.get()[ 8 ] = -0.9595253512026178 ; - mps.get()[ 9 ] = -0.25936097827312377 ; - mps.get()[ 10 ] = -0.03987001675676861 ; - mps.get()[ 11 ] = 0.10224185693597321 ; - mps.get()[ 16 ] = -0.4350591822776815 ; - mps.get()[ 17 ] = 0.22228546667942578 ; - mps.get()[ 18 ] = -0.6285732819602607 ; - mps.get()[ 19 ] = 0.5943422063507785 ; - mps.get()[ 20 ] = -0.02428345908884816 ; - mps.get()[ 21 ] = 0.026256572727652475 ; - mps.get()[ 22 ] = 0.0728063572325396 ; - mps.get()[ 23 ] = -0.07991114142962712 ; - mps.get()[ 24 ] = -0.1642035571020447 ; - mps.get()[ 25 ] = -0.8209212529030018 ; - mps.get()[ 26 ] = 0.21124207331921135 ; - mps.get()[ 27 ] = 0.4033152234452636 ; - mps.get()[ 28 ] = -0.11315780332634073 ; - mps.get()[ 29 ] = -0.18477947087021204 ; - mps.get()[ 30 ] = -0.11199707215175961 ; - mps.get()[ 31 ] = -0.17985444082650426 ; - mps.get()[ 32 ] = -0.18059162771674087 ; - mps.get()[ 33 ] = -0.12173196101857839 ; - mps.get()[ 34 ] = 0.19817171168239098 ; - mps.get()[ 35 ] = 0.063054719070231 ; - mps.get()[ 36 ] = 0.45024015745008505 ; - mps.get()[ 37 ] = 0.11068157212593255 ; - mps.get()[ 38 ] = 0.8106501683288581 ; - mps.get()[ 39 ] = 0.19287240226762353 ; - mps.get()[ 40 ] = -0.08702392413741208 ; - mps.get()[ 41 ] = 0.07370887245848737 ; - mps.get()[ 42 ] = -0.018412987786278347 ; - mps.get()[ 43 ] = 0.027921764369775018 ; - mps.get()[ 44 ] = 0.42351439662329743 ; - mps.get()[ 45 ] = -0.7466201917698305 ; - mps.get()[ 46 ] = -0.24298735917837008 ; - mps.get()[ 47 ] = 0.4359199055764641 ; - mps.get()[ 80 ] = 0.422436255430577 ; + mps.get()[ 0 ] = -0.4917038696869799 ; + mps.get()[ 1 ] = 0.016731957658280873 ; + mps.get()[ 2 ] = 0.86132663373237 ; + mps.get()[ 3 ] = 0.12674293823327035 ; + mps.get()[ 8 ] = -0.5023020703950029 ; + mps.get()[ 9 ] = -0.711083648814302 ; + mps.get()[ 10 ] = -0.20727818303023368 ; + mps.get()[ 11 ] = -0.4461932766843352 ; + mps.get()[ 16 ] = 0.15655121570640956 ; + mps.get()[ 17 ] = 0.4732738079187066 ; + mps.get()[ 18 ] = -0.08511634068671248 ; + mps.get()[ 19 ] = 0.4509108800471812 ; + mps.get()[ 20 ] = 0.3399824326377983 ; + mps.get()[ 21 ] = 0.26456637633430585 ; + mps.get()[ 22 ] = 0.5923848721836553 ; + mps.get()[ 23 ] = -0.06659540240231236 ; + mps.get()[ 24 ] = 0.3386920440520109 ; + mps.get()[ 25 ] = -0.5078386788732782 ; + mps.get()[ 26 ] = -0.5938438138167242 ; + mps.get()[ 27 ] = -0.2253530600030204 ; + mps.get()[ 28 ] = -0.08439705180650249 ; + mps.get()[ 29 ] = 0.18289872169116567 ; + mps.get()[ 30 ] = 0.33989833066754255 ; + mps.get()[ 31 ] = -0.2604753706869852 ; + mps.get()[ 32 ] = 0.3013840839514031 ; + mps.get()[ 33 ] = -0.10757629710841352 ; + mps.get()[ 34 ] = -0.043855659850960294 ; + mps.get()[ 35 ] = -0.0999497956398576 ; + mps.get()[ 36 ] = 0.6336147397284169 ; + mps.get()[ 37 ] = 0.43658807519265264 ; + mps.get()[ 38 ] = -0.448346536528476 ; + mps.get()[ 39 ] = 0.30428652791930944 ; + mps.get()[ 40 ] = 0.2954131683108271 ; + mps.get()[ 41 ] = -0.4349910681437736 ; + mps.get()[ 42 ] = 0.35640542464599323 ; + mps.get()[ 43 ] = 0.4970533197510696 ; + mps.get()[ 44 ] = -0.37101487814696105 ; + mps.get()[ 45 ] = 0.2100308254832807 ; + mps.get()[ 46 ] = 0.10591704897593116 ; + mps.get()[ 47 ] = 0.3955295090226334 ; + mps.get()[ 80 ] = -0.24953341864058454 ; mps.get()[ 81 ] = 0.0 ; - mps.get()[ 82 ] = 0.1211402132186689 ; - mps.get()[ 83 ] = -0.819174648113452 ; - mps.get()[ 84 ] = 0.0 ; - mps.get()[ 85 ] = -7.333691512826885e-20 ; - mps.get()[ 88 ] = -0.8676720638499252 ; - mps.get()[ 89 ] = 0.1360568551008419 ; - mps.get()[ 90 ] = 0.011793867549118184 ; - mps.get()[ 91 ] = -0.44097834083157716 ; - mps.get()[ 96 ] = -0.08978879818674973 ; + mps.get()[ 82 ] = -0.5480093086703182 ; + mps.get()[ 83 ] = -0.20497358945530025 ; + mps.get()[ 84 ] = -1.1887516198406813e-16 ; + mps.get()[ 85 ] = 3.714848812002129e-18 ; + mps.get()[ 88 ] = 0.6045663379213811 ; + mps.get()[ 89 ] = -0.3501271865840065 ; + mps.get()[ 90 ] = -0.29968140886676936 ; + mps.get()[ 91 ] = 0.40493683779718603 ; + mps.get()[ 96 ] = 0.3073334814703704 ; mps.get()[ 97 ] = 0.0 ; - mps.get()[ 98 ] = -0.021807957717091198 ; - mps.get()[ 99 ] = -0.05873893136151775 ; - mps.get()[ 100 ] = 0.0 ; - mps.get()[ 101 ] = -9.89518074266081e-19 ; - mps.get()[ 104 ] = 0.14307979940517454 ; - mps.get()[ 105 ] = 0.06032194765563529 ; - mps.get()[ 106 ] = 0.22589440044405648 ; - mps.get()[ 107 ] = -0.2397609424987549 ; - mps.get()[ 112 ] = 0.12734430206944722 ; + mps.get()[ 98 ] = 0.07297353820052123 ; + mps.get()[ 99 ] = -0.2859132301813451 ; + mps.get()[ 100 ] = -1.7214471606144266e-16 ; + mps.get()[ 101 ] = 5.379522376920083e-18 ; + mps.get()[ 104 ] = -0.18689238699414557 ; + mps.get()[ 105 ] = -0.4911602105890581 ; + mps.get()[ 106 ] = -0.30326863844349566 ; + mps.get()[ 107 ] = -0.22667282775953723 ; + mps.get()[ 112 ] = -0.10881711525857803 ; mps.get()[ 113 ] = 0.0 ; - mps.get()[ 114 ] = -0.003114595079760157 ; - mps.get()[ 115 ] = 0.06816683893204967 ; - mps.get()[ 116 ] = 2.722010100011512e-17 ; - mps.get()[ 117 ] = 1.9629880528929172e-18 ; - mps.get()[ 120 ] = 0.014022255715263434 ; - mps.get()[ 121 ] = 0.017127855001478075 ; - mps.get()[ 122 ] = 0.025812082320798548 ; - mps.get()[ 123 ] = -0.027110021000464 ; - mps.get()[ 128 ] = -0.018262196707018574 ; + mps.get()[ 114 ] = -0.146152770590198 ; + mps.get()[ 115 ] = 0.2149415742117364 ; + mps.get()[ 116 ] = -4.72314539505504e-16 ; + mps.get()[ 117 ] = 1.1519866817207415e-17 ; + mps.get()[ 120 ] = -0.01567698028444534 ; + mps.get()[ 121 ] = 0.013440646849502781 ; + mps.get()[ 122 ] = -0.17367051562799563 ; + mps.get()[ 123 ] = -0.24954843447516284 ; + mps.get()[ 128 ] = 0.24030153622040965 ; mps.get()[ 129 ] = 0.0 ; - mps.get()[ 130 ] = 0.0032725358458428836 ; - mps.get()[ 131 ] = 0.0310845568816579 ; - mps.get()[ 132 ] = 1.935877805637811e-17 ; - mps.get()[ 133 ] = -1.0370773958773989e-18 ; - mps.get()[ 136 ] = -0.028632305212090994 ; - mps.get()[ 137 ] = 0.012199896816087576 ; - mps.get()[ 138 ] = 0.0009323445588941451 ; - mps.get()[ 139 ] = -0.014212789540748644 ; - mps.get()[ 144 ] = -0.07762944756130831 ; - mps.get()[ 145 ] = -0.25063255485414937 ; - mps.get()[ 146 ] = 0.515385895406013 ; - mps.get()[ 147 ] = 0.7314486807404007 ; - mps.get()[ 148 ] = 0.20689214104052 ; - mps.get()[ 149 ] = 0.2781707321332216 ; - mps.get()[ 150 ] = 0.08286244916183945 ; - mps.get()[ 151 ] = 0.05888783848647657 ; + mps.get()[ 130 ] = -0.08309837568058188 ; + mps.get()[ 131 ] = 0.07924116582885271 ; + mps.get()[ 132 ] = -7.075275311738327e-17 ; + mps.get()[ 133 ] = 3.930708506521293e-18 ; + mps.get()[ 136 ] = 0.0725269370009367 ; + mps.get()[ 137 ] = 0.06123701427497634 ; + mps.get()[ 138 ] = -0.006630682493419155 ; + mps.get()[ 139 ] = 0.015491880670142021 ; + mps.get()[ 144 ] = -0.021403127627426542 ; + mps.get()[ 145 ] = 0.04422341855596844 ; + mps.get()[ 146 ] = 0.27602112861704176 ; + mps.get()[ 147 ] = 0.7790060986745896 ; + mps.get()[ 148 ] = 0.25252680029727903 ; + mps.get()[ 149 ] = 0.49967041792054084 ; + mps.get()[ 150 ] = -0.031679241045523554 ; + mps.get()[ 151 ] = -0.010202895067710558 ; ss.Sample(mps, scratch, scratch2, num_samples, 12345, &results); - std::vector expected({ - 0.00467637, - 0.0020386, - 0.00112952, - 0.0269848, - 0.00704221, - 0.00147802, - 0.00243688, - 0.0350753, - 0.0324814, - 0.0141599, - 0.0412371, - 0.0780275, - 0.0644363, - 0.140995, - 0.0355866, - 0.512215, + 0.036801, + 0.040697, + 0.002013, + 0.064595, + 0.014892, + 0.082028, + 0.008521, + 0.168310, + 0.022078, + 0.005907, + 0.024806, + 0.189074, + 0.090056, + 0.023125, + 0.116683, + 0.110406 }); std::vector hist(16, 0); for(int i =0;i Date: Tue, 21 Dec 2021 04:28:33 -0800 Subject: [PATCH 5/5] remove unused includes. --- lib/mps_statespace.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/lib/mps_statespace.h b/lib/mps_statespace.h index 721b626e..9b3acf31 100644 --- a/lib/mps_statespace.h +++ b/lib/mps_statespace.h @@ -22,11 +22,9 @@ #include #endif -#include #include #include #include -#include #include #include