Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement experimental general get_params_to_statevector #121

Conversation

gamatos
Copy link
Collaborator

@gamatos gamatos commented Jan 10, 2024

Implement experimental general get_params_to_statevector

Happy new year everyone!

This PR implements a version of the get_params_to_statevector function that supports more general
operations beyond just application of parameterised gates to an initial state. The features added
are as follows, including a link to further down in the PR where I explain the rationale behind
them:

I created a qujax.experimental module with an unstable API that can have breaking changes. That
way, we can discuss any changes or suggestions without delaying a merge into develop too much.

Finally, I added a notebook experimental/noise_channel_monte_carlo.ipynb illustrating some of
these new features. This notebook simulates a noisy quantum circuit using a Monte-Carlo/quantum
trajectories approach.

The documentation has also been updated and is rendering correctly.

Summary of features of the new get_params_to_statevector

Support for general operations

The new get_params_to_statevector function can now support arbitrary operations beyond just
contracting the statetensor being operated on by the circuit with some tensor. The reasoning behind
this feature is that some relevant applications can not be implemented by applying tensors to the
statetensor.

An example of this is mid-circuit measurements. Since the measurement probabilities are given by the
statetensor at the time the measurement is made, there is no way to perform the measurement without
it depending on the statetensor itself, which is beyond what qujax can currently support. Another
example is if the user wants to print debug information (using e.g. jax.debug.print) about the
statetensor halfway through the circuit. There are more examples, and any operation that directly
depends on the value that the statetensor itself takes at some point in the circuit can not be
implemented by the current get_params_to_statevector.

More generally, this will greatly lift restrictions on what the user can do using qujax and ensure
that applications that were not originally considered can still be covered without the user having
to hack them into the package.

The new get_params_to_statevector function has five parameters:

  • op_seq - Sequence of operations to be performed. This is similar to the the gate_seq function
    in the old get_params_to_statevector except the strings in the sequence can represent any
    operation. Functions and arrays in the sequence still represent gates.
  • op_metaparams_seq - Sequence of metaparameters to each operation in op_seq. This generalises
    the old qubit_inds_seq, in the sense that one can think of gate application as an operation that
    takes the qubit indices as a metaparameter. General operations will take more general metaparameters
    (see e.g. the "ConditionalGate" operation in the new notebook for an example). The difference
    between a metaparameter and a parameter is that a metaparameter is fixed when
    get_params_to_statevector is called (and params_to_statevector is then later
    compiled/differentiated by JAX with that fixed value), while a parameter is later passed to the
    params_to_statevector function.
  • param_pos_seq - Largely unchanged from the old get_params_to_statevector, specifies the
    positions of the parameters passed into the params_to_statevector function returned by
    get_params_to_statevector. Main difference is that it now supports dictionaries, as explained a
    further down in the PR.
  • op_dict new - Dictionary specifying which operations are supported. Defaults to a dictionary
    of default operations (currently specified in get_default_operations). Each dictionary entry is a
    function taking a set of metaparameters and returning another function, which in turn has arguments
    (parameters, statetensor_in, classical_register_in) and returns (statetensor_out, classical_register_out).
  • gate_dict new - Dictionary specifying which gates are supported. Defaults to the entries in
    qujax.gates.

A custom operation can be implemented by defining a new function and passing it as an op_dict
entry. It can then be used by passing the respective string as an entry in op_seq and the
respective metaparameters in op_metaparams_seq.

Inclusion of a classical register

The new get_params_to_statevector now takes, and returns, a classical register (represented by a
jax.Array) alongside the statetensor. This allows for values to be recorded during circuit
execution and reused later in the circuit, which allows for implementing e.g. conditional gates
which depend on a measurement.

Specifying parameters using dictionaries

One can now specify parameters using dictionaries of jax.Arrays instead of just a jax.Array. The
reasoning behind this is that there exist parameters that can be fundamentally very different, and
separating them out makes specifying and tracking indices in the param_pos_seq argument
significantly easier. An example of this is the "ConditionalGate" operation, which both takes an
index specifying a gate out of many to apply, while the gates themselves also have parameters (see
e.g. test_parameterised_stochasticity in test_experimental.py for a specific instance of this).

The old way of specifying parameters using a jax.Array is unchanged and is still supported.

Using dictionaries can also make it easier to separate parameters into non-differentiable and
differentiable parameters (note that jax.grad always differentiates with respect some specified
positional arguments, so if all parameters are passed as one monolithic array they need to be
separated out if we are only interested in differentiating with respect to some).

Note that the fact that we can work with dictionaries does not interfere with JAX (which can work
with general PyTrees).

Custom gate dictionary

The user can now specify a custom gate dictionary which replaces/adds to the gates in qujax.gates.

This allows for easier serialisation (op_seq can now be fully made into a sequence of strings and
not have any jax.Arrays or functions in it). It also enables the user to change to an alternative
set of gates (or restrict to a subset of gates) and have qujax directly enforce that, and is useful
in e.g. changing between gate definitions if needed.

To do

  • Add more tests
  • Add notebook on mid-circuit measurement
  • Add notebook on conditional operations (i.e. operations that can depend on value of
    measurements)
  • Adapt print_circuit function to work with this new get_params_to_statevector

To discuss

  • Would we want to deprecate specifying gates through functions/arrays in favour of the
    dict-based approach introduced above?
  • Which operations should we support by default? Currently there is a "ConditionalGate"
    operation implemented which applies one out of many gates based on an index (see e.g.
    experimental/noise_channel_monte_carlo.ipynb)
  • Would we want to replace the old implementation of get_params_to_statevector by this one, or
    merge this into the stable API as a different function (e.g. get_generic_params_to_statevector)?
    If the former, are we open to breaking changes or do we want to make it fully backwards compatible?
  • Should we also generalise the get_params_to_densitytensor_func function in the same way?

We don't need to decide on everything before this gets merged since this new function is under the
experimental API, so the discussion can continue until we are good to merge it into the stable API
in a release or two.

@gamatos gamatos requested a review from gabrielcqc January 10, 2024 15:45
@gamatos gamatos force-pushed the general_statevector_operations branch 4 times, most recently from 8ee67df to 0daf431 Compare January 10, 2024 15:59
* Supports general operations beyond parameterized gate application
* Allows for specification of parameters using dictionaries
* Adds a classical register for storage of intermediate computations
(e.g. mid-circuit-measurements)
* Allows the user to define new gates identified by a string

In addition:
* Introduce unstable experimental API in experimental.py
* Add `./examples/experimental` folder for notebooks using unstable API
* Add `experimental/noise_channel_monte_carlo.ipynb` with example
of simulating noise using Monte-Carlo/quantum trajectories
* Add test_experimental.py for testing unstable API
* Add `ParameterizedGateFunction`, `UnparameterizedGateFunction` types
@gamatos gamatos force-pushed the general_statevector_operations branch from 0daf431 to e98c57b Compare January 10, 2024 16:02
Copy link
Collaborator

@gabrielcqc gabrielcqc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really good work!

As for the discussion topics, here are my views:

  • If it is compatible to have both things I would leave it like this at least for some time so people using qujax can still run their code. Maybe in the future transition to only using a dict-based approach and have clear examples on how to use them.
  • No opinion at the moment.
  • If the computational cost is similar with the old and new function I'd say replace it, but if it is slower or requires more resources in any way maybe having two different functions.
  • No opinion at the moment.

@gamatos gamatos changed the base branch from develop to general_statevector_operations March 20, 2024 17:53
@gamatos
Copy link
Collaborator Author

gamatos commented Mar 20, 2024

I ended up delaying merging this PR into the develop branch until the features it implements are more fleshed out; I created a new branch called general_statevector_operations and will merge it there instead. I will open another PR soon with more tests and operations implemented, along with one or two more notebooks illustrating their use. We can then do another code review before we merge it into develop.

@gamatos gamatos merged commit 6185d2b into CQCL:general_statevector_operations Mar 20, 2024
5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants