diff --git a/lib/mps_statespace.h b/lib/mps_statespace.h index acdf69db..9b3acf31 100644 --- a/lib/mps_statespace.h +++ b/lib/mps_statespace.h @@ -26,6 +26,7 @@ #include #include #include +#include #include "../eigen/Eigen/Dense" #include "../eigen/unsupported/Eigen/CXX11/Tensor" @@ -323,7 +324,7 @@ class MPSStateSpace { // Merge top into partial_contract2. new (&top) ConstMatrixMap((Complex*)(scratch_raw + offset), bond_dim, 2 * bond_dim); - // [bd, bd] = [2bd, bd] @ [bd, 2bd] + // [bd, bd] = [bd, 2bd] @ [bd, 2bd] partial_contract.noalias() = top * partial_contract2.adjoint(); } @@ -372,6 +373,196 @@ class MPSStateSpace { out = t_4d.contract(t_2d, product_dims); } + // Draw a single bitstring sample from state, one qubit at a time, and + // append one bit per qubit to *sample (existing contents are kept). + // Each bit comes from *random_gen through a Bernoulli distribution whose + // parameter is computed from entries rdm[0] and rdm[6] of a 2x2 reduced + // density matrix. scratch and scratch2 are overwritten with copies of + // state and then used as working space; the memory starting at + // Size(state) inside state and scratch also holds temporaries. + static void SampleOnce(MPS& state, MPS& scratch, MPS& scratch2, + std::mt19937* random_gen, std::vector* sample) { + // TODO: carefully profile with perf and optimize temp storage + // locations for cache friendliness. + const auto bond_dim = state.bond_dim(); + const auto num_qubits = state.num_qubits(); + const auto end = Size(state); + const auto left_frontier_offset = GetBlockOffset(state, num_qubits + 1); + fp_type* state_raw = state.get(); + fp_type* scratch_raw = scratch.get(); + fp_type* scratch2_raw = scratch2.get(); + fp_type rdm[8]; + + sample->reserve(num_qubits); + Copy(state, scratch); + Copy(state, scratch2); + + // Store prefix contractions in scratch2. 
+ auto offset = GetBlockOffset(state, num_qubits - 1); + ConstMatrixMap top((Complex*)(state_raw + offset), bond_dim, 2); + ConstMatrixMap bot((Complex*)(scratch_raw + offset), bond_dim, 2); + MatrixMap partial_contract((Complex*)(scratch2_raw + offset), bond_dim, + bond_dim); + MatrixMap partial_contract2((Complex*)(scratch_raw + end), bond_dim, + 2 * bond_dim); + partial_contract.noalias() = top * bot.adjoint(); + + for (unsigned i = num_qubits - 2; i > 0; --i) { + offset = GetBlockOffset(state, i); + // reshape: + new (&partial_contract2) + MatrixMap((Complex*)(scratch_raw + end), 2 * bond_dim, bond_dim); + + // Merge bot into left boundary merged tensor. + new (&bot) ConstMatrixMap((Complex*)(scratch_raw + offset), 2 * bond_dim, + bond_dim); + partial_contract2.noalias() = bot * partial_contract.adjoint(); + + // reshape: + new (&partial_contract2) + MatrixMap((Complex*)(scratch_raw + end), bond_dim, 2 * bond_dim); + + // Merge top into partial_contract2. + new (&top) ConstMatrixMap((Complex*)(state_raw + offset), bond_dim, + 2 * bond_dim); + + // merge into partial_contract -> scratch2_raw. + new (&partial_contract) + MatrixMap((Complex*)(scratch2_raw + offset), bond_dim, bond_dim); + partial_contract.noalias() = top * partial_contract2.adjoint(); + } + + // Compute RDM-0 and draw first sample. 
+ offset = GetBlockOffset(state, 1); + new (&top) ConstMatrixMap((Complex*)state_raw, 2, bond_dim); + new (&bot) ConstMatrixMap((Complex*)scratch_raw, 2, bond_dim); + new (&partial_contract) + MatrixMap((Complex*)(scratch2_raw + offset), bond_dim, bond_dim); + new (&partial_contract2) + MatrixMap((Complex*)(scratch_raw + end), 2, bond_dim); + + partial_contract2.noalias() = bot * partial_contract.adjoint(); + + new (&partial_contract) MatrixMap((Complex*)rdm, 2, 2); + partial_contract.noalias() = top * partial_contract2.adjoint(); + auto p0 = rdm[0] / (rdm[0] + rdm[6]); + std::bernoulli_distribution distribution(1 - p0); + auto bit_val = distribution(*random_gen); + sample->push_back(bit_val); + + // collapse state. + new (&partial_contract) MatrixMap((Complex*)scratch_raw, 2, bond_dim); + partial_contract.row(!bit_val).setZero(); + + // Prepare left contraction frontier. + new (&partial_contract2) MatrixMap( + (Complex*)(scratch2_raw + left_frontier_offset), bond_dim, bond_dim); + partial_contract2.noalias() = + partial_contract.transpose() * partial_contract.conjugate(); + + // Compute RDM-i and draw internal tensor samples. + for (unsigned i = 1; i < num_qubits - 1; i++) { + // Get leftmost [bd, bd] contraction and contract with top. + offset = GetBlockOffset(state, i); + new (&partial_contract) MatrixMap( + (Complex*)(scratch2_raw + left_frontier_offset), bond_dim, bond_dim); + new (&top) ConstMatrixMap((Complex*)(state_raw + offset), bond_dim, + 2 * bond_dim); + new (&partial_contract2) + MatrixMap((Complex*)(state_raw + end), bond_dim, 2 * bond_dim); + partial_contract2.noalias() = partial_contract * top.conjugate(); + + // Contract top again for correct shape. + MatrixMap partial_contract3((Complex*)(scratch_raw + end), 2 * bond_dim, + 2 * bond_dim); + partial_contract3.noalias() = top.transpose() * partial_contract2; + + // Conduct final tensor contraction operations. Cannot be easily compiled + // to matmul. 
Perf reports show only ~6% of runtime spent here on large + // systems. + offset = GetBlockOffset(state, i + 1); + const Eigen::TensorMap> + t_4d((Complex*)(scratch_raw + end), 2, bond_dim, 2, bond_dim); + const Eigen::TensorMap> + t_2d((Complex*)(scratch2_raw + offset), bond_dim, bond_dim); + + const Eigen::array, 2> product_dims = { + Eigen::IndexPair(1, 0), + Eigen::IndexPair(3, 1), + }; + Eigen::TensorMap> out( + (Complex*)rdm, 2, 2); + out = t_4d.contract(t_2d, product_dims); + + // Sample bit and collapse state. + p0 = rdm[0] / (rdm[0] + rdm[6]); + distribution = std::bernoulli_distribution(1 - p0); + bit_val = distribution(*random_gen); + + sample->push_back(bit_val); + offset = GetBlockOffset(state, i); + new (&partial_contract) + MatrixMap((Complex*)(scratch_raw + offset), bond_dim * 2, bond_dim); + for (unsigned j = !bit_val; j < 2 * bond_dim; j += 2) { + partial_contract.row(j).setZero(); + } + + // Update left frontier. + new (&partial_contract) MatrixMap( + (Complex*)(scratch2_raw + left_frontier_offset), bond_dim, bond_dim); + + // reshape: + new (&partial_contract2) + MatrixMap((Complex*)(state_raw + end), bond_dim, 2 * bond_dim); + + // Merge bot into left boundary merged tensor. + new (&bot) ConstMatrixMap((Complex*)(scratch_raw + offset), bond_dim, + 2 * bond_dim); + partial_contract2.noalias() = partial_contract * bot.conjugate(); + + // reshape: + new (&partial_contract2) + MatrixMap((Complex*)(state_raw + end), 2 * bond_dim, bond_dim); + + // Merge top into partial_contract2. + new (&top) ConstMatrixMap((Complex*)(scratch_raw + offset), 2 * bond_dim, + bond_dim); + partial_contract.noalias() = top.transpose() * partial_contract2; + } + + // Compute RDM-(n-1) and sample. 
+ offset = GetBlockOffset(state, num_qubits - 1); + new (&partial_contract2) + MatrixMap((Complex*)(state_raw + end), bond_dim, 2); + + new (&top) ConstMatrixMap((Complex*)(state_raw + offset), bond_dim, 2); + partial_contract2.noalias() = partial_contract * top.conjugate(); + new (&partial_contract) MatrixMap((Complex*)rdm, 2, 2); + partial_contract.noalias() = top.transpose() * partial_contract2; + + p0 = rdm[0] / (rdm[0] + rdm[6]); + distribution = std::bernoulli_distribution(1 - p0); + bit_val = distribution(*random_gen); + sample->push_back(bit_val); + } + + // Draw num_samples bitstring samples from state and store the result + // bit vectors in results. Uses scratch and scratch2 as workspace. + // One std::mt19937 seeded with seed is shared by all draws. results is + // grown to num_samples entries when it is shorter; sample i is appended + // to whatever (*results)[i] already holds. + static void Sample(MPS& state, MPS& scratch, MPS& scratch2, + unsigned num_samples, unsigned seed, + std::vector>* results) { + std::mt19937 rand_source(seed); + // reserve() only changes capacity, so indexing (*results)[i] below needs + // the elements to exist: grow the vector when it is too short. + if (results->size() < num_samples) results->resize(num_samples); + for (unsigned i = 0; i < num_samples; i++) { + SampleOnce(state, scratch, scratch2, &rand_source, &(*results)[i]); + } + } + // Testing only. Convert the MPS to a wavefunction under "normal" ordering. // Requires: wf be allocated beforehand with bond_dim * 2 ^ num_qubits -1 // memory. diff --git a/tests/mps_statespace_test.cc b/tests/mps_statespace_test.cc index 190f08df..1ea037c5 100644 --- a/tests/mps_statespace_test.cc +++ b/tests/mps_statespace_test.cc @@ -900,6 +900,220 @@ TEST(MPSStateSpaceTest, ReduceDensityMatrixLarge){ } +TEST(MPSStateSpaceTest, SampleOnceSimple){ + auto ss = MPSStateSpace(1); + auto mps = ss.Create(3, 4); + auto scratch = ss.Create(3, 4); + auto scratch2 = ss.Create(3, 4); + std::mt19937 rand_source(1234); + std::vector results; + + // Set to |100>. + results.clear(); + ss.SetStateZero(mps); + mps.get()[0] = 0; + mps.get()[8] = 1; + ss.SampleOnce(mps, scratch, scratch2, &rand_source, &results); + EXPECT_EQ(results[0], 1); + EXPECT_EQ(results[1], 0); + EXPECT_EQ(results[2], 0); + + // Set to |010>. 
+ results.clear(); + ss.SetStateZero(mps); + mps.get()[16] = 0; + mps.get()[24] = 1; + ss.SampleOnce(mps, scratch, scratch2, &rand_source, &results); + EXPECT_EQ(results[0], 0); + EXPECT_EQ(results[1], 1); + EXPECT_EQ(results[2], 0); + + // Set to |001>. + results.clear(); + ss.SetStateZero(mps); + mps.get()[80] = 0; + mps.get()[82] = 1; + ss.SampleOnce(mps, scratch, scratch2, &rand_source, &results); + EXPECT_EQ(results[0], 0); + EXPECT_EQ(results[1], 0); + EXPECT_EQ(results[2], 1); + + // Set to |101>. + results.clear(); + ss.SetStateZero(mps); + mps.get()[0] = 0; + mps.get()[8] = 1; + mps.get()[80] = 0; + mps.get()[82] = 1; + ss.SampleOnce(mps, scratch, scratch2, &rand_source, &results); + EXPECT_EQ(results[0], 1); + EXPECT_EQ(results[1], 0); + EXPECT_EQ(results[2], 1); +} + +TEST(MPSStateSpaceTest, SampleGHZ){ // Each sample must be three equal bits; about half should be all ones. + const int num_samples = 10000; + auto ss = MPSStateSpace(1); + auto mps = ss.Create(3, 4); + auto scratch = ss.Create(3, 4); + auto scratch2 = ss.Create(3, 4); + std::vector> results( + num_samples, std::vector({})); + + memset(mps.get(), 0, ss.RawSize(mps)); + mps.get()[0] = 1; + mps.get()[10] = 1; + mps.get()[16] = 1; + mps.get()[42] = -1; + mps.get()[80] = 0.70710677; + mps.get()[86] = -0.70710677; + + float count = 0; + ss.Sample(mps, scratch, scratch2, num_samples, 1234, &results); + for(int i = 0 ; i < num_samples; i++){ + ASSERT_EQ(results[i].size(), 3); // Fatal on purpose: bits 0..2 are read right below. + bool all_same = 1; + all_same &= results[i][0] == results[i][1]; + all_same &= results[i][1] == results[i][2]; + EXPECT_EQ(all_same, 1); + count += results[i][0]; + } + EXPECT_NEAR(count / float(num_samples), 0.5, 1e-2); +} + +TEST(MPSStateSpaceTest, SampleComplex){ + const int num_samples = 10000; + auto ss = MPSStateSpace(1); + auto mps = ss.Create(4, 4); + auto scratch = ss.Create(4, 4); + auto scratch2 = ss.Create(4, 4); + std::vector> results( + num_samples, std::vector({})); + + memset(mps.get(), 0, ss.RawSize(mps)); + mps.get()[ 0 ] = -0.4917038696869799 ; + mps.get()[ 1 ] = 
0.016731957658280873 ; + mps.get()[ 2 ] = 0.86132663373237 ; + mps.get()[ 3 ] = 0.12674293823327035 ; + mps.get()[ 8 ] = -0.5023020703950029 ; + mps.get()[ 9 ] = -0.711083648814302 ; + mps.get()[ 10 ] = -0.20727818303023368 ; + mps.get()[ 11 ] = -0.4461932766843352 ; + mps.get()[ 16 ] = 0.15655121570640956 ; + mps.get()[ 17 ] = 0.4732738079187066 ; + mps.get()[ 18 ] = -0.08511634068671248 ; + mps.get()[ 19 ] = 0.4509108800471812 ; + mps.get()[ 20 ] = 0.3399824326377983 ; + mps.get()[ 21 ] = 0.26456637633430585 ; + mps.get()[ 22 ] = 0.5923848721836553 ; + mps.get()[ 23 ] = -0.06659540240231236 ; + mps.get()[ 24 ] = 0.3386920440520109 ; + mps.get()[ 25 ] = -0.5078386788732782 ; + mps.get()[ 26 ] = -0.5938438138167242 ; + mps.get()[ 27 ] = -0.2253530600030204 ; + mps.get()[ 28 ] = -0.08439705180650249 ; + mps.get()[ 29 ] = 0.18289872169116567 ; + mps.get()[ 30 ] = 0.33989833066754255 ; + mps.get()[ 31 ] = -0.2604753706869852 ; + mps.get()[ 32 ] = 0.3013840839514031 ; + mps.get()[ 33 ] = -0.10757629710841352 ; + mps.get()[ 34 ] = -0.043855659850960294 ; + mps.get()[ 35 ] = -0.0999497956398576 ; + mps.get()[ 36 ] = 0.6336147397284169 ; + mps.get()[ 37 ] = 0.43658807519265264 ; + mps.get()[ 38 ] = -0.448346536528476 ; + mps.get()[ 39 ] = 0.30428652791930944 ; + mps.get()[ 40 ] = 0.2954131683108271 ; + mps.get()[ 41 ] = -0.4349910681437736 ; + mps.get()[ 42 ] = 0.35640542464599323 ; + mps.get()[ 43 ] = 0.4970533197510696 ; + mps.get()[ 44 ] = -0.37101487814696105 ; + mps.get()[ 45 ] = 0.2100308254832807 ; + mps.get()[ 46 ] = 0.10591704897593116 ; + mps.get()[ 47 ] = 0.3955295090226334 ; + mps.get()[ 80 ] = -0.24953341864058454 ; + mps.get()[ 81 ] = 0.0 ; + mps.get()[ 82 ] = -0.5480093086703182 ; + mps.get()[ 83 ] = -0.20497358945530025 ; + mps.get()[ 84 ] = -1.1887516198406813e-16 ; + mps.get()[ 85 ] = 3.714848812002129e-18 ; + mps.get()[ 88 ] = 0.6045663379213811 ; + mps.get()[ 89 ] = -0.3501271865840065 ; + mps.get()[ 90 ] = -0.29968140886676936 ; + mps.get()[ 91 ] = 
0.40493683779718603 ; + mps.get()[ 96 ] = 0.3073334814703704 ; + mps.get()[ 97 ] = 0.0 ; + mps.get()[ 98 ] = 0.07297353820052123 ; + mps.get()[ 99 ] = -0.2859132301813451 ; + mps.get()[ 100 ] = -1.7214471606144266e-16 ; + mps.get()[ 101 ] = 5.379522376920083e-18 ; + mps.get()[ 104 ] = -0.18689238699414557 ; + mps.get()[ 105 ] = -0.4911602105890581 ; + mps.get()[ 106 ] = -0.30326863844349566 ; + mps.get()[ 107 ] = -0.22667282775953723 ; + mps.get()[ 112 ] = -0.10881711525857803 ; + mps.get()[ 113 ] = 0.0 ; + mps.get()[ 114 ] = -0.146152770590198 ; + mps.get()[ 115 ] = 0.2149415742117364 ; + mps.get()[ 116 ] = -4.72314539505504e-16 ; + mps.get()[ 117 ] = 1.1519866817207415e-17 ; + mps.get()[ 120 ] = -0.01567698028444534 ; + mps.get()[ 121 ] = 0.013440646849502781 ; + mps.get()[ 122 ] = -0.17367051562799563 ; + mps.get()[ 123 ] = -0.24954843447516284 ; + mps.get()[ 128 ] = 0.24030153622040965 ; + mps.get()[ 129 ] = 0.0 ; + mps.get()[ 130 ] = -0.08309837568058188 ; + mps.get()[ 131 ] = 0.07924116582885271 ; + mps.get()[ 132 ] = -7.075275311738327e-17 ; + mps.get()[ 133 ] = 3.930708506521293e-18 ; + mps.get()[ 136 ] = 0.0725269370009367 ; + mps.get()[ 137 ] = 0.06123701427497634 ; + mps.get()[ 138 ] = -0.006630682493419155 ; + mps.get()[ 139 ] = 0.015491880670142021 ; + mps.get()[ 144 ] = -0.021403127627426542 ; + mps.get()[ 145 ] = 0.04422341855596844 ; + mps.get()[ 146 ] = 0.27602112861704176 ; + mps.get()[ 147 ] = 0.7790060986745896 ; + mps.get()[ 148 ] = 0.25252680029727903 ; + mps.get()[ 149 ] = 0.49967041792054084 ; + mps.get()[ 150 ] = -0.031679241045523554 ; + mps.get()[ 151 ] = -0.010202895067710558 ; + + ss.Sample(mps, scratch, scratch2, num_samples, 12345, &results); + std::vector expected({ + 0.036801, + 0.040697, + 0.002013, + 0.064595, + 0.014892, + 0.082028, + 0.008521, + 0.168310, + 0.022078, + 0.005907, + 0.024806, + 0.189074, + 0.090056, + 0.023125, + 0.116683, + 0.110406 + }); + std::vector hist(16, 0); + for(int i =0;i