diff --git a/mediathread/main/tests/test_views.py b/mediathread/main/tests/test_views.py index 78c251171..2eb803520 100644 --- a/mediathread/main/tests/test_views.py +++ b/mediathread/main/tests/test_views.py @@ -435,12 +435,12 @@ def test_get_initial_anonymous(self): view.request = RequestFactory().get('/contact/') view.request.session = {} view.request.user = AnonymousUser() - view.get_initial() + initial = view.get_initial() - self.assertIsNotNone(view.initial['issue_date']) - self.assertFalse('name' in view.initial) - self.assertFalse('email' in view.initial) - self.assertFalse('username' in view.initial) + self.assertIsNotNone(initial['issue_date']) + self.assertFalse('name' in initial) + self.assertFalse('email' in initial) + self.assertFalse('username' in initial) def test_get_initial_not_anonymous(self): view = ContactUsView() @@ -450,12 +450,20 @@ def test_get_initial_not_anonymous(self): last_name='Bar', email='foo@bar.com') - view.get_initial() + initial = view.get_initial() + self.assertIsNotNone(initial['issue_date']) + self.assertEquals(initial['name'], 'Foo Bar') + self.assertEquals(initial['email'], 'foo@bar.com') + self.assertEquals(initial['username'], view.request.user.username) - self.assertIsNotNone(view.initial['issue_date']) - self.assertEquals(view.initial['name'], 'Foo Bar') - self.assertEquals(view.initial['email'], 'foo@bar.com') - self.assertEquals(view.initial['username'], view.request.user.username) + # a subsequent call using an anonymous session returns a clean initial + view.request.session = {} + view.request.user = AnonymousUser() + initial = view.get_initial() + self.assertIsNotNone(initial['issue_date']) + self.assertFalse('name' in initial) + self.assertFalse('email' in initial) + self.assertFalse('username' in initial) def test_form_valid(self): view = ContactUsView() diff --git a/mediathread/main/views.py b/mediathread/main/views.py index b149abf46..64605ea4d 100644 --- a/mediathread/main/views.py +++ b/mediathread/main/views.py @@ -417,17 +417,18 @@ def get_initial(self): """ Returns the initial data to use for forms on this view. """ + initial = super(ContactUsView, self).get_initial() if not self.request.user.is_anonymous(): - self.initial['name'] = self.request.user.get_full_name() - self.initial['email'] = self.request.user.email - self.initial['username'] = self.request.user.username + initial['name'] = self.request.user.get_full_name() + initial['email'] = self.request.user.email + initial['username'] = self.request.user.username - self.initial['issue_date'] = datetime.now() + initial['issue_date'] = datetime.now() if SESSION_KEY in self.request.session: - self.initial['course'] = self.request.session[SESSION_KEY].title + initial['course'] = self.request.session[SESSION_KEY].title - return super(ContactUsView, self).get_initial() + return initial def form_valid(self, form): subject = "Mediathread Contact Us Request"