diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..b4da6da --- /dev/null +++ b/.coveragerc @@ -0,0 +1,8 @@ +[report] +exclude_lines = + pragma: no cover + def __repr__ + def __unicode__ + raise AssertionError + raise NotImplementedError + if six.PY3: diff --git a/.gitignore b/.gitignore index 2ba2ee7..3b1bcb6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ *.pyc *.egg-info +*~ *.swp build/ @@ -8,6 +9,9 @@ cas/ dist/ db.sqlite3 manage.py +coverage.xml .tox test_venv +.coverage +htmlcov/ diff --git a/Makefile b/Makefile index e5b19f1..2273da9 100644 --- a/Makefile +++ b/Makefile @@ -44,3 +44,14 @@ test_project: test_venv test_venv/cas/manage.py run_test_server: test_project test_venv/bin/python test_venv/cas/manage.py runserver + +coverage: test_venv + test_venv/bin/pip install coverage + test_venv/bin/coverage run --source='cas_server' --omit='cas_server/migrations*' run_tests + test_venv/bin/coverage html + rm htmlcov/coverage_html.js # I am really pissed off by those keybord shortcuts + +coverage_codacy: coverage + test_venv/bin/coverage xml + test_venv/bin/pip install codacy-coverage + test_venv/bin/python-codacy-coverage -r coverage.xml diff --git a/README.rst b/README.rst index 6357f8e..124076b 100644 --- a/README.rst +++ b/README.rst @@ -10,8 +10,14 @@ CAS Server .. image:: https://img.shields.io/pypi/l/django-cas-server.svg :target: https://www.gnu.org/licenses/gpl-3.0.html +.. image:: https://api.codacy.com/project/badge/Grade/255c21623d6946ef8802fa7995b61366 + :target: https://www.codacy.com/app/valentin-samir/django-cas-server + +.. image:: https://api.codacy.com/project/badge/Coverage/255c21623d6946ef8802fa7995b61366 + :target: https://www.codacy.com/app/valentin-samir/django-cas-server + CAS Server is a Django application implementing the `CAS Protocol 3.0 Specification -`_. +`_. By defaut, the authentication process use django internal users but you can easily use any sources (see auth classes in the auth.py file) @@ -37,6 +43,15 @@ Features Quick start ----------- +0. If you want to make a virtualenv for ``django-cas-server``, you will need the following + dependencies on a bare debian like system:: + + virtualenv build-essential python-dev libxml2-dev libxslt1-dev zlib1g-dev + + If you want to use python3 instead of python2, replace ``python-dev`` with ``python3-dev``. + + If you intend to run the tox tests you will also need ``python3.4-dev`` depending of the current + version of python3 on your system. 1. Add "cas_server" to your INSTALLED_APPS setting like this:: @@ -70,7 +85,7 @@ Quick start 4. You should add some management commands to a crontab: ``clearsessions``, ``cas_clean_tickets`` and ``cas_clean_sessions``. - * ``clearsessions``: please see `Clearing the session store `_. + * ``clearsessions``: please see `Clearing the session store `_. * ``cas_clean_tickets``: old tickets and timed-out tickets do not get purge from the database automatically. They are just marked as invalid. ``cas_clean_tickets`` is a clean-up management command for this purpose. It send SingleLogOut request @@ -122,14 +137,14 @@ Template settings: Authentication settings: -* ``CAS_AUTH_CLASS``: A dotted path to a class implementing ``cas_server.auth.AuthUser``. - The default is ``"cas_server.auth.DjangoAuthUser"`` +* ``CAS_AUTH_CLASS``: A dotted path to a class or a class implementing + ``cas_server.auth.AuthUser``. The default is ``"cas_server.auth.DjangoAuthUser"`` * ``SESSION_COOKIE_AGE``: This is a django settings. Here, it control the delay in seconds after which inactive users are logged out. The default is ``1209600`` (2 weeks). You probably should reduce it to something like ``86400`` seconds (1 day). -* ``CAS_PROXY_CA_CERTIFICATE_PATH``: Path to certificates authority file. Usually on linux +* ``CAS_PROXY_CA_CERTIFICATE_PATH``: Path to certificate authorities file. Usually on linux the local CAs are in ``/etc/ssl/certs/ca-certificates.crt``. The default is ``True`` which tell requests to use its internal certificat authorities. Settings it to ``False`` should disable all x509 certificates validation and MUST not be done in production. @@ -162,7 +177,7 @@ Tickets validity settings: application. The default is ``60``. * ``CAS_PGT_VALIDITY``: Number of seconds the proxy granting tickets are valid. The default is ``3600`` (1 hour). -* ``CAS_TICKET_TIMEOUT``: Number of seconds a ticket is kept is the database before sending +* ``CAS_TICKET_TIMEOUT``: Number of seconds a ticket is kept in the database before sending Single Log Out request and being cleared. The default is ``86400`` (24 hours). Tickets miscellaneous settings: @@ -184,12 +199,12 @@ Tickets miscellaneous settings: * ``CAS_SERVICE_TICKET_PREFIX``: Prefix of service tickets. The default is ``"ST"``. The CAS specification mandate that service tickets MUST begin with the characters ST so you should not change this. -* ``CAS_PROXY_TICKET_PREFIX``: Prefix of proxy ticket. The default is ``"ST"``. +* ``CAS_PROXY_TICKET_PREFIX``: Prefix of proxy ticket. The default is ``"PT"``. * ``CAS_PROXY_GRANTING_TICKET_PREFIX``: Prefix of proxy granting ticket. The default is ``"PGT"``. * ``CAS_PROXY_GRANTING_TICKET_IOU_PREFIX``: Prefix of proxy granting ticket IOU. The default is ``"PGTIOU"``. -Mysql backend settings. Only usefull is you use the mysql authentication backend: +Mysql backend settings. Only usefull if you are using the mysql authentication backend: * ``CAS_SQL_HOST``: Host for the SQL server. The default is ``"localhost"``. * ``CAS_SQL_USERNAME``: Username for connecting to the SQL server. @@ -200,8 +215,29 @@ Mysql backend settings. Only usefull is you use the mysql authentication backend The username must be in field ``username``, the password in ``password``, additional fields are used as the user attributes. The default is ``"SELECT user AS usersame, pass AS password, users.* FROM users WHERE user = %s"`` -* ``CAS_SQL_PASSWORD_CHECK``: The method used to check the user password. Must be - ``"crypt"`` or ``"plain``". The default is ``"crypt"``. +* ``CAS_SQL_PASSWORD_CHECK``: The method used to check the user password. Must be one of the following: + + * ``"crypt"`` (see ), the password in the database + should begin this $ + * ``"ldap"`` (see https://tools.ietf.org/id/draft-stroeder-hashed-userpassword-values-01.html) + the password in the database must begin with one of {MD5}, {SMD5}, {SHA}, {SSHA}, {SHA256}, + {SSHA256}, {SHA384}, {SSHA384}, {SHA512}, {SSHA512}, {CRYPT}. + * ``"hex_HASH_NAME"`` with ``HASH_NAME`` in md5, sha1, sha224, sha256, sha384, sha512. + The hashed password in the database is compare to the hexadecimal digest of the clear + password hashed with the corresponding algorithm. + * ``"plain"``, the password in the database must be in clear. + + The default is ``"crypt"``. + + +Test backend settings. Only usefull if you are using the test authentication backend: + +* ``CAS_TEST_USER``: Username of the test user. The default is ``"test"``. +* ``CAS_TEST_PASSWORD``: Password of the test user. The default is ``"test"``. +* ``CAS_TEST_ATTRIBUTES``: Attributes of the test user. The default is + ``{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net', + 'alias': ['demo1', 'demo2']}``. + Authentication backend ---------------------- @@ -209,8 +245,8 @@ Authentication backend ``django-cas-server`` comes with some authentication backends: * dummy backend ``cas_server.auth.DummyAuthUser``: all authentication attempt fails. -* test backend ``cas_server.auth.TestAuthUser``: username is ``test`` and password is ``test`` - the returned attributes for the user are: ``{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'}`` +* test backend ``cas_server.auth.TestAuthUser``: username, password and returned attributes + for the user are defined by the ``CAS_TEST_*`` settings. * django backend ``cas_server.auth.DjangoAuthUser``: Users are authenticated agains django users system. This is the default backend. The returned attributes are the fields available on the user model. * mysql backend ``cas_server.auth.MysqlAuthUser``: see the 'Mysql backend settings' section. @@ -222,7 +258,7 @@ Logs ---- ``django-cas-server`` logs most of its actions. To enable login, you must set the ``LOGGING`` -(https://docs.djangoproject.com/en/dev/topics/logging) variable is ``settings.py``. +(https://docs.djangoproject.com/en/stable/topics/logging) variable in ``settings.py``. Users successful actions (login, logout) are logged with the level ``INFO``, failures are logged with the level ``WARNING`` and user attributes transmitted to a service are logged with the level ``DEBUG``. diff --git a/cas_server/__init__.py b/cas_server/__init__.py index 1bb1fa4..f830740 100644 --- a/cas_server/__init__.py +++ b/cas_server/__init__.py @@ -9,4 +9,4 @@ # # (c) 2015 Valentin Samir -default_app_config = 'cas_server.apps.AppConfig' +default_app_config = 'cas_server.apps.CasAppConfig' diff --git a/cas_server/admin.py b/cas_server/admin.py index bfa5a73..a6a9be4 100644 --- a/cas_server/admin.py +++ b/cas_server/admin.py @@ -14,9 +14,9 @@ from .models import ServiceTicket, ProxyTicket, ProxyGrantingTicket, User, Servi from .models import Username, ReplaceAttributName, ReplaceAttributValue, FilterAttributValue from .forms import TicketForm -tickets_readonly_fields = ('validate', 'service', 'service_pattern', +TICKETS_READONLY_FIELDS = ('validate', 'service', 'service_pattern', 'creation', 'renew', 'single_log_out', 'value') -tickets_fields = ('validate', 'service', 'service_pattern', +TICKETS_FIELDS = ('validate', 'service', 'service_pattern', 'creation', 'renew', 'single_log_out') @@ -25,8 +25,8 @@ class ServiceTicketInline(admin.TabularInline): model = ServiceTicket extra = 0 form = TicketForm - readonly_fields = tickets_readonly_fields - fields = tickets_fields + readonly_fields = TICKETS_READONLY_FIELDS + fields = TICKETS_FIELDS class ProxyTicketInline(admin.TabularInline): @@ -34,8 +34,8 @@ class ProxyTicketInline(admin.TabularInline): model = ProxyTicket extra = 0 form = TicketForm - readonly_fields = tickets_readonly_fields - fields = tickets_fields + readonly_fields = TICKETS_READONLY_FIELDS + fields = TICKETS_FIELDS class ProxyGrantingInline(admin.TabularInline): @@ -43,8 +43,8 @@ class ProxyGrantingInline(admin.TabularInline): model = ProxyGrantingTicket extra = 0 form = TicketForm - readonly_fields = tickets_readonly_fields - fields = tickets_fields[1:] + readonly_fields = TICKETS_READONLY_FIELDS + fields = TICKETS_FIELDS[1:] class UserAdmin(admin.ModelAdmin): diff --git a/cas_server/apps.py b/cas_server/apps.py index bb93d57..c34b6eb 100644 --- a/cas_server/apps.py +++ b/cas_server/apps.py @@ -2,6 +2,6 @@ from django.utils.translation import ugettext_lazy as _ from django.apps import AppConfig -class AppConfig(AppConfig): +class CasAppConfig(AppConfig): name = 'cas_server' verbose_name = _('Central Authentication Service') diff --git a/cas_server/auth.py b/cas_server/auth.py index 99018a4..231a489 100644 --- a/cas_server/auth.py +++ b/cas_server/auth.py @@ -15,10 +15,10 @@ from django.contrib.auth import get_user_model from django.utils import timezone from datetime import timedelta -try: +try: # pragma: no cover import MySQLdb import MySQLdb.cursors - import crypt + from utils import check_password except ImportError: MySQLdb = None @@ -31,14 +31,14 @@ class AuthUser(object): def test_password(self, password): """test `password` agains the user""" - raise NotImplemented() + raise NotImplementedError() def attributs(self): """return a dict of user attributes""" - raise NotImplemented() + raise NotImplementedError() -class DummyAuthUser(AuthUser): +class DummyAuthUser(AuthUser): # pragma: no cover """A Dummy authentication class""" def __init__(self, username): @@ -62,14 +62,14 @@ class TestAuthUser(AuthUser): def test_password(self, password): """test `password` agains the user""" - return self.username == "test" and password == "test" + return self.username == settings.CAS_TEST_USER and password == settings.CAS_TEST_PASSWORD def attributs(self): """return a dict of user attributes""" - return {'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'} + return settings.CAS_TEST_ATTRIBUTES -class MysqlAuthUser(AuthUser): +class MysqlAuthUser(AuthUser): # pragma: no cover """A mysql auth class: authentication user agains a mysql database""" user = None @@ -94,30 +94,25 @@ class MysqlAuthUser(AuthUser): def test_password(self, password): """test `password` agains the user""" - if not self.user: - return False + if self.user: + return check_password( + settings.CAS_SQL_PASSWORD_CHECK, + password, + self.user["password"], + settings.CAS_SQL_DBCHARSET + ) else: - if settings.CAS_SQL_PASSWORD_CHECK == "plain": - return password == self.user["password"] - elif settings.CAS_SQL_PASSWORD_CHECK == "crypt": - if self.user["password"].startswith('$'): - salt = '$'.join(self.user["password"].split('$', 3)[:-1]) - return crypt.crypt(password, salt) == self.user["password"] - else: - return crypt.crypt( - password, - self.user["password"][:2] - ) == self.user["password"] + return False def attributs(self): """return a dict of user attributes""" - if not self.user: - return {} - else: + if self.user: return self.user + else: + return {} -class DjangoAuthUser(AuthUser): +class DjangoAuthUser(AuthUser): # pragma: no cover """A django auth class: authenticate user agains django internal users""" user = None @@ -131,21 +126,20 @@ class DjangoAuthUser(AuthUser): def test_password(self, password): """test `password` agains the user""" - if not self.user: - return False - else: + if self.user: return self.user.check_password(password) + else: + return False def attributs(self): """return a dict of user attributes""" - if not self.user: - return {} - else: + if self.user: attr = {} for field in self.user._meta.fields: attr[field.attname] = getattr(self.user, field.attname) return attr - + else: + return {} class CASFederateAuth(AuthUser): user = None diff --git a/cas_server/default_settings.py b/cas_server/default_settings.py index 0705b19..750a9c9 100644 --- a/cas_server/default_settings.py +++ b/cas_server/default_settings.py @@ -78,6 +78,17 @@ setting_default('CAS_SQL_USER_QUERY', 'SELECT user AS usersame, pass AS ' 'password, users.* FROM users WHERE user = %s') setting_default('CAS_SQL_PASSWORD_CHECK', 'crypt') # crypt or plain +setting_default('CAS_TEST_USER', 'test') +setting_default('CAS_TEST_PASSWORD', 'test') +setting_default( + 'CAS_TEST_ATTRIBUTES', + { + 'nom': 'Nymous', + 'prenom': 'Ano', + 'email': 'anonymous@example.net', + 'alias': ['demo1', 'demo2'] + } +) setting_default('CAS_FEDERATE', False) # A dict of "provider suffix" -> (provider CAS server url, CAS version, verbose name) diff --git a/cas_server/static/cas_server/cas.js b/cas_server/static/cas_server/cas.js index 4c42dde..06e1a5d 100644 --- a/cas_server/static/cas_server/cas.js +++ b/cas_server/static/cas_server/cas.js @@ -1,55 +1,52 @@ function cas_login(cas_server_login, service, login_service, callback){ - url = cas_server_login + '?service=' + encodeURIComponent(service); + var url = cas_server_login + "?service=" + encodeURIComponent(service); $.ajax({ - type: 'GET', - url:url, - beforeSend: function (request) { + type: "GET", + url, + beforeSend(request) { request.setRequestHeader("X-AJAX", "1"); }, xhrFields: { withCredentials: true }, - success: function(data, textStatus, request){ - if(data.status == 'success'){ + success(data, textStatus, request){ + if(data.status === "success"){ $.ajax({ - type: 'GET', + type: "GET", url: data.url, xhrFields: { withCredentials: true }, success: callback, - error: function (request, textStatus, errorThrown) {}, + error(request, textStatus, errorThrown) {}, }); } else { - if(data.detail == "login required"){ - window.location.href = cas_server_login + '?service=' + encodeURIComponent(login_service); + if(data.detail === "login required"){ + window.location.href = cas_server_login + "?service=" + encodeURIComponent(login_service); } else { - alert('error: ' + data.messages[1].message); + alert("error: " + data.messages[1].message); } } }, - error: function (request, textStatus, errorThrown) {}, + error(request, textStatus, errorThrown) {}, }); } function cas_logout(cas_server_logout){ $.ajax({ - type: 'GET', - url:cas_server_logout, - beforeSend: function (request) { + type: "GET", + url: cas_server_logout, + beforeSend(request) { request.setRequestHeader("X-AJAX", "1"); }, xhrFields: { withCredentials: true }, - error: function (request, textStatus, errorThrown) {}, - success: function(data, textStatus, request){ - if(data.status == 'error'){ - alert('error: ' + data.messages[1].message); + error(request, textStatus, errorThrown) {}, + success(data, textStatus, request){ + if(data.status === "error"){ + alert("error: " + data.messages[1].message); } }, }); } - - - diff --git a/cas_server/static/cas_server/login.css b/cas_server/static/cas_server/login.css index b29433d..6d3524b 100644 --- a/cas_server/static/cas_server/login.css +++ b/cas_server/static/cas_server/login.css @@ -43,14 +43,14 @@ body { @media screen and (max-width: 680px) { #app-name { - margin: 0px; + margin: 0; } #app-name img { display: block; margin: auto; } body { - padding-top: 0px; - padding-bottom: 0px; + padding-top: 0; + padding-bottom: 0; } } diff --git a/cas_server/tests.py b/cas_server/tests.py new file mode 100644 index 0000000..916a6d4 --- /dev/null +++ b/cas_server/tests.py @@ -0,0 +1,943 @@ +from .default_settings import settings + +import django +from django.test import TestCase +from django.test import Client + +import re +import six +import random +from lxml import etree +from six.moves import range + +from cas_server import models +from cas_server import utils + + +def copy_form(form): + """Copy form value into a dict""" + params = {} + for field in form: + if field.value(): + params[field.name] = field.value() + else: + params[field.name] = "" + return params + + +def get_login_page_params(client=None): + """Return a client and the POST params for the client to login""" + if client is None: + client = Client() + response = client.get('/login') + params = copy_form(response.context["form"]) + return client, params + + +def get_auth_client(**update): + """return a authenticated client""" + client, params = get_login_page_params() + params["username"] = settings.CAS_TEST_USER + params["password"] = settings.CAS_TEST_PASSWORD + params.update(update) + + client.post('/login', params) + return client + + +def get_user_ticket_request(service): + """Make an auth client to request a ticket for `service`, return the tuple (user, ticket)""" + client = get_auth_client() + response = client.get("/login", {"service": service}) + ticket_value = response['Location'].split('ticket=')[-1] + user = models.User.objects.get( + username=settings.CAS_TEST_USER, + session_key=client.session.session_key + ) + ticket = models.ServiceTicket.objects.get(value=ticket_value) + return (user, ticket) + + +def get_pgt(): + """return a dict contening a service, user and PGT ticket for this service""" + (host, port) = utils.PGTUrlHandler.run()[1:3] + service = "http://%s:%s" % (host, port) + + (user, ticket) = get_user_ticket_request(service) + + client = Client() + client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service}) + params = utils.PGTUrlHandler.PARAMS.copy() + + params["service"] = service + params["user"] = user + + return params + + +class CheckPasswordCase(TestCase): + """Tests for the utils function `utils.check_password`""" + + def setUp(self): + """Generate random bytes string that will be used ass passwords""" + self.password1 = utils.gen_saml_id() + self.password2 = utils.gen_saml_id() + if not isinstance(self.password1, bytes): + self.password1 = self.password1.encode("utf8") + self.password2 = self.password2.encode("utf8") + + def test_setup(self): + """check that generated password are bytes""" + self.assertIsInstance(self.password1, bytes) + self.assertIsInstance(self.password2, bytes) + + def test_plain(self): + """test the plain auth method""" + self.assertTrue(utils.check_password("plain", self.password1, self.password1, "utf8")) + self.assertFalse(utils.check_password("plain", self.password1, self.password2, "utf8")) + + def test_crypt(self): + """test the crypt auth method""" + if six.PY3: + hashed_password1 = utils.crypt.crypt( + self.password1.decode("utf8"), + "$6$UVVAQvrMyXMF3FF3" + ).encode("utf8") + else: + hashed_password1 = utils.crypt.crypt(self.password1, "$6$UVVAQvrMyXMF3FF3") + + self.assertTrue(utils.check_password("crypt", self.password1, hashed_password1, "utf8")) + self.assertFalse(utils.check_password("crypt", self.password2, hashed_password1, "utf8")) + + def test_ldap_ssha(self): + """test the ldap auth method with a {SSHA} scheme""" + salt = b"UVVAQvrMyXMF3FF3" + hashed_password1 = utils.LdapHashUserPassword.hash(b'{SSHA}', self.password1, salt, "utf8") + + self.assertIsInstance(hashed_password1, bytes) + self.assertTrue(utils.check_password("ldap", self.password1, hashed_password1, "utf8")) + self.assertFalse(utils.check_password("ldap", self.password2, hashed_password1, "utf8")) + + def test_hex_md5(self): + """test the hex_md5 auth method""" + hashed_password1 = utils.hashlib.md5(self.password1).hexdigest() + + self.assertTrue(utils.check_password("hex_md5", self.password1, hashed_password1, "utf8")) + self.assertFalse(utils.check_password("hex_md5", self.password2, hashed_password1, "utf8")) + + def test_hex_sha512(self): + """test the hex_sha512 auth method""" + hashed_password1 = utils.hashlib.sha512(self.password1).hexdigest() + + self.assertTrue( + utils.check_password("hex_sha512", self.password1, hashed_password1, "utf8") + ) + self.assertFalse( + utils.check_password("hex_sha512", self.password2, hashed_password1, "utf8") + ) + + +class LoginTestCase(TestCase): + """Tests for the login view""" + def setUp(self): + """ + Prepare the test context: + * set the auth class to 'cas_server.auth.TestAuthUser' + * create a service pattern for https://www.example.com/** + * Set the service pattern to return all user attributes + """ + settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' + + # For general purpose testing + self.service_pattern = models.ServicePattern.objects.create( + name="example", + pattern="^https://www\.example\.com(/.*)?$", + ) + models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) + + # For testing the restrict_users attributes + self.service_pattern_restrict_user_fail = models.ServicePattern.objects.create( + name="restrict_user_fail", + pattern="^https://restrict_user_fail\.example\.com(/.*)?$", + restrict_users=True, + ) + self.service_pattern_restrict_user_success = models.ServicePattern.objects.create( + name="restrict_user_success", + pattern="^https://restrict_user_success\.example\.com(/.*)?$", + restrict_users=True, + ) + models.Username.objects.create( + value=settings.CAS_TEST_USER, + service_pattern=self.service_pattern_restrict_user_success + ) + + # For testing the user attributes filtering conditions + self.service_pattern_filter_fail = models.ServicePattern.objects.create( + name="filter_fail", + pattern="^https://filter_fail\.example\.com(/.*)?$", + ) + models.FilterAttributValue.objects.create( + attribut="right", + pattern="^admin$", + service_pattern=self.service_pattern_filter_fail + ) + self.service_pattern_filter_success = models.ServicePattern.objects.create( + name="filter_success", + pattern="^https://filter_success\.example\.com(/.*)?$", + ) + models.FilterAttributValue.objects.create( + attribut="email", + pattern="^%s$" % re.escape(settings.CAS_TEST_ATTRIBUTES['email']), + service_pattern=self.service_pattern_filter_success + ) + + # For testing the user_field attributes + self.service_pattern_field_needed_fail = models.ServicePattern.objects.create( + name="field_needed_fail", + pattern="^https://field_needed_fail\.example\.com(/.*)?$", + user_field="uid" + ) + self.service_pattern_field_needed_success = models.ServicePattern.objects.create( + name="field_needed_success", + pattern="^https://field_needed_success\.example\.com(/.*)?$", + user_field="nom" + ) + + def assert_logged(self, client, response, warn=False, code=200): + """Assertions testing that client is well authenticated""" + self.assertEqual(response.status_code, code) + self.assertTrue( + ( + b"You have successfully logged into " + b"the Central Authentication Service" + ) in response.content + ) + self.assertTrue(client.session["username"] == settings.CAS_TEST_USER) + self.assertTrue(client.session["warn"] is warn) + self.assertTrue(client.session["authenticated"] is True) + + self.assertTrue( + models.User.objects.get( + username=settings.CAS_TEST_USER, + session_key=client.session.session_key + ) + ) + + def assert_login_failed(self, client, response, code=200): + """Assertions testing a failed login attempt""" + self.assertEqual(response.status_code, code) + self.assertFalse( + ( + b"You have successfully logged into " + b"the Central Authentication Service" + ) in response.content + ) + + self.assertTrue(client.session.get("username") is None) + self.assertTrue(client.session.get("warn") is None) + self.assertTrue(client.session.get("authenticated") is None) + + def test_login_view_post_goodpass_goodlt(self): + """Test a successul login""" + client, params = get_login_page_params() + params["username"] = settings.CAS_TEST_USER + params["password"] = settings.CAS_TEST_PASSWORD + self.assertTrue(params['lt'] in client.session['lt']) + + response = client.post('/login', params) + self.assert_logged(client, response) + # LoginTicket conssumed + self.assertTrue(params['lt'] not in client.session['lt']) + + def test_login_view_post_goodpass_goodlt_warn(self): + """Test a successul login requesting to be warned before creating services tickets""" + client, params = get_login_page_params() + params["username"] = settings.CAS_TEST_USER + params["password"] = settings.CAS_TEST_PASSWORD + params["warn"] = "on" + + response = client.post('/login', params) + self.assert_logged(client, response, warn=True) + + def test_lt_max(self): + """Check we only keep the last 100 Login Ticket for a user""" + client, params = get_login_page_params() + current_lt = params["lt"] + i_in_test = random.randint(0, 100) + i_not_in_test = random.randint(100, 150) + for i in range(150): + if i == i_in_test: + self.assertTrue(current_lt in client.session['lt']) + if i == i_not_in_test: + self.assertTrue(current_lt not in client.session['lt']) + self.assertTrue(len(client.session['lt']) <= 100) + client, params = get_login_page_params(client) + self.assertTrue(len(client.session['lt']) <= 100) + + def test_login_view_post_badlt(self): + """Login attempt with a bad LoginTicket""" + client, params = get_login_page_params() + params["username"] = settings.CAS_TEST_USER + params["password"] = settings.CAS_TEST_PASSWORD + params["lt"] = 'LT-random' + + response = client.post('/login', params) + + self.assert_login_failed(client, response) + self.assertTrue(b"Invalid login ticket" in response.content) + + def test_login_view_post_badpass_good_lt(self): + """Login attempt with a bad password""" + client, params = get_login_page_params() + params["username"] = settings.CAS_TEST_USER + params["password"] = "test2" + response = client.post('/login', params) + + self.assert_login_failed(client, response) + self.assertTrue( + ( + b"The credentials you provided cannot be " + b"determined to be authentic" + ) in response.content + ) + + def assert_ticket_attributes(self, client, ticket_value): + """check the ticket attributes in the db""" + user = models.User.objects.get( + username=settings.CAS_TEST_USER, + session_key=client.session.session_key + ) + self.assertTrue(user) + ticket = models.ServiceTicket.objects.get(value=ticket_value) + self.assertEqual(ticket.user, user) + self.assertEqual(ticket.attributs, settings.CAS_TEST_ATTRIBUTES) + self.assertEqual(ticket.validate, False) + self.assertEqual(ticket.service_pattern, self.service_pattern) + + def assert_service_ticket(self, client, response): + """check that a ticket is well emited when requested on a allowed service""" + self.assertEqual(response.status_code, 302) + self.assertTrue(response.has_header('Location')) + self.assertTrue( + response['Location'].startswith( + "https://www.example.com?ticket=%s-" % settings.CAS_SERVICE_TICKET_PREFIX + ) + ) + + ticket_value = response['Location'].split('ticket=')[-1] + self.assert_ticket_attributes(client, ticket_value) + + def test_view_login_get_allowed_service(self): + """Request a ticket for an allowed service by an unauthenticated client""" + client = Client() + response = client.get("/login?service=https://www.example.com") + self.assertEqual(response.status_code, 200) + self.assertTrue( + ( + b"Authentication required by service " + b"example (https://www.example.com)" + ) in response.content + ) + + def test_view_login_get_denied_service(self): + """Request a ticket for an denied service by an unauthenticated client""" + client = Client() + response = client.get("/login?service=https://www.example.net") + self.assertEqual(response.status_code, 200) + self.assertTrue(b"Service https://www.example.net non allowed" in response.content) + + def test_view_login_get_auth_allowed_service(self): + """Request a ticket for an allowed service by an authenticated client""" + # client is already authenticated + client = get_auth_client() + response = client.get("/login?service=https://www.example.com") + self.assert_service_ticket(client, response) + + def test_view_login_get_auth_allowed_service_warn(self): + """Request a ticket for an allowed service by an authenticated client""" + # client is already authenticated + client = get_auth_client(warn="on") + response = client.get("/login?service=https://www.example.com") + self.assertEqual(response.status_code, 200) + self.assertTrue( + ( + b"Authentication has been required by service " + b"example (https://www.example.com)" + ) in response.content + ) + + params = copy_form(response.context["form"]) + response = client.post("/login", params) + self.assert_service_ticket(client, response) + + def test_view_login_get_auth_denied_service(self): + """Request a ticket for a not allowed service by an authenticated client""" + client = get_auth_client() + response = client.get("/login?service=https://www.example.org") + self.assertEqual(response.status_code, 200) + self.assertTrue(b"Service https://www.example.org non allowed" in response.content) + + def test_user_logged_not_in_db(self): + """If the user is logged but has been delete from the database, it should be logged out""" + client = get_auth_client() + models.User.objects.get( + username=settings.CAS_TEST_USER, + session_key=client.session.session_key + ).delete() + response = client.get("/login") + + self.assert_login_failed(client, response, code=302) + if django.VERSION < (1, 9): + self.assertEqual(response["Location"], "http://testserver/login") + else: + self.assertEqual(response["Location"], "/login?") + + def test_service_restrict_user(self): + """Testing the restric user capability fro a service""" + service = "https://restrict_user_fail.example.com" + client = get_auth_client() + response = client.get("/login", {'service': service}) + self.assertEqual(response.status_code, 200) + self.assertTrue(b"Username non allowed" in response.content) + + service = "https://restrict_user_success.example.com" + response = client.get("/login", {'service': service}) + self.assertEqual(response.status_code, 302) + self.assertTrue(response["Location"].startswith("%s?ticket=" % service)) + + def test_service_filter(self): + """Test the filtering on user attributes""" + service = "https://filter_fail.example.com" + client = get_auth_client() + response = client.get("/login", {'service': service}) + self.assertEqual(response.status_code, 200) + self.assertTrue(b"User charateristics non allowed" in response.content) + + service = "https://filter_success.example.com" + response = client.get("/login", {'service': service}) + self.assertEqual(response.status_code, 302) + self.assertTrue(response["Location"].startswith("%s?ticket=" % service)) + + def test_service_user_field(self): + """Test using a user attribute as username: case on if the attribute exists or not""" + service = "https://field_needed_fail.example.com" + client = get_auth_client() + response = client.get("/login", {'service': service}) + self.assertEqual(response.status_code, 200) + self.assertTrue(b"The attribut uid is needed to use that service" in response.content) + + service = "https://field_needed_success.example.com" + response = client.get("/login", {'service': service}) + self.assertEqual(response.status_code, 302) + self.assertTrue(response["Location"].startswith("%s?ticket=" % service)) + + def test_gateway(self): + """test gateway parameter""" + + # First with an authenticated client that fail to get a ticket for a service + service = "https://restrict_user_fail.example.com" + client = get_auth_client() + response = client.get("/login", {'service': service, 'gateway': 'on'}) + self.assertEqual(response.status_code, 302) + self.assertEqual(response["Location"], service) + + # second for an user not yet authenticated on a valid service + client = Client() + response = client.get('/login', {'service': service, 'gateway': 'on'}) + self.assertEqual(response.status_code, 302) + self.assertEqual(response["Location"], service) + + +class LogoutTestCase(TestCase): + + def setUp(self): + settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' + + def test_logout_view(self): + client = get_auth_client() + + response = client.get("/login") + self.assertEqual(response.status_code, 200) + self.assertTrue( + ( + b"You have successfully logged into " + b"the Central Authentication Service" + ) in response.content + ) + + response = client.get("/logout") + self.assertEqual(response.status_code, 200) + self.assertTrue( + ( + b"You have successfully logged out from " + b"the Central Authentication Service" + ) in response.content + ) + + response = client.get("/login") + self.assertEqual(response.status_code, 200) + self.assertFalse( + ( + b"You have successfully logged into " + b"the Central Authentication Service" + ) in response.content + ) + + def test_logout_view_url(self): + client = get_auth_client() + + response = client.get('/logout?url=https://www.example.com') + self.assertEqual(response.status_code, 302) + self.assertTrue(response.has_header("Location")) + self.assertEqual(response["Location"], "https://www.example.com") + + response = client.get("/login") + self.assertEqual(response.status_code, 200) + self.assertFalse( + ( + b"You have successfully logged into " + b"the Central Authentication Service" + ) in response.content + ) + + def test_logout_view_service(self): + client = get_auth_client() + + response = client.get('/logout?service=https://www.example.com') + self.assertEqual(response.status_code, 302) + self.assertTrue(response.has_header("Location")) + self.assertEqual(response["Location"], "https://www.example.com") + + response = client.get("/login") + self.assertEqual(response.status_code, 200) + self.assertFalse( + ( + b"You have successfully logged into " + b"the Central Authentication Service" + ) in response.content + ) + + +class AuthTestCase(TestCase): + + def setUp(self): + settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' + self.service = 'https://www.example.com' + models.ServicePattern.objects.create( + name="example", + pattern="^https://www\.example\.com(/.*)?$" + ) + + def test_auth_view_goodpass(self): + settings.CAS_AUTH_SHARED_SECRET = 'test' + client = Client() + response = client.post( + '/auth', + { + 'username': settings.CAS_TEST_USER, + 'password': settings.CAS_TEST_PASSWORD, + 'service': self.service, + 'secret': 'test' + } + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b'yes\n') + + def test_auth_view_badpass(self): + settings.CAS_AUTH_SHARED_SECRET = 'test' + client = Client() + response = client.post( + '/auth', + { + 'username': settings.CAS_TEST_USER, + 'password': 'badpass', + 'service': self.service, + 'secret': 'test' + } + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b'no\n') + + def test_auth_view_badservice(self): + settings.CAS_AUTH_SHARED_SECRET = 'test' + client = Client() + response = client.post( + '/auth', + { + 'username': settings.CAS_TEST_USER, + 'password': settings.CAS_TEST_PASSWORD, + 'service': 'https://www.example.org', + 'secret': 'test' + } + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b'no\n') + + def test_auth_view_badsecret(self): + settings.CAS_AUTH_SHARED_SECRET = 'test' + client = Client() + response = client.post( + '/auth', + { + 'username': settings.CAS_TEST_USER, + 'password': settings.CAS_TEST_PASSWORD, + 'service': self.service, + 'secret': 'badsecret' + } + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b'no\n') + + def test_auth_view_badsettings(self): + settings.CAS_AUTH_SHARED_SECRET = None + client = Client() + response = client.post( + '/auth', + { + 'username': settings.CAS_TEST_USER, + 'password': settings.CAS_TEST_PASSWORD, + 'service': self.service, + 'secret': 'test' + } + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b"no\nplease set CAS_AUTH_SHARED_SECRET") + + +class ValidateTestCase(TestCase): + + def setUp(self): + settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' + self.service = 'https://www.example.com' + self.service_pattern = models.ServicePattern.objects.create( + name="example", + pattern="^https://www\.example\.com(/.*)?$" + ) + models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) + + def test_validate_view_ok(self): + ticket = get_user_ticket_request(self.service)[1] + + client = Client() + response = client.get('/validate', {'ticket': ticket.value, 'service': self.service}) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b'yes\ntest\n') + + def test_validate_view_badservice(self): + ticket = get_user_ticket_request(self.service)[1] + + client = Client() + response = client.get( + '/validate', + {'ticket': ticket.value, 'service': "https://www.example.org"} + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b'no\n') + + def test_validate_view_badticket(self): + get_user_ticket_request(self.service) + + client = Client() + response = client.get( + '/validate', + {'ticket': "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX, 'service': self.service} + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b'no\n') + + +class ValidateServiceTestCase(TestCase): + + def setUp(self): + settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' + self.service = 'http://127.0.0.1:45678' + self.service_pattern = models.ServicePattern.objects.create( + name="localhost", + pattern="^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$", + proxy_callback=True + ) + models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) + + def test_validate_service_view_ok(self): + ticket = get_user_ticket_request(self.service)[1] + + client = Client() + response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': self.service}) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + sucess = root.xpath( + "//cas:authenticationSuccess", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) + self.assertTrue(sucess) + + users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(users), 1) + self.assertEqual(users[0].text, settings.CAS_TEST_USER) + + attributes = root.xpath( + "//cas:attributes", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) + self.assertEqual(len(attributes), 1) + attrs1 = set() + for attr in attributes[0]: + attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text)) + + attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(attributes), len(attrs1)) + attrs2 = set() + for attr in attributes: + attrs2.add((attr.attrib['name'], attr.attrib['value'])) + original = set() + for key, value in settings.CAS_TEST_ATTRIBUTES.items(): + if isinstance(value, list): + for sub_value in value: + original.add((key, sub_value)) + else: + original.add((key, value)) + self.assertEqual(attrs1, attrs2) + self.assertEqual(attrs1, original) + + def test_validate_service_view_badservice(self): + ticket = get_user_ticket_request(self.service)[1] + + client = Client() + bad_service = "https://www.example.org" + response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': bad_service}) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + error = root.xpath( + "//cas:authenticationFailure", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) + self.assertEqual(len(error), 1) + self.assertEqual(error[0].attrib['code'], "INVALID_SERVICE") + self.assertEqual(error[0].text, bad_service) + + def test_validate_service_view_badticket_goodprefix(self): + get_user_ticket_request(self.service) + + client = Client() + bad_ticket = "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX + response = client.get('/serviceValidate', {'ticket': bad_ticket, 'service': self.service}) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + error = root.xpath( + "//cas:authenticationFailure", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) + self.assertEqual(len(error), 1) + self.assertEqual(error[0].attrib['code'], "INVALID_TICKET") + self.assertEqual(error[0].text, 'ticket not found') + + def test_validate_service_view_badticket_badprefix(self): + get_user_ticket_request(self.service) + + client = Client() + bad_ticket = "RANDOM" + response = client.get('/serviceValidate', {'ticket': bad_ticket, 'service': self.service}) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + error = root.xpath( + "//cas:authenticationFailure", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) + self.assertEqual(len(error), 1) + self.assertEqual(error[0].attrib['code'], "INVALID_TICKET") + self.assertEqual(error[0].text, bad_ticket) + + def test_validate_service_view_ok_pgturl(self): + (host, port) = utils.PGTUrlHandler.run()[1:3] + service = "http://%s:%s" % (host, port) + + ticket = get_user_ticket_request(service)[1] + + client = Client() + response = client.get( + '/serviceValidate', + {'ticket': ticket.value, 'service': service, 'pgtUrl': service} + ) + pgt_params = utils.PGTUrlHandler.PARAMS.copy() + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + pgtiou = root.xpath( + "//cas:proxyGrantingTicket", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) + self.assertEqual(len(pgtiou), 1) + self.assertEqual(pgt_params["pgtIou"], pgtiou[0].text) + self.assertTrue("pgtId" in pgt_params) + + def test_validate_service_pgturl_bad_proxy_callback(self): + self.service_pattern.proxy_callback = False + self.service_pattern.save() + ticket = get_user_ticket_request(self.service)[1] + + client = Client() + response = client.get( + '/serviceValidate', + {'ticket': ticket.value, 'service': self.service, 'pgtUrl': self.service} + ) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + error = root.xpath( + "//cas:authenticationFailure", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) + self.assertEqual(len(error), 1) + self.assertEqual(error[0].attrib['code'], "INVALID_PROXY_CALLBACK") + self.assertEqual(error[0].text, "callback url not allowed by configuration") + + +class ProxyTestCase(TestCase): + + def setUp(self): + settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' + self.service = 'http://127.0.0.1' + self.service_pattern = models.ServicePattern.objects.create( + name="localhost", + pattern="^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$", + proxy=True, + proxy_callback=True + ) + models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) + + def test_validate_proxy_ok(self): + params = get_pgt() + + # get a proxy ticket + client1 = Client() + response = client1.get('/proxy', {'pgt': params['pgtId'], 'targetService': self.service}) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + sucess = root.xpath("//cas:proxySuccess", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertTrue(sucess) + + proxy_ticket = root.xpath( + "//cas:proxyTicket", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) + self.assertEqual(len(proxy_ticket), 1) + proxy_ticket = proxy_ticket[0].text + + # validate the proxy ticket + client2 = Client() + response = client2.get('/proxyValidate', {'ticket': proxy_ticket, 'service': self.service}) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + sucess = root.xpath( + "//cas:authenticationSuccess", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) + self.assertTrue(sucess) + + # check that the proxy is send to the end service + proxies = root.xpath("//cas:proxies", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(proxies), 1) + proxy = proxies[0].xpath("//cas:proxy", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(proxy), 1) + self.assertEqual(proxy[0].text, params["service"]) + + # same tests than those for serviceValidate + users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(users), 1) + self.assertEqual(users[0].text, settings.CAS_TEST_USER) + + attributes = root.xpath( + "//cas:attributes", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) + self.assertEqual(len(attributes), 1) + attrs1 = set() + for attr in attributes[0]: + attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text)) + + attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(attributes), len(attrs1)) + attrs2 = set() + for attr in attributes: + attrs2.add((attr.attrib['name'], attr.attrib['value'])) + original = set() + for key, value in settings.CAS_TEST_ATTRIBUTES.items(): + if isinstance(value, list): + for sub_value in value: + original.add((key, sub_value)) + else: + original.add((key, value)) + self.assertEqual(attrs1, attrs2) + self.assertEqual(attrs1, original) + + def test_validate_proxy_bad(self): + params = get_pgt() + + # bad PGT + client1 = Client() + response = client1.get( + '/proxy', + { + 'pgt': "%s-RANDOM" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX, + 'targetService': params['service'] + } + ) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + error = root.xpath( + "//cas:authenticationFailure", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) + self.assertEqual(len(error), 1) + self.assertEqual(error[0].attrib['code'], "INVALID_TICKET") + self.assertEqual( + error[0].text, + "PGT %s-RANDOM not found" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX + ) + + # bad targetService + client2 = Client() + response = client2.get( + '/proxy', + {'pgt': params['pgtId'], 'targetService': "https://www.example.org"} + ) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + error = root.xpath( + "//cas:authenticationFailure", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) + self.assertEqual(len(error), 1) + self.assertEqual(error[0].attrib['code'], "UNAUTHORIZED_SERVICE") + self.assertEqual(error[0].text, "https://www.example.org") + + # service do not allow proxy ticket + self.service_pattern.proxy = False + self.service_pattern.save() + + client3 = Client() + response = client3.get( + '/proxy', + {'pgt': params['pgtId'], 'targetService': params['service']} + ) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + error = root.xpath( + "//cas:authenticationFailure", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) + self.assertEqual(len(error), 1) + self.assertEqual(error[0].attrib['code'], "UNAUTHORIZED_SERVICE") + self.assertEqual( + error[0].text, + 'the service %s do not allow proxy ticket' % params['service'] + ) diff --git a/cas_server/urls.py b/cas_server/urls.py index 2a87ef4..556a8d2 100644 --- a/cas_server/urls.py +++ b/cas_server/urls.py @@ -14,7 +14,7 @@ from django.conf.urls import patterns, url from django.views.generic import RedirectView from django.views.decorators.debug import sensitive_post_parameters, sensitive_variables -import views +from cas_server import views urlpatterns = patterns( '', diff --git a/cas_server/utils.py b/cas_server/utils.py index f274dcd..1314e0e 100644 --- a/cas_server/utils.py +++ b/cas_server/utils.py @@ -19,14 +19,15 @@ from django.contrib import messages import random import string import json +import hashlib +import crypt +import base64 +import six +from threading import Thread from importlib import import_module from datetime import datetime, timedelta - -try: - from urlparse import urlparse, urlunparse, parse_qsl - from urllib import urlencode -except ImportError: - from urllib.parse import urlparse, urlunparse, parse_qsl, urlencode +from six.moves import BaseHTTPServer +from six.moves.urllib.parse import urlparse, urlunparse, parse_qsl, urlencode def context(params): @@ -34,7 +35,7 @@ def context(params): return params -def JsonResponse(request, data): +def json_response(request, data): data["messages"] = [] for msg in messages.get_messages(request): data["messages"].append({'message': msg.message, 'level': msg.level_tag}) @@ -120,9 +121,9 @@ def update_url(url, params): query = dict(parse_qsl(url_parts[4])) query.update(params) url_parts[4] = urlencode(query) - for i in range(len(url_parts)): - if not isinstance(url_parts[i], bytes): - url_parts[i] = url_parts[i].encode('utf-8') + for i, url_part in enumerate(url_parts): + if not isinstance(url_part, bytes): + url_parts[i] = url_part.encode('utf-8') return urlunparse(url_parts).decode('utf-8') @@ -190,3 +191,207 @@ def get_tuple(tuple, index, default=None): return tuple[index] except IndexError: return default + +class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler): + PARAMS = {} + + def do_GET(self): + self.send_response(200) + self.send_header(b"Content-type", "text/plain") + self.end_headers() + self.wfile.write(b"ok") + url = urlparse(self.path) + params = dict(parse_qsl(url.query)) + PGTUrlHandler.PARAMS.update(params) + + def log_message(self, *args): + return + + @staticmethod + def run(): + server_class = BaseHTTPServer.HTTPServer + httpd = server_class(("127.0.0.1", 0), PGTUrlHandler) + (host, port) = httpd.socket.getsockname() + + def lauch(): + httpd.handle_request() + httpd.server_close() + + httpd_thread = Thread(target=lauch) + httpd_thread.daemon = True + httpd_thread.start() + return (httpd_thread, host, port) + + +class LdapHashUserPassword(object): + """Please see https://tools.ietf.org/id/draft-stroeder-hashed-userpassword-values-01.html""" + + schemes_salt = {b"{SMD5}", b"{SSHA}", b"{SSHA256}", b"{SSHA384}", b"{SSHA512}", b"{CRYPT}"} + schemes_nosalt = {b"{MD5}", b"{SHA}", b"{SHA256}", b"{SHA384}", b"{SHA512}"} + + _schemes_to_hash = { + b"{SMD5}": hashlib.md5, + b"{MD5}": hashlib.md5, + b"{SSHA}": hashlib.sha1, + b"{SHA}": hashlib.sha1, + b"{SSHA256}": hashlib.sha256, + b"{SHA256}": hashlib.sha256, + b"{SSHA384}": hashlib.sha384, + b"{SHA384}": hashlib.sha384, + b"{SSHA512}": hashlib.sha512, + b"{SHA512}": hashlib.sha512 + } + + _schemes_to_len = { + b"{SMD5}": 16, + b"{SSHA}": 20, + b"{SSHA256}": 32, + b"{SSHA384}": 48, + b"{SSHA512}": 64, + } + + class BadScheme(ValueError): + """Error raised then the hash scheme is not in schemes_salt + schemes_nosalt""" + pass + + class BadHash(ValueError): + """Error raised then the hash is too short""" + pass + + class BadSalt(ValueError): + """Error raised then with the scheme {CRYPT} the salt is invalid""" + pass + + @classmethod + def _raise_bad_scheme(cls, scheme, valid, msg): + """ + Raise BadScheme error for `scheme`, possible valid scheme are + in `valid`, the error message is `msg` + """ + valid_schemes = [s.decode() for s in valid] + valid_schemes.sort() + raise cls.BadScheme(msg % (scheme, u", ".join(valid_schemes))) + + @classmethod + def _test_scheme(cls, scheme): + """Test if a scheme is valide or raise BadScheme""" + if scheme not in cls.schemes_salt and scheme not in cls.schemes_nosalt: + cls._raise_bad_scheme( + scheme, + cls.schemes_salt | cls.schemes_nosalt, + "The scheme %r is not valid. Valide schemes are %s." + ) + + @classmethod + def _test_scheme_salt(cls, scheme): + """Test if the scheme need a salt or raise BadScheme""" + if scheme not in cls.schemes_salt: + cls._raise_bad_scheme( + scheme, + cls.schemes_salt, + "The scheme %r is only valid without a salt. Valide schemes with salt are %s." + ) + + @classmethod + def _test_scheme_nosalt(cls, scheme): + """Test if the scheme need no salt or raise BadScheme""" + if scheme not in cls.schemes_nosalt: + cls._raise_bad_scheme( + scheme, + cls.schemes_nosalt, + "The scheme %r is only valid with a salt. Valide schemes without salt are %s." + ) + + @classmethod + def hash(cls, scheme, password, salt=None, charset="utf8"): + """ + Hash `password` with `scheme` using `salt`. + This three variable beeing encoded in `charset`. + """ + scheme = scheme.upper() + cls._test_scheme(scheme) + if salt is None or salt == b"": + salt = b"" + cls._test_scheme_nosalt(scheme) + elif salt is not None: + cls._test_scheme_salt(scheme) + try: + return scheme + base64.b64encode( + cls._schemes_to_hash[scheme](password + salt).digest() + salt + ) + except KeyError: + if six.PY3: + password = password.decode(charset) + salt = salt.decode(charset) + hashed_password = crypt.crypt(password, salt) + if hashed_password is None: + raise cls.BadSalt("System crypt implementation do not support the salt %r" % salt) + if six.PY3: + hashed_password = hashed_password.encode(charset) + return scheme + hashed_password + + @classmethod + def get_scheme(cls, hashed_passord): + """Return the scheme of `hashed_passord` or raise BadHash""" + if not hashed_passord[0] == b'{'[0] or b'}' not in hashed_passord: + raise cls.BadHash("%r should start with the scheme enclosed with { }" % hashed_passord) + scheme = hashed_passord.split(b'}', 1)[0] + scheme = scheme.upper() + b"}" + return scheme + + @classmethod + def get_salt(cls, hashed_passord): + """Return the salt of `hashed_passord` possibly empty""" + scheme = cls.get_scheme(hashed_passord) + cls._test_scheme(scheme) + if scheme in cls.schemes_nosalt: + return b"" + elif scheme == b'{CRYPT}': + return b'$'.join(hashed_passord.split(b'$', 3)[:-1]) + else: + hashed_passord = base64.b64decode(hashed_passord[len(scheme):]) + if len(hashed_passord) < cls._schemes_to_len[scheme]: + raise cls.BadHash("Hash too short for the scheme %s" % scheme) + return hashed_passord[cls._schemes_to_len[scheme]:] + + +def check_password(method, password, hashed_password, charset): + """ + Check that `password` match `hashed_password` using `method`, + assuming the encoding is `charset`. + """ + if not isinstance(password, six.binary_type): + password = password.encode(charset) + if not isinstance(hashed_password, six.binary_type): + hashed_password = hashed_password.encode(charset) + if method == "plain": + return password == hashed_password + elif method == "crypt": + if hashed_password.startswith(b'$'): + salt = b'$'.join(hashed_password.split(b'$', 3)[:-1]) + elif hashed_password.startswith(b'_'): + salt = hashed_password[:9] + else: + salt = hashed_password[:2] + if six.PY3: + password = password.decode(charset) + salt = salt.decode(charset) + hashed_password = hashed_password.decode(charset) + crypted_password = crypt.crypt(password, salt) + if crypted_password is None: + raise ValueError("System crypt implementation do not support the salt %r" % salt) + return crypted_password == hashed_password + elif method == "ldap": + scheme = LdapHashUserPassword.get_scheme(hashed_password) + salt = LdapHashUserPassword.get_salt(hashed_password) + return LdapHashUserPassword.hash(scheme, password, salt, charset=charset) == hashed_password + elif ( + method.startswith("hex_") and + method[4:] in {"md5", "sha1", "sha224", "sha256", "sha384", "sha512"} + ): + return getattr( + hashlib, + method[4:] + )(password).hexdigest().encode("ascii") == hashed_password.lower() + else: + raise ValueError("Unknown password method check %r" % method) diff --git a/cas_server/views.py b/cas_server/views.py index a6cf5fe..0d41620 100644 --- a/cas_server/views.py +++ b/cas_server/views.py @@ -23,6 +23,7 @@ from django.views.decorators.csrf import csrf_exempt from django.middleware.csrf import CsrfViewMiddleware from django.views.generic import View +import re import logging import pprint import requests @@ -34,7 +35,7 @@ import cas_server.utils as utils import cas_server.forms as forms import cas_server.models as models -from .utils import JsonResponse +from .utils import json_response from .models import ServiceTicket, ProxyTicket, ProxyGrantingTicket from .models import ServicePattern from .federate import CASFederateValidateUser @@ -63,12 +64,12 @@ class AttributesMixin(object): class LogoutMixin(object): """destroy CAS session utils""" - def logout(self, all=False): + def logout(self, all_session=False): """effectively destroy CAS session""" session_nb = 0 username = self.request.session.get("username") if username: - if all: + if all_session: logger.info("Logging out user %s from all of they sessions." % username) else: logger.info("Logging out user %s." % username) @@ -91,8 +92,8 @@ class LogoutMixin(object): # if user not found in database, flush the session anyway self.request.session.flush() - # If all is set logout user from alternative sessions - if all: + # If all_session is set logout user from alternative sessions + if all_session: for user in models.User.objects.filter(username=username): session = SessionStore(session_key=user.session_key) session.flush() @@ -110,6 +111,7 @@ class LogoutView(View, LogoutMixin): service = None def init_get(self, request): + """Initialize GET received parameters""" self.request = request self.service = request.GET.get('service') self.url = request.GET.get('url') @@ -170,13 +172,13 @@ class LogoutView(View, LogoutMixin): 'url': url, 'session_nb': session_nb } - return JsonResponse(request, data) + return json_response(request, data) else: return redirect("cas_server:login") else: if self.ajax: data = {'status': 'success', 'detail': 'logout', 'session_nb': session_nb} - return JsonResponse(request, data) + return json_response(request, data) else: return render( request, @@ -290,12 +292,10 @@ class LoginView(View, LogoutMixin): USER_NOT_AUTHENTICATED = 6 def init_post(self, request): + """Initialize POST received parameters""" self.request = request self.service = request.POST.get('service') - if request.POST.get('renew') and request.POST['renew'] != "False": - self.renew = True - else: - self.renew = False + self.renew = bool(request.POST.get('renew') and request.POST['renew'] != "False") self.gateway = request.POST.get('gateway') self.method = request.POST.get('method') self.ajax = 'HTTP_X_AJAX' in request.META @@ -306,15 +306,19 @@ class LoginView(View, LogoutMixin): self.username = request.POST.get('username') self.ticket = request.POST.get('ticket') - def check_lt(self): - # save LT for later check - lt_valid = self.request.session.get('lt', []) - lt_send = self.request.POST.get('lt') - # generate a new LT (by posting the LT has been consumed) + def gen_lt(self): + """Generate a new LoginTicket and add it to the list of valid LT for the user""" self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()] if len(self.request.session['lt']) > 100: self.request.session['lt'] = self.request.session['lt'][-100:] + def check_lt(self): + """Check is the POSTed LoginTicket is valid, if yes invalide it""" + # save LT for later check + lt_valid = self.request.session.get('lt', []) + lt_send = self.request.POST.get('lt') + # generate a new LT (by posting the LT has been consumed) + self.gen_lt() # check if send LT is valid if lt_valid is None or lt_send not in lt_valid: return False @@ -339,7 +343,7 @@ class LoginView(View, LogoutMixin): username=self.request.session['username'], session_key=self.request.session.session_key ) - self.user.save() + self.user.save() # pragma: no cover (should not happend) except models.User.DoesNotExist: self.user = models.User.objects.create( username=self.request.session['username'], @@ -355,10 +359,15 @@ class LoginView(View, LogoutMixin): elif ret == self.USER_ALREADY_LOGGED: pass else: - raise EnvironmentError("invalid output for LoginView.process_post") + raise EnvironmentError("invalid output for LoginView.process_post") # pragma: no cover return self.common() - def process_post(self, pytest=False): + def process_post(self): + """ + Analyse the POST request: + * check that the LoginTicket is valid + * check that the user sumited credentials are valid + """ if not self.check_lt(): values = self.request.POST.copy() # if not set a new LT and fail @@ -385,12 +394,10 @@ class LoginView(View, LogoutMixin): return self.USER_ALREADY_LOGGED def init_get(self, request): + """Initialize GET received parameters""" self.request = request self.service = request.GET.get('service') - if request.GET.get('renew') and request.GET['renew'] != "False": - self.renew = True - else: - self.renew = False + self.renew = bool(request.GET.get('renew') and request.GET['renew'] != "False") self.gateway = request.GET.get('gateway') self.method = request.GET.get('method') self.ajax = 'HTTP_X_AJAX' in request.META @@ -410,15 +417,16 @@ class LoginView(View, LogoutMixin): return self.common() def process_get(self): - # generate a new LT if none is present - self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()] - + """Analyse the GET request""" + # generate a new LT + self.gen_lt() if not self.request.session.get("authenticated") or self.renew: self.init_form() return self.USER_NOT_AUTHENTICATED return self.USER_AUTHENTICATED def init_form(self, values=None): + """Initialization of the good form depending of POST and GET parameters""" form_initial = { 'service': self.service, 'method': self.method, @@ -459,7 +467,7 @@ class LoginView(View, LogoutMixin): ) if self.ajax: data = {"status": "error", "detail": "confirmation needed"} - return JsonResponse(self.request, data) + return json_response(self.request, data) else: warn_form = forms.WarnForm(initial={ 'service': self.service, @@ -486,7 +494,7 @@ class LoginView(View, LogoutMixin): return HttpResponseRedirect(redirect_url) else: data = {"status": "success", "detail": "auth", "url": redirect_url} - return JsonResponse(self.request, data) + return json_response(self.request, data) except ServicePattern.DoesNotExist: error = 1 messages.add_message( @@ -530,7 +538,7 @@ class LoginView(View, LogoutMixin): ) else: data = {"status": "error", "detail": "auth", "code": error} - return JsonResponse(self.request, data) + return json_response(self.request, data) def authenticated(self): """Processing authenticated users""" @@ -552,7 +560,7 @@ class LoginView(View, LogoutMixin): "detail": "login required", "url": utils.reverse_params("cas_server:login", params=self.request.GET) } - return JsonResponse(self.request, data) + return json_response(self.request, data) else: return utils.redirect_params("cas_server:login", params=self.request.GET) @@ -562,7 +570,7 @@ class LoginView(View, LogoutMixin): else: if self.ajax: data = {"status": "success", "detail": "logged"} - return JsonResponse(self.request, data) + return json_response(self.request, data) else: return render( self.request, @@ -605,7 +613,7 @@ class LoginView(View, LogoutMixin): "detail": "login required", "url": utils.reverse_params("cas_server:login", params=self.request.GET) } - return JsonResponse(self.request, data) + return json_response(self.request, data) else: if settings.CAS_FEDERATE: if self.username and self.ticket: @@ -824,7 +832,10 @@ class ValidateService(View, AttributesMixin): params['username'] = self.ticket.user.attributs.get( self.ticket.service_pattern.user_field ) - if self.pgt_url and self.pgt_url.startswith("https://"): + if self.pgt_url and ( + self.pgt_url.startswith("https://") or + re.match("^http://(127\.0\.0\.1|localhost)(:[0-9]+)?(/.*)?$", self.pgt_url) + ): return self.process_pgturl(params) else: logger.info( diff --git a/run_tests b/run_tests new file mode 100755 index 0000000..4ea21ee --- /dev/null +++ b/run_tests @@ -0,0 +1,22 @@ +#!/usr/bin/env python +import os, sys +import django +from django.conf import settings + +import settings_tests + +settings.configure(**settings_tests.__dict__) +django.setup() + +try: + # Django <= 1.8 + from django.test.simple import DjangoTestSuiteRunner + test_runner = DjangoTestSuiteRunner(verbosity=1) +except ImportError: + # Django >= 1.8 + from django.test.runner import DiscoverRunner + test_runner = DiscoverRunner(verbosity=1) + +failures = test_runner.run_tests(['cas_server']) +if failures: + sys.exit(failures) diff --git a/settings_tests.py b/settings_tests.py new file mode 100644 index 0000000..e1c0558 --- /dev/null +++ b/settings_tests.py @@ -0,0 +1,83 @@ +""" +Django test settings for cas_server application. + +Generated by 'django-admin startproject' using Django 1.9.7. + +For more information on this file, see +https://docs.djangoproject.com/en/1.9/topics/settings/ + +For the full list of settings and their values, see +https://docs.djangoproject.com/en/1.9/ref/settings/ +""" + +import os + +# Build paths inside the project like this: os.path.join(BASE_DIR, ...) +BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + + +# Quick-start development settings - unsuitable for production +# See https://docs.djangoproject.com/en/1.9/howto/deployment/checklist/ + +# SECURITY WARNING: keep the secret key used in production secret! +SECRET_KEY = 'changeme' + +# SECURITY WARNING: don't run with debug turned on in production! +DEBUG = True + +ALLOWED_HOSTS = [] + + +# Application definition + +INSTALLED_APPS = [ + 'django.contrib.admin', + 'django.contrib.auth', + 'django.contrib.contenttypes', + 'django.contrib.sessions', + 'django.contrib.messages', + 'django.contrib.staticfiles', + 'bootstrap3', + 'cas_server', +] + +MIDDLEWARE_CLASSES = [ + 'django.contrib.sessions.middleware.SessionMiddleware', + 'django.middleware.common.CommonMiddleware', + 'django.middleware.csrf.CsrfViewMiddleware', + 'django.contrib.auth.middleware.AuthenticationMiddleware', + 'django.contrib.auth.middleware.SessionAuthenticationMiddleware', + 'django.contrib.messages.middleware.MessageMiddleware', + 'django.middleware.clickjacking.XFrameOptionsMiddleware', + 'django.middleware.locale.LocaleMiddleware', +] + +ROOT_URLCONF = 'urls_tests' + +# Database +# https://docs.djangoproject.com/en/1.9/ref/settings/#databases + +DATABASES = { + 'default': { + 'ENGINE': 'django.db.backends.sqlite3', + } +} + +# Internationalization +# https://docs.djangoproject.com/en/1.9/topics/i18n/ + +LANGUAGE_CODE = 'en-us' + +TIME_ZONE = 'UTC' + +USE_I18N = True + +USE_L10N = True + +USE_TZ = True + + +# Static files (CSS, JavaScript, Images) +# https://docs.djangoproject.com/en/1.9/howto/static-files/ + +STATIC_URL = '/static/' diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/dummy.py b/tests/dummy.py deleted file mode 100644 index 8266d7b..0000000 --- a/tests/dummy.py +++ /dev/null @@ -1,136 +0,0 @@ -import functools -from cas_server import models - -class DummyUserManager(object): - def __init__(self, username, session_key): - self.username = username - self.session_key = session_key - def get(self, username=None, session_key=None): - if username == self.username and session_key == self.session_key: - return models.User(username=username, session_key=session_key) - else: - raise models.User.DoesNotExist() - - -def dummy(*args, **kwds): - pass - -def dummy_service_pattern(**kwargs): - def decorator(func): - @functools.wraps(func) - def wrapper(*args, **kwds): - service_validate = models.ServicePattern.validate - models.ServicePattern.validate = classmethod(lambda x,y: models.ServicePattern(**kwargs)) - ret = func(*args, **kwds) - models.ServicePattern.validate = service_validate - return ret - return wrapper - return decorator - -def dummy_user(username, session_key): - def decorator(func): - @functools.wraps(func) - def wrapper(*args, **kwds): - user_manager = models.User.objects - user_save = models.User.save - user_delete = models.User.delete - models.User.objects = DummyUserManager(username, session_key) - models.User.save = dummy - models.User.delete = dummy - ret = func(*args, **kwds) - models.User.objects = user_manager - models.User.save = user_save - models.User.delete = user_delete - return ret - return wrapper - return decorator - -def dummy_ticket(ticket_class, service, ticket): - def decorator(func): - @functools.wraps(func) - def wrapper(*args, **kwds): - ticket_manager = ticket_class.objects - ticket_save = ticket_class.save - ticket_delete = ticket_class.delete - ticket_class.objects = DummyTicketManager(ticket_class, service, ticket) - ticket_class.save = dummy - ticket_class.delete = dummy - ret = func(*args, **kwds) - ticket_class.objects = ticket_manager - ticket_class.save = ticket_save - ticket_class.delete = ticket_delete - return ret - return wrapper - return decorator - - -def dummy_proxy(func): - @functools.wraps(func) - def wrapper(*args, **kwds): - proxy_manager = models.Proxy.objects - models.Proxy.objects = DummyProxyManager() - ret = func(*args, **kwds) - models.Proxy.objects = proxy_manager - return ret - return wrapper - -class DummyProxyManager(object): - def create(self, **kwargs): - for field in models.Proxy._meta.fields: - field.allow_unsaved_instance_assignment = True - return models.Proxy(**kwargs) - -class DummyTicketManager(object): - def __init__(self, ticket_class, service, ticket): - self.ticket_class = ticket_class - self.service = service - self.ticket = ticket - - def create(self, **kwargs): - for field in self.ticket_class._meta.fields: - field.allow_unsaved_instance_assignment = True - return self.ticket_class(**kwargs) - - def filter(self, *args, **kwargs): - return DummyQuerySet() - - def get(self, **kwargs): - for field in self.ticket_class._meta.fields: - field.allow_unsaved_instance_assignment = True - if 'value' in kwargs: - if kwargs['value'] != self.ticket: - raise self.ticket_class.DoesNotExist() - else: - kwargs['value'] = self.ticket - - if 'service' in kwargs: - if kwargs['service'] != self.service: - raise self.ticket_class.DoesNotExist() - else: - kwargs['service'] = self.service - if not 'user' in kwargs: - kwargs['user'] = models.User(username="test") - - for field in models.ServiceTicket._meta.fields: - field.allow_unsaved_instance_assignment = True - for key in list(kwargs): - if '__' in key: - del kwargs[key] - kwargs['attributs'] = {'mail': 'test@example.com'} - kwargs['service_pattern'] = models.ServicePattern() - return self.ticket_class(**kwargs) - - - -class DummySession(dict): - session_key = "test_session" - - def set_expiry(self, int): - pass - - def flush(self): - self.clear() - - -class DummyQuerySet(set): - pass diff --git a/tests/init.py b/tests/init.py deleted file mode 100644 index f6ede9e..0000000 --- a/tests/init.py +++ /dev/null @@ -1,32 +0,0 @@ -import django -from django.conf import settings -from django.contrib import messages - -settings.configure() -settings.STATIC_URL = "/static/" -settings.DATABASES = { - 'default': { - 'ENGINE': 'django.db.backends.sqlite3', - 'NAME': '/dev/null', - } -} -settings.INSTALLED_APPS = ( - 'django.contrib.admin', - 'django.contrib.auth', - 'django.contrib.contenttypes', - 'django.contrib.sessions', - 'django.contrib.messages', - 'django.contrib.staticfiles', - 'bootstrap3', - 'cas_server', -) - -settings.ROOT_URLCONF = "/" -settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' - -try: - django.setup() -except AttributeError: - pass -messages.add_message = lambda x,y,z:None - diff --git a/tests/test_proxy.py b/tests/test_proxy.py deleted file mode 100644 index 963d834..0000000 --- a/tests/test_proxy.py +++ /dev/null @@ -1,52 +0,0 @@ -from __future__ import absolute_import -from tests.init import * - -from django.test import RequestFactory - -import os -import pytest -from lxml import etree -from cas_server.views import ValidateService, Proxy -from cas_server import models - -from tests.dummy import * - -@pytest.mark.django_db -@dummy_ticket(models.ProxyGrantingTicket, '', "PGT-random") -@dummy_service_pattern(proxy=True) -@dummy_user(username="test", session_key="test_session") -@dummy_ticket(models.ProxyTicket, "https://www.example.com", "PT-random") -@dummy_proxy -def test_proxy_ok(): - factory = RequestFactory() - request = factory.get('/proxy?pgt=PGT-random&targetService=https://www.example.com') - - request.session = DummySession() - - proxy = Proxy() - response = proxy.get(request) - - assert response.status_code == 200 - - root = etree.fromstring(response.content) - proxy_tickets = root.xpath("//cas:proxyTicket", namespaces={'cas': "http://www.yale.edu/tp/cas"}) - - assert len(proxy_tickets) == 1 - - factory = RequestFactory() - request = factory.get('/proxyValidate?ticket=PT-random&service=https://www.example.com') - - validate = ValidateService() - validate.allow_proxy_ticket = True - response = validate.get(request) - - assert response.status_code == 200 - - root = etree.fromstring(response.content) - users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"}) - - assert len(users) == 1 - assert users[0].text == "test" - - - diff --git a/tests/test_validate_service.py b/tests/test_validate_service.py deleted file mode 100644 index 940e23b..0000000 --- a/tests/test_validate_service.py +++ /dev/null @@ -1,87 +0,0 @@ -from __future__ import absolute_import -from .init import * - -from django.test import RequestFactory - -import os -import pytest -from lxml import etree -from cas_server.views import ValidateService -from cas_server import models - -from .dummy import * - -@pytest.mark.django_db -@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random") -def test_validate_service_view_ok(): - factory = RequestFactory() - request = factory.get('/serviceValidate?ticket=ST-random&service=https://www.example.com') - - request.session = DummySession() - - validate = ValidateService() - validate.allow_proxy_ticket = False - response = validate.get(request) - - assert response.status_code == 200 - - root = etree.fromstring(response.content) - users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"}) - - assert len(users) == 1 - assert users[0].text == "test" - - attributes = root.xpath("//cas:attributes", namespaces={'cas': "http://www.yale.edu/tp/cas"}) - - assert len(attributes) == 1 - - attrs = {} - for attr in attributes[0]: - attrs[attr.tag[len("http://www.yale.edu/tp/cas")+2:]]=attr.text - - assert 'mail' in attrs - assert attrs['mail'] == 'test@example.com' - - - -@pytest.mark.django_db -@dummy_ticket(models.ServiceTicket, 'https://www.example2.com', "ST-random") -def test_validate_service_view_badservice(): - factory = RequestFactory() - request = factory.get('/serviceValidate?ticket=ST-random&service=https://www.example1.com') - - request.session = DummySession() - - validate = ValidateService() - validate.allow_proxy_ticket = False - response = validate.get(request) - - assert response.status_code == 200 - - root = etree.fromstring(response.content) - - error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"}) - - assert len(error) == 1 - assert error[0].attrib['code'] == 'INVALID_SERVICE' - -@pytest.mark.django_db -@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random2") -def test_validate_service_view_badticket(): - factory = RequestFactory() - request = factory.get('/serviceValidate?ticket=ST-random1&service=https://www.example.com') - - request.session = DummySession() - - validate = ValidateService() - validate.allow_proxy_ticket = False - response = validate.get(request) - - assert response.status_code == 200 - - root = etree.fromstring(response.content) - - error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"}) - - assert len(error) == 1 - assert error[0].attrib['code'] == 'INVALID_TICKET' diff --git a/tests/test_views_auth.py b/tests/test_views_auth.py deleted file mode 100644 index 4b4a9eb..0000000 --- a/tests/test_views_auth.py +++ /dev/null @@ -1,46 +0,0 @@ -from __future__ import absolute_import -from .init import * - -from django.test import RequestFactory - -import os -import pytest - -from cas_server.views import Auth -from cas_server import models - -from .dummy import * - -settings.CAS_AUTH_SHARED_SECRET = "test" - -@pytest.mark.django_db -@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random") -@dummy_user(username="test", session_key="test_session") -@dummy_service_pattern() -def test_auth_view_goodpass(): - factory = RequestFactory() - request = factory.post('/auth', {'username':'test', 'password':'test', 'service':'https://www.example.com', 'secret':'test'}) - - request.session = DummySession() - - auth = Auth() - response = auth.post(request) - - assert response.status_code == 200 - assert response.content == b"yes\n" - -@dummy_service_pattern() -@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random") -@dummy_user(username="test", session_key="test_session") -def test_auth_view_badpass(): - factory = RequestFactory() - request = factory.post('/auth', {'username':'test', 'password':'badpass', 'service':'https://www.example.com', 'secret':'test'}) - - request.session = DummySession() - - auth = Auth() - response = auth.post(request) - - assert response.status_code == 200 - assert response.content == b"no\n" - diff --git a/tests/test_views_login.py b/tests/test_views_login.py deleted file mode 100644 index 6aabe80..0000000 --- a/tests/test_views_login.py +++ /dev/null @@ -1,163 +0,0 @@ -from __future__ import absolute_import -from .init import * - -from django.test import RequestFactory - -import os -import pytest - -from cas_server.views import LoginView -from cas_server import models - -from .dummy import * - - - -def test_login_view_post_goodpass_goodlt(): - factory = RequestFactory() - request = factory.post('/login', {'username':'test', 'password':'test', 'lt':'LT-random'}) - request.session = DummySession() - - request.session['lt'] = ['LT-random'] - - request.session["username"] = os.urandom(20) - request.session["warn"] = os.urandom(20) - - login = LoginView() - login.init_post(request) - - ret = login.process_post(pytest=True) - - assert ret == LoginView.USER_LOGIN_OK - assert request.session.get("authenticated") == True - assert request.session.get("username") == "test" - assert request.session.get("warn") == False - -def test_login_view_post_badlt(): - factory = RequestFactory() - request = factory.post('/login', {'username':'test', 'password':'test', 'lt':'LT-random1'}) - request.session = DummySession() - - request.session['lt'] = ['LT-random2'] - - authenticated = os.urandom(20) - username = os.urandom(20) - warn = os.urandom(20) - - request.session["authenticated"] = authenticated - request.session["username"] = username - request.session["warn"] = warn - - login = LoginView() - login.init_post(request) - - ret = login.process_post(pytest=True) - - assert ret == LoginView.INVALID_LOGIN_TICKET - assert request.session.get("authenticated") == authenticated - assert request.session.get("username") == username - assert request.session.get("warn") == warn - -def test_login_view_post_badpass_good_lt(): - factory = RequestFactory() - request = factory.post('/login', {'username':'test', 'password':'badpassword', 'lt':'LT-random'}) - request.session = DummySession() - - request.session['lt'] = ['LT-random'] - - login = LoginView() - login.init_post(request) - ret = login.process_post() - - assert ret == LoginView.USER_LOGIN_FAILURE - assert not request.session.get("authenticated") - assert not request.session.get("username") - assert not request.session.get("warn") - - -def test_view_login_get_unauth(): - factory = RequestFactory() - request = factory.post('/login') - request.session = DummySession() - - login = LoginView() - login.init_get(request) - ret = login.process_get() - - assert ret == LoginView.USER_NOT_AUTHENTICATED - - login = LoginView() - response = login.get(request) - - assert response.status_code == 200 - -@pytest.mark.django_db -@dummy_user(username="test", session_key="test_session") -def test_view_login_get_auth(): - factory = RequestFactory() - request = factory.post('/login') - request.session = DummySession() - - request.session["authenticated"] = True - request.session["username"] = "test" - request.session["warn"] = False - - login = LoginView() - login.init_get(request) - ret = login.process_get() - - assert ret == LoginView.USER_AUTHENTICATED - - login = LoginView() - response = login.get(request) - - assert response.status_code == 200 - -@pytest.mark.django_db -@dummy_service_pattern() -@dummy_user(username="test", session_key="test_session") -@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random") -def test_view_login_get_auth_service(): - factory = RequestFactory() - request = factory.post('/login?service=https://www.example.com') - request.session = DummySession() - - request.session["authenticated"] = True - request.session["username"] = "test" - request.session["warn"] = False - - login = LoginView() - login.init_get(request) - ret = login.process_get() - - assert ret == LoginView.USER_AUTHENTICATED - - login = LoginView() - response = login.get(request) - - assert response.status_code == 302 - assert response['Location'].startswith('https://www.example.com?ticket=ST-') - -@pytest.mark.django_db -@dummy_service_pattern() -@dummy_user(username="test", session_key="test_session") -@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random") -def test_view_login_get_auth_service_warn(): - factory = RequestFactory() - request = factory.post('/login?service=https://www.example.com') - request.session = DummySession() - - request.session["authenticated"] = True - request.session["username"] = "test" - request.session["warn"] = True - - login = LoginView() - login.init_get(request) - ret = login.process_get() - - assert ret == LoginView.USER_AUTHENTICATED - - login = LoginView() - response = login.get(request) - - assert response.status_code == 200 diff --git a/tests/test_views_logout.py b/tests/test_views_logout.py deleted file mode 100644 index 03410bd..0000000 --- a/tests/test_views_logout.py +++ /dev/null @@ -1,80 +0,0 @@ -from __future__ import absolute_import -from .init import * - -from django.test import RequestFactory - -import os -import pytest - -from cas_server.views import LogoutView -from cas_server import models - -from .dummy import * - - -@pytest.mark.django_db -@dummy_user(username="test", session_key="test_session") -def test_logout_view(): - factory = RequestFactory() - request = factory.get('/logout') - - request.session = DummySession() - - request.session["authenticated"] = True - request.session["username"] = "test" - request.session["warn"] = False - - logout = LogoutView() - response = logout.get(request) - - assert response.status_code == 200 - assert not request.session.get("authenticated") - assert not request.session.get("username") - assert not request.session.get("warn") - - -@pytest.mark.django_db -@dummy_user(username="test", session_key="test_session") -def test_logout_view_url(): - factory = RequestFactory() - request = factory.get('/logout?url=https://www.example.com') - - request.session = DummySession() - - request.session["authenticated"] = True - request.session["username"] = "test" - request.session["warn"] = False - - logout = LogoutView() - response = logout.get(request) - - assert response.status_code == 302 - assert response['Location'] == 'https://www.example.com' - assert not request.session.get("authenticated") - assert not request.session.get("username") - assert not request.session.get("warn") - - - -@pytest.mark.django_db -@dummy_user(username="test", session_key="test_session") -def test_logout_view_service(): - factory = RequestFactory() - request = factory.get('/logout?service=https://www.example.com') - - request.session = DummySession() - - request.session["authenticated"] = True - request.session["username"] = "test" - request.session["warn"] = False - - logout = LogoutView() - response = logout.get(request) - - assert response.status_code == 302 - assert response['Location'] == 'https://www.example.com' - assert not request.session.get("authenticated") - assert not request.session.get("username") - assert not request.session.get("warn") - - diff --git a/tests/test_views_validate.py b/tests/test_views_validate.py deleted file mode 100644 index 201387f..0000000 --- a/tests/test_views_validate.py +++ /dev/null @@ -1,58 +0,0 @@ -from __future__ import absolute_import -from .init import * - -from django.test import RequestFactory - -import os -import pytest - -from cas_server.views import Validate -from cas_server import models - -from .dummy import * - -@pytest.mark.django_db -@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random") -def test_validate_view_ok(): - factory = RequestFactory() - request = factory.get('/validate?ticket=ST-random&service=https://www.example.com') - - request.session = DummySession() - - validate = Validate() - response = validate.get(request) - - assert response.status_code == 200 - assert response.content == b"yes\ntest\n" - - - -@pytest.mark.django_db -@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random") -def test_validate_view_badservice(): - factory = RequestFactory() - request = factory.get('/validate?ticket=ST-random&service=https://www.example2.com') - - request.session = DummySession() - - validate = Validate() - response = validate.get(request) - - assert response.status_code == 200 - assert response.content == b"no\n" - - - -@pytest.mark.django_db -@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random1") -def test_validate_view_badticket(): - factory = RequestFactory() - request = factory.get('/validate?ticket=ST-random2&service=https://www.example.com') - - request.session = DummySession() - - validate = Validate() - response = validate.get(request) - - assert response.status_code == 200 - assert response.content == b"no\n" diff --git a/tox.ini b/tox.ini index 997620a..0b65c56 100644 --- a/tox.ini +++ b/tox.ini @@ -17,7 +17,7 @@ deps = -r{toxinidir}/requirements-dev.txt [testenv] -commands=py.test --tb native {posargs:tests} +commands=python run_tests {posargs:tests} [testenv:py27-django17] basepython=python2.7 diff --git a/urls_tests.py b/urls_tests.py new file mode 100644 index 0000000..a9ed25c --- /dev/null +++ b/urls_tests.py @@ -0,0 +1,22 @@ +"""cas URL Configuration + +The `urlpatterns` list routes URLs to views. For more information please see: + https://docs.djangoproject.com/en/1.9/topics/http/urls/ +Examples: +Function views + 1. Add an import: from my_app import views + 2. Add a URL to urlpatterns: url(r'^$', views.home, name='home') +Class-based views + 1. Add an import: from other_app.views import Home + 2. Add a URL to urlpatterns: url(r'^$', Home.as_view(), name='home') +Including another URLconf + 1. Import the include() function: from django.conf.urls import url, include, include + 2. Add a URL to urlpatterns: url(r'^blog/', include('blog.urls')) +""" +from django.conf.urls import url, include +from django.contrib import admin + +urlpatterns = [ + url(r'^admin/', admin.site.urls), + url(r'^', include('cas_server.urls', namespace='cas_server')), +]