diff --git a/.gitignore b/.gitignore
index 0b5a2a6..2ba2ee7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,6 @@
*.pyc
*.egg-info
+*.swp
build/
bootstrap3
diff --git a/cas_server/auth.py b/cas_server/auth.py
index 7ccacae..99018a4 100644
--- a/cas_server/auth.py
+++ b/cas_server/auth.py
@@ -12,6 +12,9 @@
"""Some authentication classes for the CAS"""
from django.conf import settings
from django.contrib.auth import get_user_model
+from django.utils import timezone
+
+from datetime import timedelta
try:
import MySQLdb
import MySQLdb.cursors
@@ -19,6 +22,8 @@ try:
except ImportError:
MySQLdb = None
+from .models import FederatedUser
+
class AuthUser(object):
def __init__(self, username):
@@ -140,3 +145,37 @@ class DjangoAuthUser(AuthUser):
for field in self.user._meta.fields:
attr[field.attname] = getattr(self.user, field.attname)
return attr
+
+
+class CASFederateAuth(AuthUser):
+ user = None
+
+ def __init__(self, username):
+ component = username.split('@')
+ username = '@'.join(component[:-1])
+ provider = component[-1]
+ try:
+ self.user = FederatedUser.objects.get(username=username, provider=provider)
+ super(CASFederateAuth, self).__init__(
+ "%s@%s" % (self.user.username, self.user.provider)
+ )
+ except FederatedUser.DoesNotExist:
+ super(CASFederateAuth, self).__init__("%s@%s" % (username, provider))
+
+ def test_password(self, ticket):
+ """test `password` agains the user"""
+ if not self.user or not self.user.ticket:
+ return False
+ else:
+ return (
+ ticket == self.user.ticket and
+ self.user.last_update >
+ (timezone.now() - timedelta(seconds=settings.CAS_TICKET_VALIDITY))
+ )
+
+ def attributs(self):
+ """return a dict of user attributes"""
+ if not self.user:
+ return {}
+ else:
+ return self.user.attributs
diff --git a/cas_server/cas.py b/cas_server/cas.py
new file mode 100644
index 0000000..bea0638
--- /dev/null
+++ b/cas_server/cas.py
@@ -0,0 +1,337 @@
+from six.moves.urllib import parse as urllib_parse
+from six.moves.urllib import request as urllib_request
+from six.moves.urllib.request import Request
+from uuid import uuid4
+import datetime
+
+
+class CASError(ValueError):
+ pass
+
+
+class SingleLogoutMixin(object):
+ @classmethod
+ def get_saml_slos(cls, logout_request):
+ """returns saml logout ticket info"""
+ from lxml import etree
+ try:
+ root = etree.fromstring(logout_request)
+ return root.xpath(
+ "//samlp:SessionIndex",
+ namespaces={'samlp': "urn:oasis:names:tc:SAML:2.0:protocol"})
+ except etree.XMLSyntaxError:
+ pass
+
+
+class CASClient(object):
+ def __new__(self, *args, **kwargs):
+ version = kwargs.pop('version')
+ if version in (1, '1'):
+ return CASClientV1(*args, **kwargs)
+ elif version in (2, '2'):
+ return CASClientV2(*args, **kwargs)
+ elif version in (3, '3'):
+ return CASClientV3(*args, **kwargs)
+ elif version == 'CAS_2_SAML_1_0':
+ return CASClientWithSAMLV1(*args, **kwargs)
+ raise ValueError('Unsupported CAS_VERSION %r' % version)
+
+
+class CASClientBase(object):
+
+ logout_redirect_param_name = 'service'
+
+ def __init__(self, service_url=None, server_url=None,
+ extra_login_params=None, renew=False,
+ username_attribute=None):
+
+ self.service_url = service_url
+ self.server_url = server_url
+ self.extra_login_params = extra_login_params or {}
+ self.renew = renew
+ self.username_attribute = username_attribute
+ pass
+
+ def verify_ticket(self, ticket):
+ """must return a triple"""
+ raise NotImplementedError()
+
+ def get_login_url(self):
+ """Generates CAS login URL"""
+ params = {'service': self.service_url}
+ if self.renew:
+ params.update({'renew': 'true'})
+
+ params.update(self.extra_login_params)
+ url = urllib_parse.urljoin(self.server_url, 'login')
+ query = urllib_parse.urlencode(params)
+ return url + '?' + query
+
+ def get_logout_url(self, redirect_url=None):
+ """Generates CAS logout URL"""
+ url = urllib_parse.urljoin(self.server_url, 'logout')
+ if redirect_url:
+ params = {self.logout_redirect_param_name: redirect_url}
+ url += '?' + urllib_parse.urlencode(params)
+ return url
+
+ def get_proxy_url(self, pgt):
+ """Returns proxy url, given the proxy granting ticket"""
+ params = urllib_parse.urlencode({'pgt': pgt, 'targetService': self.service_url})
+ return "%s/proxy?%s" % (self.server_url, params)
+
+ def get_proxy_ticket(self, pgt):
+ """Returns proxy ticket given the proxy granting ticket"""
+ response = urllib_request.urlopen(self.get_proxy_url(pgt))
+ if response.code == 200:
+ from lxml import etree
+ root = etree.fromstring(response.read())
+ tickets = root.xpath(
+ "//cas:proxyTicket",
+ namespaces={"cas": "http://www.yale.edu/tp/cas"}
+ )
+ if len(tickets) == 1:
+ return tickets[0].text
+ errors = root.xpath(
+ "//cas:authenticationFailure",
+ namespaces={"cas": "http://www.yale.edu/tp/cas"}
+ )
+ if len(errors) == 1:
+ raise CASError(errors[0].attrib['code'], errors[0].text)
+ raise CASError("Bad http code %s" % response.code)
+
+
+class CASClientV1(CASClientBase):
+ """CAS Client Version 1"""
+
+ logout_redirect_param_name = 'url'
+
+ def verify_ticket(self, ticket):
+ """Verifies CAS 1.0 authentication ticket.
+
+ Returns username on success and None on failure.
+ """
+ params = [('ticket', ticket), ('service', self.service)]
+ url = (urllib_parse.urljoin(self.server_url, 'validate') + '?' +
+ urllib_parse.urlencode(params))
+ page = urllib_request.urlopen(url)
+ try:
+ verified = page.readline().strip()
+ if verified == 'yes':
+ return page.readline().strip(), None, None
+ else:
+ return None, None, None
+ finally:
+ page.close()
+
+
+class CASClientV2(CASClientBase):
+ """CAS Client Version 2"""
+
+ url_suffix = 'serviceValidate'
+ logout_redirect_param_name = 'url'
+
+ def __init__(self, proxy_callback=None, *args, **kwargs):
+ """proxy_callback is for V2 and V3 so V3 is subclass of V2"""
+ self.proxy_callback = proxy_callback
+ super(CASClientV2, self).__init__(*args, **kwargs)
+
+ def verify_ticket(self, ticket):
+ """Verifies CAS 2.0+/3.0+ XML-based authentication ticket and returns extended attributes"""
+ response = self.get_verification_response(ticket)
+ return self.verify_response(response)
+
+ def get_verification_response(self, ticket):
+ params = [('ticket', ticket), ('service', self.service_url)]
+ if self.proxy_callback:
+ params.append(('pgtUrl', self.proxy_callback))
+ base_url = urllib_parse.urljoin(self.server_url, self.url_suffix)
+ url = base_url + '?' + urllib_parse.urlencode(params)
+ page = urllib_request.urlopen(url)
+ try:
+ return page.read()
+ finally:
+ page.close()
+
+ @classmethod
+ def parse_attributes_xml_element(cls, element):
+ attributes = dict()
+ for attribute in element:
+ tag = attribute.tag.split("}").pop()
+ if tag in attributes:
+ if isinstance(attributes[tag], list):
+ attributes[tag].append(attribute.text)
+ else:
+ attributes[tag] = [attributes[tag]]
+ attributes[tag].append(attribute.text)
+ else:
+ if tag == 'attraStyle':
+ pass
+ else:
+ attributes[tag] = attribute.text
+ return attributes
+
+ @classmethod
+ def verify_response(cls, response):
+ user, attributes, pgtiou = cls.parse_response_xml(response)
+ if len(attributes) == 0:
+ attributes = None
+ return user, attributes, pgtiou
+
+ @classmethod
+ def parse_response_xml(cls, response):
+ try:
+ from xml.etree import ElementTree
+ except ImportError:
+ from elementtree import ElementTree
+
+ user = None
+ attributes = {}
+ pgtiou = None
+
+ tree = ElementTree.fromstring(response)
+ if tree[0].tag.endswith('authenticationSuccess'):
+ for element in tree[0]:
+ if element.tag.endswith('user'):
+ user = element.text
+ elif element.tag.endswith('proxyGrantingTicket'):
+ pgtiou = element.text
+ elif element.tag.endswith('attributes'):
+ attributes = cls.parse_attributes_xml_element(element)
+ return user, attributes, pgtiou
+
+
+class CASClientV3(CASClientV2, SingleLogoutMixin):
+ """CAS Client Version 3"""
+ url_suffix = 'serviceValidate'
+ logout_redirect_param_name = 'service'
+
+ @classmethod
+ def parse_attributes_xml_element(cls, element):
+ attributes = dict()
+ for attribute in element:
+ tag = attribute.tag.split("}").pop()
+ if tag in attributes:
+ if isinstance(attributes[tag], list):
+ attributes[tag].append(attribute.text)
+ else:
+ attributes[tag] = [attributes[tag]]
+ attributes[tag].append(attribute.text)
+ else:
+ attributes[tag] = attribute.text
+ return attributes
+
+ @classmethod
+ def verify_response(cls, response):
+ return cls.parse_response_xml(response)
+
+
+SAML_1_0_NS = 'urn:oasis:names:tc:SAML:1.0:'
+SAML_1_0_PROTOCOL_NS = '{' + SAML_1_0_NS + 'protocol' + '}'
+SAML_1_0_ASSERTION_NS = '{' + SAML_1_0_NS + 'assertion' + '}'
+SAML_ASSERTION_TEMPLATE = """
+
+
+
+
+{ticket}
+
+"""
+
+
+class CASClientWithSAMLV1(CASClientV2, SingleLogoutMixin):
+ """CASClient 3.0+ with SAML"""
+
+ def verify_ticket(self, ticket, **kwargs):
+ """Verifies CAS 3.0+ XML-based authentication ticket and returns extended attributes.
+
+ @date: 2011-11-30
+ @author: Carlos Gonzalez Vila
+
+ Returns username and attributes on success and None,None on failure.
+ """
+
+ try:
+ from xml.etree import ElementTree
+ except ImportError:
+ from elementtree import ElementTree
+
+ page = self.fetch_saml_validation(ticket)
+
+ try:
+ user = None
+ attributes = {}
+ response = page.read()
+ tree = ElementTree.fromstring(response)
+ # Find the authentication status
+ success = tree.find('.//' + SAML_1_0_PROTOCOL_NS + 'StatusCode')
+ if success is not None and success.attrib['Value'].endswith(':Success'):
+ # User is validated
+ attrs = tree.findall('.//' + SAML_1_0_ASSERTION_NS + 'Attribute')
+ for at in attrs:
+ if self.username_attribute in list(at.attrib.values()):
+ user = at.find(SAML_1_0_ASSERTION_NS + 'AttributeValue').text
+ attributes['uid'] = user
+
+ values = at.findall(SAML_1_0_ASSERTION_NS + 'AttributeValue')
+ if len(values) > 1:
+ values_array = []
+ for v in values:
+ values_array.append(v.text)
+ attributes[at.attrib['AttributeName']] = values_array
+ else:
+ attributes[at.attrib['AttributeName']] = values[0].text
+ return user, attributes, None
+ finally:
+ page.close()
+
+ def fetch_saml_validation(self, ticket):
+ # We do the SAML validation
+ headers = {
+ 'soapaction': 'http://www.oasis-open.org/committees/security',
+ 'cache-control': 'no-cache',
+ 'pragma': 'no-cache',
+ 'accept': 'text/xml',
+ 'connection': 'keep-alive',
+ 'content-type': 'text/xml; charset=utf-8',
+ }
+ params = [('TARGET', self.service_url)]
+ saml_validate_url = urllib_parse.urljoin(
+ self.server_url, 'samlValidate',
+ )
+ request = Request(
+ saml_validate_url + '?' + urllib_parse.urlencode(params),
+ self.get_saml_assertion(ticket),
+ headers,
+ )
+ return urllib_request.urlopen(request)
+
+ @classmethod
+ def get_saml_assertion(cls, ticket):
+ """
+ http://www.jasig.org/cas/protocol#samlvalidate-cas-3.0
+
+ SAML request values:
+
+ RequestID [REQUIRED]:
+ unique identifier for the request
+ IssueInstant [REQUIRED]:
+ timestamp of the request
+ samlp:AssertionArtifact [REQUIRED]:
+ the valid CAS Service Ticket obtained as a response parameter at login.
+ """
+ # RequestID [REQUIRED] - unique identifier for the request
+ request_id = uuid4()
+
+ # e.g. 2014-06-02T09:21:03.071189
+ timestamp = datetime.datetime.now().isoformat()
+
+ return SAML_ASSERTION_TEMPLATE.format(
+ request_id=request_id,
+ timestamp=timestamp,
+ ticket=ticket,
+ ).encode('utf8')
diff --git a/cas_server/default_settings.py b/cas_server/default_settings.py
index 139569d..fe5de28 100644
--- a/cas_server/default_settings.py
+++ b/cas_server/default_settings.py
@@ -18,6 +18,7 @@ def setting_default(name, default_value):
setattr(settings, name, value)
setting_default('CAS_LOGIN_TEMPLATE', 'cas_server/login.html')
+setting_default('CAS_FEDERATE_TEMPLATE', 'cas_server/federate.html')
setting_default('CAS_WARN_TEMPLATE', 'cas_server/warn.html')
setting_default('CAS_LOGGED_TEMPLATE', 'cas_server/logged.html')
setting_default('CAS_LOGOUT_TEMPLATE', 'cas_server/logout.html')
@@ -70,3 +71,14 @@ setting_default('CAS_SQL_DBCHARSET', 'utf8')
setting_default('CAS_SQL_USER_QUERY', 'SELECT user AS usersame, pass AS '
'password, users.* FROM users WHERE user = %s')
setting_default('CAS_SQL_PASSWORD_CHECK', 'crypt') # crypt or plain
+
+
+setting_default('CAS_FEDERATE', False)
+# A dict of "provider name" -> (provider CAS server url, CAS version)
+setting_default('CAS_FEDERATE_PROVIDERS', {})
+
+if settings.CAS_FEDERATE:
+ settings.CAS_AUTH_CLASS = "cas_server.auth.CASFederateAuth"
+
+CAS_FEDERATE_PROVIDERS_LIST = settings.CAS_FEDERATE_PROVIDERS.keys()
+CAS_FEDERATE_PROVIDERS_LIST.sort()
diff --git a/cas_server/federate.py b/cas_server/federate.py
new file mode 100644
index 0000000..529ddd1
--- /dev/null
+++ b/cas_server/federate.py
@@ -0,0 +1,69 @@
+# ⁻*- coding: utf-8 -*-
+# This program is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
+# more details.
+#
+# You should have received a copy of the GNU General Public License version 3
+# along with this program; if not, write to the Free Software Foundation, Inc., 51
+# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+#
+# (c) 2015 Valentin Samir
+from .default_settings import settings
+
+from .cas import CASClient
+from .models import FederatedUser
+
+
+class CASFederateValidateUser(object):
+ username = None
+ attributs = {}
+ client = None
+
+ def __init__(self, provider, service_url):
+ self.provider = provider
+
+ if provider in settings.CAS_FEDERATE_PROVIDERS:
+ (server_url, version) = settings.CAS_FEDERATE_PROVIDERS[provider]
+ self.client = CASClient(
+ service_url=service_url,
+ version=version,
+ server_url=server_url,
+ extra_login_params={"provider": provider},
+ renew=False,
+ )
+
+ def get_login_url(self):
+ return self.client.get_login_url() if self.client is not None else False
+
+ def get_logout_url(self, redirect_url=None):
+ return self.client.get_logout_url(redirect_url) if self.client is not None else False
+
+ def verify_ticket(self, ticket):
+ """test `password` agains the user"""
+ if self.client is None:
+ return False
+ username, attributs, pgtiou = self.client.verify_ticket(ticket)
+ if username is not None:
+ attributs["provider"] = self.provider
+ self.username = username
+ self.attributs = attributs
+ try:
+ user = FederatedUser.objects.get(
+ username=username,
+ provider=self.provider
+ )
+ user.attributs = attributs
+ user.ticket = ticket
+ user.save()
+ except FederatedUser.DoesNotExist:
+ user = FederatedUser.objects.create(
+ username=username,
+ provider=self.provider,
+ attributs=attributs,
+ ticket=ticket
+ )
+ user.save()
+ return True
+ else:
+ return False
diff --git a/cas_server/forms.py b/cas_server/forms.py
index f970ccd..33b3a2c 100644
--- a/cas_server/forms.py
+++ b/cas_server/forms.py
@@ -9,7 +9,7 @@
#
# (c) 2015 Valentin Samir
"""forms for the app"""
-from .default_settings import settings
+from .default_settings import settings, CAS_FEDERATE_PROVIDERS_LIST
from django import forms
from django.utils.translation import ugettext_lazy as _
@@ -27,6 +27,17 @@ class WarnForm(forms.Form):
lt = forms.CharField(widget=forms.HiddenInput(), required=False)
+class FederateSelect(forms.Form):
+ provider = forms.ChoiceField(
+ label=_('Identity provider'),
+ choices=[(p, p) for p in CAS_FEDERATE_PROVIDERS_LIST]
+ )
+ service = forms.CharField(label=_('service'), widget=forms.HiddenInput(), required=False)
+ method = forms.CharField(widget=forms.HiddenInput(), required=False)
+ remember = forms.BooleanField(label=_('Remember the identity provider'), required=False)
+ warn = forms.BooleanField(label=_('warn'), required=False)
+
+
class UserCredential(forms.Form):
"""Form used on the login page to retrive user credentials"""
username = forms.CharField(label=_('login'))
@@ -46,6 +57,31 @@ class UserCredential(forms.Form):
cleaned_data["username"] = auth.username
else:
raise forms.ValidationError(_(u"Bad user"))
+ return cleaned_data
+
+
+class FederateUserCredential(UserCredential):
+ """Form used on the login page to retrive user credentials"""
+ username = forms.CharField(widget=forms.HiddenInput())
+ service = forms.CharField(widget=forms.HiddenInput(), required=False)
+ password = forms.CharField(widget=forms.HiddenInput())
+ ticket = forms.CharField(widget=forms.HiddenInput())
+ lt = forms.CharField(widget=forms.HiddenInput(), required=False)
+ method = forms.CharField(widget=forms.HiddenInput(), required=False)
+ warn = forms.BooleanField(widget=forms.HiddenInput(), required=False)
+
+ def clean(self):
+ cleaned_data = super(FederateUserCredential, self).clean()
+ try:
+ component = cleaned_data["username"].split('@')
+ username = '@'.join(component[:-1])
+ provider = component[-1]
+ user = models.FederatedUser.objects.get(username=username, provider=provider)
+ user.ticket = ""
+ user.save()
+ except models.FederatedUser.DoesNotExist:
+ raise
+ return cleaned_data
class TicketForm(forms.ModelForm):
diff --git a/cas_server/migrations/0005_auto_20160616_1018.py b/cas_server/migrations/0005_auto_20160616_1018.py
new file mode 100644
index 0000000..4a503ea
--- /dev/null
+++ b/cas_server/migrations/0005_auto_20160616_1018.py
@@ -0,0 +1,31 @@
+# -*- coding: utf-8 -*-
+# Generated by Django 1.9.6 on 2016-06-16 10:18
+from __future__ import unicode_literals
+
+from django.db import migrations, models
+import picklefield.fields
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('cas_server', '0004_auto_20151218_1032'),
+ ]
+
+ operations = [
+ migrations.CreateModel(
+ name='FederatedUser',
+ fields=[
+ ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
+ ('username', models.CharField(max_length=124)),
+ ('provider', models.CharField(max_length=124)),
+ ('attributs', picklefield.fields.PickledObjectField(editable=False)),
+ ('ticket', models.CharField(max_length=255)),
+ ('last_update', models.DateTimeField(auto_now=True)),
+ ],
+ ),
+ migrations.AlterUniqueTogether(
+ name='federateduser',
+ unique_together=set([('username', 'provider')]),
+ ),
+ ]
diff --git a/cas_server/models.py b/cas_server/models.py
index 9cb0ac5..746e7e6 100644
--- a/cas_server/models.py
+++ b/cas_server/models.py
@@ -35,6 +35,16 @@ SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
logger = logging.getLogger(__name__)
+class FederatedUser(models.Model):
+ class Meta:
+ unique_together = ("username", "provider")
+ username = models.CharField(max_length=124)
+ provider = models.CharField(max_length=124)
+ attributs = PickledObjectField()
+ ticket = models.CharField(max_length=255)
+ last_update = models.DateTimeField(auto_now=True)
+
+
class User(models.Model):
"""A user logged into the CAS"""
class Meta:
diff --git a/cas_server/templates/cas_server/federate.html b/cas_server/templates/cas_server/federate.html
new file mode 100644
index 0000000..1411513
--- /dev/null
+++ b/cas_server/templates/cas_server/federate.html
@@ -0,0 +1,22 @@
+{% extends "cas_server/base.html" %}
+{% load bootstrap3 %}
+{% load staticfiles %}
+{% load i18n %}
+{% block content %}
+
+{% if auto_submit %}
+
+{% endif %}
+{% endblock %}
+
diff --git a/cas_server/templates/cas_server/login.html b/cas_server/templates/cas_server/login.html
index b423797..d4559fe 100644
--- a/cas_server/templates/cas_server/login.html
+++ b/cas_server/templates/cas_server/login.html
@@ -3,11 +3,20 @@
{% load staticfiles %}
{% load i18n %}
{% block content %}
-
+{% if auto_submit %}
+
+{% endif %}
{% endblock %}
diff --git a/cas_server/urls.py b/cas_server/urls.py
index b2ed38b..2a87ef4 100644
--- a/cas_server/urls.py
+++ b/cas_server/urls.py
@@ -59,4 +59,5 @@ urlpatterns = patterns(
),
name='auth'
),
+ url("^federate(?:/(?P([^/]+)))?$", views.FederateAuth.as_view(), name='federateAuth'),
)
diff --git a/cas_server/utils.py b/cas_server/utils.py
index c3b2c32..ee6f1c4 100644
--- a/cas_server/utils.py
+++ b/cas_server/utils.py
@@ -20,6 +20,7 @@ import random
import string
import json
from importlib import import_module
+from datetime import datetime, timedelta
try:
from urlparse import urlparse, urlunparse, parse_qsl
@@ -60,7 +61,43 @@ def redirect_params(url_name, params=None):
def reverse_params(url_name, params=None, **kwargs):
url = reverse(url_name, **kwargs)
params = urlencode(params if params else {})
- return url + "?%s" % params
+ if params:
+ return url + "?%s" % params
+ else:
+ return url
+
+
+def copy_params(get_or_post_params, ignore=set()):
+ params = {}
+ for key in get_or_post_params:
+ if key not in ignore and get_or_post_params[key]:
+ params[key] = get_or_post_params[key]
+ return params
+
+
+def set_cookie(response, key, value, max_age):
+ expires = datetime.strftime(
+ datetime.utcnow() + timedelta(seconds=max_age),
+ "%a, %d-%b-%Y %H:%M:%S GMT"
+ )
+ response.set_cookie(
+ key,
+ value,
+ max_age=max_age,
+ expires=expires,
+ domain=settings.SESSION_COOKIE_DOMAIN,
+ secure=settings.SESSION_COOKIE_SECURE or None
+ )
+
+
+def get_current_url(request, ignore_params=set()):
+ protocol = 'https' if request.is_secure() else "http"
+ service_url = "%s://%s%s" % (protocol, request.get_host(), request.path)
+ if request.GET:
+ params = copy_params(request.GET, ignore_params)
+ if params:
+ service_url += "?%s" % urlencode(params)
+ return service_url
def update_url(url, params):
diff --git a/cas_server/views.py b/cas_server/views.py
index 4e27ead..733c53c 100644
--- a/cas_server/views.py
+++ b/cas_server/views.py
@@ -37,6 +37,7 @@ import cas_server.models as models
from .utils import JsonResponse
from .models import ServiceTicket, ProxyTicket, ProxyGrantingTicket
from .models import ServicePattern
+from .federate import CASFederateValidateUser
SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
@@ -113,7 +114,18 @@ class LogoutView(View, LogoutMixin):
"""methode called on GET request on this view"""
logger.info("logout requested")
self.init_get(request)
+ # if CAS federation mode is enable, bakup the provider before flushing the sessions
+ if settings.CAS_FEDERATE:
+ component = self.request.session.get("username").split('@')
+ provider = component[-1]
+ auth = CASFederateValidateUser(provider, service_url="")
session_nb = self.logout(self.request.GET.get("all"))
+ # if CAS federation mode is enable, redirect to user CAS logout page
+ if settings.CAS_FEDERATE:
+ params = utils.copy_params(request.GET)
+ url = utils.update_url(auth.get_logout_url(), params)
+ if url:
+ return HttpResponseRedirect(url)
# if service is set, redirect to service after logout
if self.service:
list(messages.get_messages(request)) # clean messages before leaving the django app
@@ -168,6 +180,45 @@ class LogoutView(View, LogoutMixin):
)
+class FederateAuth(View):
+ def post(self, request, provider=None):
+ form = forms.FederateSelect(request.POST)
+ if form.is_valid():
+ params = utils.copy_params(
+ request.POST,
+ ignore={"provider", "csrfmiddlewaretoken", "ticket"}
+ )
+ url = utils.reverse_params(
+ "cas_server:federateAuth",
+ kwargs=dict(provider=form.cleaned_data["provider"]),
+ params=params
+ )
+ response = HttpResponseRedirect(url)
+ if form.cleaned_data["remember"]:
+ max_age = 7 * 24 * 60 * 60 # one week
+ utils.set_cookie(response, "_remember_provider", request.POST["provider"], max_age)
+ return response
+ else:
+ return redirect("cas_server:login")
+
+ def get(self, request, provider=None):
+ if provider not in settings.CAS_FEDERATE_PROVIDERS:
+ return redirect("cas_server:login")
+ service_url = utils.get_current_url(request, {"ticket", "provider"})
+ auth = CASFederateValidateUser(provider, service_url)
+ if 'ticket' not in request.GET:
+ return HttpResponseRedirect(auth.get_login_url())
+ else:
+ ticket = request.GET['ticket']
+ if auth.verify_ticket(ticket):
+ params = utils.copy_params(request.GET)
+ params['username'] = "%s@%s" % (auth.username, auth.provider)
+ url = utils.reverse_params("cas_server:login", params)
+ return HttpResponseRedirect(url)
+ else:
+ return HttpResponseRedirect(auth.get_login_url())
+
+
class LoginView(View, LogoutMixin):
"""credential requestor / acceptor"""
@@ -206,6 +257,10 @@ class LoginView(View, LogoutMixin):
self.ajax = 'HTTP_X_AJAX' in request.META
if request.POST.get('warned') and request.POST['warned'] != "False":
self.warned = True
+ self.warn = request.POST.get('warn')
+ if settings.CAS_FEDERATE:
+ self.username = request.POST.get('username')
+ self.ticket = request.POST.get('ticket')
def check_lt(self):
# save LT for later check
@@ -248,6 +303,7 @@ class LoginView(View, LogoutMixin):
)
self.user.save()
elif ret == self.USER_LOGIN_FAILURE: # bad user login
+ self.ticket = None
self.logout()
elif ret == self.USER_ALREADY_LOGGED:
pass
@@ -291,6 +347,10 @@ class LoginView(View, LogoutMixin):
self.gateway = request.GET.get('gateway')
self.method = request.GET.get('method')
self.ajax = 'HTTP_X_AJAX' in request.META
+ self.warn = request.GET.get('warn')
+ if settings.CAS_FEDERATE:
+ self.username = request.GET.get('username')
+ self.ticket = request.GET.get('ticket')
def get(self, request, *args, **kwargs):
"""methode called on GET request on this view"""
@@ -308,15 +368,28 @@ class LoginView(View, LogoutMixin):
return self.USER_AUTHENTICATED
def init_form(self, values=None):
- self.form = forms.UserCredential(
- values,
- initial={
- 'service': self.service,
- 'method': self.method,
- 'warn': self.request.session.get("warn"),
- 'lt': self.request.session['lt'][-1]
- }
- )
+ form_initial = {
+ 'service': self.service,
+ 'method': self.method,
+ 'warn': self.warn or self.request.session.get("warn"),
+ 'lt': self.request.session['lt'][-1]
+ }
+ if settings.CAS_FEDERATE:
+ if self.username and self.ticket:
+ form_initial['username'] = self.username
+ form_initial['password'] = self.ticket
+ form_initial['ticket'] = self.ticket
+ self.form = forms.FederateUserCredential(
+ values,
+ initial=form_initial
+ )
+ else:
+ self.form = forms.FederateSelect(values, initial=form_initial)
+ else:
+ self.form = forms.UserCredential(
+ values,
+ initial=form_initial
+ )
def service_login(self):
"""Perform login agains a service"""
@@ -483,7 +556,38 @@ class LoginView(View, LogoutMixin):
}
return JsonResponse(self.request, data)
else:
- return render(self.request, settings.CAS_LOGIN_TEMPLATE, {'form': self.form})
+ if settings.CAS_FEDERATE:
+ if self.username and self.ticket:
+ return render(
+ self.request,
+ settings.CAS_LOGIN_TEMPLATE,
+ {
+ 'form': self.form,
+ 'auto_submit': True,
+ 'post_url': reverse("cas_server:login")
+ }
+ )
+ else:
+ if (
+ self.request.COOKIES.get('_remember_provider') and
+ self.request.COOKIES['_remember_provider'] in
+ settings.CAS_FEDERATE_PROVIDERS
+ ):
+ params = utils.copy_params(self.request.GET)
+ url = utils.reverse_params(
+ "cas_server:federateAuth",
+ params=params,
+ kwargs=dict(provider=self.request.COOKIES['_remember_provider'])
+ )
+ return HttpResponseRedirect(url)
+ else:
+ return render(
+ self.request,
+ settings.CAS_FEDERATE_TEMPLATE,
+ {'form': self.form}
+ )
+ else:
+ return render(self.request, settings.CAS_LOGIN_TEMPLATE, {'form': self.form})
def common(self):
"""Part execute uppon GET and POST request"""