Skip to content

Commit

Permalink
Use RandomStreams in test_printing
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Mar 11, 2023
1 parent 1f5d81f commit cbd0e2f
Showing 1 changed file with 23 additions and 15 deletions.
38 changes: 23 additions & 15 deletions tests/test_printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@


def test_PreamblPPrinter():
# Make sure we can print a `Function` and `FunctionGraph`
"""Make sure we can print a `Function` and `FunctionGraph`."""
srng = at.random.RandomStream(seed=2320)

mu = at.scalar("\\mu")
sigma = at.scalar("\\sigma")
b = at.scalar("b")

y = b * at.random.normal(mu, sigma)
y = b * srng.normal(mu, sigma)

y_fn = aesara.function([mu, sigma, b], y)

Expand All @@ -28,7 +30,9 @@ def test_PreamblPPrinter():


def test_notex_print():
normalrv_noname_expr = at.scalar("b") * at.random.normal(
srng = at.random.RandomStream(seed=2320)

normalrv_noname_expr = at.scalar("b") * srng.normal(
at.scalar("\\mu"), at.scalar("\\sigma")
)
expected = textwrap.dedent(
Expand All @@ -41,7 +45,7 @@ def test_notex_print():
assert pprint(normalrv_noname_expr) == expected.strip()

# Make sure the constant shape is show in values and not symbols.
normalrv_name_expr = at.scalar("b") * at.random.normal(
normalrv_name_expr = at.scalar("b") * srng.normal(
at.scalar("\\mu"), at.scalar("\\sigma"), size=[2, 1], name="X"
)
expected = textwrap.dedent(
Expand All @@ -53,10 +57,10 @@ def test_notex_print():
)
assert pprint(normalrv_name_expr) == expected.strip()

normalrv_noname_expr_2 = at.matrix("M") * at.random.normal(
normalrv_noname_expr_2 = at.matrix("M") * srng.normal(
at.scalar("\\mu_2"), at.scalar("\\sigma_2")
)
normalrv_noname_expr_2 *= at.scalar("b") * at.random.normal(
normalrv_noname_expr_2 *= at.scalar("b") * srng.normal(
normalrv_noname_expr_2, at.scalar("\\sigma")
) + at.scalar("c")
expected = textwrap.dedent(
Expand Down Expand Up @@ -99,7 +103,9 @@ def test_notex_print():


def test_tex_print():
normalrv_noname_expr = at.scalar("b") * at.random.normal(
srng = at.random.RandomStream(seed=2320)

normalrv_noname_expr = at.scalar("b") * srng.normal(
at.scalar("\\mu"), at.scalar("\\sigma")
)
expected = textwrap.dedent(
Expand All @@ -117,7 +123,7 @@ def test_tex_print():
)
assert latex_pprint(normalrv_noname_expr) == expected.strip()

normalrv_name_expr = at.scalar("b") * at.random.normal(
normalrv_name_expr = at.scalar("b") * srng.normal(
at.scalar("\\mu"), at.scalar("\\sigma"), size=[2, 1], name="X"
)
expected = textwrap.dedent(
Expand All @@ -135,10 +141,10 @@ def test_tex_print():
)
assert latex_pprint(normalrv_name_expr) == expected.strip()

normalrv_noname_expr_2 = at.matrix("M") * at.random.normal(
normalrv_noname_expr_2 = at.matrix("M") * srng.normal(
at.scalar("\\mu_2"), at.scalar("\\sigma_2")
)
normalrv_noname_expr_2 *= at.scalar("b") * at.random.normal(
normalrv_noname_expr_2 *= at.scalar("b") * srng.normal(
normalrv_noname_expr_2, at.scalar("\\sigma")
) + at.scalar("c")
expected = textwrap.dedent(
Expand Down Expand Up @@ -205,9 +211,9 @@ def test_tex_print():
)
assert latex_pprint(at.vector("M", dtype="uint32")[0:4:2]) == expected.strip()

S_rv = at.random.invgamma(0.5, 0.5, name="S")
T_rv = at.random.halfcauchy(1.0, name="T")
Y_rv = at.random.normal(T_rv, at.sqrt(S_rv), name="Y")
S_rv = srng.invgamma(0.5, 0.5, name="S")
T_rv = srng.halfcauchy(1.0, name="T")
Y_rv = srng.normal(T_rv, at.sqrt(S_rv), name="Y")
expected = textwrap.dedent(
r"""
\begin{equation}
Expand All @@ -230,8 +236,10 @@ def test_tex_print():
reason=r"AePPL is not aware of the distributions' support and displays \mathbb{R} by default"
)
def test_tex_print_support_dimension():
U_rv = at.random.uniform(0, 1, name="U")
T_rv = at.random.halfcauchy(U_rv, name="T")
srng = at.random.RandomStream(seed=2320)

U_rv = srng.uniform(0, 1, name="U")
T_rv = srng.halfcauchy(U_rv, name="T")
expected = textwrap.dedent(
r"""
\begin{equation}
Expand Down

0 comments on commit cbd0e2f

Please sign in to comment.