Skip to content

Commit

Permalink
Moving CuTensorNetHandle to the general.py at the root of the module
Browse files Browse the repository at this point in the history
  • Loading branch information
PabloAndresCQ committed Jun 4, 2024
1 parent 3954fec commit 26f848a
Show file tree
Hide file tree
Showing 9 changed files with 81 additions and 69 deletions.
1 change: 1 addition & 0 deletions pytket/extensions/cutensornet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Module for conversion from tket primitives to cuQuantum primitives."""

from .backends import CuTensorNetBackend
from .general import CuTensorNetHandle

# _metadata.py is copied to the folder after installation.
from ._metadata import __extension_version__, __extension_name__ # type: ignore
67 changes: 67 additions & 0 deletions pytket/extensions/cutensornet/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,76 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations # type: ignore
import warnings
import logging
from logging import Logger

from typing import Any, Optional

try:
import cupy as cp # type: ignore
except ImportError:
warnings.warn("local settings failed to import cupy", ImportWarning)
try:
import cuquantum.cutensornet as cutn # type: ignore
except ImportError:
warnings.warn("local settings failed to import cutensornet", ImportWarning)


class CuTensorNetHandle:
"""Initialise the cuTensorNet library with automatic workspace memory
management.
Note:
Always use as ``with CuTensorNetHandle() as libhandle:`` so that cuTensorNet
handles are automatically destroyed at the end of execution.
Attributes:
handle (int): The cuTensorNet library handle created by this initialisation.
device_id (int): The ID of the device (GPU) where cuTensorNet is initialised.
If not provided, defaults to ``cp.cuda.Device()``.
"""

def __init__(self, device_id: Optional[int] = None):
self._is_destroyed = False

# Make sure CuPy uses the specified device
dev = cp.cuda.Device(device_id)
dev.use()

self.dev = dev
self.device_id = dev.id

self.handle = cutn.create()

def destroy(self) -> None:
"""Destroys the memory handle, releasing memory.
Only call this method if you are initialising a ``CuTensorNetHandle`` outside
a ``with CuTensorNetHandle() as libhandle`` statement.
"""
cutn.destroy(self.handle)
self._is_destroyed = True

def __enter__(self) -> CuTensorNetHandle:
return self

def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
self.destroy()

def print_device_properties(self, logger: Logger) -> None:
"""Prints local GPU properties."""
device_props = cp.cuda.runtime.getDeviceProperties(self.dev.id)
logger.debug("===== device info ======")
logger.debug("GPU-name:", device_props["name"].decode())
logger.debug("GPU-clock:", device_props["clockRate"])
logger.debug("GPU-memoryClock:", device_props["memoryClockRate"])
logger.debug("GPU-nSM:", device_props["multiProcessorCount"])
logger.debug("GPU-major:", device_props["major"])
logger.debug("GPU-minor:", device_props["minor"])
logger.debug("========================")


def set_logger(
logger_name: str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
from sympy import Expr # type: ignore
from numpy.typing import NDArray
from pytket.circuit import Circuit
from pytket.extensions.cutensornet.general import set_logger
from pytket.extensions.cutensornet.structured_state import CuTensorNetHandle
from pytket.extensions.cutensornet.general import CuTensorNetHandle, set_logger
from pytket.utils.operators import QubitPauliOperator

try:
Expand Down
4 changes: 3 additions & 1 deletion pytket/extensions/cutensornet/structured_state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
https://github.com/CQCL/pytket-cutensornet.
"""

from .general import CuTensorNetHandle, Config, StructuredState
from pytket.extensions.cutensornet import CuTensorNetHandle

from .general import Config, StructuredState
from .simulation import SimulationAlgorithm, simulate, prepare_circuit_mps

from .mps import DirMPS, MPS
Expand Down
60 changes: 1 addition & 59 deletions pytket/extensions/cutensornet/structured_state/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from abc import ABC, abstractmethod
import warnings
import logging
from logging import Logger
from typing import Any, Optional, Type

import numpy as np # type: ignore
Expand All @@ -27,11 +26,8 @@
import cupy as cp # type: ignore
except ImportError:
warnings.warn("local settings failed to import cupy", ImportWarning)
try:
import cuquantum.cutensornet as cutn # type: ignore
except ImportError:
warnings.warn("local settings failed to import cutensornet", ImportWarning)

from pytket.extensions.cutensornet import CuTensorNetHandle

# An alias for the CuPy type used for tensors
try:
Expand All @@ -40,60 +36,6 @@
Tensor = Any


class CuTensorNetHandle:
"""Initialise the cuTensorNet library with automatic workspace memory
management.
Note:
Always use as ``with CuTensorNetHandle() as libhandle:`` so that cuTensorNet
handles are automatically destroyed at the end of execution.
Attributes:
handle (int): The cuTensorNet library handle created by this initialisation.
device_id (int): The ID of the device (GPU) where cuTensorNet is initialised.
If not provided, defaults to ``cp.cuda.Device()``.
"""

def __init__(self, device_id: Optional[int] = None):
self._is_destroyed = False

# Make sure CuPy uses the specified device
dev = cp.cuda.Device(device_id)
dev.use()

self.dev = dev
self.device_id = dev.id

self.handle = cutn.create()

def destroy(self) -> None:
"""Destroys the memory handle, releasing memory.
Only call this method if you are initialising a ``CuTensorNetHandle`` outside
a ``with CuTensorNetHandle() as libhandle`` statement.
"""
cutn.destroy(self.handle)
self._is_destroyed = True

def __enter__(self) -> CuTensorNetHandle:
return self

def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
self.destroy()

def print_device_properties(self, logger: Logger) -> None:
"""Prints local GPU properties."""
device_props = cp.cuda.runtime.getDeviceProperties(self.dev.id)
logger.debug("===== device info ======")
logger.debug("GPU-name:", device_props["name"].decode())
logger.debug("GPU-clock:", device_props["clockRate"])
logger.debug("GPU-memoryClock:", device_props["memoryClockRate"])
logger.debug("GPU-nSM:", device_props["multiProcessorCount"])
logger.debug("GPU-major:", device_props["major"])
logger.debug("GPU-minor:", device_props["minor"])
logger.debug("========================")


class Config:
"""Configuration class for simulation using ``StructuredState``."""

Expand Down
4 changes: 2 additions & 2 deletions pytket/extensions/cutensornet/structured_state/mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@
from pytket.circuit import Command, Op, OpType, Qubit
from pytket.pauli import Pauli, QubitPauliString

from pytket.extensions.cutensornet.general import set_logger
from pytket.extensions.cutensornet.general import CuTensorNetHandle, set_logger

from .general import CuTensorNetHandle, Config, StructuredState, Tensor
from .general import Config, StructuredState, Tensor


class DirMPS(Enum):
Expand Down
3 changes: 2 additions & 1 deletion pytket/extensions/cutensornet/structured_state/mps_mpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
warnings.warn("local settings failed to import cutensornet", ImportWarning)

from pytket.circuit import Qubit
from .general import CuTensorNetHandle, Tensor, Config
from pytket.extensions.cutensornet import CuTensorNetHandle
from .general import Tensor, Config
from .mps import (
DirMPS,
MPS,
Expand Down
4 changes: 2 additions & 2 deletions pytket/extensions/cutensornet/structured_state/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
from pytket.passes import DefaultMappingPass
from pytket.predicates import CompilationUnit

from pytket.extensions.cutensornet.general import set_logger
from .general import CuTensorNetHandle, Config, StructuredState
from pytket.extensions.cutensornet.general import CuTensorNetHandle, set_logger
from .general import Config, StructuredState
from .mps_gate import MPSxGate
from .mps_mpo import MPSxMPO
from .ttn_gate import TTNxGate
Expand Down
4 changes: 2 additions & 2 deletions pytket/extensions/cutensornet/structured_state/ttn.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@
from pytket.circuit import Command, Qubit
from pytket.pauli import QubitPauliString

from pytket.extensions.cutensornet.general import set_logger
from pytket.extensions.cutensornet.general import CuTensorNetHandle, set_logger

from .general import CuTensorNetHandle, Config, StructuredState, Tensor
from .general import Config, StructuredState, Tensor


class DirTTN(IntEnum):
Expand Down

0 comments on commit 26f848a

Please sign in to comment.