Tweak the cas client lib to always return unicode
hence, the behaviour is consistent between python2 and python3
This commit is contained in:
		@@ -21,6 +21,7 @@
 | 
			
		||||
# This file is originated from https://github.com/python-cas/python-cas
 | 
			
		||||
# at commit ec1f2d4779625229398547b9234d0e9e874a2c9a
 | 
			
		||||
 | 
			
		||||
import six
 | 
			
		||||
from six.moves.urllib import parse as urllib_parse
 | 
			
		||||
from six.moves.urllib import request as urllib_request
 | 
			
		||||
from six.moves.urllib.request import Request
 | 
			
		||||
@@ -32,6 +33,15 @@ class CASError(ValueError):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ReturnUnicode(object):
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def unicode(string, charset):
 | 
			
		||||
        if not isinstance(string, six.text_type):
 | 
			
		||||
            return string.decode(charset)
 | 
			
		||||
        else:
 | 
			
		||||
            return string
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SingleLogoutMixin(object):
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def get_saml_slos(cls, logout_request):
 | 
			
		||||
@@ -124,7 +134,7 @@ class CASClientBase(object):
 | 
			
		||||
        raise CASError("Bad http code %s" % response.code)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CASClientV1(CASClientBase):
 | 
			
		||||
class CASClientV1(CASClientBase, ReturnUnicode):
 | 
			
		||||
    """CAS Client Version 1"""
 | 
			
		||||
 | 
			
		||||
    logout_redirect_param_name = 'url'
 | 
			
		||||
@@ -140,15 +150,21 @@ class CASClientV1(CASClientBase):
 | 
			
		||||
        page = urllib_request.urlopen(url)
 | 
			
		||||
        try:
 | 
			
		||||
            verified = page.readline().strip()
 | 
			
		||||
            if verified == 'yes':
 | 
			
		||||
                return page.readline().strip(), None, None
 | 
			
		||||
            if verified == b'yes':
 | 
			
		||||
                content_type = page.info().get('Content-type')
 | 
			
		||||
                if "charset=" in content_type:
 | 
			
		||||
                    charset = content_type.split("charset=")[-1]
 | 
			
		||||
                else:
 | 
			
		||||
                    charset = "ascii"
 | 
			
		||||
                user = self.unicode(page.readline().strip(), charset)
 | 
			
		||||
                return user, None, None
 | 
			
		||||
            else:
 | 
			
		||||
                return None, None, None
 | 
			
		||||
        finally:
 | 
			
		||||
            page.close()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CASClientV2(CASClientBase):
 | 
			
		||||
class CASClientV2(CASClientBase, ReturnUnicode):
 | 
			
		||||
    """CAS Client Version 2"""
 | 
			
		||||
 | 
			
		||||
    url_suffix = 'serviceValidate'
 | 
			
		||||
@@ -161,8 +177,8 @@ class CASClientV2(CASClientBase):
 | 
			
		||||
 | 
			
		||||
    def verify_ticket(self, ticket):
 | 
			
		||||
        """Verifies CAS 2.0+/3.0+ XML-based authentication ticket and returns extended attributes"""
 | 
			
		||||
        response = self.get_verification_response(ticket)
 | 
			
		||||
        return self.verify_response(response)
 | 
			
		||||
        (response, charset) = self.get_verification_response(ticket)
 | 
			
		||||
        return self.verify_response(response, charset)
 | 
			
		||||
 | 
			
		||||
    def get_verification_response(self, ticket):
 | 
			
		||||
        params = [('ticket', ticket), ('service', self.service_url)]
 | 
			
		||||
@@ -172,37 +188,42 @@ class CASClientV2(CASClientBase):
 | 
			
		||||
        url = base_url + '?' + urllib_parse.urlencode(params)
 | 
			
		||||
        page = urllib_request.urlopen(url)
 | 
			
		||||
        try:
 | 
			
		||||
            return page.read()
 | 
			
		||||
            content_type = page.info().get('Content-type')
 | 
			
		||||
            if "charset=" in content_type:
 | 
			
		||||
                charset = content_type.split("charset=")[-1]
 | 
			
		||||
            else:
 | 
			
		||||
                charset = "ascii"
 | 
			
		||||
            return (page.read(), charset)
 | 
			
		||||
        finally:
 | 
			
		||||
            page.close()
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def parse_attributes_xml_element(cls, element):
 | 
			
		||||
    def parse_attributes_xml_element(cls, element, charset):
 | 
			
		||||
        attributes = dict()
 | 
			
		||||
        for attribute in element:
 | 
			
		||||
            tag = attribute.tag.split("}").pop()
 | 
			
		||||
            tag = cls.self.unicode(attribute.tag, charset).split(u"}").pop()
 | 
			
		||||
            if tag in attributes:
 | 
			
		||||
                if isinstance(attributes[tag], list):
 | 
			
		||||
                    attributes[tag].append(attribute.text)
 | 
			
		||||
                    attributes[tag].append(cls.unicode(attribute.text, charset))
 | 
			
		||||
                else:
 | 
			
		||||
                    attributes[tag] = [attributes[tag]]
 | 
			
		||||
                    attributes[tag].append(attribute.text)
 | 
			
		||||
                    attributes[tag].append(cls.unicode(attribute.text, charset))
 | 
			
		||||
            else:
 | 
			
		||||
                if tag == 'attraStyle':
 | 
			
		||||
                if tag == u'attraStyle':
 | 
			
		||||
                    pass
 | 
			
		||||
                else:
 | 
			
		||||
                    attributes[tag] = attribute.text
 | 
			
		||||
                    attributes[tag] = cls.unicode(attribute.text, charset)
 | 
			
		||||
        return attributes
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def verify_response(cls, response):
 | 
			
		||||
        user, attributes, pgtiou = cls.parse_response_xml(response)
 | 
			
		||||
    def verify_response(cls, response, charset):
 | 
			
		||||
        user, attributes, pgtiou = cls.parse_response_xml(response, charset)
 | 
			
		||||
        if len(attributes) == 0:
 | 
			
		||||
            attributes = None
 | 
			
		||||
        return user, attributes, pgtiou
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def parse_response_xml(cls, response):
 | 
			
		||||
    def parse_response_xml(cls, response, charset):
 | 
			
		||||
        try:
 | 
			
		||||
            from xml.etree import ElementTree
 | 
			
		||||
        except ImportError:
 | 
			
		||||
@@ -216,11 +237,11 @@ class CASClientV2(CASClientBase):
 | 
			
		||||
        if tree[0].tag.endswith('authenticationSuccess'):
 | 
			
		||||
            for element in tree[0]:
 | 
			
		||||
                if element.tag.endswith('user'):
 | 
			
		||||
                    user = element.text
 | 
			
		||||
                    user = cls.unicode(element.text, charset)
 | 
			
		||||
                elif element.tag.endswith('proxyGrantingTicket'):
 | 
			
		||||
                    pgtiou = element.text
 | 
			
		||||
                    pgtiou = cls.unicode(element.text, charset)
 | 
			
		||||
                elif element.tag.endswith('attributes'):
 | 
			
		||||
                    attributes = cls.parse_attributes_xml_element(element)
 | 
			
		||||
                    attributes = cls.parse_attributes_xml_element(element, charset)
 | 
			
		||||
        return user, attributes, pgtiou
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -230,23 +251,23 @@ class CASClientV3(CASClientV2, SingleLogoutMixin):
 | 
			
		||||
    logout_redirect_param_name = 'service'
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def parse_attributes_xml_element(cls, element):
 | 
			
		||||
    def parse_attributes_xml_element(cls, element, charset):
 | 
			
		||||
        attributes = dict()
 | 
			
		||||
        for attribute in element:
 | 
			
		||||
            tag = attribute.tag.split("}").pop()
 | 
			
		||||
            tag = cls.unicode(attribute.tag, charset).split(u"}").pop()
 | 
			
		||||
            if tag in attributes:
 | 
			
		||||
                if isinstance(attributes[tag], list):
 | 
			
		||||
                    attributes[tag].append(attribute.text)
 | 
			
		||||
                    attributes[tag].append(cls.unicode(attribute.text, charset))
 | 
			
		||||
                else:
 | 
			
		||||
                    attributes[tag] = [attributes[tag]]
 | 
			
		||||
                    attributes[tag].append(attribute.text)
 | 
			
		||||
                    attributes[tag].append(cls.unicode(attribute.text, charset))
 | 
			
		||||
            else:
 | 
			
		||||
                attributes[tag] = attribute.text
 | 
			
		||||
                attributes[tag] = cls.unicode(attribute.text, charset)
 | 
			
		||||
        return attributes
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def verify_response(cls, response):
 | 
			
		||||
        return cls.parse_response_xml(response)
 | 
			
		||||
    def verify_response(cls, response, charset):
 | 
			
		||||
        return cls.parse_response_xml(response, charset)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
SAML_1_0_NS = 'urn:oasis:names:tc:SAML:1.0:'
 | 
			
		||||
@@ -284,6 +305,11 @@ class CASClientWithSAMLV1(CASClientV2, SingleLogoutMixin):
 | 
			
		||||
            from elementtree import ElementTree
 | 
			
		||||
 | 
			
		||||
        page = self.fetch_saml_validation(ticket)
 | 
			
		||||
        content_type = page.info().get('Content-type')
 | 
			
		||||
        if "charset=" in content_type:
 | 
			
		||||
            charset = content_type.split("charset=")[-1]
 | 
			
		||||
        else:
 | 
			
		||||
            charset = "ascii"
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            user = None
 | 
			
		||||
@@ -296,21 +322,25 @@ class CASClientWithSAMLV1(CASClientV2, SingleLogoutMixin):
 | 
			
		||||
                # User is validated
 | 
			
		||||
                name_identifier = tree.find('.//' + SAML_1_0_ASSERTION_NS + 'NameIdentifier')
 | 
			
		||||
                if name_identifier is not None:
 | 
			
		||||
                    user = name_identifier.text
 | 
			
		||||
                    user = self.unicode(name_identifier.text, charset)
 | 
			
		||||
                attrs = tree.findall('.//' + SAML_1_0_ASSERTION_NS + 'Attribute')
 | 
			
		||||
                for at in attrs:
 | 
			
		||||
                    if self.username_attribute in list(at.attrib.values()):
 | 
			
		||||
                        user = at.find(SAML_1_0_ASSERTION_NS + 'AttributeValue').text
 | 
			
		||||
                        attributes['uid'] = user
 | 
			
		||||
                        user = self.unicode(
 | 
			
		||||
                            at.find(SAML_1_0_ASSERTION_NS + 'AttributeValue').text,
 | 
			
		||||
                            charset
 | 
			
		||||
                        )
 | 
			
		||||
                        attributes[u'uid'] = user
 | 
			
		||||
 | 
			
		||||
                    values = at.findall(SAML_1_0_ASSERTION_NS + 'AttributeValue')
 | 
			
		||||
                    key = self.unicode(at.attrib['AttributeName'], charset)
 | 
			
		||||
                    if len(values) > 1:
 | 
			
		||||
                        values_array = []
 | 
			
		||||
                        for v in values:
 | 
			
		||||
                            values_array.append(v.text)
 | 
			
		||||
                            attributes[at.attrib['AttributeName']] = values_array
 | 
			
		||||
                            values_array.append(self.unicode(v.text, charset))
 | 
			
		||||
                            attributes[key] = values_array
 | 
			
		||||
                    else:
 | 
			
		||||
                        attributes[at.attrib['AttributeName']] = values[0].text
 | 
			
		||||
                        attributes[key] = self.unicode(values[0].text, charset)
 | 
			
		||||
            return user, attributes, None
 | 
			
		||||
        finally:
 | 
			
		||||
            page.close()
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user