Skip to content

Commit

Permalink
Merge branch 'main' into mangle-openqasm-reserved-words
Browse files Browse the repository at this point in the history
  • Loading branch information
rmshaffer authored Jun 12, 2024
2 parents 83dfd3c + 360cab1 commit 58dadab
Show file tree
Hide file tree
Showing 5 changed files with 214 additions and 1 deletion.
63 changes: 63 additions & 0 deletions src/autoqasm/converters/arithmetic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

"""Converters for aritmetic operator nodes"""

import ast

import gast
from malt.core import ag_ctx, converter
from malt.pyct import templates

ARITHMETIC_OPERATORS = {
gast.FloorDiv: "ag__.floor_div",
}


class ArithmeticTransformer(converter.Base):
"""Transformer for arithmetic nodes."""

def visit_BinOp(self, node: ast.stmt) -> ast.stmt:
"""Transforms a BinOp node.
Args :
node(ast.stmt) : AST node to transform
Returns :
ast.stmt : Transformed node
"""
node = self.generic_visit(node)
op_type = type(node.op)
if op_type not in ARITHMETIC_OPERATORS:
return node

template = f"{ARITHMETIC_OPERATORS[op_type]}(lhs_,rhs_)"

new_node = templates.replace(
template,
lhs_=node.left,
rhs_=node.right,
original=node,
)[0].value

return new_node


def transform(node: ast.stmt, ctx: ag_ctx.ControlStatusCtx) -> ast.stmt:
"""Transform arithmetic nodes.
Args:
node(ast.stmt) : AST node to transform
ctx (ag_ctx.ControlStatusCtx) : Transformer context.
Returns :
ast.stmt : Transformed node.
"""

return ArithmeticTransformer(ctx).visit(node)
1 change: 1 addition & 0 deletions src/autoqasm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from malt.impl.api import autograph_artifact # noqa: F401
from malt.operators.variables import Undefined, UndefinedReturnValue, ld, ldu # noqa: F401

from .arithmetic import floor_div # noqa: F401
from .assignments import assign_for_output, assign_stmt # noqa: F401
from .comparisons import gt_, gteq_, lt_, lteq_ # noqa: F401
from .conditional_expressions import if_exp # noqa: F401
Expand Down
81 changes: 81 additions & 0 deletions src/autoqasm/operators/arithmetic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

"""Operators for arithmetic operators: // """

from __future__ import annotations

from autoqasm import program
from autoqasm import types as aq_types

from .utils import _register_and_convert_parameters


def floor_div(
num: aq_types.IntVar | aq_types.FloatVar | int | float,
den: aq_types.IntVar | aq_types.FloatVar | int | float,
) -> int | aq_types.IntVar:
"""Functional form of "//".
Args:
num (IntVar | FloatVar | int | float) : The numerator of the integer division
den (IntVar | FloatVar | int | float) : The denominator of the integer division
Returns :
int | IntVar : integer division, IntVar if either numerator or denominator
are QASM types, else int
"""
if aq_types.is_qasm_type(num) or aq_types.is_qasm_type(den):
return _oqpy_floor_div(num, den)
else:
return _py_floor_div(num, den)


def _oqpy_floor_div(
num: aq_types.IntVar | aq_types.FloatVar | int | float,
den: aq_types.IntVar | aq_types.FloatVar | int | float,
) -> aq_types.IntVar | aq_types.FloatVar:
num, den = _register_and_convert_parameters(num, den)
oqpy_program = program.get_program_conversion_context().get_oqpy_program()
num_is_float = isinstance(num, (aq_types.FloatVar, float))
den_is_float = isinstance(den, (aq_types.FloatVar, float))

# if either is a FloatVar, then both must be FloatVar
if num_is_float and not den_is_float:
den_float_var = aq_types.FloatVar()
oqpy_program.declare(den_float_var)
oqpy_program.set(den_float_var, den)
den = den_float_var
if den_is_float and not num_is_float:
num_float_var = aq_types.FloatVar()
oqpy_program.declare(num_float_var)
oqpy_program.set(num_float_var, num)
num = num_float_var

# if either is a FloatVar, then the result will be a FloatVar
result = aq_types.IntVar()
oqpy_program.declare(result)
oqpy_program.set(result, num / den)

if num_is_float or den_is_float:
float_result = aq_types.FloatVar()
oqpy_program.declare(float_result)
oqpy_program.set(float_result, result)
return float_result

return result


def _py_floor_div(
num: int | float,
den: int | float,
) -> int | float:
return num // den
9 changes: 8 additions & 1 deletion src/autoqasm/transpiler/transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,13 @@
from malt.utils import ag_logging as logging

from autoqasm import operators, program, types
from autoqasm.converters import assignments, break_statements, comparisons, return_statements
from autoqasm.converters import (
arithmetic,
assignments,
break_statements,
comparisons,
return_statements,
)


class PyToOqpy(transpiler.PyToPy):
Expand Down Expand Up @@ -135,6 +141,7 @@ def transform_ast(
node = control_flow.transform(node, ctx)
node = conditional_expressions.transform(node, ctx)
node = comparisons.transform(node, ctx)
node = arithmetic.transform(node, ctx)
node = logical_expressions.transform(node, ctx)
node = variables.transform(node, ctx)

Expand Down
61 changes: 61 additions & 0 deletions test/unit_tests/autoqasm/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,3 +930,64 @@ def test_list_ops():
assert np.array_equal(c, [[2, 3, 4], [2, 3, 4]])

assert test_list_ops.build().to_ir()


def test_integer_division_on_intvars():
@aq.main(num_qubits=2)
def main():
a = aq.IntVar(5)
b = aq.IntVar(2)
c = a // b # noqa: F841

expected_ir = """OPENQASM 3.0;
int[32] c;
qubit[2] __qubits__;
int[32] a = 5;
int[32] b = 2;
int[32] __int_2__;
__int_2__ = a / b;
c = __int_2__;"""
assert main.build().to_ir() == expected_ir


def test_integer_division_on_mixed_vars():
@aq.main(num_qubits=2)
def main():
a = aq.IntVar(5)
b = aq.FloatVar(2.3)
c = a // b # noqa: F841
d = b // a # noqa: F841

expected_ir = """OPENQASM 3.0;
float[64] c;
float[64] d;
qubit[2] __qubits__;
int[32] a = 5;
float[64] b = 2.3;
float[64] __float_2__;
__float_2__ = a;
int[32] __int_3__;
__int_3__ = __float_2__ / b;
float[64] __float_4__;
__float_4__ = __int_3__;
c = __float_4__;
float[64] __float_5__;
__float_5__ = a;
int[32] __int_6__;
__int_6__ = b / __float_5__;
float[64] __float_7__;
__float_7__ = __int_6__;
d = __float_7__;"""
assert main.build().to_ir() == expected_ir


def test_integer_division_on_python_types():
@aq.main(num_qubits=2)
def main():
a = 5
b = 2.3
c = a // b # noqa: F841

expected_ir = """OPENQASM 3.0;
qubit[2] __qubits__;"""
assert main.build().to_ir() == expected_ir

0 comments on commit 58dadab

Please sign in to comment.