Skip to content

Commit

Permalink
Manually adapt casmiddleware to cas branch status
Browse files Browse the repository at this point in the history
  • Loading branch information
pvaut committed Feb 26, 2014
1 parent 8e0b99f commit 72942dc
Showing 1 changed file with 149 additions and 125 deletions.
274 changes: 149 additions & 125 deletions cas/casmiddleware.py
Original file line number Diff line number Diff line change
@@ -1,72 +1,60 @@
import logging
import time
from cgi import parse_qs
from urllib import quote, urlencode
from urllib import quote, urlencode, unquote_plus
from urlparse import urlparse
import requests
import xml.dom.minidom
from werkzeug.formparser import parse_form_data
from werkzeug.wrappers import Request,Response
import re

__all__ = ['CASMiddleware']


# Session keys
CAS_USERNAME = 'cas.username'
CAS_GROUPS = 'cas.groups'
CAS_TOKEN = 'cas.token'

CAS_ORIGIN = 'cas.origin'


def get_original_url(environ):
url = environ['wsgi.url_scheme'] + '://'

if environ.get('HTTP_HOST'):
url += environ['HTTP_HOST']
else:
url += environ['SERVER_NAME']

if environ['wsgi.url_scheme'] == 'https':
if environ['SERVER_PORT'] != '443':
url += ':' + environ['SERVER_PORT']
else:
if environ['SERVER_PORT'] != '80':
url += ':' + environ['SERVER_PORT']

url += quote(environ.get('SCRIPT_NAME',''))
url += quote(environ.get('PATH_INFO',''))

if environ.get('QUERY_STRING'):
params = parse_qs(environ['QUERY_STRING'])

for k in ('ticket',):
if k in params:
del params[k]
if params:
url += '?' + urlencode(params, doseq=True)

return url

CAS_COOKIE_NAME = 'cas.cookie'

class CASMiddleware(object):

casNamespaceUri = 'http://www.yale.edu/tp/cas'
samlpNamespaceUri = 'urn:oasis:names:tc:SAML:2.0:protocol'
samlNamespaceUri = 'urn:oasis:names:tc:SAML:2.0:assertion'

def __init__(self, application, cas_root_url, logout_url = '/logout', logout_dest = '', protocol_version = 2, casfailed_url=None):

def __init__(self, application, cas_root_url, entry_page = '/', logout_url = '/logout', logout_dest = '', protocol_version = 2, casfailed_url=None, session_store = None, ignore_redirect = None, ignored_callback = None):
self._application = application
self._root_url = cas_root_url
self._login_url = cas_root_url + '/login'
self._logout_url = logout_url
self._sso_logout_url = cas_root_url + '/logout'
self._logout_dest = logout_dest
self._entry_page = entry_page
self._protocol = protocol_version
self._casfailed_url = casfailed_url
self._session_store = session_store
self._session = None
self._cookie_expires = False
if ignore_redirect is not None:
self._ignore_redirect = re.compile(ignore_redirect)
self._ignored_callback = ignored_callback
else:
self._ignore_redirect = None

def _validate(self, environ, session, ticket):

def _validate(self, service_url, ticket):

if self._protocol == 2:
validate_url = self._root_url + '/serviceValidate'
elif self._protocol == 3:
validate_url = self._root_url + '/p3/serviceValidate'

service_url = get_original_url(environ)
r = requests.get(validate_url, params = {'service': service_url, 'ticket': ticket})
result = r.text
logging.debug(result)
Expand All @@ -80,7 +68,7 @@ def _validate(self, environ, session, ticket):
userNode = nodes[0]
if userNode.firstChild is not None:
username = userNode.firstChild.nodeValue
self._set_session_var(session, CAS_USERNAME, username)
self._set_session_var(CAS_USERNAME, username)
nodes = successNode.getElementsByTagNameNS(self.casNamespaceUri, 'memberOf')
if nodes:
groupName = []
Expand All @@ -89,26 +77,53 @@ def _validate(self, environ, session, ticket):
groupName.append(groupNode.firstChild.nodeValue)
if self._protocol == 2:
#Common but non standard extension - only one value - concatenated on the server
self._set_session_var(session, CAS_GROUPS, groupName[0])
self._set_session_var(CAS_GROUPS, groupName[0])
elif self._protocol == 3:
#So that the value is the same for version 2 or 3
self._set_session_var(session, CAS_GROUPS, '[' + ', '.join(groupName) + ']')
self._set_session_var(CAS_GROUPS, '[' + ', '.join(groupName) + ']')
dom.unlink()

return username

def _is_single_sign_out(self, environ, session):
def _is_session_expired(self, request):
#
# self._session_store.delete(self._session)
# self._get_session(request)
# return True
return False

def _remove_session_by_ticket(self, ticket_id):
sessions = self._session_store.list()
for sid in sessions:
session = self._session_store.get(sid)
logging.debug("Checking session:" + str(session))
if CAS_TOKEN in session and session[CAS_TOKEN] == ticket_id:
logging.info("Removed session for ticket:" + ticket_id)
self._session_store.delete(session)

def _is_single_sign_out(self, environ):
logging.debug("Testing for SLO")
if environ['REQUEST_METHOD'] == 'POST':
service_url = get_original_url(environ)
origin = self._get_session_var(session, CAS_ORIGIN)
current_url = environ.get('PATH_INFO','')
origin = self._entry_page
logging.debug("Testing for SLO:" + current_url + " vs " + origin)
if current_url == origin:
try:
request_body_size = int(environ.get('CONTENT_LENGTH', 0))
except (ValueError):
request_body_size = 0

request_body = environ['wsgi.input'].read(request_body_size)
logging.debug(request_body)
form = parse_form_data(environ)[1]
request_body = form['logoutRequest']
request_body = unquote_plus(request_body).decode('utf8')
dom = xml.dom.minidom.parseString(request_body)
nodes = dom.getElementsByTagNameNS(self.samlpNamespaceUri, 'SessionIndex')
if nodes:
sessionNode = nodes[0]
if sessionNode.firstChild is not None:
sessionId = sessionNode.firstChild.nodeValue
logging.info("Received SLO request for:" + sessionId)
self._remove_session_by_ticket(sessionId)
return True
except (Exception):
logging.warning("Exception parsing post")
logging.exception("Exception parsing post:" + request_body)
return False

def _is_logout(self, environ):
Expand All @@ -119,118 +134,127 @@ def _is_logout(self, environ):
return False

def __call__(self, environ, start_response):
session = self._get_session(environ)
if self._has_session_var(session, CAS_USERNAME):
self._set_values(environ, session)
request = Request(environ)
response = Response('')
self._get_session(request)
if self._has_session_var(CAS_USERNAME) and not self._is_session_expired(request):
self._set_values(environ)
if self._is_logout(environ):
self._do_session_logout(session)
dest = self._get_logout_redirect_url(session)
start_response('302 Moved Temporarily', [
('Location',
'%s?service=%s' % (self._sso_logout_url,
quote(dest)))
])
return []
#Check for single sign out
if (self._is_single_sign_out(environ, session)):
logging.debug('Single sign out request received')
return []
self._do_session_logout()
response = self._get_logout_redirect_url()
return response(environ, start_response)
return self._application(environ, start_response)
else:
query_string = environ.get('QUERY_STRING', '')
params = parse_qs(query_string)
logging.debug('Session not authenticated' + str(session))
params = request.args
logging.debug('Session not authenticated' + str(self._session))
if params.has_key('ticket'):
# Have ticket, validate with CAS server
ticket = params['ticket'][0]
ticket = params['ticket']

service_url = get_original_url(environ)
service_url = request.url

username = self._validate(environ, session, ticket)
service_url = re.sub(r".ticket=" + ticket, "", service_url)
logging.debug('Service URL' + service_url)
logging.debug(str(request))

username = self._validate(service_url, ticket)

if username is not None:
# Validation succeeded, redirect back to app
logging.debug('Validated ' + username)
self._set_session_var(session, CAS_ORIGIN, service_url)
self._save_session(session)
start_response('302 Moved Temporarily', [
('Location', service_url)
])
return []
self._set_session_var(CAS_ORIGIN, service_url)
self._set_session_var(CAS_TOKEN, ticket)
self._save_session()
response.status = '302 Moved Temporarily'
response.headers['Location'] = service_url
return response(environ, start_response)
else:
# Validation failed (for whatever reason)
return self._casfailed(environ, service_url, start_response)
response = self._casfailed(environ, service_url, start_response)
return response(environ, start_response)
else:
#Check for single sign out
if (self._is_single_sign_out(environ)):
logging.debug('Single sign out request received')
response.status = '200 OK'
return response(environ, start_response)
if self._ignore_redirect is not None:
if self._ignore_redirect.match(request.url):
if self._ignored_callback is not None:
return self._ignored_callback(environ, start_response)
logging.debug('Does not have ticket redirecting')
service_url = get_original_url(environ)
# Checking if we came here from an AJAX request to DQXServer
# Note that in principle this should not happen, as the first thing to authenticate is the html page
# Sending a clean error message to the client in this case anyway
if service_url.find('?datatype=') > 0:
resp = '{"Error":"NotAuthenticated"}'
start_response('200 OK', [('Content-type', 'application/json'), ('Content-Length', str(len(resp)))])
return [resp]
else:
start_response('302 Moved Temporarily', [
('Location',
'%s?service=%s' % (self._login_url,
quote(service_url)))
])
return []

def _get_session(self, environ):
return environ['beaker.session']

def _has_session_var(self, session, name):
return name in session

def _remove_session_var(self, session, name):
del session[name]

def _set_session_var(self, session, name, value):
session[name] = value
service_url = request.url
response.status = '302 Moved Temporarily'
response.headers['Location'] = '%s?service=%s' % (self._login_url, quote(service_url))
response.set_cookie(CAS_COOKIE_NAME, value = self._session.sid, max_age = None, expires = None)
return response(environ, start_response)

def _get_session(self, request):
sid = request.cookies.get(CAS_COOKIE_NAME)
if sid is None:
self._session = self._session_store.new()
self._set_session_var('_created_time', str(time.time()))
else:
self._session = self._session_store.get(sid)
self._set_session_var('_accessed_time', str(time.time()))

def _has_session_var(self, name):
return name in self._session

def _remove_session_var(self, name):
del self._session[name]

def _set_session_var(self, name, value):
self._session[name] = value
logging.debug("Setting session:" + name + " to " + value)

def _get_session_var(self, session, name):
return (session[name])
def _get_session_var(self, name):
return (self._session[name])

def _save_session(self, session):
logging.debug("Saving session:" + str(session))
session.save()
def _save_session(self):
if self._session.should_save:
logging.debug("Saving session:" + str(self._session))
self._session_store.save(self._session)

def _do_session_logout(self, session):
self._remove_session_var(session, CAS_USERNAME)
self._remove_session_var(session, CAS_GROUPS)
self._save_session(session)

def _get_logout_redirect_url(self, session):
def _do_session_logout(self):
self._remove_session_var(CAS_USERNAME)
self._remove_session_var(CAS_GROUPS)
self._save_session()
self._session_store.delete(self._session)

def _get_logout_redirect_url(self):
response = Response('')
dest = self._logout_dest
if dest == '':
dest = self._get_session_var(session, CAS_ORIGIN)
if dest == '' and self._has_session_var(CAS_ORIGIN):
dest = self._get_session_var(CAS_ORIGIN)
logging.debug("Log out dest:" + dest)
parsed = urlparse(dest)
if parsed.path == self._logout_url:
dest = self._sso_logout_url
logging.debug("Log out redirecting to:" + dest)
return dest
response.status = '302 Moved Temporarily'
response.headers['Location'] = '%s?service=%s' % (self._sso_logout_url, quote(dest))
return response

#Communicate values to the rest of the application
def _set_values(self, environ, session):
username = self._get_session_var(session, CAS_USERNAME)
groups = self._get_session_var(session, CAS_GROUPS)
def _set_values(self, environ):
username = self._get_session_var(CAS_USERNAME)
logging.debug('Session authenticated for ' + username)
environ['AUTH_TYPE'] = 'CAS'
environ['REMOTE_USER'] = str(username)
environ['HTTP_CAS_MEMBEROF'] = str(groups)

def _casfailed(self, environ, service_url, start_response):

response = Response('')
if self._casfailed_url is not None:
start_response('302 Moved Temporarily', [
('Location', self._casfailed_url)
])
return []
response.status = '302 Moved Temporarily'
response.headers['Location'] = self._casfailed_url
else:
# Default failure notice
start_response('401 Unauthorized', [('Content-Type', 'text/plain'), ('WWW-Authenticate','CAS casUrl="' + self._root_url + '" service="' + service_url + '"')])
return ['CAS authentication failed\n']
response.status = '401 Unauthorized'
response.headers['Location'] = self._casfailed_url
response.headers['Content-Type'] = 'text/plain'
response.headers['WWW-Authenticate'] = 'CAS casUrl="' + self._root_url + '" service="' + service_url + '"'
response.data = 'CAS authentication failed\n'
return response

0 comments on commit 72942dc

Please sign in to comment.