Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add logo, update documentation #110

Merged
merged 3 commits into from
Nov 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
sphinx ~= 7.2
sphinx_rtd_theme
myst-parser ~= 2.0
80 changes: 12 additions & 68 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,25 +1,21 @@
# qujax

<div align="center">
<a href="https://cqcl.github.io/qujax/"><img src="docs/logo.svg" alt="logo"></img></a>
</div>

[![PyPI - Version](https://img.shields.io/pypi/v/qujax)](https://pypi.org/project/qujax/)
[![DOI](https://joss.theoj.org/papers/10.21105/joss.05504/status.svg)](https://doi.org/10.21105/joss.05504)

* [Installation](#installation)
* [Quick start](#quick-start)
+ [Pure state simulation](#pure-state-simulation)
+ [Mixed state simulation](#mixed-state-simulation)
* [Converting from TKET](#converting-from-tket)
* [Examples](#examples)
* [Contributing](#contributing)
* [Citing qujax](#citing-qujax)
* [API Reference](https://cqcl.github.io/qujax/)
[**Documentation**](https://cqcl.github.io/qujax/) | [**Installation**](#installation) | [**Quick start**](#quick-start) | [**Examples**](https://cqcl.github.io/qujax/examples.html) | [**Contributing**](#contributing) | [**Citing qujax**](#citing-qujax)

qujax is a [JAX](https://github.com/google/jax)-based Python library for the classical simulation of quantum circuits. It is designed to be *simple*, *fast* and *flexible*.

It follows a functional programming design by translating circuits into pure functions. This allows qujax to [seamlessly interface with JAX](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions), enabling direct access to its powerful automatic differentiation tools, just-in-time compiler, vectorization capabilities, GPU/TPU integration and growing ecosystem of packages.

qujax can be used both for pure and for mixed quantum state simulation. It not only supports the standard gate set, but also allows user-defined custom operations, including general quantum channels, enabling the user to e.g. model device noise and errors.

An overview of the core functionalities of qujax can be found in the [Quick start](#quick-start) section. More advanced use-cases, including the training of parameterised quantum circuits, are listed in [Examples](#examples).
A summary of the core functionalities of qujax can be found in the [Quick start](#quick-start) section. More advanced use-cases, including the training of parameterised quantum circuits, can be found in the [Examples](https://cqcl.github.io/qujax/examples.html) section of the documentation.


## Installation
Expand All @@ -33,9 +29,7 @@ pip install qujax

**Important note: qujax circuit parameters are expressed in units of $\pi$ (e.g. in the range $[0,2]$ as opposed to $[0, 2\pi]$)**.

### Pure state simulation

We start by defining the quantum gates making up the circuit, along with the qubits that they act on and the indices of the parameters for each gate.
Start by defining the quantum gates making up the circuit, the qubits that they act on, and the indices of the parameters for each gate.

A list of all gates can be found [here](https://github.com/CQCL/qujax/blob/main/qujax/gates.py) (custom operations can be included by [passing an array or function](https://cqcl.github.io/qujax/statetensor/get_params_to_statetensor_func.html) instead of a string).

Expand All @@ -56,7 +50,7 @@ qujax.print_circuit(circuit_gates, circuit_qubit_inds, circuit_params_inds);
# q1: ---------------------CZ--
```

We then translate the circuit to a pure function `param_to_st` that takes a set of parameters and an (optional) initial quantum state as its input.
Translate the circuit to a pure function `param_to_st` that takes a set of parameters and an (optional) initial quantum state as its input.

```python
param_to_st = qujax.get_params_to_statetensor_func(circuit_gates,
Expand All @@ -70,18 +64,7 @@ param_to_st(jnp.array([0.1]))

The optional initial state can be passed to `param_to_st` using the `statetensor_in` argument. When it is not provided, the initial state defaults to $\ket{0...0}$.

Note that qujax represents quantum states as _statetensors_. For example, for $N=4$ qubits, the corresponding vector space has $2^4$ dimensions, and a quantum state in this space is represented by an array with shape `(2,2,2,2)`. The usual statevector representation with shape `(16,)` can be obtained by calling `.flatten()` or `.reshape(-1)` or `.reshape(2**N)` on this array.

In the statetensor representation, the coefficient associated with e.g. basis state $\ket{0101}$ is given by `arr[0,1,0,1]`; each axis corresponds to one qubit.

```python
param_to_st(jnp.array([0.1])).flatten()
# Array([0.58778524+0.j, 0.+0.j, 0.80901706+0.j, 0.+0.j], dtype=complex64)
```

Finally, by defining an observable, we can map the statetensor to an expectation value. A general observable is specified using lists of Pauli matrices, the qubits they act on, and the associated coefficients.

For example, $Z_1Z_2Z_3Z_4 - 2 X_3$ would be written as `[['Z','Z','Z','Z'], ['X']], [[1,2,3,4], [3]], [1., -2.]`.
Map the state to an expectation value by defining an observable using lists of Pauli matrices, the qubits they act on, and the associated coefficients.

```python
st_to_expectation = qujax.get_statetensor_to_expectation_func([['Z']], [[0]], [1.])
Expand All @@ -99,52 +82,13 @@ expectation_and_grad(jnp.array([0.1]))
# Array([-2.987832], dtype=float32))
```

### Mixed state simulation
Mixed state simulations are analogous to the above, but with calls to `get_params_to_densitytensor_func` and `get_densitytensor_to_expectation_func` instead.
Mixed state simulations are analogous to the above, but with calls to [`get_params_to_densitytensor_func`](https://cqcl.github.io/qujax/densitytensor/get_params_to_densitytensor_func.html) and [`get_densitytensor_to_expectation_func`](https://cqcl.github.io/qujax/densitytensor/get_densitytensor_to_expectation_func.html) instead.

```python
param_to_dt = qujax.get_params_to_densitytensor_func(circuit_gates,
circuit_qubit_inds,
circuit_params_inds)
dt = param_to_dt(jnp.array([0.1]))
dt.shape
# (2, 2, 2, 2)

dt_to_expectation = qujax.get_densitytensor_to_expectation_func([['Z']], [[0]], [1.])
dt_to_expectation(dt)
# Array(-0.3090171, dtype=float32)
```

Similarly to a statetensor, which represents the reshaped $2^N$-dimensional statevector of a pure quantum state, a _densitytensor_ represents the reshaped $2^N \times 2^N$ density matrix of a mixed quantum state. This densitytensor has shape `(2,) * 2 * N`.

For example, for $N=2$, and a mixed state $\frac{1}{2} (\ket{00}\bra{11} + \ket{11}\bra{00} + \ket{11}\bra{11} + \ket{00}\bra{00})$, the corresponding densitytensor `dt` is such that `dt[0,0,1,1] = dt[1,1,0,0] = dt[1,1,1,1] = dt[0,0,0,0] = 1/2`, and all other entries are zero.

The equivalent density matrix can be obtained by calling `.reshape(2 ** N, 2 ** N)`.
A more in-depth version of the above can be found in the [Getting started](https://cqcl.github.io/qujax/getting_started.html) section of the documentation. More advanced use-cases, including the training of parameterised quantum circuits, can be found in the [Examples](https://cqcl.github.io/qujax/examples.html) section of the documentation.

## Converting from TKET

One can directly convert a [`pytket`](https://cqcl.github.io/tket/pytket/api/) circuit using the [`tk_to_qujax`](https://cqcl.github.io/pytket-qujax/api/api.html#pytket.extensions.qujax.qujax_convert.tk_to_qujax) and [`tk_to_qujax_symbolic`](https://cqcl.github.io/pytket-qujax/api/api.html#pytket.extensions.qujax.qujax_convert.tk_to_qujax_symbolic) functions in the [**`pytket-qujax`**](https://github.com/CQCL/pytket-qujax) extension.

An example of this can be found in the [`pytket-qujax_heisenberg_vqe.ipynb`](https://github.com/CQCL/pytket/blob/main/examples/pytket-qujax_heisenberg_vqe.ipynb) notebook.

## Examples

Below are some use-case notebooks. These both illustrate the flexibility of qujax and the power of directly interfacing with JAX and its package ecosystem.

- [`heisenberg_vqe.ipynb`](https://github.com/CQCL/qujax/blob/develop/examples/heisenberg_vqe.ipynb) - an implementation of the variational quantum eigensolver to find the ground state of a quantum Hamiltonian.
- [`maxcut_vqe.ipynb`](https://github.com/CQCL/qujax/blob/develop/examples/maxcut_vqe.ipynb) - an implementation of the variational quantum eigensolver to solve a MaxCut problem. Trains with Adam via [`optax`](https://github.com/deepmind/optax) and uses more realistic stochastic parameter shift gradients.
- [`noise_channel.ipynb`](https://github.com/CQCL/qujax/blob/develop/examples/noise_channel.ipynb) - uses the densitytensor simulator to fit the parameters of a depolarising noise channel.
- [`qaoa.ipynb`](https://github.com/CQCL/qujax/blob/develop/examples/qaoa.ipynb) - uses a problem-inspired QAOA ansatz to find the ground state of a quantum Hamiltonian. Demonstrates how to encode more sophisticated parameters that control multiple gates.
- [`barren_plateaus.ipynb`](https://github.com/CQCL/qujax/blob/develop/examples/barren_plateaus.ipynb) - illustrates how to sample gradients of a cost function to identify the presence of barren plateaus. Uses batched/vectorized evaluation to speed up computation.
- [`reducing_jit_compilation_time.ipynb`](https://github.com/CQCL/qujax/blob/develop/examples/reducing_jit_compilation_time.ipynb) - explains how JAX compilation works and how that can lead to excessive compilation times when executing quantum circuits. Presents a solution for the case of circuits with a repeating structure.
- [`variational_inference.ipynb`](https://github.com/CQCL/qujax/blob/develop/examples/variational_inference.ipynb) - uses a parameterised quantum circuit as a variational distribution to fit to a target probability mass function. Uses Adam via [`optax`](https://github.com/deepmind/optax) to minimise the KL divergence between circuit and target distributions.
- [`classification.ipynb`](https://github.com/CQCL/qujax/blob/develop/examples/classification.ipynb) - train a quantum circuit for binary classification using data re-uploading.
- [`generative_modelling.ipynb`](https://github.com/CQCL/qujax/blob/develop/examples/generative_modelling.ipynb) - uses a parameterised quantum circuit as a generative model for a real life dataset. Trains via stochastic gradient Langevin dynamics on the maximum mean discrepancy between statetensor and dataset.

The [`pytket`](https://github.com/CQCL/pytket) repository also contains `tk_to_qujax` implementations for some of the above at [`pytket-qujax_classification.ipynb`](https://github.com/CQCL/pytket/blob/main/examples/pytket-qujax-classification.ipynb),
[`pytket-qujax_heisenberg_vqe.ipynb`](https://github.com/CQCL/pytket/blob/main/examples/pytket-qujax_heisenberg_vqe.ipynb)
and [`pytket-qujax_qaoa.ipynb`](https://github.com/CQCL/pytket/blob/main/examples/pytket-qujax_qaoa.ipynb).

A [`pytket`](https://cqcl.github.io/tket/pytket/api/) circuit can be directly converted using the [`tk_to_qujax`](https://cqcl.github.io/pytket-qujax/api/api.html#pytket.extensions.qujax.qujax_convert.tk_to_qujax) and [`tk_to_qujax_symbolic`](https://cqcl.github.io/pytket-qujax/api/api.html#pytket.extensions.qujax.qujax_convert.tk_to_qujax_symbolic) functions in the [**`pytket-qujax`**](https://github.com/CQCL/pytket-qujax) extension. See [`pytket-qujax_heisenberg_vqe.ipynb`](https://github.com/CQCL/pytket/blob/main/examples/pytket-qujax_heisenberg_vqe.ipynb) for an example.

## Contributing

Expand Down
53 changes: 53 additions & 0 deletions docs/_static/css/custom.css
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@

.wy-nav-top{
background-color: #203847
}

.wy-side-nav-search{
background-color: white;
}

.icon.icon-home{
color: #000000
}

.wy-menu-vertical p.caption{
color: #85cfcb;
}

.wy-side-nav-search > a{
color: #000000;
}

.wy-side-nav-search > div.version{
color: #000000;
}

.sig {
background: #85cfcb;
}

.wy-nav-content {
max-width: 1000px;
}

html.writer-html4 .rst-content dl:not(.docutils) > dt, html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.citation):not(.glossary):not(.simple) > dt {
display:block;
background-color: #ebeef1;
}

#examples ul, ul.simple {
list-style: none;
}

#examples ul li, ul.simple li {
margin-bottom: 10px;
}

h1, h2, h3, h4, h5, h6 {
color: #203847
}

div.toctree-wrapper .caption-text{
color: #203847;
}
20 changes: 20 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"sphinx_rtd_theme",
"sphinx.ext.napoleon",
"sphinx.ext.mathjax",
"myst_parser",
]

templates_path = ["_templates"]
Expand All @@ -42,3 +43,22 @@
}

latex_engine = "pdflatex"

titles_only = True

rst_prolog = """
.. role:: python(code)
:language: python
"""

html_logo = "logo.svg"

html_static_path = ["_static"]
html_css_files = [
"css/custom.css",
]

html_theme_options = {
"collapse_navigation": False,
"prev_next_buttons_location": "None",
}
17 changes: 17 additions & 0 deletions docs/examples.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Examples

Below are some use-case notebooks. These both illustrate the flexibility of qujax and the power of directly interfacing with JAX and its package ecosystem.

- [heisenberg_vqe.ipynb](https://github.com/CQCL/qujax/blob/develop/examples/heisenberg_vqe.ipynb) - an implementation of the variational quantum eigensolver to find the ground state of a quantum Hamiltonian.
- [maxcut_vqe.ipynb](https://github.com/CQCL/qujax/blob/develop/examples/maxcut_vqe.ipynb) - an implementation of the variational quantum eigensolver to solve a MaxCut problem. Trains with Adam via [`optax`](https://github.com/deepmind/optax) and uses more realistic stochastic parameter shift gradients.
- [noise_channel.ipynb](https://github.com/CQCL/qujax/blob/develop/examples/noise_channel.ipynb) - uses the densitytensor simulator to fit the parameters of a depolarising noise channel.
- [qaoa.ipynb](https://github.com/CQCL/qujax/blob/develop/examples/qaoa.ipynb) - uses a problem-inspired QAOA ansatz to find the ground state of a quantum Hamiltonian. Demonstrates how to encode more sophisticated parameters that control multiple gates.
- [barren_plateaus.ipynb](https://github.com/CQCL/qujax/blob/develop/examples/barren_plateaus.ipynb) - illustrates how to sample gradients of a cost function to identify the presence of barren plateaus. Uses batched/vectorized evaluation to speed up computation.
- [reducing_jit_compilation_time.ipynb](https://github.com/CQCL/qujax/blob/develop/examples/reducing_jit_compilation_time.ipynb) - explains how JAX compilation works and how that can lead to excessive compilation times when executing quantum circuits. Presents a solution for the case of circuits with a repeating structure.
- [variational_inference.ipynb](https://github.com/CQCL/qujax/blob/develop/examples/variational_inference.ipynb) - uses a parameterised quantum circuit as a variational distribution to fit to a target probability mass function. Uses Adam via [`optax`](https://github.com/deepmind/optax) to minimise the KL divergence between circuit and target distributions.
- [classification.ipynb](https://github.com/CQCL/qujax/blob/develop/examples/classification.ipynb) - train a quantum circuit for binary classification using data re-uploading.
- [generative_modelling.ipynb](https://github.com/CQCL/qujax/blob/develop/examples/generative_modelling.ipynb) - uses a parameterised quantum circuit as a generative model for a real life dataset. Trains via stochastic gradient Langevin dynamics on the maximum mean discrepancy between statetensor and dataset.

The [pytket](https://github.com/CQCL/pytket) repository also contains `tk_to_qujax` implementations for some of the above at [pytket-qujax_classification.ipynb](https://github.com/CQCL/pytket/blob/main/examples/pytket-qujax-classification.ipynb),
[pytket-qujax_heisenberg_vqe.ipynb](https://github.com/CQCL/pytket/blob/main/examples/pytket-qujax_heisenberg_vqe.ipynb)
and [pytket-qujax_qaoa.ipynb](https://github.com/CQCL/pytket/blob/main/examples/pytket-qujax_qaoa.ipynb).
Loading