Implement experimental general get_params_to_statevector
#121
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Implement experimental general
get_params_to_statevector
Happy new year everyone!
This PR implements a version of the
get_params_to_statevector
function that supports more generaloperations beyond just application of parameterised gates to an initial state. The features added
are as follows, including a link to further down in the PR where I explain the rationale behind
them:
application
mid-circuit-measurements)
dictionaries
I created a
qujax.experimental
module with an unstable API that can have breaking changes. Thatway, we can discuss any changes or suggestions without delaying a merge into
develop
too much.Finally, I added a notebook
experimental/noise_channel_monte_carlo.ipynb
illustrating some ofthese new features. This notebook simulates a noisy quantum circuit using a Monte-Carlo/quantum
trajectories approach.
The documentation has also been updated and is rendering correctly.
Summary of features of the new
get_params_to_statevector
Support for general operations
The new
get_params_to_statevector
function can now support arbitrary operations beyond justcontracting the statetensor being operated on by the circuit with some tensor. The reasoning behind
this feature is that some relevant applications can not be implemented by applying tensors to the
statetensor.
An example of this is mid-circuit measurements. Since the measurement probabilities are given by the
statetensor at the time the measurement is made, there is no way to perform the measurement without
it depending on the statetensor itself, which is beyond what qujax can currently support. Another
example is if the user wants to print debug information (using e.g. jax.debug.print) about the
statetensor halfway through the circuit. There are more examples, and any operation that directly
depends on the value that the statetensor itself takes at some point in the circuit can not be
implemented by the current
get_params_to_statevector
.More generally, this will greatly lift restrictions on what the user can do using qujax and ensure
that applications that were not originally considered can still be covered without the user having
to hack them into the package.
The new
get_params_to_statevector
function has five parameters:op_seq
- Sequence of operations to be performed. This is similar to the thegate_seq
functionin the old
get_params_to_statevector
except the strings in the sequence can represent anyoperation. Functions and arrays in the sequence still represent gates.
op_metaparams_seq
- Sequence of metaparameters to each operation inop_seq
. This generalisesthe old
qubit_inds_seq
, in the sense that one can think of gate application as an operation thattakes the qubit indices as a metaparameter. General operations will take more general metaparameters
(see e.g. the
"ConditionalGate"
operation in the new notebook for an example). The differencebetween a metaparameter and a parameter is that a metaparameter is fixed when
get_params_to_statevector
is called (andparams_to_statevector
is then latercompiled/differentiated by JAX with that fixed value), while a parameter is later passed to the
params_to_statevector
function.param_pos_seq
- Largely unchanged from the oldget_params_to_statevector
, specifies thepositions of the parameters passed into the
params_to_statevector
function returned byget_params_to_statevector
. Main difference is that it now supports dictionaries, as explained afurther down in the PR.
op_dict
new - Dictionary specifying which operations are supported. Defaults to a dictionaryof default operations (currently specified in
get_default_operations
). Each dictionary entry is afunction taking a set of metaparameters and returning another function, which in turn has arguments
(parameters, statetensor_in, classical_register_in)
and returns(statetensor_out, classical_register_out)
.gate_dict
new - Dictionary specifying which gates are supported. Defaults to the entries inqujax.gates
.A custom operation can be implemented by defining a new function and passing it as an
op_dict
entry. It can then be used by passing the respective string as an entry in
op_seq
and therespective metaparameters in
op_metaparams_seq
.Inclusion of a classical register
The new
get_params_to_statevector
now takes, and returns, a classical register (represented by ajax.Array
) alongside the statetensor. This allows for values to be recorded during circuitexecution and reused later in the circuit, which allows for implementing e.g. conditional gates
which depend on a measurement.
Specifying parameters using dictionaries
One can now specify parameters using dictionaries of
jax.Array
s instead of just ajax.Array
. Thereasoning behind this is that there exist parameters that can be fundamentally very different, and
separating them out makes specifying and tracking indices in the
param_pos_seq
argumentsignificantly easier. An example of this is the
"ConditionalGate"
operation, which both takes anindex specifying a gate out of many to apply, while the gates themselves also have parameters (see
e.g.
test_parameterised_stochasticity
intest_experimental.py
for a specific instance of this).The old way of specifying parameters using a jax.Array is unchanged and is still supported.
Using dictionaries can also make it easier to separate parameters into non-differentiable and
differentiable parameters (note that
jax.grad
always differentiates with respect some specifiedpositional arguments, so if all parameters are passed as one monolithic array they need to be
separated out if we are only interested in differentiating with respect to some).
Note that the fact that we can work with dictionaries does not interfere with JAX (which can work
with general PyTrees).
Custom gate dictionary
The user can now specify a custom gate dictionary which replaces/adds to the gates in
qujax.gates
.This allows for easier serialisation (
op_seq
can now be fully made into a sequence of strings andnot have any
jax.Array
s or functions in it). It also enables the user to change to an alternativeset of gates (or restrict to a subset of gates) and have qujax directly enforce that, and is useful
in e.g. changing between gate definitions if needed.
To do
measurements)
print_circuit
function to work with this newget_params_to_statevector
To discuss
dict-based approach introduced above?
"ConditionalGate"
operation implemented which applies one out of many gates based on an index (see e.g.
experimental/noise_channel_monte_carlo.ipynb
)get_params_to_statevector
by this one, ormerge this into the stable API as a different function (e.g.
get_generic_params_to_statevector
)?If the former, are we open to breaking changes or do we want to make it fully backwards compatible?
get_params_to_densitytensor_func
function in the same way?We don't need to decide on everything before this gets merged since this new function is under the
experimental API, so the discussion can continue until we are good to merge it into the stable API
in a release or two.