Merge branch 'dev' into federate
This commit is contained in:
		
							
								
								
									
										8
									
								
								.coveragerc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								.coveragerc
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,8 @@
 | 
			
		||||
[report]
 | 
			
		||||
exclude_lines =
 | 
			
		||||
    pragma: no cover
 | 
			
		||||
    def __repr__
 | 
			
		||||
    def __unicode__
 | 
			
		||||
    raise AssertionError
 | 
			
		||||
    raise NotImplementedError
 | 
			
		||||
    if six.PY3:
 | 
			
		||||
							
								
								
									
										4
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							@@ -1,5 +1,6 @@
 | 
			
		||||
*.pyc
 | 
			
		||||
*.egg-info
 | 
			
		||||
*~
 | 
			
		||||
*.swp
 | 
			
		||||
 | 
			
		||||
build/
 | 
			
		||||
@@ -8,6 +9,9 @@ cas/
 | 
			
		||||
dist/
 | 
			
		||||
db.sqlite3
 | 
			
		||||
manage.py
 | 
			
		||||
coverage.xml
 | 
			
		||||
 | 
			
		||||
.tox
 | 
			
		||||
test_venv
 | 
			
		||||
.coverage
 | 
			
		||||
htmlcov/
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										11
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										11
									
								
								Makefile
									
									
									
									
									
								
							@@ -44,3 +44,14 @@ test_project: test_venv test_venv/cas/manage.py
 | 
			
		||||
 | 
			
		||||
run_test_server: test_project
 | 
			
		||||
	test_venv/bin/python test_venv/cas/manage.py runserver
 | 
			
		||||
 | 
			
		||||
coverage: test_venv
 | 
			
		||||
	test_venv/bin/pip install coverage
 | 
			
		||||
	test_venv/bin/coverage run --source='cas_server' --omit='cas_server/migrations*' run_tests
 | 
			
		||||
	test_venv/bin/coverage html
 | 
			
		||||
	rm htmlcov/coverage_html.js  # I am really pissed off by those keybord shortcuts
 | 
			
		||||
 | 
			
		||||
coverage_codacy: coverage
 | 
			
		||||
	test_venv/bin/coverage xml
 | 
			
		||||
	test_venv/bin/pip install codacy-coverage
 | 
			
		||||
	test_venv/bin/python-codacy-coverage -r coverage.xml
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										62
									
								
								README.rst
									
									
									
									
									
								
							
							
						
						
									
										62
									
								
								README.rst
									
									
									
									
									
								
							@@ -10,8 +10,14 @@ CAS Server
 | 
			
		||||
.. image:: https://img.shields.io/pypi/l/django-cas-server.svg
 | 
			
		||||
    :target: https://www.gnu.org/licenses/gpl-3.0.html
 | 
			
		||||
 | 
			
		||||
.. image:: https://api.codacy.com/project/badge/Grade/255c21623d6946ef8802fa7995b61366
 | 
			
		||||
    :target: https://www.codacy.com/app/valentin-samir/django-cas-server
 | 
			
		||||
 | 
			
		||||
.. image:: https://api.codacy.com/project/badge/Coverage/255c21623d6946ef8802fa7995b61366
 | 
			
		||||
    :target: https://www.codacy.com/app/valentin-samir/django-cas-server
 | 
			
		||||
 | 
			
		||||
CAS Server is a Django application implementing the `CAS Protocol 3.0 Specification
 | 
			
		||||
<https://jasig.github.io/cas/development/protocol/CAS-Protocol-Specification.html>`_.
 | 
			
		||||
<https://apereo.github.io/cas/4.2.x/protocol/CAS-Protocol-Specification.html>`_.
 | 
			
		||||
 | 
			
		||||
By defaut, the authentication process use django internal users but you can easily
 | 
			
		||||
use any sources (see auth classes in the auth.py file)
 | 
			
		||||
@@ -37,6 +43,15 @@ Features
 | 
			
		||||
 | 
			
		||||
Quick start
 | 
			
		||||
-----------
 | 
			
		||||
0. If you want to make a virtualenv for ``django-cas-server``, you will need the following
 | 
			
		||||
   dependencies on a bare debian like system::
 | 
			
		||||
 | 
			
		||||
    virtualenv build-essential python-dev libxml2-dev libxslt1-dev zlib1g-dev
 | 
			
		||||
 | 
			
		||||
   If you want to use python3 instead of python2, replace ``python-dev`` with ``python3-dev``.
 | 
			
		||||
 | 
			
		||||
   If you intend to run the tox tests you will also need ``python3.4-dev`` depending of the current
 | 
			
		||||
   version of python3 on your system.
 | 
			
		||||
 | 
			
		||||
1. Add "cas_server" to your INSTALLED_APPS setting like this::
 | 
			
		||||
 | 
			
		||||
@@ -70,7 +85,7 @@ Quick start
 | 
			
		||||
4. You should add some management commands to a crontab: ``clearsessions``,
 | 
			
		||||
   ``cas_clean_tickets`` and ``cas_clean_sessions``.
 | 
			
		||||
 | 
			
		||||
 * ``clearsessions``:  please see `Clearing the session store <https://docs.djangoproject.com/en/1.9/topics/http/sessions/#clearing-the-session-store>`_.
 | 
			
		||||
 * ``clearsessions``:  please see `Clearing the session store <https://docs.djangoproject.com/en/stable/topics/http/sessions/#clearing-the-session-store>`_.
 | 
			
		||||
 * ``cas_clean_tickets``: old tickets and timed-out tickets do not get purge from
 | 
			
		||||
   the database automatically. They are just marked as invalid. ``cas_clean_tickets``
 | 
			
		||||
   is a clean-up management command for this purpose. It send SingleLogOut request
 | 
			
		||||
@@ -122,14 +137,14 @@ Template settings:
 | 
			
		||||
 | 
			
		||||
Authentication settings:
 | 
			
		||||
 | 
			
		||||
*  ``CAS_AUTH_CLASS``: A dotted path to a class implementing ``cas_server.auth.AuthUser``.
 | 
			
		||||
   The default is ``"cas_server.auth.DjangoAuthUser"``
 | 
			
		||||
*  ``CAS_AUTH_CLASS``: A dotted path to a class or a class implementing
 | 
			
		||||
  ``cas_server.auth.AuthUser``. The default is ``"cas_server.auth.DjangoAuthUser"``
 | 
			
		||||
 | 
			
		||||
*  ``SESSION_COOKIE_AGE``: This is a django settings. Here, it control the delay in seconds after
 | 
			
		||||
   which inactive users are logged out. The default is ``1209600`` (2 weeks). You probably should
 | 
			
		||||
   reduce it to something like ``86400`` seconds (1 day).
 | 
			
		||||
 | 
			
		||||
* ``CAS_PROXY_CA_CERTIFICATE_PATH``: Path to certificates authority file. Usually on linux
 | 
			
		||||
* ``CAS_PROXY_CA_CERTIFICATE_PATH``: Path to certificate authorities file. Usually on linux
 | 
			
		||||
  the local CAs are in ``/etc/ssl/certs/ca-certificates.crt``. The default is ``True`` which
 | 
			
		||||
  tell requests to use its internal certificat authorities. Settings it to ``False`` should
 | 
			
		||||
  disable all x509 certificates validation and MUST not be done in production.
 | 
			
		||||
@@ -162,7 +177,7 @@ Tickets validity settings:
 | 
			
		||||
  application. The default is ``60``.
 | 
			
		||||
* ``CAS_PGT_VALIDITY``: Number of seconds the proxy granting tickets are valid.
 | 
			
		||||
  The default is ``3600`` (1 hour).
 | 
			
		||||
* ``CAS_TICKET_TIMEOUT``: Number of seconds a ticket is kept is the database before sending
 | 
			
		||||
* ``CAS_TICKET_TIMEOUT``: Number of seconds a ticket is kept in the database before sending
 | 
			
		||||
  Single Log Out request and being cleared. The default is ``86400`` (24 hours).
 | 
			
		||||
 | 
			
		||||
Tickets miscellaneous settings:
 | 
			
		||||
@@ -184,12 +199,12 @@ Tickets miscellaneous settings:
 | 
			
		||||
* ``CAS_SERVICE_TICKET_PREFIX``: Prefix of service tickets. The default is ``"ST"``.
 | 
			
		||||
  The CAS specification mandate that service tickets MUST begin with the characters ST
 | 
			
		||||
  so you should not change this.
 | 
			
		||||
* ``CAS_PROXY_TICKET_PREFIX``: Prefix of proxy ticket. The default is ``"ST"``.
 | 
			
		||||
* ``CAS_PROXY_TICKET_PREFIX``: Prefix of proxy ticket. The default is ``"PT"``.
 | 
			
		||||
* ``CAS_PROXY_GRANTING_TICKET_PREFIX``: Prefix of proxy granting ticket. The default is ``"PGT"``.
 | 
			
		||||
* ``CAS_PROXY_GRANTING_TICKET_IOU_PREFIX``: Prefix of proxy granting ticket IOU. The default is ``"PGTIOU"``.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
Mysql backend settings. Only usefull is you use the mysql authentication backend:
 | 
			
		||||
Mysql backend settings. Only usefull if you are using the mysql authentication backend:
 | 
			
		||||
 | 
			
		||||
* ``CAS_SQL_HOST``: Host for the SQL server. The default is ``"localhost"``.
 | 
			
		||||
* ``CAS_SQL_USERNAME``: Username for connecting to the SQL server.
 | 
			
		||||
@@ -200,8 +215,29 @@ Mysql backend settings. Only usefull is you use the mysql authentication backend
 | 
			
		||||
  The username must be in field ``username``, the password in ``password``,
 | 
			
		||||
  additional fields are used as the user attributes.
 | 
			
		||||
  The default is ``"SELECT user AS usersame, pass AS password, users.* FROM users WHERE user = %s"``
 | 
			
		||||
* ``CAS_SQL_PASSWORD_CHECK``: The method used to check the user password. Must be
 | 
			
		||||
  ``"crypt"`` or ``"plain``". The default is ``"crypt"``.
 | 
			
		||||
* ``CAS_SQL_PASSWORD_CHECK``: The method used to check the user password. Must be one of the following:
 | 
			
		||||
 | 
			
		||||
    * ``"crypt"`` (see <https://en.wikipedia.org/wiki/Crypt_(C)>), the password in the database
 | 
			
		||||
      should begin this $
 | 
			
		||||
    * ``"ldap"`` (see https://tools.ietf.org/id/draft-stroeder-hashed-userpassword-values-01.html)
 | 
			
		||||
      the password in the database must begin with one of {MD5}, {SMD5}, {SHA}, {SSHA}, {SHA256},
 | 
			
		||||
      {SSHA256}, {SHA384}, {SSHA384}, {SHA512}, {SSHA512}, {CRYPT}.
 | 
			
		||||
    * ``"hex_HASH_NAME"`` with ``HASH_NAME`` in md5, sha1, sha224, sha256, sha384, sha512.
 | 
			
		||||
      The hashed password in the database is compare to the hexadecimal digest of the clear
 | 
			
		||||
      password hashed with the corresponding algorithm.
 | 
			
		||||
    * ``"plain"``, the password in the database must be in clear.
 | 
			
		||||
 | 
			
		||||
  The default is ``"crypt"``.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
Test backend settings. Only usefull if you are using the test authentication backend:
 | 
			
		||||
 | 
			
		||||
* ``CAS_TEST_USER``: Username of the test user. The default is ``"test"``.
 | 
			
		||||
* ``CAS_TEST_PASSWORD``: Password of the test user. The default is ``"test"``.
 | 
			
		||||
* ``CAS_TEST_ATTRIBUTES``: Attributes of the test user. The default is
 | 
			
		||||
  ``{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net',
 | 
			
		||||
  'alias': ['demo1', 'demo2']}``.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
Authentication backend
 | 
			
		||||
----------------------
 | 
			
		||||
@@ -209,8 +245,8 @@ Authentication backend
 | 
			
		||||
``django-cas-server`` comes with some authentication backends:
 | 
			
		||||
 | 
			
		||||
* dummy backend ``cas_server.auth.DummyAuthUser``: all authentication attempt fails.
 | 
			
		||||
* test backend ``cas_server.auth.TestAuthUser``: username is ``test`` and password is ``test``
 | 
			
		||||
  the returned attributes for the user are: ``{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'}``
 | 
			
		||||
* test backend ``cas_server.auth.TestAuthUser``: username, password and returned attributes
 | 
			
		||||
  for the user are defined by the ``CAS_TEST_*`` settings.
 | 
			
		||||
* django backend ``cas_server.auth.DjangoAuthUser``: Users are authenticated agains django users system.
 | 
			
		||||
  This is the default backend. The returned attributes are the fields available on the user model.
 | 
			
		||||
* mysql backend ``cas_server.auth.MysqlAuthUser``: see the 'Mysql backend settings' section.
 | 
			
		||||
@@ -222,7 +258,7 @@ Logs
 | 
			
		||||
----
 | 
			
		||||
 | 
			
		||||
``django-cas-server`` logs most of its actions. To enable login, you must set the ``LOGGING``
 | 
			
		||||
(https://docs.djangoproject.com/en/dev/topics/logging) variable is ``settings.py``.
 | 
			
		||||
(https://docs.djangoproject.com/en/stable/topics/logging) variable in ``settings.py``.
 | 
			
		||||
 | 
			
		||||
Users successful actions (login, logout) are logged with the level ``INFO``, failures are logged
 | 
			
		||||
with the level ``WARNING`` and user attributes transmitted to a service are logged with the level ``DEBUG``.
 | 
			
		||||
 
 | 
			
		||||
@@ -9,4 +9,4 @@
 | 
			
		||||
#
 | 
			
		||||
# (c) 2015 Valentin Samir
 | 
			
		||||
 | 
			
		||||
default_app_config = 'cas_server.apps.AppConfig'
 | 
			
		||||
default_app_config = 'cas_server.apps.CasAppConfig'
 | 
			
		||||
 
 | 
			
		||||
@@ -14,9 +14,9 @@ from .models import ServiceTicket, ProxyTicket, ProxyGrantingTicket, User, Servi
 | 
			
		||||
from .models import Username, ReplaceAttributName, ReplaceAttributValue, FilterAttributValue
 | 
			
		||||
from .forms import TicketForm
 | 
			
		||||
 | 
			
		||||
tickets_readonly_fields = ('validate', 'service', 'service_pattern',
 | 
			
		||||
TICKETS_READONLY_FIELDS = ('validate', 'service', 'service_pattern',
 | 
			
		||||
                           'creation', 'renew', 'single_log_out', 'value')
 | 
			
		||||
tickets_fields = ('validate', 'service', 'service_pattern',
 | 
			
		||||
TICKETS_FIELDS = ('validate', 'service', 'service_pattern',
 | 
			
		||||
                  'creation', 'renew', 'single_log_out')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -25,8 +25,8 @@ class ServiceTicketInline(admin.TabularInline):
 | 
			
		||||
    model = ServiceTicket
 | 
			
		||||
    extra = 0
 | 
			
		||||
    form = TicketForm
 | 
			
		||||
    readonly_fields = tickets_readonly_fields
 | 
			
		||||
    fields = tickets_fields
 | 
			
		||||
    readonly_fields = TICKETS_READONLY_FIELDS
 | 
			
		||||
    fields = TICKETS_FIELDS
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ProxyTicketInline(admin.TabularInline):
 | 
			
		||||
@@ -34,8 +34,8 @@ class ProxyTicketInline(admin.TabularInline):
 | 
			
		||||
    model = ProxyTicket
 | 
			
		||||
    extra = 0
 | 
			
		||||
    form = TicketForm
 | 
			
		||||
    readonly_fields = tickets_readonly_fields
 | 
			
		||||
    fields = tickets_fields
 | 
			
		||||
    readonly_fields = TICKETS_READONLY_FIELDS
 | 
			
		||||
    fields = TICKETS_FIELDS
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ProxyGrantingInline(admin.TabularInline):
 | 
			
		||||
@@ -43,8 +43,8 @@ class ProxyGrantingInline(admin.TabularInline):
 | 
			
		||||
    model = ProxyGrantingTicket
 | 
			
		||||
    extra = 0
 | 
			
		||||
    form = TicketForm
 | 
			
		||||
    readonly_fields = tickets_readonly_fields
 | 
			
		||||
    fields = tickets_fields[1:]
 | 
			
		||||
    readonly_fields = TICKETS_READONLY_FIELDS
 | 
			
		||||
    fields = TICKETS_FIELDS[1:]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class UserAdmin(admin.ModelAdmin):
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,6 @@ from django.utils.translation import ugettext_lazy as _
 | 
			
		||||
from django.apps import AppConfig
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AppConfig(AppConfig):
 | 
			
		||||
class CasAppConfig(AppConfig):
 | 
			
		||||
    name = 'cas_server'
 | 
			
		||||
    verbose_name = _('Central Authentication Service')
 | 
			
		||||
 
 | 
			
		||||
@@ -15,10 +15,10 @@ from django.contrib.auth import get_user_model
 | 
			
		||||
from django.utils import timezone
 | 
			
		||||
 | 
			
		||||
from datetime import timedelta
 | 
			
		||||
try:
 | 
			
		||||
try:  # pragma: no cover
 | 
			
		||||
    import MySQLdb
 | 
			
		||||
    import MySQLdb.cursors
 | 
			
		||||
    import crypt
 | 
			
		||||
    from utils import check_password
 | 
			
		||||
except ImportError:
 | 
			
		||||
    MySQLdb = None
 | 
			
		||||
 | 
			
		||||
@@ -31,14 +31,14 @@ class AuthUser(object):
 | 
			
		||||
 | 
			
		||||
    def test_password(self, password):
 | 
			
		||||
        """test `password` agains the user"""
 | 
			
		||||
        raise NotImplemented()
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
    def attributs(self):
 | 
			
		||||
        """return a dict of user attributes"""
 | 
			
		||||
        raise NotImplemented()
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DummyAuthUser(AuthUser):
 | 
			
		||||
class DummyAuthUser(AuthUser):  # pragma: no cover
 | 
			
		||||
    """A Dummy authentication class"""
 | 
			
		||||
 | 
			
		||||
    def __init__(self, username):
 | 
			
		||||
@@ -62,14 +62,14 @@ class TestAuthUser(AuthUser):
 | 
			
		||||
 | 
			
		||||
    def test_password(self, password):
 | 
			
		||||
        """test `password` agains the user"""
 | 
			
		||||
        return self.username == "test" and password == "test"
 | 
			
		||||
        return self.username == settings.CAS_TEST_USER and password == settings.CAS_TEST_PASSWORD
 | 
			
		||||
 | 
			
		||||
    def attributs(self):
 | 
			
		||||
        """return a dict of user attributes"""
 | 
			
		||||
        return {'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'}
 | 
			
		||||
        return settings.CAS_TEST_ATTRIBUTES
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MysqlAuthUser(AuthUser):
 | 
			
		||||
class MysqlAuthUser(AuthUser):  # pragma: no cover
 | 
			
		||||
    """A mysql auth class: authentication user agains a mysql database"""
 | 
			
		||||
    user = None
 | 
			
		||||
 | 
			
		||||
@@ -94,30 +94,25 @@ class MysqlAuthUser(AuthUser):
 | 
			
		||||
 | 
			
		||||
    def test_password(self, password):
 | 
			
		||||
        """test `password` agains the user"""
 | 
			
		||||
        if not self.user:
 | 
			
		||||
            return False
 | 
			
		||||
        if self.user:
 | 
			
		||||
            return check_password(
 | 
			
		||||
                settings.CAS_SQL_PASSWORD_CHECK,
 | 
			
		||||
                password,
 | 
			
		||||
                self.user["password"],
 | 
			
		||||
                settings.CAS_SQL_DBCHARSET
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            if settings.CAS_SQL_PASSWORD_CHECK == "plain":
 | 
			
		||||
                return password == self.user["password"]
 | 
			
		||||
            elif settings.CAS_SQL_PASSWORD_CHECK == "crypt":
 | 
			
		||||
                if self.user["password"].startswith('$'):
 | 
			
		||||
                    salt = '$'.join(self.user["password"].split('$', 3)[:-1])
 | 
			
		||||
                    return crypt.crypt(password, salt) == self.user["password"]
 | 
			
		||||
                else:
 | 
			
		||||
                    return crypt.crypt(
 | 
			
		||||
                        password,
 | 
			
		||||
                        self.user["password"][:2]
 | 
			
		||||
                    ) == self.user["password"]
 | 
			
		||||
            return False
 | 
			
		||||
 | 
			
		||||
    def attributs(self):
 | 
			
		||||
        """return a dict of user attributes"""
 | 
			
		||||
        if not self.user:
 | 
			
		||||
            return {}
 | 
			
		||||
        else:
 | 
			
		||||
        if self.user:
 | 
			
		||||
            return self.user
 | 
			
		||||
        else:
 | 
			
		||||
            return {}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DjangoAuthUser(AuthUser):
 | 
			
		||||
class DjangoAuthUser(AuthUser):  # pragma: no cover
 | 
			
		||||
    """A django auth class: authenticate user agains django internal users"""
 | 
			
		||||
    user = None
 | 
			
		||||
 | 
			
		||||
@@ -131,21 +126,20 @@ class DjangoAuthUser(AuthUser):
 | 
			
		||||
 | 
			
		||||
    def test_password(self, password):
 | 
			
		||||
        """test `password` agains the user"""
 | 
			
		||||
        if not self.user:
 | 
			
		||||
            return False
 | 
			
		||||
        else:
 | 
			
		||||
        if self.user:
 | 
			
		||||
            return self.user.check_password(password)
 | 
			
		||||
        else:
 | 
			
		||||
            return False
 | 
			
		||||
 | 
			
		||||
    def attributs(self):
 | 
			
		||||
        """return a dict of user attributes"""
 | 
			
		||||
        if not self.user:
 | 
			
		||||
            return {}
 | 
			
		||||
        else:
 | 
			
		||||
        if self.user:
 | 
			
		||||
            attr = {}
 | 
			
		||||
            for field in self.user._meta.fields:
 | 
			
		||||
                attr[field.attname] = getattr(self.user, field.attname)
 | 
			
		||||
            return attr
 | 
			
		||||
 | 
			
		||||
        else:
 | 
			
		||||
            return {}
 | 
			
		||||
 | 
			
		||||
class CASFederateAuth(AuthUser):
 | 
			
		||||
    user = None
 | 
			
		||||
 
 | 
			
		||||
@@ -78,6 +78,17 @@ setting_default('CAS_SQL_USER_QUERY', 'SELECT user AS usersame, pass AS '
 | 
			
		||||
                'password, users.* FROM users WHERE user = %s')
 | 
			
		||||
setting_default('CAS_SQL_PASSWORD_CHECK', 'crypt')  # crypt or plain
 | 
			
		||||
 | 
			
		||||
setting_default('CAS_TEST_USER', 'test')
 | 
			
		||||
setting_default('CAS_TEST_PASSWORD', 'test')
 | 
			
		||||
setting_default(
 | 
			
		||||
    'CAS_TEST_ATTRIBUTES',
 | 
			
		||||
    {
 | 
			
		||||
        'nom': 'Nymous',
 | 
			
		||||
        'prenom': 'Ano',
 | 
			
		||||
        'email': 'anonymous@example.net',
 | 
			
		||||
        'alias': ['demo1', 'demo2']
 | 
			
		||||
    }
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
setting_default('CAS_FEDERATE', False)
 | 
			
		||||
# A dict of "provider suffix" -> (provider CAS server url, CAS version, verbose name)
 | 
			
		||||
 
 | 
			
		||||
@@ -1,55 +1,52 @@
 | 
			
		||||
function cas_login(cas_server_login, service, login_service, callback){
 | 
			
		||||
  url = cas_server_login + '?service=' + encodeURIComponent(service);
 | 
			
		||||
  var url = cas_server_login + "?service=" + encodeURIComponent(service);
 | 
			
		||||
  $.ajax({
 | 
			
		||||
    type: 'GET',
 | 
			
		||||
    url:url,
 | 
			
		||||
    beforeSend: function (request) {   
 | 
			
		||||
    type: "GET",
 | 
			
		||||
    url,
 | 
			
		||||
    beforeSend(request) {
 | 
			
		||||
      request.setRequestHeader("X-AJAX", "1");
 | 
			
		||||
    },
 | 
			
		||||
    xhrFields: {
 | 
			
		||||
      withCredentials: true
 | 
			
		||||
    },
 | 
			
		||||
    success: function(data, textStatus, request){
 | 
			
		||||
      if(data.status == 'success'){
 | 
			
		||||
    success(data, textStatus, request){
 | 
			
		||||
      if(data.status === "success"){
 | 
			
		||||
        $.ajax({
 | 
			
		||||
          type: 'GET',
 | 
			
		||||
          type: "GET",
 | 
			
		||||
          url: data.url,
 | 
			
		||||
          xhrFields: {
 | 
			
		||||
            withCredentials: true
 | 
			
		||||
          },
 | 
			
		||||
          success: callback,
 | 
			
		||||
          error: function (request, textStatus, errorThrown) {},
 | 
			
		||||
          error(request, textStatus, errorThrown) {},
 | 
			
		||||
        });
 | 
			
		||||
      } else {
 | 
			
		||||
        if(data.detail == "login required"){
 | 
			
		||||
          window.location.href = cas_server_login + '?service=' + encodeURIComponent(login_service);
 | 
			
		||||
        if(data.detail === "login required"){
 | 
			
		||||
          window.location.href = cas_server_login + "?service=" + encodeURIComponent(login_service);
 | 
			
		||||
        } else {
 | 
			
		||||
          alert('error: ' + data.messages[1].message);
 | 
			
		||||
          alert("error: " + data.messages[1].message);
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
    },
 | 
			
		||||
    error: function (request, textStatus, errorThrown) {},
 | 
			
		||||
    error(request, textStatus, errorThrown) {},
 | 
			
		||||
  });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
function cas_logout(cas_server_logout){
 | 
			
		||||
  $.ajax({
 | 
			
		||||
    type: 'GET',
 | 
			
		||||
    url:cas_server_logout,
 | 
			
		||||
    beforeSend: function (request) {   
 | 
			
		||||
    type: "GET",
 | 
			
		||||
    url: cas_server_logout,
 | 
			
		||||
    beforeSend(request) {
 | 
			
		||||
      request.setRequestHeader("X-AJAX", "1");
 | 
			
		||||
    },
 | 
			
		||||
    xhrFields: {
 | 
			
		||||
      withCredentials: true
 | 
			
		||||
    },
 | 
			
		||||
    error: function (request, textStatus, errorThrown) {},
 | 
			
		||||
    success: function(data, textStatus, request){
 | 
			
		||||
      if(data.status == 'error'){
 | 
			
		||||
        alert('error: ' + data.messages[1].message);
 | 
			
		||||
    error(request, textStatus, errorThrown) {},
 | 
			
		||||
    success(data, textStatus, request){
 | 
			
		||||
      if(data.status === "error"){
 | 
			
		||||
        alert("error: " + data.messages[1].message);
 | 
			
		||||
      }
 | 
			
		||||
    },
 | 
			
		||||
  });
 | 
			
		||||
}
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
   
 | 
			
		||||
 
 | 
			
		||||
@@ -43,14 +43,14 @@ body {
 | 
			
		||||
 | 
			
		||||
@media screen and (max-width: 680px) {
 | 
			
		||||
    #app-name {
 | 
			
		||||
        margin: 0px;
 | 
			
		||||
        margin: 0;
 | 
			
		||||
    }
 | 
			
		||||
    #app-name img {
 | 
			
		||||
        display: block;
 | 
			
		||||
        margin: auto;
 | 
			
		||||
    }
 | 
			
		||||
    body {
 | 
			
		||||
      padding-top: 0px;
 | 
			
		||||
      padding-bottom: 0px;
 | 
			
		||||
      padding-top: 0;
 | 
			
		||||
      padding-bottom: 0;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										943
									
								
								cas_server/tests.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										943
									
								
								cas_server/tests.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,943 @@
 | 
			
		||||
from .default_settings import settings
 | 
			
		||||
 | 
			
		||||
import django
 | 
			
		||||
from django.test import TestCase
 | 
			
		||||
from django.test import Client
 | 
			
		||||
 | 
			
		||||
import re
 | 
			
		||||
import six
 | 
			
		||||
import random
 | 
			
		||||
from lxml import etree
 | 
			
		||||
from six.moves import range
 | 
			
		||||
 | 
			
		||||
from cas_server import models
 | 
			
		||||
from cas_server import utils
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def copy_form(form):
 | 
			
		||||
    """Copy form value into a dict"""
 | 
			
		||||
    params = {}
 | 
			
		||||
    for field in form:
 | 
			
		||||
        if field.value():
 | 
			
		||||
            params[field.name] = field.value()
 | 
			
		||||
        else:
 | 
			
		||||
            params[field.name] = ""
 | 
			
		||||
    return params
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_login_page_params(client=None):
 | 
			
		||||
    """Return a client and the POST params for the client to login"""
 | 
			
		||||
    if client is None:
 | 
			
		||||
        client = Client()
 | 
			
		||||
    response = client.get('/login')
 | 
			
		||||
    params = copy_form(response.context["form"])
 | 
			
		||||
    return client, params
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_auth_client(**update):
 | 
			
		||||
    """return a authenticated client"""
 | 
			
		||||
    client, params = get_login_page_params()
 | 
			
		||||
    params["username"] = settings.CAS_TEST_USER
 | 
			
		||||
    params["password"] = settings.CAS_TEST_PASSWORD
 | 
			
		||||
    params.update(update)
 | 
			
		||||
 | 
			
		||||
    client.post('/login', params)
 | 
			
		||||
    return client
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_user_ticket_request(service):
 | 
			
		||||
    """Make an auth client to request a ticket for `service`, return the tuple (user, ticket)"""
 | 
			
		||||
    client = get_auth_client()
 | 
			
		||||
    response = client.get("/login", {"service": service})
 | 
			
		||||
    ticket_value = response['Location'].split('ticket=')[-1]
 | 
			
		||||
    user = models.User.objects.get(
 | 
			
		||||
        username=settings.CAS_TEST_USER,
 | 
			
		||||
        session_key=client.session.session_key
 | 
			
		||||
    )
 | 
			
		||||
    ticket = models.ServiceTicket.objects.get(value=ticket_value)
 | 
			
		||||
    return (user, ticket)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_pgt():
 | 
			
		||||
    """return a dict contening a service, user and PGT ticket for this service"""
 | 
			
		||||
    (host, port) = utils.PGTUrlHandler.run()[1:3]
 | 
			
		||||
    service = "http://%s:%s" % (host, port)
 | 
			
		||||
 | 
			
		||||
    (user, ticket) = get_user_ticket_request(service)
 | 
			
		||||
 | 
			
		||||
    client = Client()
 | 
			
		||||
    client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service})
 | 
			
		||||
    params = utils.PGTUrlHandler.PARAMS.copy()
 | 
			
		||||
 | 
			
		||||
    params["service"] = service
 | 
			
		||||
    params["user"] = user
 | 
			
		||||
 | 
			
		||||
    return params
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CheckPasswordCase(TestCase):
 | 
			
		||||
    """Tests for the utils function `utils.check_password`"""
 | 
			
		||||
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        """Generate random bytes string that will be used ass passwords"""
 | 
			
		||||
        self.password1 = utils.gen_saml_id()
 | 
			
		||||
        self.password2 = utils.gen_saml_id()
 | 
			
		||||
        if not isinstance(self.password1, bytes):
 | 
			
		||||
            self.password1 = self.password1.encode("utf8")
 | 
			
		||||
            self.password2 = self.password2.encode("utf8")
 | 
			
		||||
 | 
			
		||||
    def test_setup(self):
 | 
			
		||||
        """check that generated password are bytes"""
 | 
			
		||||
        self.assertIsInstance(self.password1, bytes)
 | 
			
		||||
        self.assertIsInstance(self.password2, bytes)
 | 
			
		||||
 | 
			
		||||
    def test_plain(self):
 | 
			
		||||
        """test the plain auth method"""
 | 
			
		||||
        self.assertTrue(utils.check_password("plain", self.password1, self.password1, "utf8"))
 | 
			
		||||
        self.assertFalse(utils.check_password("plain", self.password1, self.password2, "utf8"))
 | 
			
		||||
 | 
			
		||||
    def test_crypt(self):
 | 
			
		||||
        """test the crypt auth method"""
 | 
			
		||||
        if six.PY3:
 | 
			
		||||
            hashed_password1 = utils.crypt.crypt(
 | 
			
		||||
                self.password1.decode("utf8"),
 | 
			
		||||
                "$6$UVVAQvrMyXMF3FF3"
 | 
			
		||||
            ).encode("utf8")
 | 
			
		||||
        else:
 | 
			
		||||
            hashed_password1 = utils.crypt.crypt(self.password1, "$6$UVVAQvrMyXMF3FF3")
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(utils.check_password("crypt", self.password1, hashed_password1, "utf8"))
 | 
			
		||||
        self.assertFalse(utils.check_password("crypt", self.password2, hashed_password1, "utf8"))
 | 
			
		||||
 | 
			
		||||
    def test_ldap_ssha(self):
 | 
			
		||||
        """test the ldap auth method with a {SSHA} scheme"""
 | 
			
		||||
        salt = b"UVVAQvrMyXMF3FF3"
 | 
			
		||||
        hashed_password1 = utils.LdapHashUserPassword.hash(b'{SSHA}', self.password1, salt, "utf8")
 | 
			
		||||
 | 
			
		||||
        self.assertIsInstance(hashed_password1, bytes)
 | 
			
		||||
        self.assertTrue(utils.check_password("ldap", self.password1, hashed_password1, "utf8"))
 | 
			
		||||
        self.assertFalse(utils.check_password("ldap", self.password2, hashed_password1, "utf8"))
 | 
			
		||||
 | 
			
		||||
    def test_hex_md5(self):
 | 
			
		||||
        """test the hex_md5 auth method"""
 | 
			
		||||
        hashed_password1 = utils.hashlib.md5(self.password1).hexdigest()
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(utils.check_password("hex_md5", self.password1, hashed_password1, "utf8"))
 | 
			
		||||
        self.assertFalse(utils.check_password("hex_md5", self.password2, hashed_password1, "utf8"))
 | 
			
		||||
 | 
			
		||||
    def test_hex_sha512(self):
 | 
			
		||||
        """test the hex_sha512 auth method"""
 | 
			
		||||
        hashed_password1 = utils.hashlib.sha512(self.password1).hexdigest()
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            utils.check_password("hex_sha512", self.password1, hashed_password1, "utf8")
 | 
			
		||||
        )
 | 
			
		||||
        self.assertFalse(
 | 
			
		||||
            utils.check_password("hex_sha512", self.password2, hashed_password1, "utf8")
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LoginTestCase(TestCase):
 | 
			
		||||
    """Tests for the login view"""
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        """
 | 
			
		||||
            Prepare the test context:
 | 
			
		||||
                * set the auth class to 'cas_server.auth.TestAuthUser'
 | 
			
		||||
                * create a service pattern for https://www.example.com/**
 | 
			
		||||
                * Set the service pattern to return all user attributes
 | 
			
		||||
        """
 | 
			
		||||
        settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
 | 
			
		||||
 | 
			
		||||
        # For general purpose testing
 | 
			
		||||
        self.service_pattern = models.ServicePattern.objects.create(
 | 
			
		||||
            name="example",
 | 
			
		||||
            pattern="^https://www\.example\.com(/.*)?$",
 | 
			
		||||
        )
 | 
			
		||||
        models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
 | 
			
		||||
 | 
			
		||||
        # For testing the restrict_users attributes
 | 
			
		||||
        self.service_pattern_restrict_user_fail = models.ServicePattern.objects.create(
 | 
			
		||||
            name="restrict_user_fail",
 | 
			
		||||
            pattern="^https://restrict_user_fail\.example\.com(/.*)?$",
 | 
			
		||||
            restrict_users=True,
 | 
			
		||||
        )
 | 
			
		||||
        self.service_pattern_restrict_user_success = models.ServicePattern.objects.create(
 | 
			
		||||
            name="restrict_user_success",
 | 
			
		||||
            pattern="^https://restrict_user_success\.example\.com(/.*)?$",
 | 
			
		||||
            restrict_users=True,
 | 
			
		||||
        )
 | 
			
		||||
        models.Username.objects.create(
 | 
			
		||||
            value=settings.CAS_TEST_USER,
 | 
			
		||||
            service_pattern=self.service_pattern_restrict_user_success
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # For testing the user attributes filtering conditions
 | 
			
		||||
        self.service_pattern_filter_fail = models.ServicePattern.objects.create(
 | 
			
		||||
            name="filter_fail",
 | 
			
		||||
            pattern="^https://filter_fail\.example\.com(/.*)?$",
 | 
			
		||||
        )
 | 
			
		||||
        models.FilterAttributValue.objects.create(
 | 
			
		||||
            attribut="right",
 | 
			
		||||
            pattern="^admin$",
 | 
			
		||||
            service_pattern=self.service_pattern_filter_fail
 | 
			
		||||
        )
 | 
			
		||||
        self.service_pattern_filter_success = models.ServicePattern.objects.create(
 | 
			
		||||
            name="filter_success",
 | 
			
		||||
            pattern="^https://filter_success\.example\.com(/.*)?$",
 | 
			
		||||
        )
 | 
			
		||||
        models.FilterAttributValue.objects.create(
 | 
			
		||||
            attribut="email",
 | 
			
		||||
            pattern="^%s$" % re.escape(settings.CAS_TEST_ATTRIBUTES['email']),
 | 
			
		||||
            service_pattern=self.service_pattern_filter_success
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # For testing the user_field attributes
 | 
			
		||||
        self.service_pattern_field_needed_fail = models.ServicePattern.objects.create(
 | 
			
		||||
            name="field_needed_fail",
 | 
			
		||||
            pattern="^https://field_needed_fail\.example\.com(/.*)?$",
 | 
			
		||||
            user_field="uid"
 | 
			
		||||
        )
 | 
			
		||||
        self.service_pattern_field_needed_success = models.ServicePattern.objects.create(
 | 
			
		||||
            name="field_needed_success",
 | 
			
		||||
            pattern="^https://field_needed_success\.example\.com(/.*)?$",
 | 
			
		||||
            user_field="nom"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def assert_logged(self, client, response, warn=False, code=200):
 | 
			
		||||
        """Assertions testing that client is well authenticated"""
 | 
			
		||||
        self.assertEqual(response.status_code, code)
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            (
 | 
			
		||||
                b"You have successfully logged into "
 | 
			
		||||
                b"the Central Authentication Service"
 | 
			
		||||
            ) in response.content
 | 
			
		||||
        )
 | 
			
		||||
        self.assertTrue(client.session["username"] == settings.CAS_TEST_USER)
 | 
			
		||||
        self.assertTrue(client.session["warn"] is warn)
 | 
			
		||||
        self.assertTrue(client.session["authenticated"] is True)
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            models.User.objects.get(
 | 
			
		||||
                username=settings.CAS_TEST_USER,
 | 
			
		||||
                session_key=client.session.session_key
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def assert_login_failed(self, client, response, code=200):
 | 
			
		||||
        """Assertions testing a failed login attempt"""
 | 
			
		||||
        self.assertEqual(response.status_code, code)
 | 
			
		||||
        self.assertFalse(
 | 
			
		||||
            (
 | 
			
		||||
                b"You have successfully logged into "
 | 
			
		||||
                b"the Central Authentication Service"
 | 
			
		||||
            ) in response.content
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(client.session.get("username") is None)
 | 
			
		||||
        self.assertTrue(client.session.get("warn") is None)
 | 
			
		||||
        self.assertTrue(client.session.get("authenticated") is None)
 | 
			
		||||
 | 
			
		||||
    def test_login_view_post_goodpass_goodlt(self):
 | 
			
		||||
        """Test a successul login"""
 | 
			
		||||
        client, params = get_login_page_params()
 | 
			
		||||
        params["username"] = settings.CAS_TEST_USER
 | 
			
		||||
        params["password"] = settings.CAS_TEST_PASSWORD
 | 
			
		||||
        self.assertTrue(params['lt'] in client.session['lt'])
 | 
			
		||||
 | 
			
		||||
        response = client.post('/login', params)
 | 
			
		||||
        self.assert_logged(client, response)
 | 
			
		||||
        # LoginTicket conssumed
 | 
			
		||||
        self.assertTrue(params['lt'] not in client.session['lt'])
 | 
			
		||||
 | 
			
		||||
    def test_login_view_post_goodpass_goodlt_warn(self):
 | 
			
		||||
        """Test a successul login requesting to be warned before creating services tickets"""
 | 
			
		||||
        client, params = get_login_page_params()
 | 
			
		||||
        params["username"] = settings.CAS_TEST_USER
 | 
			
		||||
        params["password"] = settings.CAS_TEST_PASSWORD
 | 
			
		||||
        params["warn"] = "on"
 | 
			
		||||
 | 
			
		||||
        response = client.post('/login', params)
 | 
			
		||||
        self.assert_logged(client, response, warn=True)
 | 
			
		||||
 | 
			
		||||
    def test_lt_max(self):
 | 
			
		||||
        """Check we only keep the last 100 Login Ticket for a user"""
 | 
			
		||||
        client, params = get_login_page_params()
 | 
			
		||||
        current_lt = params["lt"]
 | 
			
		||||
        i_in_test = random.randint(0, 100)
 | 
			
		||||
        i_not_in_test = random.randint(100, 150)
 | 
			
		||||
        for i in range(150):
 | 
			
		||||
            if i == i_in_test:
 | 
			
		||||
                self.assertTrue(current_lt in client.session['lt'])
 | 
			
		||||
            if i == i_not_in_test:
 | 
			
		||||
                self.assertTrue(current_lt not in client.session['lt'])
 | 
			
		||||
                self.assertTrue(len(client.session['lt']) <= 100)
 | 
			
		||||
            client, params = get_login_page_params(client)
 | 
			
		||||
        self.assertTrue(len(client.session['lt']) <= 100)
 | 
			
		||||
 | 
			
		||||
    def test_login_view_post_badlt(self):
 | 
			
		||||
        """Login attempt with a bad LoginTicket"""
 | 
			
		||||
        client, params = get_login_page_params()
 | 
			
		||||
        params["username"] = settings.CAS_TEST_USER
 | 
			
		||||
        params["password"] = settings.CAS_TEST_PASSWORD
 | 
			
		||||
        params["lt"] = 'LT-random'
 | 
			
		||||
 | 
			
		||||
        response = client.post('/login', params)
 | 
			
		||||
 | 
			
		||||
        self.assert_login_failed(client, response)
 | 
			
		||||
        self.assertTrue(b"Invalid login ticket" in response.content)
 | 
			
		||||
 | 
			
		||||
    def test_login_view_post_badpass_good_lt(self):
 | 
			
		||||
        """Login attempt with a bad password"""
 | 
			
		||||
        client, params = get_login_page_params()
 | 
			
		||||
        params["username"] = settings.CAS_TEST_USER
 | 
			
		||||
        params["password"] = "test2"
 | 
			
		||||
        response = client.post('/login', params)
 | 
			
		||||
 | 
			
		||||
        self.assert_login_failed(client, response)
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            (
 | 
			
		||||
                b"The credentials you provided cannot be "
 | 
			
		||||
                b"determined to be authentic"
 | 
			
		||||
            ) in response.content
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def assert_ticket_attributes(self, client, ticket_value):
 | 
			
		||||
        """check the ticket attributes in the db"""
 | 
			
		||||
        user = models.User.objects.get(
 | 
			
		||||
            username=settings.CAS_TEST_USER,
 | 
			
		||||
            session_key=client.session.session_key
 | 
			
		||||
        )
 | 
			
		||||
        self.assertTrue(user)
 | 
			
		||||
        ticket = models.ServiceTicket.objects.get(value=ticket_value)
 | 
			
		||||
        self.assertEqual(ticket.user, user)
 | 
			
		||||
        self.assertEqual(ticket.attributs, settings.CAS_TEST_ATTRIBUTES)
 | 
			
		||||
        self.assertEqual(ticket.validate, False)
 | 
			
		||||
        self.assertEqual(ticket.service_pattern, self.service_pattern)
 | 
			
		||||
 | 
			
		||||
    def assert_service_ticket(self, client, response):
 | 
			
		||||
        """check that a ticket is well emited when requested on a allowed service"""
 | 
			
		||||
        self.assertEqual(response.status_code, 302)
 | 
			
		||||
        self.assertTrue(response.has_header('Location'))
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            response['Location'].startswith(
 | 
			
		||||
                "https://www.example.com?ticket=%s-" % settings.CAS_SERVICE_TICKET_PREFIX
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        ticket_value = response['Location'].split('ticket=')[-1]
 | 
			
		||||
        self.assert_ticket_attributes(client, ticket_value)
 | 
			
		||||
 | 
			
		||||
    def test_view_login_get_allowed_service(self):
 | 
			
		||||
        """Request a ticket for an allowed service by an unauthenticated client"""
 | 
			
		||||
        client = Client()
 | 
			
		||||
        response = client.get("/login?service=https://www.example.com")
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            (
 | 
			
		||||
                b"Authentication required by service "
 | 
			
		||||
                b"example (https://www.example.com)"
 | 
			
		||||
            ) in response.content
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_view_login_get_denied_service(self):
 | 
			
		||||
        """Request a ticket for an denied service by an unauthenticated client"""
 | 
			
		||||
        client = Client()
 | 
			
		||||
        response = client.get("/login?service=https://www.example.net")
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        self.assertTrue(b"Service https://www.example.net non allowed" in response.content)
 | 
			
		||||
 | 
			
		||||
    def test_view_login_get_auth_allowed_service(self):
 | 
			
		||||
        """Request a ticket for an allowed service by an authenticated client"""
 | 
			
		||||
        # client is already authenticated
 | 
			
		||||
        client = get_auth_client()
 | 
			
		||||
        response = client.get("/login?service=https://www.example.com")
 | 
			
		||||
        self.assert_service_ticket(client, response)
 | 
			
		||||
 | 
			
		||||
    def test_view_login_get_auth_allowed_service_warn(self):
 | 
			
		||||
        """Request a ticket for an allowed service by an authenticated client"""
 | 
			
		||||
        # client is already authenticated
 | 
			
		||||
        client = get_auth_client(warn="on")
 | 
			
		||||
        response = client.get("/login?service=https://www.example.com")
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            (
 | 
			
		||||
                b"Authentication has been required by service "
 | 
			
		||||
                b"example (https://www.example.com)"
 | 
			
		||||
            ) in response.content
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        params = copy_form(response.context["form"])
 | 
			
		||||
        response = client.post("/login", params)
 | 
			
		||||
        self.assert_service_ticket(client, response)
 | 
			
		||||
 | 
			
		||||
    def test_view_login_get_auth_denied_service(self):
 | 
			
		||||
        """Request a ticket for a not allowed service by an authenticated client"""
 | 
			
		||||
        client = get_auth_client()
 | 
			
		||||
        response = client.get("/login?service=https://www.example.org")
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        self.assertTrue(b"Service https://www.example.org non allowed" in response.content)
 | 
			
		||||
 | 
			
		||||
    def test_user_logged_not_in_db(self):
 | 
			
		||||
        """If the user is logged but has been delete from the database, it should be logged out"""
 | 
			
		||||
        client = get_auth_client()
 | 
			
		||||
        models.User.objects.get(
 | 
			
		||||
            username=settings.CAS_TEST_USER,
 | 
			
		||||
            session_key=client.session.session_key
 | 
			
		||||
        ).delete()
 | 
			
		||||
        response = client.get("/login")
 | 
			
		||||
 | 
			
		||||
        self.assert_login_failed(client, response, code=302)
 | 
			
		||||
        if django.VERSION < (1, 9):
 | 
			
		||||
            self.assertEqual(response["Location"], "http://testserver/login")
 | 
			
		||||
        else:
 | 
			
		||||
            self.assertEqual(response["Location"], "/login?")
 | 
			
		||||
 | 
			
		||||
    def test_service_restrict_user(self):
 | 
			
		||||
        """Testing the restric user capability fro a service"""
 | 
			
		||||
        service = "https://restrict_user_fail.example.com"
 | 
			
		||||
        client = get_auth_client()
 | 
			
		||||
        response = client.get("/login", {'service': service})
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        self.assertTrue(b"Username non allowed" in response.content)
 | 
			
		||||
 | 
			
		||||
        service = "https://restrict_user_success.example.com"
 | 
			
		||||
        response = client.get("/login", {'service': service})
 | 
			
		||||
        self.assertEqual(response.status_code, 302)
 | 
			
		||||
        self.assertTrue(response["Location"].startswith("%s?ticket=" % service))
 | 
			
		||||
 | 
			
		||||
    def test_service_filter(self):
 | 
			
		||||
        """Test the filtering on user attributes"""
 | 
			
		||||
        service = "https://filter_fail.example.com"
 | 
			
		||||
        client = get_auth_client()
 | 
			
		||||
        response = client.get("/login", {'service': service})
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        self.assertTrue(b"User charateristics non allowed" in response.content)
 | 
			
		||||
 | 
			
		||||
        service = "https://filter_success.example.com"
 | 
			
		||||
        response = client.get("/login", {'service': service})
 | 
			
		||||
        self.assertEqual(response.status_code, 302)
 | 
			
		||||
        self.assertTrue(response["Location"].startswith("%s?ticket=" % service))
 | 
			
		||||
 | 
			
		||||
    def test_service_user_field(self):
 | 
			
		||||
        """Test using a user attribute as username: case on if the attribute exists or not"""
 | 
			
		||||
        service = "https://field_needed_fail.example.com"
 | 
			
		||||
        client = get_auth_client()
 | 
			
		||||
        response = client.get("/login", {'service': service})
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        self.assertTrue(b"The attribut uid is needed to use that service" in response.content)
 | 
			
		||||
 | 
			
		||||
        service = "https://field_needed_success.example.com"
 | 
			
		||||
        response = client.get("/login", {'service': service})
 | 
			
		||||
        self.assertEqual(response.status_code, 302)
 | 
			
		||||
        self.assertTrue(response["Location"].startswith("%s?ticket=" % service))
 | 
			
		||||
 | 
			
		||||
    def test_gateway(self):
 | 
			
		||||
        """test gateway parameter"""
 | 
			
		||||
 | 
			
		||||
        # First with an authenticated client that fail to get a ticket for a service
 | 
			
		||||
        service = "https://restrict_user_fail.example.com"
 | 
			
		||||
        client = get_auth_client()
 | 
			
		||||
        response = client.get("/login", {'service': service, 'gateway': 'on'})
 | 
			
		||||
        self.assertEqual(response.status_code, 302)
 | 
			
		||||
        self.assertEqual(response["Location"], service)
 | 
			
		||||
 | 
			
		||||
        # second for an user not yet authenticated on a valid service
 | 
			
		||||
        client = Client()
 | 
			
		||||
        response = client.get('/login', {'service': service, 'gateway': 'on'})
 | 
			
		||||
        self.assertEqual(response.status_code, 302)
 | 
			
		||||
        self.assertEqual(response["Location"], service)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LogoutTestCase(TestCase):
 | 
			
		||||
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
 | 
			
		||||
 | 
			
		||||
    def test_logout_view(self):
 | 
			
		||||
        client = get_auth_client()
 | 
			
		||||
 | 
			
		||||
        response = client.get("/login")
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            (
 | 
			
		||||
                b"You have successfully logged into "
 | 
			
		||||
                b"the Central Authentication Service"
 | 
			
		||||
            ) in response.content
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        response = client.get("/logout")
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            (
 | 
			
		||||
                b"You have successfully logged out from "
 | 
			
		||||
                b"the Central Authentication Service"
 | 
			
		||||
            ) in response.content
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        response = client.get("/login")
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        self.assertFalse(
 | 
			
		||||
            (
 | 
			
		||||
                b"You have successfully logged into "
 | 
			
		||||
                b"the Central Authentication Service"
 | 
			
		||||
            ) in response.content
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_logout_view_url(self):
 | 
			
		||||
        client = get_auth_client()
 | 
			
		||||
 | 
			
		||||
        response = client.get('/logout?url=https://www.example.com')
 | 
			
		||||
        self.assertEqual(response.status_code, 302)
 | 
			
		||||
        self.assertTrue(response.has_header("Location"))
 | 
			
		||||
        self.assertEqual(response["Location"], "https://www.example.com")
 | 
			
		||||
 | 
			
		||||
        response = client.get("/login")
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        self.assertFalse(
 | 
			
		||||
            (
 | 
			
		||||
                b"You have successfully logged into "
 | 
			
		||||
                b"the Central Authentication Service"
 | 
			
		||||
            ) in response.content
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_logout_view_service(self):
 | 
			
		||||
        client = get_auth_client()
 | 
			
		||||
 | 
			
		||||
        response = client.get('/logout?service=https://www.example.com')
 | 
			
		||||
        self.assertEqual(response.status_code, 302)
 | 
			
		||||
        self.assertTrue(response.has_header("Location"))
 | 
			
		||||
        self.assertEqual(response["Location"], "https://www.example.com")
 | 
			
		||||
 | 
			
		||||
        response = client.get("/login")
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        self.assertFalse(
 | 
			
		||||
            (
 | 
			
		||||
                b"You have successfully logged into "
 | 
			
		||||
                b"the Central Authentication Service"
 | 
			
		||||
            ) in response.content
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AuthTestCase(TestCase):
 | 
			
		||||
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
 | 
			
		||||
        self.service = 'https://www.example.com'
 | 
			
		||||
        models.ServicePattern.objects.create(
 | 
			
		||||
            name="example",
 | 
			
		||||
            pattern="^https://www\.example\.com(/.*)?$"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_auth_view_goodpass(self):
 | 
			
		||||
        settings.CAS_AUTH_SHARED_SECRET = 'test'
 | 
			
		||||
        client = Client()
 | 
			
		||||
        response = client.post(
 | 
			
		||||
            '/auth',
 | 
			
		||||
            {
 | 
			
		||||
                'username': settings.CAS_TEST_USER,
 | 
			
		||||
                'password': settings.CAS_TEST_PASSWORD,
 | 
			
		||||
                'service': self.service,
 | 
			
		||||
                'secret': 'test'
 | 
			
		||||
            }
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        self.assertEqual(response.content, b'yes\n')
 | 
			
		||||
 | 
			
		||||
    def test_auth_view_badpass(self):
 | 
			
		||||
        settings.CAS_AUTH_SHARED_SECRET = 'test'
 | 
			
		||||
        client = Client()
 | 
			
		||||
        response = client.post(
 | 
			
		||||
            '/auth',
 | 
			
		||||
            {
 | 
			
		||||
                'username': settings.CAS_TEST_USER,
 | 
			
		||||
                'password': 'badpass',
 | 
			
		||||
                'service': self.service,
 | 
			
		||||
                'secret': 'test'
 | 
			
		||||
            }
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        self.assertEqual(response.content, b'no\n')
 | 
			
		||||
 | 
			
		||||
    def test_auth_view_badservice(self):
 | 
			
		||||
        settings.CAS_AUTH_SHARED_SECRET = 'test'
 | 
			
		||||
        client = Client()
 | 
			
		||||
        response = client.post(
 | 
			
		||||
            '/auth',
 | 
			
		||||
            {
 | 
			
		||||
                'username': settings.CAS_TEST_USER,
 | 
			
		||||
                'password': settings.CAS_TEST_PASSWORD,
 | 
			
		||||
                'service': 'https://www.example.org',
 | 
			
		||||
                'secret': 'test'
 | 
			
		||||
            }
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        self.assertEqual(response.content, b'no\n')
 | 
			
		||||
 | 
			
		||||
    def test_auth_view_badsecret(self):
 | 
			
		||||
        settings.CAS_AUTH_SHARED_SECRET = 'test'
 | 
			
		||||
        client = Client()
 | 
			
		||||
        response = client.post(
 | 
			
		||||
            '/auth',
 | 
			
		||||
            {
 | 
			
		||||
                'username': settings.CAS_TEST_USER,
 | 
			
		||||
                'password': settings.CAS_TEST_PASSWORD,
 | 
			
		||||
                'service': self.service,
 | 
			
		||||
                'secret': 'badsecret'
 | 
			
		||||
            }
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        self.assertEqual(response.content, b'no\n')
 | 
			
		||||
 | 
			
		||||
    def test_auth_view_badsettings(self):
 | 
			
		||||
        settings.CAS_AUTH_SHARED_SECRET = None
 | 
			
		||||
        client = Client()
 | 
			
		||||
        response = client.post(
 | 
			
		||||
            '/auth',
 | 
			
		||||
            {
 | 
			
		||||
                'username': settings.CAS_TEST_USER,
 | 
			
		||||
                'password': settings.CAS_TEST_PASSWORD,
 | 
			
		||||
                'service': self.service,
 | 
			
		||||
                'secret': 'test'
 | 
			
		||||
            }
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        self.assertEqual(response.content, b"no\nplease set CAS_AUTH_SHARED_SECRET")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ValidateTestCase(TestCase):
 | 
			
		||||
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
 | 
			
		||||
        self.service = 'https://www.example.com'
 | 
			
		||||
        self.service_pattern = models.ServicePattern.objects.create(
 | 
			
		||||
            name="example",
 | 
			
		||||
            pattern="^https://www\.example\.com(/.*)?$"
 | 
			
		||||
        )
 | 
			
		||||
        models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
 | 
			
		||||
 | 
			
		||||
    def test_validate_view_ok(self):
 | 
			
		||||
        ticket = get_user_ticket_request(self.service)[1]
 | 
			
		||||
 | 
			
		||||
        client = Client()
 | 
			
		||||
        response = client.get('/validate', {'ticket': ticket.value, 'service': self.service})
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        self.assertEqual(response.content, b'yes\ntest\n')
 | 
			
		||||
 | 
			
		||||
    def test_validate_view_badservice(self):
 | 
			
		||||
        ticket = get_user_ticket_request(self.service)[1]
 | 
			
		||||
 | 
			
		||||
        client = Client()
 | 
			
		||||
        response = client.get(
 | 
			
		||||
            '/validate',
 | 
			
		||||
            {'ticket': ticket.value, 'service': "https://www.example.org"}
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        self.assertEqual(response.content, b'no\n')
 | 
			
		||||
 | 
			
		||||
    def test_validate_view_badticket(self):
 | 
			
		||||
        get_user_ticket_request(self.service)
 | 
			
		||||
 | 
			
		||||
        client = Client()
 | 
			
		||||
        response = client.get(
 | 
			
		||||
            '/validate',
 | 
			
		||||
            {'ticket': "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX, 'service': self.service}
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        self.assertEqual(response.content, b'no\n')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ValidateServiceTestCase(TestCase):
 | 
			
		||||
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
 | 
			
		||||
        self.service = 'http://127.0.0.1:45678'
 | 
			
		||||
        self.service_pattern = models.ServicePattern.objects.create(
 | 
			
		||||
            name="localhost",
 | 
			
		||||
            pattern="^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
 | 
			
		||||
            proxy_callback=True
 | 
			
		||||
        )
 | 
			
		||||
        models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
 | 
			
		||||
 | 
			
		||||
    def test_validate_service_view_ok(self):
 | 
			
		||||
        ticket = get_user_ticket_request(self.service)[1]
 | 
			
		||||
 | 
			
		||||
        client = Client()
 | 
			
		||||
        response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': self.service})
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
 | 
			
		||||
        root = etree.fromstring(response.content)
 | 
			
		||||
        sucess = root.xpath(
 | 
			
		||||
            "//cas:authenticationSuccess",
 | 
			
		||||
            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
			
		||||
        )
 | 
			
		||||
        self.assertTrue(sucess)
 | 
			
		||||
 | 
			
		||||
        users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
 | 
			
		||||
        self.assertEqual(len(users), 1)
 | 
			
		||||
        self.assertEqual(users[0].text, settings.CAS_TEST_USER)
 | 
			
		||||
 | 
			
		||||
        attributes = root.xpath(
 | 
			
		||||
            "//cas:attributes",
 | 
			
		||||
            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(len(attributes), 1)
 | 
			
		||||
        attrs1 = set()
 | 
			
		||||
        for attr in attributes[0]:
 | 
			
		||||
            attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text))
 | 
			
		||||
 | 
			
		||||
        attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
 | 
			
		||||
        self.assertEqual(len(attributes), len(attrs1))
 | 
			
		||||
        attrs2 = set()
 | 
			
		||||
        for attr in attributes:
 | 
			
		||||
            attrs2.add((attr.attrib['name'], attr.attrib['value']))
 | 
			
		||||
        original = set()
 | 
			
		||||
        for key, value in settings.CAS_TEST_ATTRIBUTES.items():
 | 
			
		||||
            if isinstance(value, list):
 | 
			
		||||
                for sub_value in value:
 | 
			
		||||
                    original.add((key, sub_value))
 | 
			
		||||
            else:
 | 
			
		||||
                original.add((key, value))
 | 
			
		||||
        self.assertEqual(attrs1, attrs2)
 | 
			
		||||
        self.assertEqual(attrs1, original)
 | 
			
		||||
 | 
			
		||||
    def test_validate_service_view_badservice(self):
 | 
			
		||||
        ticket = get_user_ticket_request(self.service)[1]
 | 
			
		||||
 | 
			
		||||
        client = Client()
 | 
			
		||||
        bad_service = "https://www.example.org"
 | 
			
		||||
        response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': bad_service})
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
 | 
			
		||||
        root = etree.fromstring(response.content)
 | 
			
		||||
        error = root.xpath(
 | 
			
		||||
            "//cas:authenticationFailure",
 | 
			
		||||
            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(len(error), 1)
 | 
			
		||||
        self.assertEqual(error[0].attrib['code'], "INVALID_SERVICE")
 | 
			
		||||
        self.assertEqual(error[0].text, bad_service)
 | 
			
		||||
 | 
			
		||||
    def test_validate_service_view_badticket_goodprefix(self):
 | 
			
		||||
        get_user_ticket_request(self.service)
 | 
			
		||||
 | 
			
		||||
        client = Client()
 | 
			
		||||
        bad_ticket = "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX
 | 
			
		||||
        response = client.get('/serviceValidate', {'ticket': bad_ticket, 'service': self.service})
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
 | 
			
		||||
        root = etree.fromstring(response.content)
 | 
			
		||||
        error = root.xpath(
 | 
			
		||||
            "//cas:authenticationFailure",
 | 
			
		||||
            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(len(error), 1)
 | 
			
		||||
        self.assertEqual(error[0].attrib['code'], "INVALID_TICKET")
 | 
			
		||||
        self.assertEqual(error[0].text, 'ticket not found')
 | 
			
		||||
 | 
			
		||||
    def test_validate_service_view_badticket_badprefix(self):
 | 
			
		||||
        get_user_ticket_request(self.service)
 | 
			
		||||
 | 
			
		||||
        client = Client()
 | 
			
		||||
        bad_ticket = "RANDOM"
 | 
			
		||||
        response = client.get('/serviceValidate', {'ticket': bad_ticket, 'service': self.service})
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
 | 
			
		||||
        root = etree.fromstring(response.content)
 | 
			
		||||
        error = root.xpath(
 | 
			
		||||
            "//cas:authenticationFailure",
 | 
			
		||||
            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(len(error), 1)
 | 
			
		||||
        self.assertEqual(error[0].attrib['code'], "INVALID_TICKET")
 | 
			
		||||
        self.assertEqual(error[0].text, bad_ticket)
 | 
			
		||||
 | 
			
		||||
    def test_validate_service_view_ok_pgturl(self):
 | 
			
		||||
        (host, port) = utils.PGTUrlHandler.run()[1:3]
 | 
			
		||||
        service = "http://%s:%s" % (host, port)
 | 
			
		||||
 | 
			
		||||
        ticket = get_user_ticket_request(service)[1]
 | 
			
		||||
 | 
			
		||||
        client = Client()
 | 
			
		||||
        response = client.get(
 | 
			
		||||
            '/serviceValidate',
 | 
			
		||||
            {'ticket': ticket.value, 'service': service, 'pgtUrl': service}
 | 
			
		||||
        )
 | 
			
		||||
        pgt_params = utils.PGTUrlHandler.PARAMS.copy()
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
 | 
			
		||||
        root = etree.fromstring(response.content)
 | 
			
		||||
        pgtiou = root.xpath(
 | 
			
		||||
            "//cas:proxyGrantingTicket",
 | 
			
		||||
            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(len(pgtiou), 1)
 | 
			
		||||
        self.assertEqual(pgt_params["pgtIou"], pgtiou[0].text)
 | 
			
		||||
        self.assertTrue("pgtId" in pgt_params)
 | 
			
		||||
 | 
			
		||||
    def test_validate_service_pgturl_bad_proxy_callback(self):
 | 
			
		||||
        self.service_pattern.proxy_callback = False
 | 
			
		||||
        self.service_pattern.save()
 | 
			
		||||
        ticket = get_user_ticket_request(self.service)[1]
 | 
			
		||||
 | 
			
		||||
        client = Client()
 | 
			
		||||
        response = client.get(
 | 
			
		||||
            '/serviceValidate',
 | 
			
		||||
            {'ticket': ticket.value, 'service': self.service, 'pgtUrl': self.service}
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
 | 
			
		||||
        root = etree.fromstring(response.content)
 | 
			
		||||
        error = root.xpath(
 | 
			
		||||
            "//cas:authenticationFailure",
 | 
			
		||||
            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(len(error), 1)
 | 
			
		||||
        self.assertEqual(error[0].attrib['code'], "INVALID_PROXY_CALLBACK")
 | 
			
		||||
        self.assertEqual(error[0].text, "callback url not allowed by configuration")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ProxyTestCase(TestCase):
 | 
			
		||||
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
 | 
			
		||||
        self.service = 'http://127.0.0.1'
 | 
			
		||||
        self.service_pattern = models.ServicePattern.objects.create(
 | 
			
		||||
            name="localhost",
 | 
			
		||||
            pattern="^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
 | 
			
		||||
            proxy=True,
 | 
			
		||||
            proxy_callback=True
 | 
			
		||||
        )
 | 
			
		||||
        models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
 | 
			
		||||
 | 
			
		||||
    def test_validate_proxy_ok(self):
 | 
			
		||||
        params = get_pgt()
 | 
			
		||||
 | 
			
		||||
        # get a proxy ticket
 | 
			
		||||
        client1 = Client()
 | 
			
		||||
        response = client1.get('/proxy', {'pgt': params['pgtId'], 'targetService': self.service})
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
 | 
			
		||||
        root = etree.fromstring(response.content)
 | 
			
		||||
        sucess = root.xpath("//cas:proxySuccess", namespaces={'cas': "http://www.yale.edu/tp/cas"})
 | 
			
		||||
        self.assertTrue(sucess)
 | 
			
		||||
 | 
			
		||||
        proxy_ticket = root.xpath(
 | 
			
		||||
            "//cas:proxyTicket",
 | 
			
		||||
            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(len(proxy_ticket), 1)
 | 
			
		||||
        proxy_ticket = proxy_ticket[0].text
 | 
			
		||||
 | 
			
		||||
        # validate the proxy ticket
 | 
			
		||||
        client2 = Client()
 | 
			
		||||
        response = client2.get('/proxyValidate', {'ticket': proxy_ticket, 'service': self.service})
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
 | 
			
		||||
        root = etree.fromstring(response.content)
 | 
			
		||||
        sucess = root.xpath(
 | 
			
		||||
            "//cas:authenticationSuccess",
 | 
			
		||||
            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
			
		||||
        )
 | 
			
		||||
        self.assertTrue(sucess)
 | 
			
		||||
 | 
			
		||||
        # check that the proxy is send to the end service
 | 
			
		||||
        proxies = root.xpath("//cas:proxies", namespaces={'cas': "http://www.yale.edu/tp/cas"})
 | 
			
		||||
        self.assertEqual(len(proxies), 1)
 | 
			
		||||
        proxy = proxies[0].xpath("//cas:proxy", namespaces={'cas': "http://www.yale.edu/tp/cas"})
 | 
			
		||||
        self.assertEqual(len(proxy), 1)
 | 
			
		||||
        self.assertEqual(proxy[0].text, params["service"])
 | 
			
		||||
 | 
			
		||||
        # same tests than those for serviceValidate
 | 
			
		||||
        users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
 | 
			
		||||
        self.assertEqual(len(users), 1)
 | 
			
		||||
        self.assertEqual(users[0].text, settings.CAS_TEST_USER)
 | 
			
		||||
 | 
			
		||||
        attributes = root.xpath(
 | 
			
		||||
            "//cas:attributes",
 | 
			
		||||
            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(len(attributes), 1)
 | 
			
		||||
        attrs1 = set()
 | 
			
		||||
        for attr in attributes[0]:
 | 
			
		||||
            attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text))
 | 
			
		||||
 | 
			
		||||
        attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
 | 
			
		||||
        self.assertEqual(len(attributes), len(attrs1))
 | 
			
		||||
        attrs2 = set()
 | 
			
		||||
        for attr in attributes:
 | 
			
		||||
            attrs2.add((attr.attrib['name'], attr.attrib['value']))
 | 
			
		||||
        original = set()
 | 
			
		||||
        for key, value in settings.CAS_TEST_ATTRIBUTES.items():
 | 
			
		||||
            if isinstance(value, list):
 | 
			
		||||
                for sub_value in value:
 | 
			
		||||
                    original.add((key, sub_value))
 | 
			
		||||
            else:
 | 
			
		||||
                original.add((key, value))
 | 
			
		||||
        self.assertEqual(attrs1, attrs2)
 | 
			
		||||
        self.assertEqual(attrs1, original)
 | 
			
		||||
 | 
			
		||||
    def test_validate_proxy_bad(self):
 | 
			
		||||
        params = get_pgt()
 | 
			
		||||
 | 
			
		||||
        # bad PGT
 | 
			
		||||
        client1 = Client()
 | 
			
		||||
        response = client1.get(
 | 
			
		||||
            '/proxy',
 | 
			
		||||
            {
 | 
			
		||||
                'pgt': "%s-RANDOM" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX,
 | 
			
		||||
                'targetService': params['service']
 | 
			
		||||
            }
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
 | 
			
		||||
        root = etree.fromstring(response.content)
 | 
			
		||||
        error = root.xpath(
 | 
			
		||||
            "//cas:authenticationFailure",
 | 
			
		||||
            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(len(error), 1)
 | 
			
		||||
        self.assertEqual(error[0].attrib['code'], "INVALID_TICKET")
 | 
			
		||||
        self.assertEqual(
 | 
			
		||||
            error[0].text,
 | 
			
		||||
            "PGT %s-RANDOM not found" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # bad targetService
 | 
			
		||||
        client2 = Client()
 | 
			
		||||
        response = client2.get(
 | 
			
		||||
            '/proxy',
 | 
			
		||||
            {'pgt': params['pgtId'], 'targetService': "https://www.example.org"}
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
 | 
			
		||||
        root = etree.fromstring(response.content)
 | 
			
		||||
        error = root.xpath(
 | 
			
		||||
            "//cas:authenticationFailure",
 | 
			
		||||
            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(len(error), 1)
 | 
			
		||||
        self.assertEqual(error[0].attrib['code'], "UNAUTHORIZED_SERVICE")
 | 
			
		||||
        self.assertEqual(error[0].text, "https://www.example.org")
 | 
			
		||||
 | 
			
		||||
        # service do not allow proxy ticket
 | 
			
		||||
        self.service_pattern.proxy = False
 | 
			
		||||
        self.service_pattern.save()
 | 
			
		||||
 | 
			
		||||
        client3 = Client()
 | 
			
		||||
        response = client3.get(
 | 
			
		||||
            '/proxy',
 | 
			
		||||
            {'pgt': params['pgtId'], 'targetService': params['service']}
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
 | 
			
		||||
        root = etree.fromstring(response.content)
 | 
			
		||||
        error = root.xpath(
 | 
			
		||||
            "//cas:authenticationFailure",
 | 
			
		||||
            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(len(error), 1)
 | 
			
		||||
        self.assertEqual(error[0].attrib['code'], "UNAUTHORIZED_SERVICE")
 | 
			
		||||
        self.assertEqual(
 | 
			
		||||
            error[0].text,
 | 
			
		||||
            'the service %s do not allow proxy ticket' % params['service']
 | 
			
		||||
        )
 | 
			
		||||
@@ -14,7 +14,7 @@ from django.conf.urls import patterns, url
 | 
			
		||||
from django.views.generic import RedirectView
 | 
			
		||||
from django.views.decorators.debug import sensitive_post_parameters, sensitive_variables
 | 
			
		||||
 | 
			
		||||
import views
 | 
			
		||||
from cas_server import views
 | 
			
		||||
 | 
			
		||||
urlpatterns = patterns(
 | 
			
		||||
    '',
 | 
			
		||||
 
 | 
			
		||||
@@ -19,14 +19,15 @@ from django.contrib import messages
 | 
			
		||||
import random
 | 
			
		||||
import string
 | 
			
		||||
import json
 | 
			
		||||
import hashlib
 | 
			
		||||
import crypt
 | 
			
		||||
import base64
 | 
			
		||||
import six
 | 
			
		||||
from threading import Thread
 | 
			
		||||
from importlib import import_module
 | 
			
		||||
from datetime import datetime, timedelta
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
    from urlparse import urlparse, urlunparse, parse_qsl
 | 
			
		||||
    from urllib import urlencode
 | 
			
		||||
except ImportError:
 | 
			
		||||
    from urllib.parse import urlparse, urlunparse, parse_qsl, urlencode
 | 
			
		||||
from six.moves import BaseHTTPServer
 | 
			
		||||
from six.moves.urllib.parse import urlparse, urlunparse, parse_qsl, urlencode
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def context(params):
 | 
			
		||||
@@ -34,7 +35,7 @@ def context(params):
 | 
			
		||||
    return params
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def JsonResponse(request, data):
 | 
			
		||||
def json_response(request, data):
 | 
			
		||||
    data["messages"] = []
 | 
			
		||||
    for msg in messages.get_messages(request):
 | 
			
		||||
        data["messages"].append({'message': msg.message, 'level': msg.level_tag})
 | 
			
		||||
@@ -120,9 +121,9 @@ def update_url(url, params):
 | 
			
		||||
    query = dict(parse_qsl(url_parts[4]))
 | 
			
		||||
    query.update(params)
 | 
			
		||||
    url_parts[4] = urlencode(query)
 | 
			
		||||
    for i in range(len(url_parts)):
 | 
			
		||||
        if not isinstance(url_parts[i], bytes):
 | 
			
		||||
            url_parts[i] = url_parts[i].encode('utf-8')
 | 
			
		||||
    for i, url_part in enumerate(url_parts):
 | 
			
		||||
        if not isinstance(url_part, bytes):
 | 
			
		||||
            url_parts[i] = url_part.encode('utf-8')
 | 
			
		||||
    return urlunparse(url_parts).decode('utf-8')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -190,3 +191,207 @@ def get_tuple(tuple, index, default=None):
 | 
			
		||||
        return tuple[index]
 | 
			
		||||
    except IndexError:
 | 
			
		||||
        return default
 | 
			
		||||
 | 
			
		||||
class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler):
 | 
			
		||||
    PARAMS = {}
 | 
			
		||||
 | 
			
		||||
    def do_GET(self):
 | 
			
		||||
        self.send_response(200)
 | 
			
		||||
        self.send_header(b"Content-type", "text/plain")
 | 
			
		||||
        self.end_headers()
 | 
			
		||||
        self.wfile.write(b"ok")
 | 
			
		||||
        url = urlparse(self.path)
 | 
			
		||||
        params = dict(parse_qsl(url.query))
 | 
			
		||||
        PGTUrlHandler.PARAMS.update(params)
 | 
			
		||||
 | 
			
		||||
    def log_message(self, *args):
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def run():
 | 
			
		||||
        server_class = BaseHTTPServer.HTTPServer
 | 
			
		||||
        httpd = server_class(("127.0.0.1", 0), PGTUrlHandler)
 | 
			
		||||
        (host, port) = httpd.socket.getsockname()
 | 
			
		||||
 | 
			
		||||
        def lauch():
 | 
			
		||||
            httpd.handle_request()
 | 
			
		||||
            httpd.server_close()
 | 
			
		||||
 | 
			
		||||
        httpd_thread = Thread(target=lauch)
 | 
			
		||||
        httpd_thread.daemon = True
 | 
			
		||||
        httpd_thread.start()
 | 
			
		||||
        return (httpd_thread, host, port)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LdapHashUserPassword(object):
 | 
			
		||||
    """Please see https://tools.ietf.org/id/draft-stroeder-hashed-userpassword-values-01.html"""
 | 
			
		||||
 | 
			
		||||
    schemes_salt = {b"{SMD5}", b"{SSHA}", b"{SSHA256}", b"{SSHA384}", b"{SSHA512}", b"{CRYPT}"}
 | 
			
		||||
    schemes_nosalt = {b"{MD5}", b"{SHA}", b"{SHA256}", b"{SHA384}", b"{SHA512}"}
 | 
			
		||||
 | 
			
		||||
    _schemes_to_hash = {
 | 
			
		||||
        b"{SMD5}": hashlib.md5,
 | 
			
		||||
        b"{MD5}": hashlib.md5,
 | 
			
		||||
        b"{SSHA}": hashlib.sha1,
 | 
			
		||||
        b"{SHA}": hashlib.sha1,
 | 
			
		||||
        b"{SSHA256}": hashlib.sha256,
 | 
			
		||||
        b"{SHA256}": hashlib.sha256,
 | 
			
		||||
        b"{SSHA384}": hashlib.sha384,
 | 
			
		||||
        b"{SHA384}": hashlib.sha384,
 | 
			
		||||
        b"{SSHA512}": hashlib.sha512,
 | 
			
		||||
        b"{SHA512}": hashlib.sha512
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    _schemes_to_len = {
 | 
			
		||||
        b"{SMD5}": 16,
 | 
			
		||||
        b"{SSHA}": 20,
 | 
			
		||||
        b"{SSHA256}": 32,
 | 
			
		||||
        b"{SSHA384}": 48,
 | 
			
		||||
        b"{SSHA512}": 64,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    class BadScheme(ValueError):
 | 
			
		||||
        """Error raised then the hash scheme is not in schemes_salt + schemes_nosalt"""
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    class BadHash(ValueError):
 | 
			
		||||
        """Error raised then the hash is too short"""
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    class BadSalt(ValueError):
 | 
			
		||||
        """Error raised then with the scheme {CRYPT} the salt is invalid"""
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def _raise_bad_scheme(cls, scheme, valid, msg):
 | 
			
		||||
        """
 | 
			
		||||
            Raise BadScheme error for `scheme`, possible valid scheme are
 | 
			
		||||
            in `valid`, the error message is `msg`
 | 
			
		||||
        """
 | 
			
		||||
        valid_schemes = [s.decode() for s in valid]
 | 
			
		||||
        valid_schemes.sort()
 | 
			
		||||
        raise cls.BadScheme(msg % (scheme, u", ".join(valid_schemes)))
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def _test_scheme(cls, scheme):
 | 
			
		||||
        """Test if a scheme is valide or raise BadScheme"""
 | 
			
		||||
        if scheme not in cls.schemes_salt and scheme not in cls.schemes_nosalt:
 | 
			
		||||
            cls._raise_bad_scheme(
 | 
			
		||||
                scheme,
 | 
			
		||||
                cls.schemes_salt | cls.schemes_nosalt,
 | 
			
		||||
                "The scheme %r is not valid. Valide schemes are %s."
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def _test_scheme_salt(cls, scheme):
 | 
			
		||||
        """Test if the scheme need a salt or raise BadScheme"""
 | 
			
		||||
        if scheme not in cls.schemes_salt:
 | 
			
		||||
            cls._raise_bad_scheme(
 | 
			
		||||
                scheme,
 | 
			
		||||
                cls.schemes_salt,
 | 
			
		||||
                "The scheme %r is only valid without a salt. Valide schemes with salt are %s."
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def _test_scheme_nosalt(cls, scheme):
 | 
			
		||||
        """Test if the scheme need no salt or raise BadScheme"""
 | 
			
		||||
        if scheme not in cls.schemes_nosalt:
 | 
			
		||||
            cls._raise_bad_scheme(
 | 
			
		||||
                scheme,
 | 
			
		||||
                cls.schemes_nosalt,
 | 
			
		||||
                "The scheme %r is only valid with a salt. Valide schemes without salt are %s."
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def hash(cls, scheme, password, salt=None, charset="utf8"):
 | 
			
		||||
        """
 | 
			
		||||
           Hash `password` with `scheme` using `salt`.
 | 
			
		||||
           This three variable beeing encoded in `charset`.
 | 
			
		||||
        """
 | 
			
		||||
        scheme = scheme.upper()
 | 
			
		||||
        cls._test_scheme(scheme)
 | 
			
		||||
        if salt is None or salt == b"":
 | 
			
		||||
            salt = b""
 | 
			
		||||
            cls._test_scheme_nosalt(scheme)
 | 
			
		||||
        elif salt is not None:
 | 
			
		||||
            cls._test_scheme_salt(scheme)
 | 
			
		||||
        try:
 | 
			
		||||
            return scheme + base64.b64encode(
 | 
			
		||||
                cls._schemes_to_hash[scheme](password + salt).digest() + salt
 | 
			
		||||
            )
 | 
			
		||||
        except KeyError:
 | 
			
		||||
            if six.PY3:
 | 
			
		||||
                password = password.decode(charset)
 | 
			
		||||
                salt = salt.decode(charset)
 | 
			
		||||
            hashed_password = crypt.crypt(password, salt)
 | 
			
		||||
            if hashed_password is None:
 | 
			
		||||
                raise cls.BadSalt("System crypt implementation do not support the salt %r" % salt)
 | 
			
		||||
            if six.PY3:
 | 
			
		||||
                hashed_password = hashed_password.encode(charset)
 | 
			
		||||
            return scheme + hashed_password
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def get_scheme(cls, hashed_passord):
 | 
			
		||||
        """Return the scheme of `hashed_passord` or raise BadHash"""
 | 
			
		||||
        if not hashed_passord[0] == b'{'[0] or b'}' not in hashed_passord:
 | 
			
		||||
            raise cls.BadHash("%r should start with the scheme enclosed with { }" % hashed_passord)
 | 
			
		||||
        scheme = hashed_passord.split(b'}', 1)[0]
 | 
			
		||||
        scheme = scheme.upper() + b"}"
 | 
			
		||||
        return scheme
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def get_salt(cls, hashed_passord):
 | 
			
		||||
        """Return the salt of `hashed_passord` possibly empty"""
 | 
			
		||||
        scheme = cls.get_scheme(hashed_passord)
 | 
			
		||||
        cls._test_scheme(scheme)
 | 
			
		||||
        if scheme in cls.schemes_nosalt:
 | 
			
		||||
            return b""
 | 
			
		||||
        elif scheme == b'{CRYPT}':
 | 
			
		||||
            return b'$'.join(hashed_passord.split(b'$', 3)[:-1])
 | 
			
		||||
        else:
 | 
			
		||||
            hashed_passord = base64.b64decode(hashed_passord[len(scheme):])
 | 
			
		||||
            if len(hashed_passord) < cls._schemes_to_len[scheme]:
 | 
			
		||||
                raise cls.BadHash("Hash too short for the scheme %s" % scheme)
 | 
			
		||||
            return hashed_passord[cls._schemes_to_len[scheme]:]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def check_password(method, password, hashed_password, charset):
 | 
			
		||||
    """
 | 
			
		||||
        Check that `password` match `hashed_password` using `method`,
 | 
			
		||||
        assuming the encoding is `charset`.
 | 
			
		||||
    """
 | 
			
		||||
    if not isinstance(password, six.binary_type):
 | 
			
		||||
        password = password.encode(charset)
 | 
			
		||||
    if not isinstance(hashed_password, six.binary_type):
 | 
			
		||||
        hashed_password = hashed_password.encode(charset)
 | 
			
		||||
    if method == "plain":
 | 
			
		||||
        return password == hashed_password
 | 
			
		||||
    elif method == "crypt":
 | 
			
		||||
        if hashed_password.startswith(b'$'):
 | 
			
		||||
            salt = b'$'.join(hashed_password.split(b'$', 3)[:-1])
 | 
			
		||||
        elif hashed_password.startswith(b'_'):
 | 
			
		||||
            salt = hashed_password[:9]
 | 
			
		||||
        else:
 | 
			
		||||
            salt = hashed_password[:2]
 | 
			
		||||
        if six.PY3:
 | 
			
		||||
            password = password.decode(charset)
 | 
			
		||||
            salt = salt.decode(charset)
 | 
			
		||||
            hashed_password = hashed_password.decode(charset)
 | 
			
		||||
        crypted_password = crypt.crypt(password, salt)
 | 
			
		||||
        if crypted_password is None:
 | 
			
		||||
            raise ValueError("System crypt implementation do not support the salt %r" % salt)
 | 
			
		||||
        return crypted_password == hashed_password
 | 
			
		||||
    elif method == "ldap":
 | 
			
		||||
        scheme = LdapHashUserPassword.get_scheme(hashed_password)
 | 
			
		||||
        salt = LdapHashUserPassword.get_salt(hashed_password)
 | 
			
		||||
        return LdapHashUserPassword.hash(scheme, password, salt, charset=charset) == hashed_password
 | 
			
		||||
    elif (
 | 
			
		||||
       method.startswith("hex_") and
 | 
			
		||||
       method[4:] in {"md5", "sha1", "sha224", "sha256", "sha384", "sha512"}
 | 
			
		||||
    ):
 | 
			
		||||
        return getattr(
 | 
			
		||||
            hashlib,
 | 
			
		||||
            method[4:]
 | 
			
		||||
        )(password).hexdigest().encode("ascii") == hashed_password.lower()
 | 
			
		||||
    else:
 | 
			
		||||
        raise ValueError("Unknown password method check %r" % method)
 | 
			
		||||
 
 | 
			
		||||
@@ -23,6 +23,7 @@ from django.views.decorators.csrf import csrf_exempt
 | 
			
		||||
from django.middleware.csrf import CsrfViewMiddleware
 | 
			
		||||
from django.views.generic import View
 | 
			
		||||
 | 
			
		||||
import re
 | 
			
		||||
import logging
 | 
			
		||||
import pprint
 | 
			
		||||
import requests
 | 
			
		||||
@@ -34,7 +35,7 @@ import cas_server.utils as utils
 | 
			
		||||
import cas_server.forms as forms
 | 
			
		||||
import cas_server.models as models
 | 
			
		||||
 | 
			
		||||
from .utils import JsonResponse
 | 
			
		||||
from .utils import json_response
 | 
			
		||||
from .models import ServiceTicket, ProxyTicket, ProxyGrantingTicket
 | 
			
		||||
from .models import ServicePattern
 | 
			
		||||
from .federate import CASFederateValidateUser
 | 
			
		||||
@@ -63,12 +64,12 @@ class AttributesMixin(object):
 | 
			
		||||
 | 
			
		||||
class LogoutMixin(object):
 | 
			
		||||
    """destroy CAS session utils"""
 | 
			
		||||
    def logout(self, all=False):
 | 
			
		||||
    def logout(self, all_session=False):
 | 
			
		||||
        """effectively destroy CAS session"""
 | 
			
		||||
        session_nb = 0
 | 
			
		||||
        username = self.request.session.get("username")
 | 
			
		||||
        if username:
 | 
			
		||||
            if all:
 | 
			
		||||
            if all_session:
 | 
			
		||||
                logger.info("Logging out user %s from all of they sessions." % username)
 | 
			
		||||
            else:
 | 
			
		||||
                logger.info("Logging out user %s." % username)
 | 
			
		||||
@@ -91,8 +92,8 @@ class LogoutMixin(object):
 | 
			
		||||
            # if user not found in database, flush the session anyway
 | 
			
		||||
            self.request.session.flush()
 | 
			
		||||
 | 
			
		||||
        # If all is set logout user from alternative sessions
 | 
			
		||||
        if all:
 | 
			
		||||
        # If all_session is set logout user from alternative sessions
 | 
			
		||||
        if all_session:
 | 
			
		||||
            for user in models.User.objects.filter(username=username):
 | 
			
		||||
                session = SessionStore(session_key=user.session_key)
 | 
			
		||||
                session.flush()
 | 
			
		||||
@@ -110,6 +111,7 @@ class LogoutView(View, LogoutMixin):
 | 
			
		||||
    service = None
 | 
			
		||||
 | 
			
		||||
    def init_get(self, request):
 | 
			
		||||
        """Initialize GET received parameters"""
 | 
			
		||||
        self.request = request
 | 
			
		||||
        self.service = request.GET.get('service')
 | 
			
		||||
        self.url = request.GET.get('url')
 | 
			
		||||
@@ -170,13 +172,13 @@ class LogoutView(View, LogoutMixin):
 | 
			
		||||
                        'url': url,
 | 
			
		||||
                        'session_nb': session_nb
 | 
			
		||||
                    }
 | 
			
		||||
                    return JsonResponse(request, data)
 | 
			
		||||
                    return json_response(request, data)
 | 
			
		||||
                else:
 | 
			
		||||
                    return redirect("cas_server:login")
 | 
			
		||||
            else:
 | 
			
		||||
                if self.ajax:
 | 
			
		||||
                    data = {'status': 'success', 'detail': 'logout', 'session_nb': session_nb}
 | 
			
		||||
                    return JsonResponse(request, data)
 | 
			
		||||
                    return json_response(request, data)
 | 
			
		||||
                else:
 | 
			
		||||
                    return render(
 | 
			
		||||
                        request,
 | 
			
		||||
@@ -290,12 +292,10 @@ class LoginView(View, LogoutMixin):
 | 
			
		||||
    USER_NOT_AUTHENTICATED = 6
 | 
			
		||||
 | 
			
		||||
    def init_post(self, request):
 | 
			
		||||
        """Initialize POST received parameters"""
 | 
			
		||||
        self.request = request
 | 
			
		||||
        self.service = request.POST.get('service')
 | 
			
		||||
        if request.POST.get('renew') and request.POST['renew'] != "False":
 | 
			
		||||
            self.renew = True
 | 
			
		||||
        else:
 | 
			
		||||
            self.renew = False
 | 
			
		||||
        self.renew = bool(request.POST.get('renew') and request.POST['renew'] != "False")
 | 
			
		||||
        self.gateway = request.POST.get('gateway')
 | 
			
		||||
        self.method = request.POST.get('method')
 | 
			
		||||
        self.ajax = 'HTTP_X_AJAX' in request.META
 | 
			
		||||
@@ -306,15 +306,19 @@ class LoginView(View, LogoutMixin):
 | 
			
		||||
            self.username = request.POST.get('username')
 | 
			
		||||
            self.ticket = request.POST.get('ticket')
 | 
			
		||||
 | 
			
		||||
    def check_lt(self):
 | 
			
		||||
        # save LT for later check
 | 
			
		||||
        lt_valid = self.request.session.get('lt', [])
 | 
			
		||||
        lt_send = self.request.POST.get('lt')
 | 
			
		||||
        # generate a new LT (by posting the LT has been consumed)
 | 
			
		||||
    def gen_lt(self):
 | 
			
		||||
        """Generate a new LoginTicket and add it to the list of valid LT for the user"""
 | 
			
		||||
        self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()]
 | 
			
		||||
        if len(self.request.session['lt']) > 100:
 | 
			
		||||
            self.request.session['lt'] = self.request.session['lt'][-100:]
 | 
			
		||||
 | 
			
		||||
    def check_lt(self):
 | 
			
		||||
        """Check is the POSTed LoginTicket is valid, if yes invalide it"""
 | 
			
		||||
        # save LT for later check
 | 
			
		||||
        lt_valid = self.request.session.get('lt', [])
 | 
			
		||||
        lt_send = self.request.POST.get('lt')
 | 
			
		||||
        # generate a new LT (by posting the LT has been consumed)
 | 
			
		||||
        self.gen_lt()
 | 
			
		||||
        # check if send LT is valid
 | 
			
		||||
        if lt_valid is None or lt_send not in lt_valid:
 | 
			
		||||
            return False
 | 
			
		||||
@@ -339,7 +343,7 @@ class LoginView(View, LogoutMixin):
 | 
			
		||||
                    username=self.request.session['username'],
 | 
			
		||||
                    session_key=self.request.session.session_key
 | 
			
		||||
                )
 | 
			
		||||
                self.user.save()
 | 
			
		||||
                self.user.save()  # pragma: no cover (should not happend)
 | 
			
		||||
            except models.User.DoesNotExist:
 | 
			
		||||
                self.user = models.User.objects.create(
 | 
			
		||||
                    username=self.request.session['username'],
 | 
			
		||||
@@ -355,10 +359,15 @@ class LoginView(View, LogoutMixin):
 | 
			
		||||
        elif ret == self.USER_ALREADY_LOGGED:
 | 
			
		||||
            pass
 | 
			
		||||
        else:
 | 
			
		||||
            raise EnvironmentError("invalid output for LoginView.process_post")
 | 
			
		||||
            raise EnvironmentError("invalid output for LoginView.process_post")  # pragma: no cover
 | 
			
		||||
        return self.common()
 | 
			
		||||
 | 
			
		||||
    def process_post(self, pytest=False):
 | 
			
		||||
    def process_post(self):
 | 
			
		||||
        """
 | 
			
		||||
            Analyse the POST request:
 | 
			
		||||
                * check that the LoginTicket is valid
 | 
			
		||||
                * check that the user sumited credentials are valid
 | 
			
		||||
        """
 | 
			
		||||
        if not self.check_lt():
 | 
			
		||||
            values = self.request.POST.copy()
 | 
			
		||||
            # if not set a new LT and fail
 | 
			
		||||
@@ -385,12 +394,10 @@ class LoginView(View, LogoutMixin):
 | 
			
		||||
            return self.USER_ALREADY_LOGGED
 | 
			
		||||
 | 
			
		||||
    def init_get(self, request):
 | 
			
		||||
        """Initialize GET received parameters"""
 | 
			
		||||
        self.request = request
 | 
			
		||||
        self.service = request.GET.get('service')
 | 
			
		||||
        if request.GET.get('renew') and request.GET['renew'] != "False":
 | 
			
		||||
            self.renew = True
 | 
			
		||||
        else:
 | 
			
		||||
            self.renew = False
 | 
			
		||||
        self.renew = bool(request.GET.get('renew') and request.GET['renew'] != "False")
 | 
			
		||||
        self.gateway = request.GET.get('gateway')
 | 
			
		||||
        self.method = request.GET.get('method')
 | 
			
		||||
        self.ajax = 'HTTP_X_AJAX' in request.META
 | 
			
		||||
@@ -410,15 +417,16 @@ class LoginView(View, LogoutMixin):
 | 
			
		||||
        return self.common()
 | 
			
		||||
 | 
			
		||||
    def process_get(self):
 | 
			
		||||
        # generate a new LT if none is present
 | 
			
		||||
        self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()]
 | 
			
		||||
 | 
			
		||||
        """Analyse the GET request"""
 | 
			
		||||
        # generate a new LT
 | 
			
		||||
        self.gen_lt()
 | 
			
		||||
        if not self.request.session.get("authenticated") or self.renew:
 | 
			
		||||
            self.init_form()
 | 
			
		||||
            return self.USER_NOT_AUTHENTICATED
 | 
			
		||||
        return self.USER_AUTHENTICATED
 | 
			
		||||
 | 
			
		||||
    def init_form(self, values=None):
 | 
			
		||||
        """Initialization of the good form depending of POST and GET parameters"""
 | 
			
		||||
        form_initial = {
 | 
			
		||||
            'service': self.service,
 | 
			
		||||
            'method': self.method,
 | 
			
		||||
@@ -459,7 +467,7 @@ class LoginView(View, LogoutMixin):
 | 
			
		||||
                )
 | 
			
		||||
                if self.ajax:
 | 
			
		||||
                    data = {"status": "error", "detail": "confirmation needed"}
 | 
			
		||||
                    return JsonResponse(self.request, data)
 | 
			
		||||
                    return json_response(self.request, data)
 | 
			
		||||
                else:
 | 
			
		||||
                    warn_form = forms.WarnForm(initial={
 | 
			
		||||
                        'service': self.service,
 | 
			
		||||
@@ -486,7 +494,7 @@ class LoginView(View, LogoutMixin):
 | 
			
		||||
                    return HttpResponseRedirect(redirect_url)
 | 
			
		||||
                else:
 | 
			
		||||
                    data = {"status": "success", "detail": "auth", "url": redirect_url}
 | 
			
		||||
                    return JsonResponse(self.request, data)
 | 
			
		||||
                    return json_response(self.request, data)
 | 
			
		||||
        except ServicePattern.DoesNotExist:
 | 
			
		||||
            error = 1
 | 
			
		||||
            messages.add_message(
 | 
			
		||||
@@ -530,7 +538,7 @@ class LoginView(View, LogoutMixin):
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            data = {"status": "error", "detail": "auth", "code": error}
 | 
			
		||||
            return JsonResponse(self.request, data)
 | 
			
		||||
            return json_response(self.request, data)
 | 
			
		||||
 | 
			
		||||
    def authenticated(self):
 | 
			
		||||
        """Processing authenticated users"""
 | 
			
		||||
@@ -552,7 +560,7 @@ class LoginView(View, LogoutMixin):
 | 
			
		||||
                    "detail": "login required",
 | 
			
		||||
                    "url": utils.reverse_params("cas_server:login", params=self.request.GET)
 | 
			
		||||
                }
 | 
			
		||||
                return JsonResponse(self.request, data)
 | 
			
		||||
                return json_response(self.request, data)
 | 
			
		||||
            else:
 | 
			
		||||
                return utils.redirect_params("cas_server:login", params=self.request.GET)
 | 
			
		||||
 | 
			
		||||
@@ -562,7 +570,7 @@ class LoginView(View, LogoutMixin):
 | 
			
		||||
        else:
 | 
			
		||||
            if self.ajax:
 | 
			
		||||
                data = {"status": "success", "detail": "logged"}
 | 
			
		||||
                return JsonResponse(self.request, data)
 | 
			
		||||
                return json_response(self.request, data)
 | 
			
		||||
            else:
 | 
			
		||||
                return render(
 | 
			
		||||
                    self.request,
 | 
			
		||||
@@ -605,7 +613,7 @@ class LoginView(View, LogoutMixin):
 | 
			
		||||
                "detail": "login required",
 | 
			
		||||
                "url": utils.reverse_params("cas_server:login",  params=self.request.GET)
 | 
			
		||||
            }
 | 
			
		||||
            return JsonResponse(self.request, data)
 | 
			
		||||
            return json_response(self.request, data)
 | 
			
		||||
        else:
 | 
			
		||||
            if settings.CAS_FEDERATE:
 | 
			
		||||
                if self.username and self.ticket:
 | 
			
		||||
@@ -824,7 +832,10 @@ class ValidateService(View, AttributesMixin):
 | 
			
		||||
                    params['username'] = self.ticket.user.attributs.get(
 | 
			
		||||
                        self.ticket.service_pattern.user_field
 | 
			
		||||
                    )
 | 
			
		||||
                if self.pgt_url and self.pgt_url.startswith("https://"):
 | 
			
		||||
                if self.pgt_url and (
 | 
			
		||||
                    self.pgt_url.startswith("https://") or
 | 
			
		||||
                    re.match("^http://(127\.0\.0\.1|localhost)(:[0-9]+)?(/.*)?$", self.pgt_url)
 | 
			
		||||
                ):
 | 
			
		||||
                    return self.process_pgturl(params)
 | 
			
		||||
                else:
 | 
			
		||||
                    logger.info(
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										22
									
								
								run_tests
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										22
									
								
								run_tests
									
									
									
									
									
										Executable file
									
								
							@@ -0,0 +1,22 @@
 | 
			
		||||
#!/usr/bin/env python
 | 
			
		||||
import os, sys
 | 
			
		||||
import django
 | 
			
		||||
from django.conf import settings
 | 
			
		||||
 | 
			
		||||
import settings_tests
 | 
			
		||||
 | 
			
		||||
settings.configure(**settings_tests.__dict__)
 | 
			
		||||
django.setup()
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
    # Django <= 1.8
 | 
			
		||||
    from django.test.simple import DjangoTestSuiteRunner
 | 
			
		||||
    test_runner = DjangoTestSuiteRunner(verbosity=1)
 | 
			
		||||
except ImportError:
 | 
			
		||||
    # Django >= 1.8
 | 
			
		||||
    from django.test.runner import DiscoverRunner
 | 
			
		||||
    test_runner = DiscoverRunner(verbosity=1)
 | 
			
		||||
 | 
			
		||||
failures = test_runner.run_tests(['cas_server'])
 | 
			
		||||
if failures:
 | 
			
		||||
    sys.exit(failures)
 | 
			
		||||
							
								
								
									
										83
									
								
								settings_tests.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										83
									
								
								settings_tests.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,83 @@
 | 
			
		||||
"""
 | 
			
		||||
Django test settings for cas_server application.
 | 
			
		||||
 | 
			
		||||
Generated by 'django-admin startproject' using Django 1.9.7.
 | 
			
		||||
 | 
			
		||||
For more information on this file, see
 | 
			
		||||
https://docs.djangoproject.com/en/1.9/topics/settings/
 | 
			
		||||
 | 
			
		||||
For the full list of settings and their values, see
 | 
			
		||||
https://docs.djangoproject.com/en/1.9/ref/settings/
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
# Build paths inside the project like this: os.path.join(BASE_DIR, ...)
 | 
			
		||||
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Quick-start development settings - unsuitable for production
 | 
			
		||||
# See https://docs.djangoproject.com/en/1.9/howto/deployment/checklist/
 | 
			
		||||
 | 
			
		||||
# SECURITY WARNING: keep the secret key used in production secret!
 | 
			
		||||
SECRET_KEY = 'changeme'
 | 
			
		||||
 | 
			
		||||
# SECURITY WARNING: don't run with debug turned on in production!
 | 
			
		||||
DEBUG = True
 | 
			
		||||
 | 
			
		||||
ALLOWED_HOSTS = []
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Application definition
 | 
			
		||||
 | 
			
		||||
INSTALLED_APPS = [
 | 
			
		||||
    'django.contrib.admin',
 | 
			
		||||
    'django.contrib.auth',
 | 
			
		||||
    'django.contrib.contenttypes',
 | 
			
		||||
    'django.contrib.sessions',
 | 
			
		||||
    'django.contrib.messages',
 | 
			
		||||
    'django.contrib.staticfiles',
 | 
			
		||||
    'bootstrap3',
 | 
			
		||||
    'cas_server',
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
MIDDLEWARE_CLASSES = [
 | 
			
		||||
    'django.contrib.sessions.middleware.SessionMiddleware',
 | 
			
		||||
    'django.middleware.common.CommonMiddleware',
 | 
			
		||||
    'django.middleware.csrf.CsrfViewMiddleware',
 | 
			
		||||
    'django.contrib.auth.middleware.AuthenticationMiddleware',
 | 
			
		||||
    'django.contrib.auth.middleware.SessionAuthenticationMiddleware',
 | 
			
		||||
    'django.contrib.messages.middleware.MessageMiddleware',
 | 
			
		||||
    'django.middleware.clickjacking.XFrameOptionsMiddleware',
 | 
			
		||||
    'django.middleware.locale.LocaleMiddleware',
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
ROOT_URLCONF = 'urls_tests'
 | 
			
		||||
 | 
			
		||||
# Database
 | 
			
		||||
# https://docs.djangoproject.com/en/1.9/ref/settings/#databases
 | 
			
		||||
 | 
			
		||||
DATABASES = {
 | 
			
		||||
    'default': {
 | 
			
		||||
        'ENGINE': 'django.db.backends.sqlite3',
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
# Internationalization
 | 
			
		||||
# https://docs.djangoproject.com/en/1.9/topics/i18n/
 | 
			
		||||
 | 
			
		||||
LANGUAGE_CODE = 'en-us'
 | 
			
		||||
 | 
			
		||||
TIME_ZONE = 'UTC'
 | 
			
		||||
 | 
			
		||||
USE_I18N = True
 | 
			
		||||
 | 
			
		||||
USE_L10N = True
 | 
			
		||||
 | 
			
		||||
USE_TZ = True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Static files (CSS, JavaScript, Images)
 | 
			
		||||
# https://docs.djangoproject.com/en/1.9/howto/static-files/
 | 
			
		||||
 | 
			
		||||
STATIC_URL = '/static/'
 | 
			
		||||
							
								
								
									
										136
									
								
								tests/dummy.py
									
									
									
									
									
								
							
							
						
						
									
										136
									
								
								tests/dummy.py
									
									
									
									
									
								
							@@ -1,136 +0,0 @@
 | 
			
		||||
import functools
 | 
			
		||||
from cas_server import models
 | 
			
		||||
 | 
			
		||||
class DummyUserManager(object):
 | 
			
		||||
    def __init__(self, username, session_key):
 | 
			
		||||
        self.username = username
 | 
			
		||||
        self.session_key = session_key
 | 
			
		||||
    def get(self, username=None, session_key=None):
 | 
			
		||||
        if username == self.username and session_key == self.session_key:
 | 
			
		||||
            return models.User(username=username, session_key=session_key)
 | 
			
		||||
        else:
 | 
			
		||||
            raise models.User.DoesNotExist()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def dummy(*args, **kwds):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
def dummy_service_pattern(**kwargs):
 | 
			
		||||
    def decorator(func):
 | 
			
		||||
        @functools.wraps(func)
 | 
			
		||||
        def wrapper(*args, **kwds):
 | 
			
		||||
            service_validate = models.ServicePattern.validate
 | 
			
		||||
            models.ServicePattern.validate = classmethod(lambda x,y: models.ServicePattern(**kwargs))
 | 
			
		||||
            ret = func(*args, **kwds)
 | 
			
		||||
            models.ServicePattern.validate = service_validate
 | 
			
		||||
            return ret
 | 
			
		||||
        return wrapper
 | 
			
		||||
    return decorator
 | 
			
		||||
 | 
			
		||||
def dummy_user(username, session_key):
 | 
			
		||||
    def decorator(func):
 | 
			
		||||
        @functools.wraps(func)
 | 
			
		||||
        def wrapper(*args, **kwds):
 | 
			
		||||
            user_manager = models.User.objects
 | 
			
		||||
            user_save = models.User.save
 | 
			
		||||
            user_delete = models.User.delete
 | 
			
		||||
            models.User.objects = DummyUserManager(username, session_key)
 | 
			
		||||
            models.User.save = dummy
 | 
			
		||||
            models.User.delete = dummy
 | 
			
		||||
            ret = func(*args, **kwds)
 | 
			
		||||
            models.User.objects = user_manager
 | 
			
		||||
            models.User.save = user_save
 | 
			
		||||
            models.User.delete = user_delete
 | 
			
		||||
            return ret
 | 
			
		||||
        return wrapper
 | 
			
		||||
    return decorator
 | 
			
		||||
 | 
			
		||||
def dummy_ticket(ticket_class, service, ticket):
 | 
			
		||||
    def decorator(func):
 | 
			
		||||
        @functools.wraps(func)
 | 
			
		||||
        def wrapper(*args, **kwds):
 | 
			
		||||
            ticket_manager = ticket_class.objects
 | 
			
		||||
            ticket_save = ticket_class.save
 | 
			
		||||
            ticket_delete = ticket_class.delete
 | 
			
		||||
            ticket_class.objects = DummyTicketManager(ticket_class, service, ticket)
 | 
			
		||||
            ticket_class.save = dummy
 | 
			
		||||
            ticket_class.delete = dummy
 | 
			
		||||
            ret = func(*args, **kwds)
 | 
			
		||||
            ticket_class.objects = ticket_manager
 | 
			
		||||
            ticket_class.save = ticket_save
 | 
			
		||||
            ticket_class.delete = ticket_delete
 | 
			
		||||
            return ret
 | 
			
		||||
        return wrapper
 | 
			
		||||
    return decorator
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def dummy_proxy(func):
 | 
			
		||||
    @functools.wraps(func)
 | 
			
		||||
    def wrapper(*args, **kwds):
 | 
			
		||||
        proxy_manager = models.Proxy.objects
 | 
			
		||||
        models.Proxy.objects = DummyProxyManager()
 | 
			
		||||
        ret = func(*args, **kwds)
 | 
			
		||||
        models.Proxy.objects = proxy_manager
 | 
			
		||||
        return ret
 | 
			
		||||
    return wrapper
 | 
			
		||||
 | 
			
		||||
class DummyProxyManager(object):
 | 
			
		||||
    def create(self, **kwargs):
 | 
			
		||||
        for field in models.Proxy._meta.fields:
 | 
			
		||||
            field.allow_unsaved_instance_assignment = True
 | 
			
		||||
        return models.Proxy(**kwargs)
 | 
			
		||||
 | 
			
		||||
class DummyTicketManager(object):
 | 
			
		||||
    def __init__(self, ticket_class, service, ticket):
 | 
			
		||||
        self.ticket_class = ticket_class
 | 
			
		||||
        self.service = service
 | 
			
		||||
        self.ticket = ticket
 | 
			
		||||
 | 
			
		||||
    def create(self, **kwargs):
 | 
			
		||||
        for field in self.ticket_class._meta.fields:
 | 
			
		||||
            field.allow_unsaved_instance_assignment = True
 | 
			
		||||
        return self.ticket_class(**kwargs)
 | 
			
		||||
 | 
			
		||||
    def filter(self, *args, **kwargs):
 | 
			
		||||
        return DummyQuerySet()
 | 
			
		||||
 | 
			
		||||
    def get(self, **kwargs):
 | 
			
		||||
        for field in self.ticket_class._meta.fields:
 | 
			
		||||
            field.allow_unsaved_instance_assignment = True
 | 
			
		||||
        if 'value' in kwargs:
 | 
			
		||||
            if kwargs['value'] != self.ticket:
 | 
			
		||||
                raise self.ticket_class.DoesNotExist()
 | 
			
		||||
        else:
 | 
			
		||||
            kwargs['value'] = self.ticket
 | 
			
		||||
        
 | 
			
		||||
        if 'service' in kwargs:
 | 
			
		||||
            if kwargs['service'] != self.service:
 | 
			
		||||
                raise self.ticket_class.DoesNotExist()
 | 
			
		||||
        else:
 | 
			
		||||
            kwargs['service'] = self.service
 | 
			
		||||
        if not 'user' in kwargs:
 | 
			
		||||
            kwargs['user'] = models.User(username="test")
 | 
			
		||||
        
 | 
			
		||||
        for field in models.ServiceTicket._meta.fields:
 | 
			
		||||
            field.allow_unsaved_instance_assignment = True
 | 
			
		||||
        for key in list(kwargs):
 | 
			
		||||
            if '__' in key:
 | 
			
		||||
                del kwargs[key]
 | 
			
		||||
        kwargs['attributs'] = {'mail': 'test@example.com'}
 | 
			
		||||
        kwargs['service_pattern'] = models.ServicePattern()
 | 
			
		||||
        return self.ticket_class(**kwargs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DummySession(dict):
 | 
			
		||||
    session_key = "test_session"
 | 
			
		||||
 | 
			
		||||
    def set_expiry(self, int):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def flush(self):
 | 
			
		||||
        self.clear()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DummyQuerySet(set):
 | 
			
		||||
    pass
 | 
			
		||||
@@ -1,32 +0,0 @@
 | 
			
		||||
import django
 | 
			
		||||
from django.conf import settings
 | 
			
		||||
from django.contrib import messages
 | 
			
		||||
 | 
			
		||||
settings.configure()
 | 
			
		||||
settings.STATIC_URL = "/static/"
 | 
			
		||||
settings.DATABASES = {
 | 
			
		||||
    'default': {
 | 
			
		||||
        'ENGINE': 'django.db.backends.sqlite3',
 | 
			
		||||
        'NAME': '/dev/null',
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
settings.INSTALLED_APPS = (
 | 
			
		||||
    'django.contrib.admin',
 | 
			
		||||
    'django.contrib.auth',
 | 
			
		||||
    'django.contrib.contenttypes',
 | 
			
		||||
    'django.contrib.sessions',
 | 
			
		||||
    'django.contrib.messages',
 | 
			
		||||
    'django.contrib.staticfiles',
 | 
			
		||||
    'bootstrap3',
 | 
			
		||||
    'cas_server',
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
settings.ROOT_URLCONF = "/"
 | 
			
		||||
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
    django.setup()
 | 
			
		||||
except AttributeError:
 | 
			
		||||
    pass
 | 
			
		||||
messages.add_message = lambda x,y,z:None
 | 
			
		||||
 | 
			
		||||
@@ -1,52 +0,0 @@
 | 
			
		||||
from __future__ import absolute_import
 | 
			
		||||
from tests.init import *
 | 
			
		||||
 | 
			
		||||
from django.test import RequestFactory
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
import pytest
 | 
			
		||||
from lxml import etree
 | 
			
		||||
from cas_server.views import ValidateService, Proxy
 | 
			
		||||
from cas_server import models
 | 
			
		||||
 | 
			
		||||
from tests.dummy import *
 | 
			
		||||
 | 
			
		||||
@pytest.mark.django_db
 | 
			
		||||
@dummy_ticket(models.ProxyGrantingTicket, '', "PGT-random")
 | 
			
		||||
@dummy_service_pattern(proxy=True)
 | 
			
		||||
@dummy_user(username="test", session_key="test_session")
 | 
			
		||||
@dummy_ticket(models.ProxyTicket, "https://www.example.com", "PT-random")
 | 
			
		||||
@dummy_proxy
 | 
			
		||||
def test_proxy_ok():
 | 
			
		||||
    factory = RequestFactory()
 | 
			
		||||
    request = factory.get('/proxy?pgt=PGT-random&targetService=https://www.example.com')
 | 
			
		||||
 | 
			
		||||
    request.session = DummySession()
 | 
			
		||||
 | 
			
		||||
    proxy = Proxy()
 | 
			
		||||
    response = proxy.get(request)
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
 | 
			
		||||
    root = etree.fromstring(response.content)
 | 
			
		||||
    proxy_tickets = root.xpath("//cas:proxyTicket", namespaces={'cas': "http://www.yale.edu/tp/cas"})
 | 
			
		||||
 | 
			
		||||
    assert len(proxy_tickets) == 1
 | 
			
		||||
 | 
			
		||||
    factory = RequestFactory()
 | 
			
		||||
    request = factory.get('/proxyValidate?ticket=PT-random&service=https://www.example.com')
 | 
			
		||||
 | 
			
		||||
    validate = ValidateService()
 | 
			
		||||
    validate.allow_proxy_ticket = True
 | 
			
		||||
    response = validate.get(request)
 | 
			
		||||
    
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
 | 
			
		||||
    root = etree.fromstring(response.content)
 | 
			
		||||
    users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
 | 
			
		||||
 | 
			
		||||
    assert len(users) == 1
 | 
			
		||||
    assert users[0].text == "test"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -1,87 +0,0 @@
 | 
			
		||||
from __future__ import absolute_import
 | 
			
		||||
from .init import *
 | 
			
		||||
 | 
			
		||||
from django.test import RequestFactory
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
import pytest
 | 
			
		||||
from lxml import etree
 | 
			
		||||
from cas_server.views import ValidateService
 | 
			
		||||
from cas_server import models
 | 
			
		||||
 | 
			
		||||
from .dummy import *
 | 
			
		||||
 | 
			
		||||
@pytest.mark.django_db
 | 
			
		||||
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
 | 
			
		||||
def test_validate_service_view_ok():
 | 
			
		||||
    factory = RequestFactory()
 | 
			
		||||
    request = factory.get('/serviceValidate?ticket=ST-random&service=https://www.example.com')
 | 
			
		||||
 | 
			
		||||
    request.session = DummySession()
 | 
			
		||||
 | 
			
		||||
    validate = ValidateService()
 | 
			
		||||
    validate.allow_proxy_ticket = False
 | 
			
		||||
    response = validate.get(request)
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
 | 
			
		||||
    root = etree.fromstring(response.content)
 | 
			
		||||
    users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
 | 
			
		||||
 | 
			
		||||
    assert len(users) == 1
 | 
			
		||||
    assert users[0].text == "test"
 | 
			
		||||
 | 
			
		||||
    attributes = root.xpath("//cas:attributes", namespaces={'cas': "http://www.yale.edu/tp/cas"})
 | 
			
		||||
 | 
			
		||||
    assert len(attributes) == 1
 | 
			
		||||
    
 | 
			
		||||
    attrs = {}
 | 
			
		||||
    for attr in attributes[0]:
 | 
			
		||||
        attrs[attr.tag[len("http://www.yale.edu/tp/cas")+2:]]=attr.text
 | 
			
		||||
 | 
			
		||||
    assert 'mail' in attrs
 | 
			
		||||
    assert attrs['mail'] == 'test@example.com'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.django_db
 | 
			
		||||
@dummy_ticket(models.ServiceTicket, 'https://www.example2.com', "ST-random")
 | 
			
		||||
def test_validate_service_view_badservice():
 | 
			
		||||
    factory = RequestFactory()
 | 
			
		||||
    request = factory.get('/serviceValidate?ticket=ST-random&service=https://www.example1.com')
 | 
			
		||||
 | 
			
		||||
    request.session = DummySession()
 | 
			
		||||
 | 
			
		||||
    validate = ValidateService()
 | 
			
		||||
    validate.allow_proxy_ticket = False
 | 
			
		||||
    response = validate.get(request)
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
 | 
			
		||||
    root = etree.fromstring(response.content)
 | 
			
		||||
 | 
			
		||||
    error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"})
 | 
			
		||||
    
 | 
			
		||||
    assert len(error) == 1
 | 
			
		||||
    assert error[0].attrib['code'] == 'INVALID_SERVICE'
 | 
			
		||||
 | 
			
		||||
@pytest.mark.django_db
 | 
			
		||||
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random2")
 | 
			
		||||
def test_validate_service_view_badticket():
 | 
			
		||||
    factory = RequestFactory()
 | 
			
		||||
    request = factory.get('/serviceValidate?ticket=ST-random1&service=https://www.example.com')
 | 
			
		||||
 | 
			
		||||
    request.session = DummySession()
 | 
			
		||||
 | 
			
		||||
    validate = ValidateService()
 | 
			
		||||
    validate.allow_proxy_ticket = False
 | 
			
		||||
    response = validate.get(request)
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
 | 
			
		||||
    root = etree.fromstring(response.content)
 | 
			
		||||
 | 
			
		||||
    error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"})
 | 
			
		||||
    
 | 
			
		||||
    assert len(error) == 1
 | 
			
		||||
    assert error[0].attrib['code'] == 'INVALID_TICKET'
 | 
			
		||||
@@ -1,46 +0,0 @@
 | 
			
		||||
from __future__ import absolute_import
 | 
			
		||||
from .init import *
 | 
			
		||||
 | 
			
		||||
from django.test import RequestFactory
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
import pytest
 | 
			
		||||
 | 
			
		||||
from cas_server.views import Auth
 | 
			
		||||
from cas_server import models
 | 
			
		||||
 | 
			
		||||
from .dummy import *
 | 
			
		||||
 | 
			
		||||
settings.CAS_AUTH_SHARED_SECRET = "test"
 | 
			
		||||
 | 
			
		||||
@pytest.mark.django_db
 | 
			
		||||
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
 | 
			
		||||
@dummy_user(username="test", session_key="test_session")
 | 
			
		||||
@dummy_service_pattern()
 | 
			
		||||
def test_auth_view_goodpass():
 | 
			
		||||
    factory = RequestFactory()
 | 
			
		||||
    request = factory.post('/auth', {'username':'test', 'password':'test', 'service':'https://www.example.com', 'secret':'test'})
 | 
			
		||||
 | 
			
		||||
    request.session = DummySession()
 | 
			
		||||
 | 
			
		||||
    auth = Auth()
 | 
			
		||||
    response = auth.post(request)
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
    assert response.content == b"yes\n"
 | 
			
		||||
 | 
			
		||||
@dummy_service_pattern()
 | 
			
		||||
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
 | 
			
		||||
@dummy_user(username="test", session_key="test_session")
 | 
			
		||||
def test_auth_view_badpass():
 | 
			
		||||
    factory = RequestFactory()
 | 
			
		||||
    request = factory.post('/auth', {'username':'test', 'password':'badpass', 'service':'https://www.example.com', 'secret':'test'})
 | 
			
		||||
 | 
			
		||||
    request.session = DummySession()
 | 
			
		||||
 | 
			
		||||
    auth = Auth()
 | 
			
		||||
    response = auth.post(request)
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
    assert response.content == b"no\n"
 | 
			
		||||
 | 
			
		||||
@@ -1,163 +0,0 @@
 | 
			
		||||
from __future__ import absolute_import
 | 
			
		||||
from .init import *
 | 
			
		||||
 | 
			
		||||
from django.test import RequestFactory
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
import pytest
 | 
			
		||||
 | 
			
		||||
from cas_server.views import LoginView
 | 
			
		||||
from cas_server import models
 | 
			
		||||
 | 
			
		||||
from .dummy import *
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_login_view_post_goodpass_goodlt():
 | 
			
		||||
    factory = RequestFactory()
 | 
			
		||||
    request = factory.post('/login', {'username':'test', 'password':'test', 'lt':'LT-random'})
 | 
			
		||||
    request.session = DummySession()
 | 
			
		||||
 | 
			
		||||
    request.session['lt'] = ['LT-random']
 | 
			
		||||
 | 
			
		||||
    request.session["username"] = os.urandom(20)
 | 
			
		||||
    request.session["warn"] = os.urandom(20)
 | 
			
		||||
 | 
			
		||||
    login = LoginView()
 | 
			
		||||
    login.init_post(request)
 | 
			
		||||
 | 
			
		||||
    ret = login.process_post(pytest=True)
 | 
			
		||||
 | 
			
		||||
    assert ret == LoginView.USER_LOGIN_OK
 | 
			
		||||
    assert request.session.get("authenticated") == True
 | 
			
		||||
    assert request.session.get("username") == "test"
 | 
			
		||||
    assert request.session.get("warn") == False
 | 
			
		||||
 | 
			
		||||
def test_login_view_post_badlt():
 | 
			
		||||
    factory = RequestFactory()
 | 
			
		||||
    request = factory.post('/login', {'username':'test', 'password':'test', 'lt':'LT-random1'})
 | 
			
		||||
    request.session = DummySession()
 | 
			
		||||
 | 
			
		||||
    request.session['lt'] = ['LT-random2']
 | 
			
		||||
 | 
			
		||||
    authenticated = os.urandom(20)
 | 
			
		||||
    username = os.urandom(20)
 | 
			
		||||
    warn = os.urandom(20)
 | 
			
		||||
 | 
			
		||||
    request.session["authenticated"] = authenticated
 | 
			
		||||
    request.session["username"] = username
 | 
			
		||||
    request.session["warn"] = warn
 | 
			
		||||
 | 
			
		||||
    login = LoginView()
 | 
			
		||||
    login.init_post(request)
 | 
			
		||||
 | 
			
		||||
    ret = login.process_post(pytest=True)
 | 
			
		||||
 | 
			
		||||
    assert ret == LoginView.INVALID_LOGIN_TICKET
 | 
			
		||||
    assert request.session.get("authenticated") == authenticated
 | 
			
		||||
    assert request.session.get("username") == username
 | 
			
		||||
    assert request.session.get("warn") == warn
 | 
			
		||||
 | 
			
		||||
def test_login_view_post_badpass_good_lt():
 | 
			
		||||
    factory = RequestFactory()
 | 
			
		||||
    request = factory.post('/login', {'username':'test', 'password':'badpassword', 'lt':'LT-random'})
 | 
			
		||||
    request.session = DummySession()
 | 
			
		||||
 | 
			
		||||
    request.session['lt'] = ['LT-random']
 | 
			
		||||
 | 
			
		||||
    login = LoginView()
 | 
			
		||||
    login.init_post(request)
 | 
			
		||||
    ret = login.process_post()
 | 
			
		||||
 | 
			
		||||
    assert ret == LoginView.USER_LOGIN_FAILURE
 | 
			
		||||
    assert not request.session.get("authenticated")
 | 
			
		||||
    assert not request.session.get("username")
 | 
			
		||||
    assert not request.session.get("warn")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_view_login_get_unauth():
 | 
			
		||||
    factory = RequestFactory()
 | 
			
		||||
    request = factory.post('/login')
 | 
			
		||||
    request.session = DummySession()
 | 
			
		||||
 | 
			
		||||
    login = LoginView()
 | 
			
		||||
    login.init_get(request)
 | 
			
		||||
    ret = login.process_get()
 | 
			
		||||
 | 
			
		||||
    assert ret == LoginView.USER_NOT_AUTHENTICATED
 | 
			
		||||
 | 
			
		||||
    login = LoginView()
 | 
			
		||||
    response = login.get(request)
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
 | 
			
		||||
@pytest.mark.django_db
 | 
			
		||||
@dummy_user(username="test", session_key="test_session")
 | 
			
		||||
def test_view_login_get_auth():
 | 
			
		||||
    factory = RequestFactory()
 | 
			
		||||
    request = factory.post('/login')
 | 
			
		||||
    request.session = DummySession()
 | 
			
		||||
 | 
			
		||||
    request.session["authenticated"] = True
 | 
			
		||||
    request.session["username"] = "test"
 | 
			
		||||
    request.session["warn"] = False
 | 
			
		||||
 | 
			
		||||
    login = LoginView()
 | 
			
		||||
    login.init_get(request)
 | 
			
		||||
    ret = login.process_get()
 | 
			
		||||
 | 
			
		||||
    assert ret == LoginView.USER_AUTHENTICATED
 | 
			
		||||
 | 
			
		||||
    login = LoginView()
 | 
			
		||||
    response = login.get(request)
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
 | 
			
		||||
@pytest.mark.django_db
 | 
			
		||||
@dummy_service_pattern()
 | 
			
		||||
@dummy_user(username="test", session_key="test_session")
 | 
			
		||||
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
 | 
			
		||||
def test_view_login_get_auth_service():
 | 
			
		||||
    factory = RequestFactory()
 | 
			
		||||
    request = factory.post('/login?service=https://www.example.com')
 | 
			
		||||
    request.session = DummySession()
 | 
			
		||||
 | 
			
		||||
    request.session["authenticated"] = True
 | 
			
		||||
    request.session["username"] = "test"
 | 
			
		||||
    request.session["warn"] = False
 | 
			
		||||
 | 
			
		||||
    login = LoginView()
 | 
			
		||||
    login.init_get(request)
 | 
			
		||||
    ret = login.process_get()
 | 
			
		||||
 | 
			
		||||
    assert ret == LoginView.USER_AUTHENTICATED
 | 
			
		||||
 | 
			
		||||
    login = LoginView()
 | 
			
		||||
    response = login.get(request)
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 302
 | 
			
		||||
    assert response['Location'].startswith('https://www.example.com?ticket=ST-')
 | 
			
		||||
 | 
			
		||||
@pytest.mark.django_db
 | 
			
		||||
@dummy_service_pattern()
 | 
			
		||||
@dummy_user(username="test", session_key="test_session")
 | 
			
		||||
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
 | 
			
		||||
def test_view_login_get_auth_service_warn():
 | 
			
		||||
    factory = RequestFactory()
 | 
			
		||||
    request = factory.post('/login?service=https://www.example.com')
 | 
			
		||||
    request.session = DummySession()
 | 
			
		||||
 | 
			
		||||
    request.session["authenticated"] = True
 | 
			
		||||
    request.session["username"] = "test"
 | 
			
		||||
    request.session["warn"] = True
 | 
			
		||||
 | 
			
		||||
    login = LoginView()
 | 
			
		||||
    login.init_get(request)
 | 
			
		||||
    ret = login.process_get()
 | 
			
		||||
 | 
			
		||||
    assert ret == LoginView.USER_AUTHENTICATED
 | 
			
		||||
 | 
			
		||||
    login = LoginView()
 | 
			
		||||
    response = login.get(request)
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
@@ -1,80 +0,0 @@
 | 
			
		||||
from __future__ import absolute_import
 | 
			
		||||
from .init import *
 | 
			
		||||
 | 
			
		||||
from django.test import RequestFactory
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
import pytest
 | 
			
		||||
 | 
			
		||||
from cas_server.views import LogoutView
 | 
			
		||||
from cas_server import models
 | 
			
		||||
 | 
			
		||||
from .dummy import *
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.django_db
 | 
			
		||||
@dummy_user(username="test", session_key="test_session")
 | 
			
		||||
def test_logout_view():
 | 
			
		||||
    factory = RequestFactory()
 | 
			
		||||
    request = factory.get('/logout')
 | 
			
		||||
 | 
			
		||||
    request.session = DummySession()
 | 
			
		||||
 | 
			
		||||
    request.session["authenticated"] = True
 | 
			
		||||
    request.session["username"] = "test"
 | 
			
		||||
    request.session["warn"] = False
 | 
			
		||||
 | 
			
		||||
    logout = LogoutView()
 | 
			
		||||
    response = logout.get(request)
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
    assert not request.session.get("authenticated")
 | 
			
		||||
    assert not request.session.get("username")
 | 
			
		||||
    assert not request.session.get("warn")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.django_db
 | 
			
		||||
@dummy_user(username="test", session_key="test_session")
 | 
			
		||||
def test_logout_view_url():
 | 
			
		||||
    factory = RequestFactory()
 | 
			
		||||
    request = factory.get('/logout?url=https://www.example.com')
 | 
			
		||||
 | 
			
		||||
    request.session = DummySession()
 | 
			
		||||
 | 
			
		||||
    request.session["authenticated"] = True
 | 
			
		||||
    request.session["username"] = "test"
 | 
			
		||||
    request.session["warn"] = False
 | 
			
		||||
 | 
			
		||||
    logout = LogoutView()
 | 
			
		||||
    response = logout.get(request)
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 302
 | 
			
		||||
    assert response['Location'] == 'https://www.example.com'
 | 
			
		||||
    assert not request.session.get("authenticated")
 | 
			
		||||
    assert not request.session.get("username")
 | 
			
		||||
    assert not request.session.get("warn")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.django_db
 | 
			
		||||
@dummy_user(username="test", session_key="test_session")
 | 
			
		||||
def test_logout_view_service():
 | 
			
		||||
    factory = RequestFactory()
 | 
			
		||||
    request = factory.get('/logout?service=https://www.example.com')
 | 
			
		||||
 | 
			
		||||
    request.session = DummySession()
 | 
			
		||||
 | 
			
		||||
    request.session["authenticated"] = True
 | 
			
		||||
    request.session["username"] = "test"
 | 
			
		||||
    request.session["warn"] = False
 | 
			
		||||
 | 
			
		||||
    logout = LogoutView()
 | 
			
		||||
    response = logout.get(request)
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 302
 | 
			
		||||
    assert response['Location'] == 'https://www.example.com'
 | 
			
		||||
    assert not request.session.get("authenticated")
 | 
			
		||||
    assert not request.session.get("username")
 | 
			
		||||
    assert not request.session.get("warn")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -1,58 +0,0 @@
 | 
			
		||||
from __future__ import absolute_import
 | 
			
		||||
from .init import *
 | 
			
		||||
 | 
			
		||||
from django.test import RequestFactory
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
import pytest
 | 
			
		||||
 | 
			
		||||
from cas_server.views import Validate
 | 
			
		||||
from cas_server import models
 | 
			
		||||
 | 
			
		||||
from .dummy import *
 | 
			
		||||
 | 
			
		||||
@pytest.mark.django_db
 | 
			
		||||
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
 | 
			
		||||
def test_validate_view_ok():
 | 
			
		||||
    factory = RequestFactory()
 | 
			
		||||
    request = factory.get('/validate?ticket=ST-random&service=https://www.example.com')
 | 
			
		||||
 | 
			
		||||
    request.session = DummySession()
 | 
			
		||||
 | 
			
		||||
    validate = Validate()
 | 
			
		||||
    response = validate.get(request)
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
    assert response.content == b"yes\ntest\n"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.django_db
 | 
			
		||||
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
 | 
			
		||||
def test_validate_view_badservice():
 | 
			
		||||
    factory = RequestFactory()
 | 
			
		||||
    request = factory.get('/validate?ticket=ST-random&service=https://www.example2.com')
 | 
			
		||||
 | 
			
		||||
    request.session = DummySession()
 | 
			
		||||
 | 
			
		||||
    validate = Validate()
 | 
			
		||||
    response = validate.get(request)
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
    assert response.content == b"no\n"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.django_db
 | 
			
		||||
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random1")
 | 
			
		||||
def test_validate_view_badticket():
 | 
			
		||||
    factory = RequestFactory()
 | 
			
		||||
    request = factory.get('/validate?ticket=ST-random2&service=https://www.example.com')
 | 
			
		||||
 | 
			
		||||
    request.session = DummySession()
 | 
			
		||||
 | 
			
		||||
    validate = Validate()
 | 
			
		||||
    response = validate.get(request)
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
    assert response.content == b"no\n"
 | 
			
		||||
							
								
								
									
										2
									
								
								tox.ini
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								tox.ini
									
									
									
									
									
								
							@@ -17,7 +17,7 @@ deps =
 | 
			
		||||
    -r{toxinidir}/requirements-dev.txt
 | 
			
		||||
 | 
			
		||||
[testenv]
 | 
			
		||||
commands=py.test --tb native {posargs:tests}
 | 
			
		||||
commands=python run_tests {posargs:tests}
 | 
			
		||||
 | 
			
		||||
[testenv:py27-django17]
 | 
			
		||||
basepython=python2.7
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										22
									
								
								urls_tests.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								urls_tests.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,22 @@
 | 
			
		||||
"""cas URL Configuration
 | 
			
		||||
 | 
			
		||||
The `urlpatterns` list routes URLs to views. For more information please see:
 | 
			
		||||
    https://docs.djangoproject.com/en/1.9/topics/http/urls/
 | 
			
		||||
Examples:
 | 
			
		||||
Function views
 | 
			
		||||
    1. Add an import:  from my_app import views
 | 
			
		||||
    2. Add a URL to urlpatterns:  url(r'^$', views.home, name='home')
 | 
			
		||||
Class-based views
 | 
			
		||||
    1. Add an import:  from other_app.views import Home
 | 
			
		||||
    2. Add a URL to urlpatterns:  url(r'^$', Home.as_view(), name='home')
 | 
			
		||||
Including another URLconf
 | 
			
		||||
    1. Import the include() function: from django.conf.urls import url, include, include
 | 
			
		||||
    2. Add a URL to urlpatterns:  url(r'^blog/', include('blog.urls'))
 | 
			
		||||
"""
 | 
			
		||||
from django.conf.urls import url, include
 | 
			
		||||
from django.contrib import admin
 | 
			
		||||
 | 
			
		||||
urlpatterns = [
 | 
			
		||||
    url(r'^admin/', admin.site.urls),
 | 
			
		||||
    url(r'^', include('cas_server.urls', namespace='cas_server')),
 | 
			
		||||
]
 | 
			
		||||
		Reference in New Issue
	
	Block a user