diff --git a/py4web/core.py b/py4web/core.py index da825f29..368b0dfd 100644 --- a/py4web/core.py +++ b/py4web/core.py @@ -514,10 +514,7 @@ def on_success(self, context): output = context["output"] flash = self.local.flash or "" if isinstance(output, dict): - if "template_inject" in context: - context["template_inject"]["flash"] = flash - else: - context["template_inject"] = dict(flash=flash) + context["template_inject"]["flash"] = flash elif self.local.flash is not None: response.headers.setdefault("component-flash", json.dumps(flash)) @@ -611,7 +608,7 @@ def on_success(self, context): ctx = dict(request=request) ctx.update(HELPERS) ctx.update(URL=URL) - ctx.update(context.get("template_inject", {})) + ctx.update(context["template_inject"]) ctx.update(output) ctx["__vars__"] = output app_folder = os.path.join(os.environ["PY4WEB_APPS_FOLDER"], request.app_name) @@ -993,6 +990,7 @@ def wrapper(*args, **kwargs): "output": None, "exception": None, "processed": processed, + "template_inject": {}, } try: for fixture in fixtures: diff --git a/py4web/utils/auth.py b/py4web/utils/auth.py index e8b501eb..a32e2f95 100644 --- a/py4web/utils/auth.py +++ b/py4web/utils/auth.py @@ -339,7 +339,7 @@ def deny_action(self, action_name): def on_success(self, context): if self.inject: - context["template_inject"] = {"user": self.get_user()} + context["template_inject"]["user"] = self.get_user() def define_tables(self): """Defines the auth_user table""" diff --git a/tests/test_action.py b/tests/test_action.py index 1e229eb2..34917fef 100644 --- a/tests/test_action.py +++ b/tests/test_action.py @@ -2,6 +2,7 @@ import copy import multiprocessing import os +import sys import threading import time import unittest @@ -18,7 +19,18 @@ ) SECRET = str(uuid.uuid4()) -db = DAL("sqlite://storage_%s" % uuid.uuid4(), folder="/tmp/") +if sys.platform == "win32": + path = "./tmp/" +else: + path = "/tmp/" + +try: + os.mkdir(path) +except Exception: + pass +with open(path + "sql.log", "w"): + pass +db = DAL("sqlite://storage_%s" % uuid.uuid4(), folder=path) db.define_table("thing", Field("name")) session = Session(secret=SECRET) cache = Cache() diff --git a/tests/test_template.py b/tests/test_template.py index ff996493..025608ed 100644 --- a/tests/test_template.py +++ b/tests/test_template.py @@ -9,7 +9,7 @@ class TemplateTest(unittest.TestCase): def test_template(self): t = Template("index.html", path=PATH) - context = dict(output=dict(n=3)) + context = dict(output=dict(n=3), template_inject={}) t.on_success(context) output = context["output"] self.assertEqual(output, "0,1,2.\n")