Full coverage for saml + split tests
This commit is contained in:
		
							
								
								
									
										145
									
								
								cas_server/tests/mixin.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										145
									
								
								cas_server/tests/mixin.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,145 @@
 | 
			
		||||
"""Some mixin classes for tests"""
 | 
			
		||||
from cas_server.default_settings import settings
 | 
			
		||||
 | 
			
		||||
import re
 | 
			
		||||
from lxml import etree
 | 
			
		||||
 | 
			
		||||
from cas_server import models
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BaseServicePattern(object):
 | 
			
		||||
    """Mixing for setting up service pattern for testing"""
 | 
			
		||||
    def setup_service_patterns(self, proxy=False):
 | 
			
		||||
        """setting up service pattern"""
 | 
			
		||||
        # For general purpose testing
 | 
			
		||||
        self.service = "https://www.example.com"
 | 
			
		||||
        self.service_pattern = models.ServicePattern.objects.create(
 | 
			
		||||
            name="example",
 | 
			
		||||
            pattern="^https://www\.example\.com(/.*)?$",
 | 
			
		||||
            proxy=proxy,
 | 
			
		||||
        )
 | 
			
		||||
        models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
 | 
			
		||||
 | 
			
		||||
        # For testing the restrict_users attributes
 | 
			
		||||
        self.service_restrict_user_fail = "https://restrict_user_fail.example.com"
 | 
			
		||||
        self.service_pattern_restrict_user_fail = models.ServicePattern.objects.create(
 | 
			
		||||
            name="restrict_user_fail",
 | 
			
		||||
            pattern="^https://restrict_user_fail\.example\.com(/.*)?$",
 | 
			
		||||
            restrict_users=True,
 | 
			
		||||
            proxy=proxy,
 | 
			
		||||
        )
 | 
			
		||||
        self.service_restrict_user_success = "https://restrict_user_success.example.com"
 | 
			
		||||
        self.service_pattern_restrict_user_success = models.ServicePattern.objects.create(
 | 
			
		||||
            name="restrict_user_success",
 | 
			
		||||
            pattern="^https://restrict_user_success\.example\.com(/.*)?$",
 | 
			
		||||
            restrict_users=True,
 | 
			
		||||
            proxy=proxy,
 | 
			
		||||
        )
 | 
			
		||||
        models.Username.objects.create(
 | 
			
		||||
            value=settings.CAS_TEST_USER,
 | 
			
		||||
            service_pattern=self.service_pattern_restrict_user_success
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # For testing the user attributes filtering conditions
 | 
			
		||||
        self.service_filter_fail = "https://filter_fail.example.com"
 | 
			
		||||
        self.service_pattern_filter_fail = models.ServicePattern.objects.create(
 | 
			
		||||
            name="filter_fail",
 | 
			
		||||
            pattern="^https://filter_fail\.example\.com(/.*)?$",
 | 
			
		||||
            proxy=proxy,
 | 
			
		||||
        )
 | 
			
		||||
        models.FilterAttributValue.objects.create(
 | 
			
		||||
            attribut="right",
 | 
			
		||||
            pattern="^admin$",
 | 
			
		||||
            service_pattern=self.service_pattern_filter_fail
 | 
			
		||||
        )
 | 
			
		||||
        self.service_filter_success = "https://filter_success.example.com"
 | 
			
		||||
        self.service_pattern_filter_success = models.ServicePattern.objects.create(
 | 
			
		||||
            name="filter_success",
 | 
			
		||||
            pattern="^https://filter_success\.example\.com(/.*)?$",
 | 
			
		||||
            proxy=proxy,
 | 
			
		||||
        )
 | 
			
		||||
        models.FilterAttributValue.objects.create(
 | 
			
		||||
            attribut="email",
 | 
			
		||||
            pattern="^%s$" % re.escape(settings.CAS_TEST_ATTRIBUTES['email']),
 | 
			
		||||
            service_pattern=self.service_pattern_filter_success
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # For testing the user_field attributes
 | 
			
		||||
        self.service_field_needed_fail = "https://field_needed_fail.example.com"
 | 
			
		||||
        self.service_pattern_field_needed_fail = models.ServicePattern.objects.create(
 | 
			
		||||
            name="field_needed_fail",
 | 
			
		||||
            pattern="^https://field_needed_fail\.example\.com(/.*)?$",
 | 
			
		||||
            user_field="uid",
 | 
			
		||||
            proxy=proxy,
 | 
			
		||||
        )
 | 
			
		||||
        self.service_field_needed_success = "https://field_needed_success.example.com"
 | 
			
		||||
        self.service_pattern_field_needed_success = models.ServicePattern.objects.create(
 | 
			
		||||
            name="field_needed_success",
 | 
			
		||||
            pattern="^https://field_needed_success\.example\.com(/.*)?$",
 | 
			
		||||
            user_field="alias",
 | 
			
		||||
            proxy=proxy,
 | 
			
		||||
        )
 | 
			
		||||
        self.service_field_needed_success_alt = "https://field_needed_success_alt.example.com"
 | 
			
		||||
        self.service_pattern_field_needed_success = models.ServicePattern.objects.create(
 | 
			
		||||
            name="field_needed_success_alt",
 | 
			
		||||
            pattern="^https://field_needed_success_alt\.example\.com(/.*)?$",
 | 
			
		||||
            user_field="nom",
 | 
			
		||||
            proxy=proxy,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class XmlContent(object):
 | 
			
		||||
    """Mixin for test on CAS XML responses"""
 | 
			
		||||
    def assert_error(self, response, code, text=None):
 | 
			
		||||
        """Assert a validation error"""
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        root = etree.fromstring(response.content)
 | 
			
		||||
        error = root.xpath(
 | 
			
		||||
            "//cas:authenticationFailure",
 | 
			
		||||
            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(len(error), 1)
 | 
			
		||||
        self.assertEqual(error[0].attrib['code'], code)
 | 
			
		||||
        if text is not None:
 | 
			
		||||
            self.assertEqual(error[0].text, text)
 | 
			
		||||
 | 
			
		||||
    def assert_success(self, response, username, original_attributes):
 | 
			
		||||
        """assert a ticket validation success"""
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
 | 
			
		||||
        root = etree.fromstring(response.content)
 | 
			
		||||
        sucess = root.xpath(
 | 
			
		||||
            "//cas:authenticationSuccess",
 | 
			
		||||
            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
			
		||||
        )
 | 
			
		||||
        self.assertTrue(sucess)
 | 
			
		||||
 | 
			
		||||
        users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
 | 
			
		||||
        self.assertEqual(len(users), 1)
 | 
			
		||||
        self.assertEqual(users[0].text, username)
 | 
			
		||||
 | 
			
		||||
        attributes = root.xpath(
 | 
			
		||||
            "//cas:attributes",
 | 
			
		||||
            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(len(attributes), 1)
 | 
			
		||||
        attrs1 = set()
 | 
			
		||||
        for attr in attributes[0]:
 | 
			
		||||
            attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text))
 | 
			
		||||
 | 
			
		||||
        attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
 | 
			
		||||
        self.assertEqual(len(attributes), len(attrs1))
 | 
			
		||||
        attrs2 = set()
 | 
			
		||||
        for attr in attributes:
 | 
			
		||||
            attrs2.add((attr.attrib['name'], attr.attrib['value']))
 | 
			
		||||
        original = set()
 | 
			
		||||
        for key, value in original_attributes.items():
 | 
			
		||||
            if isinstance(value, list):
 | 
			
		||||
                for sub_value in value:
 | 
			
		||||
                    original.add((key, sub_value))
 | 
			
		||||
            else:
 | 
			
		||||
                original.add((key, value))
 | 
			
		||||
        self.assertEqual(attrs1, attrs2)
 | 
			
		||||
        self.assertEqual(attrs1, original)
 | 
			
		||||
 | 
			
		||||
        return root
 | 
			
		||||
							
								
								
									
										67
									
								
								cas_server/tests/test_utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										67
									
								
								cas_server/tests/test_utils.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,67 @@
 | 
			
		||||
from django.test import TestCase
 | 
			
		||||
 | 
			
		||||
import six
 | 
			
		||||
 | 
			
		||||
from cas_server import utils
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CheckPasswordCase(TestCase):
 | 
			
		||||
    """Tests for the utils function `utils.check_password`"""
 | 
			
		||||
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        """Generate random bytes string that will be used ass passwords"""
 | 
			
		||||
        self.password1 = utils.gen_saml_id()
 | 
			
		||||
        self.password2 = utils.gen_saml_id()
 | 
			
		||||
        if not isinstance(self.password1, bytes):  # pragma: no cover executed only in python3
 | 
			
		||||
            self.password1 = self.password1.encode("utf8")
 | 
			
		||||
            self.password2 = self.password2.encode("utf8")
 | 
			
		||||
 | 
			
		||||
    def test_setup(self):
 | 
			
		||||
        """check that generated password are bytes"""
 | 
			
		||||
        self.assertIsInstance(self.password1, bytes)
 | 
			
		||||
        self.assertIsInstance(self.password2, bytes)
 | 
			
		||||
 | 
			
		||||
    def test_plain(self):
 | 
			
		||||
        """test the plain auth method"""
 | 
			
		||||
        self.assertTrue(utils.check_password("plain", self.password1, self.password1, "utf8"))
 | 
			
		||||
        self.assertFalse(utils.check_password("plain", self.password1, self.password2, "utf8"))
 | 
			
		||||
 | 
			
		||||
    def test_crypt(self):
 | 
			
		||||
        """test the crypt auth method"""
 | 
			
		||||
        if six.PY3:
 | 
			
		||||
            hashed_password1 = utils.crypt.crypt(
 | 
			
		||||
                self.password1.decode("utf8"),
 | 
			
		||||
                "$6$UVVAQvrMyXMF3FF3"
 | 
			
		||||
            ).encode("utf8")
 | 
			
		||||
        else:
 | 
			
		||||
            hashed_password1 = utils.crypt.crypt(self.password1, "$6$UVVAQvrMyXMF3FF3")
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(utils.check_password("crypt", self.password1, hashed_password1, "utf8"))
 | 
			
		||||
        self.assertFalse(utils.check_password("crypt", self.password2, hashed_password1, "utf8"))
 | 
			
		||||
 | 
			
		||||
    def test_ldap_ssha(self):
 | 
			
		||||
        """test the ldap auth method with a {SSHA} scheme"""
 | 
			
		||||
        salt = b"UVVAQvrMyXMF3FF3"
 | 
			
		||||
        hashed_password1 = utils.LdapHashUserPassword.hash(b'{SSHA}', self.password1, salt, "utf8")
 | 
			
		||||
 | 
			
		||||
        self.assertIsInstance(hashed_password1, bytes)
 | 
			
		||||
        self.assertTrue(utils.check_password("ldap", self.password1, hashed_password1, "utf8"))
 | 
			
		||||
        self.assertFalse(utils.check_password("ldap", self.password2, hashed_password1, "utf8"))
 | 
			
		||||
 | 
			
		||||
    def test_hex_md5(self):
 | 
			
		||||
        """test the hex_md5 auth method"""
 | 
			
		||||
        hashed_password1 = utils.hashlib.md5(self.password1).hexdigest()
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(utils.check_password("hex_md5", self.password1, hashed_password1, "utf8"))
 | 
			
		||||
        self.assertFalse(utils.check_password("hex_md5", self.password2, hashed_password1, "utf8"))
 | 
			
		||||
 | 
			
		||||
    def test_hex_sha512(self):
 | 
			
		||||
        """test the hex_sha512 auth method"""
 | 
			
		||||
        hashed_password1 = utils.hashlib.sha512(self.password1).hexdigest()
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            utils.check_password("hex_sha512", self.password1, hashed_password1, "utf8")
 | 
			
		||||
        )
 | 
			
		||||
        self.assertFalse(
 | 
			
		||||
            utils.check_password("hex_sha512", self.password2, hashed_password1, "utf8")
 | 
			
		||||
        )
 | 
			
		||||
@@ -1,12 +1,12 @@
 | 
			
		||||
"""Tests module"""
 | 
			
		||||
"""Tests module for views"""
 | 
			
		||||
from cas_server.default_settings import settings
 | 
			
		||||
 | 
			
		||||
import django
 | 
			
		||||
from django.test import TestCase, Client
 | 
			
		||||
from django.test.utils import override_settings
 | 
			
		||||
from django.utils import timezone
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
import re
 | 
			
		||||
import six
 | 
			
		||||
import random
 | 
			
		||||
import json
 | 
			
		||||
from lxml import etree
 | 
			
		||||
@@ -14,203 +14,15 @@ from six.moves import range
 | 
			
		||||
 | 
			
		||||
from cas_server import models
 | 
			
		||||
from cas_server import utils
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def copy_form(form):
 | 
			
		||||
    """Copy form value into a dict"""
 | 
			
		||||
    params = {}
 | 
			
		||||
    for field in form:
 | 
			
		||||
        if field.value():
 | 
			
		||||
            params[field.name] = field.value()
 | 
			
		||||
        else:
 | 
			
		||||
            params[field.name] = ""
 | 
			
		||||
    return params
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_login_page_params(client=None):
 | 
			
		||||
    """Return a client and the POST params for the client to login"""
 | 
			
		||||
    if client is None:
 | 
			
		||||
        client = Client()
 | 
			
		||||
    response = client.get('/login')
 | 
			
		||||
    params = copy_form(response.context["form"])
 | 
			
		||||
    return client, params
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_auth_client(**update):
 | 
			
		||||
    """return a authenticated client"""
 | 
			
		||||
    client, params = get_login_page_params()
 | 
			
		||||
    params["username"] = settings.CAS_TEST_USER
 | 
			
		||||
    params["password"] = settings.CAS_TEST_PASSWORD
 | 
			
		||||
    params.update(update)
 | 
			
		||||
 | 
			
		||||
    client.post('/login', params)
 | 
			
		||||
    return client
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_user_ticket_request(service):
 | 
			
		||||
    """Make an auth client to request a ticket for `service`, return the tuple (user, ticket)"""
 | 
			
		||||
    client = get_auth_client()
 | 
			
		||||
    response = client.get("/login", {"service": service})
 | 
			
		||||
    ticket_value = response['Location'].split('ticket=')[-1]
 | 
			
		||||
    user = models.User.objects.get(
 | 
			
		||||
        username=settings.CAS_TEST_USER,
 | 
			
		||||
        session_key=client.session.session_key
 | 
			
		||||
    )
 | 
			
		||||
    ticket = models.ServiceTicket.objects.get(value=ticket_value)
 | 
			
		||||
    return (user, ticket)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_pgt():
 | 
			
		||||
    """return a dict contening a service, user and PGT ticket for this service"""
 | 
			
		||||
    (host, port) = utils.PGTUrlHandler.run()[1:3]
 | 
			
		||||
    service = "http://%s:%s" % (host, port)
 | 
			
		||||
 | 
			
		||||
    (user, ticket) = get_user_ticket_request(service)
 | 
			
		||||
 | 
			
		||||
    client = Client()
 | 
			
		||||
    client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service})
 | 
			
		||||
    params = utils.PGTUrlHandler.PARAMS.copy()
 | 
			
		||||
 | 
			
		||||
    params["service"] = service
 | 
			
		||||
    params["user"] = user
 | 
			
		||||
 | 
			
		||||
    return params
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CheckPasswordCase(TestCase):
 | 
			
		||||
    """Tests for the utils function `utils.check_password`"""
 | 
			
		||||
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        """Generate random bytes string that will be used ass passwords"""
 | 
			
		||||
        self.password1 = utils.gen_saml_id()
 | 
			
		||||
        self.password2 = utils.gen_saml_id()
 | 
			
		||||
        if not isinstance(self.password1, bytes):  # pragma: no cover executed only in python3
 | 
			
		||||
            self.password1 = self.password1.encode("utf8")
 | 
			
		||||
            self.password2 = self.password2.encode("utf8")
 | 
			
		||||
 | 
			
		||||
    def test_setup(self):
 | 
			
		||||
        """check that generated password are bytes"""
 | 
			
		||||
        self.assertIsInstance(self.password1, bytes)
 | 
			
		||||
        self.assertIsInstance(self.password2, bytes)
 | 
			
		||||
 | 
			
		||||
    def test_plain(self):
 | 
			
		||||
        """test the plain auth method"""
 | 
			
		||||
        self.assertTrue(utils.check_password("plain", self.password1, self.password1, "utf8"))
 | 
			
		||||
        self.assertFalse(utils.check_password("plain", self.password1, self.password2, "utf8"))
 | 
			
		||||
 | 
			
		||||
    def test_crypt(self):
 | 
			
		||||
        """test the crypt auth method"""
 | 
			
		||||
        if six.PY3:
 | 
			
		||||
            hashed_password1 = utils.crypt.crypt(
 | 
			
		||||
                self.password1.decode("utf8"),
 | 
			
		||||
                "$6$UVVAQvrMyXMF3FF3"
 | 
			
		||||
            ).encode("utf8")
 | 
			
		||||
        else:
 | 
			
		||||
            hashed_password1 = utils.crypt.crypt(self.password1, "$6$UVVAQvrMyXMF3FF3")
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(utils.check_password("crypt", self.password1, hashed_password1, "utf8"))
 | 
			
		||||
        self.assertFalse(utils.check_password("crypt", self.password2, hashed_password1, "utf8"))
 | 
			
		||||
 | 
			
		||||
    def test_ldap_ssha(self):
 | 
			
		||||
        """test the ldap auth method with a {SSHA} scheme"""
 | 
			
		||||
        salt = b"UVVAQvrMyXMF3FF3"
 | 
			
		||||
        hashed_password1 = utils.LdapHashUserPassword.hash(b'{SSHA}', self.password1, salt, "utf8")
 | 
			
		||||
 | 
			
		||||
        self.assertIsInstance(hashed_password1, bytes)
 | 
			
		||||
        self.assertTrue(utils.check_password("ldap", self.password1, hashed_password1, "utf8"))
 | 
			
		||||
        self.assertFalse(utils.check_password("ldap", self.password2, hashed_password1, "utf8"))
 | 
			
		||||
 | 
			
		||||
    def test_hex_md5(self):
 | 
			
		||||
        """test the hex_md5 auth method"""
 | 
			
		||||
        hashed_password1 = utils.hashlib.md5(self.password1).hexdigest()
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(utils.check_password("hex_md5", self.password1, hashed_password1, "utf8"))
 | 
			
		||||
        self.assertFalse(utils.check_password("hex_md5", self.password2, hashed_password1, "utf8"))
 | 
			
		||||
 | 
			
		||||
    def test_hex_sha512(self):
 | 
			
		||||
        """test the hex_sha512 auth method"""
 | 
			
		||||
        hashed_password1 = utils.hashlib.sha512(self.password1).hexdigest()
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            utils.check_password("hex_sha512", self.password1, hashed_password1, "utf8")
 | 
			
		||||
        )
 | 
			
		||||
        self.assertFalse(
 | 
			
		||||
            utils.check_password("hex_sha512", self.password2, hashed_password1, "utf8")
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BaseServicePattern(object):
 | 
			
		||||
    """Mixing for setting up service pattern for testing"""
 | 
			
		||||
    def setup_service_patterns(self, proxy=False):
 | 
			
		||||
        """setting up service pattern"""
 | 
			
		||||
        # For general purpose testing
 | 
			
		||||
        self.service = "https://www.example.com"
 | 
			
		||||
        self.service_pattern = models.ServicePattern.objects.create(
 | 
			
		||||
            name="example",
 | 
			
		||||
            pattern="^https://www\.example\.com(/.*)?$",
 | 
			
		||||
            proxy=proxy,
 | 
			
		||||
        )
 | 
			
		||||
        models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
 | 
			
		||||
 | 
			
		||||
        # For testing the restrict_users attributes
 | 
			
		||||
        self.service_restrict_user_fail = "https://restrict_user_fail.example.com"
 | 
			
		||||
        self.service_pattern_restrict_user_fail = models.ServicePattern.objects.create(
 | 
			
		||||
            name="restrict_user_fail",
 | 
			
		||||
            pattern="^https://restrict_user_fail\.example\.com(/.*)?$",
 | 
			
		||||
            restrict_users=True,
 | 
			
		||||
            proxy=proxy,
 | 
			
		||||
        )
 | 
			
		||||
        self.service_restrict_user_success = "https://restrict_user_success.example.com"
 | 
			
		||||
        self.service_pattern_restrict_user_success = models.ServicePattern.objects.create(
 | 
			
		||||
            name="restrict_user_success",
 | 
			
		||||
            pattern="^https://restrict_user_success\.example\.com(/.*)?$",
 | 
			
		||||
            restrict_users=True,
 | 
			
		||||
            proxy=proxy,
 | 
			
		||||
        )
 | 
			
		||||
        models.Username.objects.create(
 | 
			
		||||
            value=settings.CAS_TEST_USER,
 | 
			
		||||
            service_pattern=self.service_pattern_restrict_user_success
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # For testing the user attributes filtering conditions
 | 
			
		||||
        self.service_filter_fail = "https://filter_fail.example.com"
 | 
			
		||||
        self.service_pattern_filter_fail = models.ServicePattern.objects.create(
 | 
			
		||||
            name="filter_fail",
 | 
			
		||||
            pattern="^https://filter_fail\.example\.com(/.*)?$",
 | 
			
		||||
            proxy=proxy,
 | 
			
		||||
        )
 | 
			
		||||
        models.FilterAttributValue.objects.create(
 | 
			
		||||
            attribut="right",
 | 
			
		||||
            pattern="^admin$",
 | 
			
		||||
            service_pattern=self.service_pattern_filter_fail
 | 
			
		||||
        )
 | 
			
		||||
        self.service_filter_success = "https://filter_success.example.com"
 | 
			
		||||
        self.service_pattern_filter_success = models.ServicePattern.objects.create(
 | 
			
		||||
            name="filter_success",
 | 
			
		||||
            pattern="^https://filter_success\.example\.com(/.*)?$",
 | 
			
		||||
            proxy=proxy,
 | 
			
		||||
        )
 | 
			
		||||
        models.FilterAttributValue.objects.create(
 | 
			
		||||
            attribut="email",
 | 
			
		||||
            pattern="^%s$" % re.escape(settings.CAS_TEST_ATTRIBUTES['email']),
 | 
			
		||||
            service_pattern=self.service_pattern_filter_success
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # For testing the user_field attributes
 | 
			
		||||
        self.service_field_needed_fail = "https://field_needed_fail.example.com"
 | 
			
		||||
        self.service_pattern_field_needed_fail = models.ServicePattern.objects.create(
 | 
			
		||||
            name="field_needed_fail",
 | 
			
		||||
            pattern="^https://field_needed_fail\.example\.com(/.*)?$",
 | 
			
		||||
            user_field="uid",
 | 
			
		||||
            proxy=proxy,
 | 
			
		||||
        )
 | 
			
		||||
        self.service_field_needed_success = "https://field_needed_success.example.com"
 | 
			
		||||
        self.service_pattern_field_needed_success = models.ServicePattern.objects.create(
 | 
			
		||||
            name="field_needed_success",
 | 
			
		||||
            pattern="^https://field_needed_success\.example\.com(/.*)?$",
 | 
			
		||||
            user_field="nom",
 | 
			
		||||
            proxy=proxy,
 | 
			
		||||
        )
 | 
			
		||||
from cas_server.tests.utils import (
 | 
			
		||||
    copy_form,
 | 
			
		||||
    get_login_page_params,
 | 
			
		||||
    get_auth_client,
 | 
			
		||||
    get_user_ticket_request,
 | 
			
		||||
    get_pgt,
 | 
			
		||||
    get_proxy_ticket
 | 
			
		||||
)
 | 
			
		||||
from cas_server.tests.mixin import BaseServicePattern, XmlContent
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
 | 
			
		||||
@@ -449,7 +261,7 @@ class LoginTestCase(TestCase, BaseServicePattern):
 | 
			
		||||
            response["Location"].startswith("%s?ticket=" % self.service_field_needed_success)
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @override_settings(CAS_TEST_ATTRIBUTES={'nom': []})
 | 
			
		||||
    @override_settings(CAS_TEST_ATTRIBUTES={'alias': []})
 | 
			
		||||
    def test_service_user_field_evaluate_to_false(self):
 | 
			
		||||
        """
 | 
			
		||||
            Test using a user attribute as username:
 | 
			
		||||
@@ -458,7 +270,7 @@ class LoginTestCase(TestCase, BaseServicePattern):
 | 
			
		||||
        client = get_auth_client()
 | 
			
		||||
        response = client.get("/login", {"service": self.service_field_needed_success})
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        self.assertTrue(b"The attribut nom is needed to use that service" in response.content)
 | 
			
		||||
        self.assertTrue(b"The attribut alias is needed to use that service" in response.content)
 | 
			
		||||
 | 
			
		||||
    def test_gateway(self):
 | 
			
		||||
        """test gateway parameter"""
 | 
			
		||||
@@ -743,6 +555,22 @@ class AuthTestCase(TestCase):
 | 
			
		||||
    @override_settings(CAS_AUTH_SHARED_SECRET='test')
 | 
			
		||||
    def test_auth_view_goodpass(self):
 | 
			
		||||
        """successful request are awsered by yes"""
 | 
			
		||||
        client = get_auth_client()
 | 
			
		||||
        response = client.post(
 | 
			
		||||
            '/auth',
 | 
			
		||||
            {
 | 
			
		||||
                'username': settings.CAS_TEST_USER,
 | 
			
		||||
                'password': settings.CAS_TEST_PASSWORD,
 | 
			
		||||
                'service': self.service,
 | 
			
		||||
                'secret': 'test'
 | 
			
		||||
            }
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        self.assertEqual(response.content, b'yes\n')
 | 
			
		||||
 | 
			
		||||
    @override_settings(CAS_AUTH_SHARED_SECRET='test')
 | 
			
		||||
    def test_auth_view_goodpass_logged(self):
 | 
			
		||||
        """successful request are awsered by yes, using a logged sessions"""
 | 
			
		||||
        client = Client()
 | 
			
		||||
        response = client.post(
 | 
			
		||||
            '/auth',
 | 
			
		||||
@@ -853,6 +681,12 @@ class ValidateTestCase(TestCase):
 | 
			
		||||
            pattern="^https://user_field\.example\.com(/.*)?$",
 | 
			
		||||
            user_field="alias"
 | 
			
		||||
        )
 | 
			
		||||
        self.service_user_field_alt = "https://user_field_alt.example.com"
 | 
			
		||||
        self.service_pattern_user_field_alt = models.ServicePattern.objects.create(
 | 
			
		||||
            name="user field alt",
 | 
			
		||||
            pattern="^https://user_field_alt\.example\.com(/.*)?$",
 | 
			
		||||
            user_field="nom"
 | 
			
		||||
        )
 | 
			
		||||
        models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
 | 
			
		||||
 | 
			
		||||
    def test_validate_view_ok(self):
 | 
			
		||||
@@ -893,14 +727,18 @@ class ValidateTestCase(TestCase):
 | 
			
		||||
            test with a good user_field. A bad user_field (that evaluate to False)
 | 
			
		||||
            wont happed cause it is filtered in the login view
 | 
			
		||||
        """
 | 
			
		||||
        ticket = get_user_ticket_request(self.service_user_field)[1]
 | 
			
		||||
        for (service, username) in [
 | 
			
		||||
            (self.service_user_field, b"demo1"),
 | 
			
		||||
            (self.service_user_field_alt, b"Nymous")
 | 
			
		||||
        ]:
 | 
			
		||||
            ticket = get_user_ticket_request(service)[1]
 | 
			
		||||
            client = Client()
 | 
			
		||||
            response = client.get(
 | 
			
		||||
                '/validate',
 | 
			
		||||
            {'ticket': ticket.value, 'service': self.service_user_field}
 | 
			
		||||
                {'ticket': ticket.value, 'service': service}
 | 
			
		||||
            )
 | 
			
		||||
            self.assertEqual(response.status_code, 200)
 | 
			
		||||
        self.assertEqual(response.content, b'yes\ndemo1\n')
 | 
			
		||||
            self.assertEqual(response.content, b'yes\n' + username + b'\n')
 | 
			
		||||
 | 
			
		||||
    def test_validate_missing_parameter(self):
 | 
			
		||||
        """test with a missing GET parameter among [service, ticket]"""
 | 
			
		||||
@@ -916,63 +754,6 @@ class ValidateTestCase(TestCase):
 | 
			
		||||
            self.assertEqual(response.content, b'no\n')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class XmlContent(object):
 | 
			
		||||
    """Mixin for test on CAS XML responses"""
 | 
			
		||||
    def assert_error(self, response, code, text=None):
 | 
			
		||||
        """Assert a validation error"""
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        root = etree.fromstring(response.content)
 | 
			
		||||
        error = root.xpath(
 | 
			
		||||
            "//cas:authenticationFailure",
 | 
			
		||||
            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(len(error), 1)
 | 
			
		||||
        self.assertEqual(error[0].attrib['code'], code)
 | 
			
		||||
        if text is not None:
 | 
			
		||||
            self.assertEqual(error[0].text, text)
 | 
			
		||||
 | 
			
		||||
    def assert_success(self, response, username, original_attributes):
 | 
			
		||||
        """assert a ticket validation success"""
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
 | 
			
		||||
        root = etree.fromstring(response.content)
 | 
			
		||||
        sucess = root.xpath(
 | 
			
		||||
            "//cas:authenticationSuccess",
 | 
			
		||||
            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
			
		||||
        )
 | 
			
		||||
        self.assertTrue(sucess)
 | 
			
		||||
 | 
			
		||||
        users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
 | 
			
		||||
        self.assertEqual(len(users), 1)
 | 
			
		||||
        self.assertEqual(users[0].text, username)
 | 
			
		||||
 | 
			
		||||
        attributes = root.xpath(
 | 
			
		||||
            "//cas:attributes",
 | 
			
		||||
            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(len(attributes), 1)
 | 
			
		||||
        attrs1 = set()
 | 
			
		||||
        for attr in attributes[0]:
 | 
			
		||||
            attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text))
 | 
			
		||||
 | 
			
		||||
        attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
 | 
			
		||||
        self.assertEqual(len(attributes), len(attrs1))
 | 
			
		||||
        attrs2 = set()
 | 
			
		||||
        for attr in attributes:
 | 
			
		||||
            attrs2.add((attr.attrib['name'], attr.attrib['value']))
 | 
			
		||||
        original = set()
 | 
			
		||||
        for key, value in original_attributes.items():
 | 
			
		||||
            if isinstance(value, list):
 | 
			
		||||
                for sub_value in value:
 | 
			
		||||
                    original.add((key, sub_value))
 | 
			
		||||
            else:
 | 
			
		||||
                original.add((key, value))
 | 
			
		||||
        self.assertEqual(attrs1, attrs2)
 | 
			
		||||
        self.assertEqual(attrs1, original)
 | 
			
		||||
 | 
			
		||||
        return root
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
 | 
			
		||||
class ValidateServiceTestCase(TestCase, XmlContent):
 | 
			
		||||
    """tests for the serviceValidate view"""
 | 
			
		||||
@@ -992,6 +773,12 @@ class ValidateServiceTestCase(TestCase, XmlContent):
 | 
			
		||||
            pattern="^https://user_field\.example\.com(/.*)?$",
 | 
			
		||||
            user_field="alias"
 | 
			
		||||
        )
 | 
			
		||||
        self.service_user_field_alt = "https://user_field_alt.example.com"
 | 
			
		||||
        self.service_pattern_user_field_alt = models.ServicePattern.objects.create(
 | 
			
		||||
            name="user field alt",
 | 
			
		||||
            pattern="^https://user_field_alt\.example\.com(/.*)?$",
 | 
			
		||||
            user_field="nom"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.service_one_attribute = "https://one_attribute.example.com"
 | 
			
		||||
        self.service_pattern_one_attribute = models.ServicePattern.objects.create(
 | 
			
		||||
@@ -1171,15 +958,19 @@ class ValidateServiceTestCase(TestCase, XmlContent):
 | 
			
		||||
            test with a good user_field. A bad user_field (that evaluate to False)
 | 
			
		||||
            wont happed cause it is filtered in the login view
 | 
			
		||||
        """
 | 
			
		||||
        ticket = get_user_ticket_request(self.service_user_field)[1]
 | 
			
		||||
        for (service, username) in [
 | 
			
		||||
            (self.service_user_field, settings.CAS_TEST_ATTRIBUTES["alias"][0]),
 | 
			
		||||
            (self.service_user_field_alt, settings.CAS_TEST_ATTRIBUTES["nom"])
 | 
			
		||||
        ]:
 | 
			
		||||
            ticket = get_user_ticket_request(service)[1]
 | 
			
		||||
            client = Client()
 | 
			
		||||
            response = client.get(
 | 
			
		||||
                '/serviceValidate',
 | 
			
		||||
            {'ticket': ticket.value, 'service': self.service_user_field}
 | 
			
		||||
                {'ticket': ticket.value, 'service': service}
 | 
			
		||||
            )
 | 
			
		||||
            self.assert_success(
 | 
			
		||||
                response,
 | 
			
		||||
            settings.CAS_TEST_ATTRIBUTES["alias"][0],
 | 
			
		||||
                username,
 | 
			
		||||
                {}
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
@@ -1349,3 +1140,198 @@ class ProxyTestCase(TestCase, BaseServicePattern, XmlContent):
 | 
			
		||||
                "INVALID_REQUEST",
 | 
			
		||||
                'you must specify and pgt and targetService'
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
 | 
			
		||||
class SamlValidateTestCase(TestCase, BaseServicePattern, XmlContent):
 | 
			
		||||
    """tests for the proxy view"""
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        """preparing test context"""
 | 
			
		||||
        self.setup_service_patterns(proxy=True)
 | 
			
		||||
 | 
			
		||||
        self.service_pgt = 'http://127.0.0.1'
 | 
			
		||||
        self.service_pattern_pgt = models.ServicePattern.objects.create(
 | 
			
		||||
            name="localhost",
 | 
			
		||||
            pattern="^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
 | 
			
		||||
            proxy=True,
 | 
			
		||||
            proxy_callback=True
 | 
			
		||||
        )
 | 
			
		||||
        models.ReplaceAttributName.objects.create(
 | 
			
		||||
            name="*",
 | 
			
		||||
            service_pattern=self.service_pattern_pgt
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    xml_template = """
 | 
			
		||||
<SOAP-ENV:Envelope xmlns:SOAP-ENV="http://schemas.xmlsoap.org/soap/envelope/">
 | 
			
		||||
    <SOAP-ENV:Header/>
 | 
			
		||||
    <SOAP-ENV:Body>
 | 
			
		||||
        <samlp:Request
 | 
			
		||||
            xmlns:samlp="urn:oasis:names:tc:SAML:1.0:protocol"
 | 
			
		||||
            MajorVersion="1" MinorVersion="1"
 | 
			
		||||
            RequestID="%(request_id)s"
 | 
			
		||||
            IssueInstant="%(issue_instant)s"
 | 
			
		||||
        >
 | 
			
		||||
            <samlp:AssertionArtifact>%(ticket)s</samlp:AssertionArtifact>
 | 
			
		||||
        </samlp:Request>
 | 
			
		||||
    </SOAP-ENV:Body>
 | 
			
		||||
</SOAP-ENV:Envelope>"""
 | 
			
		||||
 | 
			
		||||
    def assert_success(self, response, username, original_attributes):
 | 
			
		||||
        """assert ticket validation success"""
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        root = etree.fromstring(response.content)
 | 
			
		||||
        success = root.xpath(
 | 
			
		||||
            "//samlp:StatusCode",
 | 
			
		||||
            namespaces={'samlp': "urn:oasis:names:tc:SAML:1.0:protocol"}
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(len(success), 1)
 | 
			
		||||
        self.assertTrue(success[0].attrib['Value'].endswith(":Success"))
 | 
			
		||||
 | 
			
		||||
        user = root.xpath(
 | 
			
		||||
            "//samla:NameIdentifier",
 | 
			
		||||
            namespaces={'samla': "urn:oasis:names:tc:SAML:1.0:assertion"}
 | 
			
		||||
        )
 | 
			
		||||
        self.assertTrue(user)
 | 
			
		||||
        self.assertEqual(user[0].text, username)
 | 
			
		||||
 | 
			
		||||
        attributes = root.xpath(
 | 
			
		||||
            "//samla:AttributeStatement/samla:Attribute",
 | 
			
		||||
            namespaces={'samla': "urn:oasis:names:tc:SAML:1.0:assertion"}
 | 
			
		||||
        )
 | 
			
		||||
        attrs = set()
 | 
			
		||||
        for attr in attributes:
 | 
			
		||||
            attrs.add((attr.attrib['AttributeName'], attr.getchildren()[0].text))
 | 
			
		||||
        original = set()
 | 
			
		||||
        for key, value in original_attributes.items():
 | 
			
		||||
            if isinstance(value, list):
 | 
			
		||||
                for subval in value:
 | 
			
		||||
                    original.add((key, subval))
 | 
			
		||||
            else:
 | 
			
		||||
                original.add((key, value))
 | 
			
		||||
        self.assertEqual(original, attrs)
 | 
			
		||||
 | 
			
		||||
    def assert_error(self, response, code, msg=None):
 | 
			
		||||
        """assert ticket validation error"""
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        root = etree.fromstring(response.content)
 | 
			
		||||
        error = root.xpath(
 | 
			
		||||
            "//samlp:StatusCode",
 | 
			
		||||
            namespaces={'samlp': "urn:oasis:names:tc:SAML:1.0:protocol"}
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(len(error), 1)
 | 
			
		||||
        self.assertTrue(error[0].attrib['Value'].endswith(":%s" % code))
 | 
			
		||||
        if msg is not None:
 | 
			
		||||
            self.assertEqual(error[0].text, msg)
 | 
			
		||||
 | 
			
		||||
    def test_saml_ok(self):
 | 
			
		||||
        """
 | 
			
		||||
            test with a valid (ticket, service), with a ST and a PT,
 | 
			
		||||
            the username and all attributes are transmited"""
 | 
			
		||||
        tickets = [
 | 
			
		||||
            get_user_ticket_request(self.service)[1],
 | 
			
		||||
            get_proxy_ticket(self.service)
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
        for ticket in tickets:
 | 
			
		||||
            client = Client()
 | 
			
		||||
            response = client.post(
 | 
			
		||||
                '/samlValidate?TARGET=%s' % self.service,
 | 
			
		||||
                self.xml_template % {
 | 
			
		||||
                    'ticket': ticket.value,
 | 
			
		||||
                    'request_id': utils.gen_saml_id(),
 | 
			
		||||
                    'issue_instant': timezone.now().isoformat()
 | 
			
		||||
                },
 | 
			
		||||
                content_type="text/xml; encoding='utf-8'"
 | 
			
		||||
            )
 | 
			
		||||
            self.assert_success(response, settings.CAS_TEST_USER, settings.CAS_TEST_ATTRIBUTES)
 | 
			
		||||
 | 
			
		||||
    def test_saml_ok_user_field(self):
 | 
			
		||||
        """test with a valid(ticket, service), use a attributes as transmitted username"""
 | 
			
		||||
        for (service, username) in [
 | 
			
		||||
            (self.service_field_needed_success, settings.CAS_TEST_ATTRIBUTES['alias'][0]),
 | 
			
		||||
            (self.service_field_needed_success_alt, settings.CAS_TEST_ATTRIBUTES['nom'])
 | 
			
		||||
        ]:
 | 
			
		||||
            ticket = get_user_ticket_request(service)[1]
 | 
			
		||||
 | 
			
		||||
            client = Client()
 | 
			
		||||
            response = client.post(
 | 
			
		||||
                '/samlValidate?TARGET=%s' % service,
 | 
			
		||||
                self.xml_template % {
 | 
			
		||||
                    'ticket': ticket.value,
 | 
			
		||||
                    'request_id': utils.gen_saml_id(),
 | 
			
		||||
                    'issue_instant': timezone.now().isoformat()
 | 
			
		||||
                },
 | 
			
		||||
                content_type="text/xml; encoding='utf-8'"
 | 
			
		||||
            )
 | 
			
		||||
            self.assert_success(response, username, {})
 | 
			
		||||
 | 
			
		||||
    def test_saml_bad_ticket(self):
 | 
			
		||||
        """test validation with a bad ST and a bad PT, validation should fail"""
 | 
			
		||||
        tickets = [utils.gen_st(), utils.gen_pt()]
 | 
			
		||||
 | 
			
		||||
        for ticket in tickets:
 | 
			
		||||
            client = Client()
 | 
			
		||||
            response = client.post(
 | 
			
		||||
                '/samlValidate?TARGET=%s' % self.service,
 | 
			
		||||
                self.xml_template % {
 | 
			
		||||
                    'ticket': ticket,
 | 
			
		||||
                    'request_id': utils.gen_saml_id(),
 | 
			
		||||
                    'issue_instant': timezone.now().isoformat()
 | 
			
		||||
                },
 | 
			
		||||
                content_type="text/xml; encoding='utf-8'"
 | 
			
		||||
            )
 | 
			
		||||
            self.assert_error(
 | 
			
		||||
                response,
 | 
			
		||||
                "AuthnFailed",
 | 
			
		||||
                'ticket %s not found' % ticket
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    def test_saml_bad_ticket_prefix(self):
 | 
			
		||||
        """test validation with a bad ticket prefix. Validation should fail with 'AuthnFailed'"""
 | 
			
		||||
        bad_ticket = "RANDOM-NOT-BEGINING-WITH-ST-OR-ST"
 | 
			
		||||
        client = Client()
 | 
			
		||||
        response = client.post(
 | 
			
		||||
            '/samlValidate?TARGET=%s' % self.service,
 | 
			
		||||
            self.xml_template % {
 | 
			
		||||
                'ticket': bad_ticket,
 | 
			
		||||
                'request_id': utils.gen_saml_id(),
 | 
			
		||||
                'issue_instant': timezone.now().isoformat()
 | 
			
		||||
            },
 | 
			
		||||
            content_type="text/xml; encoding='utf-8'"
 | 
			
		||||
        )
 | 
			
		||||
        self.assert_error(
 | 
			
		||||
            response,
 | 
			
		||||
            "AuthnFailed",
 | 
			
		||||
            'ticket %s should begin with PT- or ST-' % bad_ticket
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_saml_bad_target(self):
 | 
			
		||||
        """test with a valid(ticket, service), but using a bad target"""
 | 
			
		||||
        bad_target = "https://www.example.org"
 | 
			
		||||
        ticket = get_user_ticket_request(self.service)[1]
 | 
			
		||||
 | 
			
		||||
        client = Client()
 | 
			
		||||
        response = client.post(
 | 
			
		||||
            '/samlValidate?TARGET=%s' % bad_target,
 | 
			
		||||
            self.xml_template % {
 | 
			
		||||
                'ticket': ticket.value,
 | 
			
		||||
                'request_id': utils.gen_saml_id(),
 | 
			
		||||
                'issue_instant': timezone.now().isoformat()
 | 
			
		||||
            },
 | 
			
		||||
            content_type="text/xml; encoding='utf-8'"
 | 
			
		||||
        )
 | 
			
		||||
        self.assert_error(
 | 
			
		||||
            response,
 | 
			
		||||
            "AuthnFailed",
 | 
			
		||||
            'TARGET %s do not match ticket service' % bad_target
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_saml_bad_xml(self):
 | 
			
		||||
        """test validation with a bad xml request, validation should fail"""
 | 
			
		||||
        client = Client()
 | 
			
		||||
        response = client.post(
 | 
			
		||||
            '/samlValidate?TARGET=%s' % self.service,
 | 
			
		||||
            "<root></root>",
 | 
			
		||||
            content_type="text/xml; encoding='utf-8'"
 | 
			
		||||
        )
 | 
			
		||||
        self.assert_error(response, 'VersionMismatch')
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										86
									
								
								cas_server/tests/utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										86
									
								
								cas_server/tests/utils.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,86 @@
 | 
			
		||||
"""Some utils functions for tests"""
 | 
			
		||||
from cas_server.default_settings import settings
 | 
			
		||||
 | 
			
		||||
from django.test import Client
 | 
			
		||||
 | 
			
		||||
from lxml import etree
 | 
			
		||||
 | 
			
		||||
from cas_server import models
 | 
			
		||||
from cas_server import utils
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def copy_form(form):
 | 
			
		||||
    """Copy form value into a dict"""
 | 
			
		||||
    params = {}
 | 
			
		||||
    for field in form:
 | 
			
		||||
        if field.value():
 | 
			
		||||
            params[field.name] = field.value()
 | 
			
		||||
        else:
 | 
			
		||||
            params[field.name] = ""
 | 
			
		||||
    return params
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_login_page_params(client=None):
 | 
			
		||||
    """Return a client and the POST params for the client to login"""
 | 
			
		||||
    if client is None:
 | 
			
		||||
        client = Client()
 | 
			
		||||
    response = client.get('/login')
 | 
			
		||||
    params = copy_form(response.context["form"])
 | 
			
		||||
    return client, params
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_auth_client(**update):
 | 
			
		||||
    """return a authenticated client"""
 | 
			
		||||
    client, params = get_login_page_params()
 | 
			
		||||
    params["username"] = settings.CAS_TEST_USER
 | 
			
		||||
    params["password"] = settings.CAS_TEST_PASSWORD
 | 
			
		||||
    params.update(update)
 | 
			
		||||
 | 
			
		||||
    client.post('/login', params)
 | 
			
		||||
    return client
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_user_ticket_request(service):
 | 
			
		||||
    """Make an auth client to request a ticket for `service`, return the tuple (user, ticket)"""
 | 
			
		||||
    client = get_auth_client()
 | 
			
		||||
    response = client.get("/login", {"service": service})
 | 
			
		||||
    ticket_value = response['Location'].split('ticket=')[-1]
 | 
			
		||||
    user = models.User.objects.get(
 | 
			
		||||
        username=settings.CAS_TEST_USER,
 | 
			
		||||
        session_key=client.session.session_key
 | 
			
		||||
    )
 | 
			
		||||
    ticket = models.ServiceTicket.objects.get(value=ticket_value)
 | 
			
		||||
    return (user, ticket)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_pgt():
 | 
			
		||||
    """return a dict contening a service, user and PGT ticket for this service"""
 | 
			
		||||
    (host, port) = utils.PGTUrlHandler.run()[1:3]
 | 
			
		||||
    service = "http://%s:%s" % (host, port)
 | 
			
		||||
 | 
			
		||||
    (user, ticket) = get_user_ticket_request(service)
 | 
			
		||||
 | 
			
		||||
    client = Client()
 | 
			
		||||
    client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service})
 | 
			
		||||
    params = utils.PGTUrlHandler.PARAMS.copy()
 | 
			
		||||
 | 
			
		||||
    params["service"] = service
 | 
			
		||||
    params["user"] = user
 | 
			
		||||
 | 
			
		||||
    return params
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_proxy_ticket(service):
 | 
			
		||||
    params = get_pgt()
 | 
			
		||||
 | 
			
		||||
    # get a proxy ticket
 | 
			
		||||
    client = Client()
 | 
			
		||||
    response = client.get('/proxy', {'pgt': params['pgtId'], 'targetService': service})
 | 
			
		||||
    root = etree.fromstring(response.content)
 | 
			
		||||
    proxy_ticket = root.xpath(
 | 
			
		||||
        "//cas:proxyTicket",
 | 
			
		||||
        namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
			
		||||
    )
 | 
			
		||||
    proxy_ticket = proxy_ticket[0].text
 | 
			
		||||
    ticket = models.ProxyTicket.objects.get(value=proxy_ticket)
 | 
			
		||||
    return ticket
 | 
			
		||||
@@ -923,11 +923,15 @@ class SamlValidate(View, AttributesMixin):
 | 
			
		||||
                'username': self.ticket.user.username,
 | 
			
		||||
                'attributes': attributes
 | 
			
		||||
            }
 | 
			
		||||
            if self.ticket.service_pattern.user_field and \
 | 
			
		||||
                    self.ticket.user.attributs.get(self.ticket.service_pattern.user_field):
 | 
			
		||||
            if (self.ticket.service_pattern.user_field and
 | 
			
		||||
                    self.ticket.user.attributs.get(self.ticket.service_pattern.user_field)):
 | 
			
		||||
                params['username'] = self.ticket.user.attributs.get(
 | 
			
		||||
                    self.ticket.service_pattern.user_field
 | 
			
		||||
                )
 | 
			
		||||
                if isinstance(params['username'], list):
 | 
			
		||||
                    # the list is not empty because we wont generate a ticket with a user_field
 | 
			
		||||
                    # that evaluate to False
 | 
			
		||||
                    params['username'] = params['username'][0]
 | 
			
		||||
            logger.info(
 | 
			
		||||
                "SamlValidate: ticket %s validated for user %s on service %s." % (
 | 
			
		||||
                    self.ticket.value,
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user