diff --git a/tests/test_printing.py b/tests/test_printing.py index 6eb96100..941340a0 100644 --- a/tests/test_printing.py +++ b/tests/test_printing.py @@ -8,12 +8,14 @@ def test_PreamblPPrinter(): - # Make sure we can print a `Function` and `FunctionGraph` + """Make sure we can print a `Function` and `FunctionGraph`.""" + srng = at.random.RandomStream(seed=2320) + mu = at.scalar("\\mu") sigma = at.scalar("\\sigma") b = at.scalar("b") - y = b * at.random.normal(mu, sigma) + y = b * srng.normal(mu, sigma) y_fn = aesara.function([mu, sigma, b], y) @@ -28,7 +30,9 @@ def test_PreamblPPrinter(): def test_notex_print(): - normalrv_noname_expr = at.scalar("b") * at.random.normal( + srng = at.random.RandomStream(seed=2320) + + normalrv_noname_expr = at.scalar("b") * srng.normal( at.scalar("\\mu"), at.scalar("\\sigma") ) expected = textwrap.dedent( @@ -41,7 +45,7 @@ def test_notex_print(): assert pprint(normalrv_noname_expr) == expected.strip() # Make sure the constant shape is show in values and not symbols. - normalrv_name_expr = at.scalar("b") * at.random.normal( + normalrv_name_expr = at.scalar("b") * srng.normal( at.scalar("\\mu"), at.scalar("\\sigma"), size=[2, 1], name="X" ) expected = textwrap.dedent( @@ -53,10 +57,10 @@ def test_notex_print(): ) assert pprint(normalrv_name_expr) == expected.strip() - normalrv_noname_expr_2 = at.matrix("M") * at.random.normal( + normalrv_noname_expr_2 = at.matrix("M") * srng.normal( at.scalar("\\mu_2"), at.scalar("\\sigma_2") ) - normalrv_noname_expr_2 *= at.scalar("b") * at.random.normal( + normalrv_noname_expr_2 *= at.scalar("b") * srng.normal( normalrv_noname_expr_2, at.scalar("\\sigma") ) + at.scalar("c") expected = textwrap.dedent( @@ -99,7 +103,9 @@ def test_notex_print(): def test_tex_print(): - normalrv_noname_expr = at.scalar("b") * at.random.normal( + srng = at.random.RandomStream(seed=2320) + + normalrv_noname_expr = at.scalar("b") * srng.normal( at.scalar("\\mu"), at.scalar("\\sigma") ) expected = textwrap.dedent( @@ -117,7 +123,7 @@ def test_tex_print(): ) assert latex_pprint(normalrv_noname_expr) == expected.strip() - normalrv_name_expr = at.scalar("b") * at.random.normal( + normalrv_name_expr = at.scalar("b") * srng.normal( at.scalar("\\mu"), at.scalar("\\sigma"), size=[2, 1], name="X" ) expected = textwrap.dedent( @@ -135,10 +141,10 @@ def test_tex_print(): ) assert latex_pprint(normalrv_name_expr) == expected.strip() - normalrv_noname_expr_2 = at.matrix("M") * at.random.normal( + normalrv_noname_expr_2 = at.matrix("M") * srng.normal( at.scalar("\\mu_2"), at.scalar("\\sigma_2") ) - normalrv_noname_expr_2 *= at.scalar("b") * at.random.normal( + normalrv_noname_expr_2 *= at.scalar("b") * srng.normal( normalrv_noname_expr_2, at.scalar("\\sigma") ) + at.scalar("c") expected = textwrap.dedent( @@ -205,9 +211,9 @@ def test_tex_print(): ) assert latex_pprint(at.vector("M", dtype="uint32")[0:4:2]) == expected.strip() - S_rv = at.random.invgamma(0.5, 0.5, name="S") - T_rv = at.random.halfcauchy(1.0, name="T") - Y_rv = at.random.normal(T_rv, at.sqrt(S_rv), name="Y") + S_rv = srng.invgamma(0.5, 0.5, name="S") + T_rv = srng.halfcauchy(1.0, name="T") + Y_rv = srng.normal(T_rv, at.sqrt(S_rv), name="Y") expected = textwrap.dedent( r""" \begin{equation} @@ -230,8 +236,10 @@ def test_tex_print(): reason=r"AePPL is not aware of the distributions' support and displays \mathbb{R} by default" ) def test_tex_print_support_dimension(): - U_rv = at.random.uniform(0, 1, name="U") - T_rv = at.random.halfcauchy(U_rv, name="T") + srng = at.random.RandomStream(seed=2320) + + U_rv = srng.uniform(0, 1, name="U") + T_rv = srng.halfcauchy(U_rv, name="T") expected = textwrap.dedent( r""" \begin{equation}