Adds working bitstring Sample function to MPS. #490
Conversation
lib/mps_statespace.h (Outdated)
  // Sample left block.
  ReduceDensityMatrix(scratch, scratch2, 0, rdm);
  auto p0 = rdm[0] / (rdm[0] + rdm[6]);
  std::bernoulli_distribution distribution(1 - p0);
  auto bit_val = distribution(*random_gen);

  sample->push_back(bit_val);
  MatrixMap tensor_block((Complex*)scratch_raw, 2, bond_dim);
  tensor_block.row(!bit_val).setZero();
  tensor_block.imag() *= -1;

  // Sample internal blocks.
  for (unsigned i = 1; i < num_qubits - 1; i++) {
    ReduceDensityMatrix(scratch, scratch2, i, rdm);
    p0 = rdm[0] / (rdm[0] + rdm[6]);
    distribution = std::bernoulli_distribution(1 - p0);
    bit_val = distribution(*random_gen);

    sample->push_back(bit_val);
    const auto mem_start = GetBlockOffset(scratch, i);
    new (&tensor_block) MatrixMap((Complex*)(scratch_raw + mem_start),
                                  bond_dim * 2, bond_dim);
    for (unsigned j = !bit_val; j < 2 * bond_dim; j += 2) {
      tensor_block.row(j).setZero();
    }
    tensor_block.imag() *= -1;
  }

  // Sample right block.
  ReduceDensityMatrix(scratch, scratch2, num_qubits - 1, rdm);
  p0 = rdm[0] / (rdm[0] + rdm[6]);
  distribution = std::bernoulli_distribution(1 - p0);
  bit_val = distribution(*random_gen);
  sample->push_back(bit_val);
}
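For readers skimming the diff: per site, the code builds a single-qubit reduced density matrix, draws a Bernoulli bit from its diagonal, and projects the site tensor onto the sampled outcome. A rough Python paraphrase of the coin toss (assuming, as the rdm[0]/rdm[6] indexing suggests, that the flat rdm buffer interleaves real and imaginary parts, so those two entries are the real diagonal elements; the helper name is mine):

import numpy as np

def sample_site_bit(rdm, rng):
    """Draw one bit from an (unnormalized) 2x2 reduced density matrix.

    Mirrors `p0 = rdm[0] / (rdm[0] + rdm[6])` followed by
    `std::bernoulli_distribution(1 - p0)` in the snippet above.
    """
    p0 = rdm[0, 0].real / (rdm[0, 0].real + rdm[1, 1].real)
    return int(rng.random() >= p0)

# e.g. sample_site_bit(np.diag([0.75, 0.25]).astype(complex), np.random.default_rng(0))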
Thanks @MichaelBroughton for adding this! If I understand the implementation correctly, you can reduce the time from O(n^2) to O(n), where n is the number of sites, by storing intermediate contractions. That is, first compute \rho_{1,...,n-1} (the rdm obtained by tracing out the last qubit) and store it in memory, then from this compute \rho_{1,...,n-2} and store it in memory, then from this ... compute \rho_{1}. Then sample bits from 1 to n, again reusing these already-computed rdms.
I hope that makes sense... I can't find a reference implementation right now. Basically the idea is not to contract edges more than once - e.g., currently to get \rho_{1} and \rho_{2}, the last n - 2 edges of the MPS are individually contracted twice, but you can avoid this.
Hi @rmlarose, I'm not sure I follow. Can you write out the approach here quickly using something like python and einsum?
Actually, looking back I think you may already be doing what I suggested, since scratch and scratch2 get modified in each call to ReduceDensityMatrix.
For an n=3 MPS, the sequence of rdms should be:
| | |
@--@--@ (rho_123)
| | |
| |
@--@ (rho_12 obtained from rho_123)
| |
|
@ (rho_1 obtained from rho_12, *not* from rho_123)
|
That is, for each i you want to obtain rho_1...i from rho_1...(i + 1), not from rho_1...n. If you are already doing this (it's not 100% clear to me how scratch and scratch2 get modified in ReduceDensityMatrix), you can ignore my comment.
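In symbols (my notation), the suggestion is the partial-trace recursion \rho_{1,...,i} = Tr_{i+1}[\rho_{1,...,i+1}] for i = n-1, ..., 1, so that each rdm is built from its immediate predecessor rather than recontracted from the full state.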
I'm not entirely sure I see how this contraction process is O(n) and not O(n^2). The contraction process in this PR is this:
- Compute the RDM on the first qubit.
- Toss a coin with probabilities set by the RDM.
- Project the qubit to the zero or one state depending on the coin outcome.
- Repeat this process for qubits [1, ..., n].
Step 1 in this operation always considers tensors [0, ..., n]. You compute n RDM matrices at the cost of n contractions each, so it's O(n^2) to draw a single bitstring. Looking at your proposal, IIUC you do this:
- Compute the RDM on the first qubit.
- Toss a coin with probability set by the RDM.
- Project the qubit to the zero or one state depending on the coin outcome.
- As I move to the next qubit, re-use some of the contractions done up to this point (on the left side only) to compute the RDM, instead of contracting all the way from the first qubit every time.
Suppose I have this MPS:
│ │ │ │ │
┌─┴─┐ ┌─┴─┐ ┌─┴─┐ ┌─┴─┐ ┌─┴─┐
│ 0 ├─┤ 1 ├─┤ 2 ├─┤ 3 ├─┤ 4 │
└───┘ └───┘ └───┘ └───┘ └───┘
and part way through my sampling process I need to compute the RDM on qubit 3 (having already sampled and projected down qubits 0,1 and 2):
│
┌───┐ ┌───┐ ┌───┐ ┌─┴─┐ ┌───┐
│ 0 ├─┤ 1 ├─┤ 2 ├─┤ 3 ├─┤ 4 │
└─┬─┘ └─┬─┘ └─┬─┘ └───┘ └─┬─┘
│ │ │ │
┌─┴─┐ ┌─┴─┐ ┌─┴─┐ ┌───┐ ┌─┴─┐
│ 0 ├─┤ 1 ├─┤ 2 ├─┤ 3 ├─┤ 4 │
└───┘ └───┘ └───┘ └─┬─┘ └───┘
│
With my approach I would just naively recontract everything together. IIUC your approach would keep another temporary variable containing the contractions of the projections from tensor 0 up to i - 1 if I want to compute the RDM for qubit i. This way I can use the fact that I have previously computed:
│
┌───┐ ┌───┐ ┌─┴─┐ ┌───┐ ┌───┐
│ 0 ├─┤ 1 ├─┤ 2 ├─┤ 3 ├─┤ 4 │
└─┬─┘ └─┬─┘ └───┘ └─┬─┘ └─┬─┘
│ │ │ │
┌─┴─┐ ┌─┴─┐ ┌───┐ ┌─┴─┐ ┌─┴─┐
│ 0 ├─┤ 1 ├─┤ 2 ├─┤ 3 ├─┤ 4 │
└───┘ └───┘ └─┬─┘ └───┘ └───┘
│
Where I could store this:
│
┌───┐ ┌─┴─┐ ┌───┐ ┌───┐
│ ├─┤ 2 ├─┤ 3 ├─┤ 4 │
│ │ └───┘ └─┬─┘ └─┬─┘
│0,1│ │ │
│ │ ┌───┐ ┌─┴─┐ ┌─┴─┐
│ ├─┤ 2 ├─┤ 3 ├─┤ 4 │
└───┘ └─┬─┘ └───┘ └───┘
│
And then (from the diagram above) I can contract 2 into 0,1 -> 0,1,2 and have the entire left side of my RDM-3 calculation done simply by remembering some of the computations from the previous step. This means, iteration by iteration:
- To compute RDM-0, I contract n-1 tensors to the right of tensor 0 and none on the left.
- To compute RDM-1, I contract n-2 tensors to the right of tensor 1 and just one on the left (creating my history variable).
- To compute RDM-2, I contract n-3 tensors to the right of tensor 2 and just one on the left (applying one contraction onto my history variable).
- To compute RDM-3, I contract n-4 tensors to the right of tensor 3 and just one on the left (applying one contraction onto my history variable).
.
.
.
There will be n steps, where step i requires roughly n - i contractions: \sum_{i=1}^{n} (n - i) = n(n - 1)/2 = O(n^2), with a prefactor of 1/2 instead of a prefactor of 1. How do I get this down to O(n)? You can't "cache" tensor contractions on the right side of your RDM index as you go using this same approach, can you?
If we can get it down to O(n) I'd say it's worth updating in this PR, but if it's just the prefactor of 2 I'd say that's fine to come in a subsequent PR.
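For reference, the two counts side by side (using the estimates from this comment and the earlier one, written in my notation):

naive:        \sum_{i=1}^{n} n       = n^2
left-cached:  \sum_{i=1}^{n} (n - i) = n(n - 1)/2

which is where the factor of 1/2 (but still O(n^2)) comes from.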
Thanks @MichaelBroughton for the very detailed and clear response. It seems my original comment does apply.
With my approach I would just naively recontract everything together. IIUC your approach would keep another temporary variable containing the contractions of the projections from tensor 0 up to i - 1 if I want to compute the RDM for qubit i.
Indeed this is precisely the idea. There is one important change to your sequence of diagrams to get to the O(n) algorithm. Below I sketch out the full O(n) sampling algorithm, which proceeds in two "phases".
Phase 1: Computing RDMs
Starting from the first step with an n=4 MPS, you have this tensor network:
| | | |
@--@--@--@
| | | |
First, you store this. (Note this is not a single-site RDM, which I think is the key difference from how you are thinking about this.) Then you connect the rightmost free legs and contract every edge on the last tensor to get the RDM on the first three sites:
| | |
@--@--@
| | |
Store this. Then connect the rightmost free legs and contract every edge on the last tensor to get the RDM on the first two sites:
| |
@--@
| |
Store this. Then connect...to get the RDM on the first site:
|
@
|
Store this. Note: There are n steps here, and each step requires contracting O(1) edges, yielding O(n) total contracted edges. The output of this step is the list of n "sub tensor networks" shown above.
Phase 2: Sampling
To sample, you iterate backwards through this list. Starting with the last "sub tensor network", namely the RDM on the first site
|
@
|
sample a bit a.
Next, move to the previous "sub tensor network", namely the RDM on the first two sites, and attach a (as a vector, |0> or |1>, corresponding to the bit) to the first site:
a
| |
@--@
| |
a
Contract this tensor network to get a 2x2 rdm, then sample a bit b.
Next, move to the previous "sub tensor network", namely the RDM on the first three sites, and attach a and b (as vectors) to the first two sites:
a b
| | |
@--@--@
| | |
a b
Contract this tensor network to get a 2x2 rdm, then sample a bit c. Note: in the previous step, the tensor network
a
|
@--
|
a
was already contracted. To make "Phase 2" also have O(n) contractions, you need to store and reuse this result.
Finally, repeat the same process to get the last bit. Namely, form
a b c
| | | |
@--@--@--@
| | | |
a b c
contract to get a 2x2 rdm, then sample a bit d. Note again that the tensor network
a b
| |
@--@--
| |
a b
appearing in this diagram has already been contracted in the previous step.
Output abcd. Note: This phase has n steps, each of which requires O(1) contractions, yielding O(n) contractions. The total for both phase 1 and phase 2 is then O(n).
Pseudocode
If helpful, I sketched this procedure out in the following rough-around-the-edges pseudocode.
def sample(mps) -> List[int]:
    """Input mps is |psi><psi|:

    | | | |     (top_free_legs)
    @--@--@--@  (tensors)
    | | | |     (bottom_free_legs)
    """
    # Phase 1: Compute reduced density matrices rho_1, ..., rho_1...n.
    n = number of qubits in mps
    rdms = [mps]  # List of tensor networks corresponding to each rdm
    for i in range(n - 1, 0, -1):
        mps = connect top_free_legs[i] and bottom_free_legs[i] then contract all legs on tensors[i]
        rdms.append(mps)

    # Phase 2: Sample bits.
    bits: List[int] = []
    for i in range(n):
        rdm = rdms[-i - 1]
        bit = sample a bit from rdm, a 2x2 density matrix up to normalization
        bits.append(bit)
        left_tensor_network = add bit to free legs of rdm and contract
        replace i + 1 leftmost tensors in rdms[-i - 2] with left_tensor_network
    return bits
I see. This makes sense. One final question: I'm still not sold on the Phase 1 portion. Looking at the pseudocode here:
rdms = [mps]  # List of tensor networks corresponding to each rdm
for i in range(n - 1, 0, -1):
    mps = connect top_free_legs[i] and bottom_free_legs[i] then contract all legs on tensors[i]
    rdms.append(mps)
Initially rdms will contain an n-site tensor network. After the first loop iteration it will contain an n-site network and an (n-1)-site network, and so on, all the way down to a 1-site network. In terms of space requirements it feels like \sum_{i=1}^{n} i = n(n + 1)/2 = O(n^2) has shown up again, only now in space and not time. Am I still missing something?
You're absolutely right that phase 1, as I've described it, requires O(n^2) memory. However, after talking with a "famous" mps guy - who has formally kicked me out of the tensor network community for describing it this way 😄 - there's a better way to do phase 1 which only requires O(n) memory:
You still store the copy of the first tensor network (using the same n=4 example):
| | | |
u--v--t--w
| | | |
But now you just compute and store the following tensors by successively connecting top/bottom free indices and contracting:
--w
--(tw)
--(vtw)
Explicitly, as an example, --(tw) is
i| |j
--t--w
i| |j
(i and j indicate these legs are connected, respectively).
The time is still O(n), but now the space is also O(n).
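If it helps to make the --w, --(tw), --(vtw) construction concrete, here is a rough numpy/einsum sketch, under the assumption that each site tensor is stored as a (left_bond, physical, right_bond) array; the function and variable names (right_caps, sites, cap) are mine, not anything in the PR:

import numpy as np

def right_caps(sites):
    """Build the right 'caps' --w, --(tw), --(vtw), ... for an MPS on |psi><psi|.

    sites: list of complex arrays, sites[k] of shape (chi_l, 2, chi_r), with
    chi_l = 1 on the first site and chi_r = 1 on the last.
    Returns caps, where caps[k] is the ket/bra contraction of sites k..n-1,
    a matrix over the bond to the left of site k; caps[n] is the trivial 1x1 cap.
    """
    cap = np.ones((1, 1), dtype=complex)
    caps = [cap]
    for A in reversed(sites):
        # Join the physical legs of the ket/bra copies of this site and fold
        # its right bonds into the cap accumulated so far.
        cap = np.einsum('asb,csd,bd->ac', A, np.conj(A), cap)
        caps.append(cap)
    caps.reverse()
    return caps

Only chi x chi matrices are stored, one per site, so the memory is O(n) as described above.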
The only difference in phase 2 is how the RDMs are obtained. For example, in the first step you form
|
u--(vtw)
|
and contract to get the 2x2 density matrix, then sample a bit a. Note that
|
u
|
is selected from the copy of the tensor network in the first step of phase 1.
Next you form
a
| |
u--v--(tw)
| |
a
and contract, sample, etc.
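Putting the two phases together, a minimal numpy sketch of the full O(n) sampling loop might look like the following. It reuses the hypothetical right_caps helper sketched above and again assumes (left_bond, physical, right_bond) site tensors; sample_bits, left, and bits are my own names, not the PR's:

import numpy as np

def sample_bits(sites, rng):
    """Draw one bitstring from the MPS, using the precomputed right caps
    (phase 1) plus a running left contraction (phase 2), so that each site
    costs O(1) contractions."""
    caps = right_caps(sites)                  # phase 1, as sketched above
    left = np.ones((1, 1), dtype=complex)     # contraction of the already-sampled sites
    bits = []
    for k, A in enumerate(sites):
        # Unnormalized 2x2 reduced density matrix on site k:
        # left -- ket/bra copies of site k -- caps[k + 1].
        rdm = np.einsum('ac,asb,ctd,bd->st', left, A, np.conj(A), caps[k + 1])
        p0 = rdm[0, 0].real / (rdm[0, 0].real + rdm[1, 1].real)
        bit = int(rng.random() >= p0)         # ~ bernoulli_distribution(1 - p0)
        bits.append(bit)
        # Attach the sampled bit (|0> or |1>) to both free legs of this site
        # and fold it into the left contraction for the next step.
        ket = A[:, bit, :]
        left = np.einsum('ac,ab,cd->bd', left, ket, np.conj(ket))
    return bits

# Usage on a random 4-site MPS with bond dimension 3:
# rng = np.random.default_rng(42)
# shapes = [(1, 2, 3), (3, 2, 3), (3, 2, 3), (3, 2, 1)]
# sites = [rng.normal(size=s) + 1j * rng.normal(size=s) for s in shapes]
# print(sample_bits(sites, rng))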
Alright then. I've gone ahead and implemented this. It's a lot faster on the big systems, which is great. I didn't look too carefully into profiling it for things like cache friendliness, but generally speaking it looks like it's spending a lot of time doing BLAS operations and moving data. This might be a good follow-up for anyone interested in trying to squeeze out some more performance.
Thanks to @rmlarose for the MPS review - I leave that portion in his capable hands.
Only a couple of notes from me on the other parts of this PR.
Working implementation and replacement for #460. We should have everything we need now to implement all the operations for MPS in TFQ.
cc: @jaeyoo, @dstrain115