mirror of
https://gitlab.crans.org/bde/nk20
synced 2025-06-21 01:48:21 +02:00
Parse input of search filters to prevent errors based on invalid regex, fixes #113
Signed-off-by: Yohann D'ANELLO <ynerant@crans.org>
This commit is contained in:
42
apps/api/filters.py
Normal file
42
apps/api/filters.py
Normal file
@ -0,0 +1,42 @@
|
||||
import re
|
||||
from functools import lru_cache
|
||||
|
||||
from rest_framework.filters import SearchFilter
|
||||
|
||||
|
||||
class RegexSafeSearchFilter(SearchFilter):
|
||||
@lru_cache
|
||||
def validate_regex(self, search_term) -> bool:
|
||||
try:
|
||||
re.compile(search_term)
|
||||
return True
|
||||
except re.error:
|
||||
return False
|
||||
|
||||
def get_search_fields(self, view, request):
|
||||
"""
|
||||
Ensure that given regex are valid.
|
||||
If not, we consider that the user is trying to search by substring.
|
||||
"""
|
||||
search_fields = super().get_search_fields(view, request)
|
||||
search_terms = self.get_search_terms(request)
|
||||
|
||||
for search_term in search_terms:
|
||||
if not self.validate_regex(search_term):
|
||||
# Invalid regex. We assume we don't query by regex but by substring.
|
||||
search_fields = [f.replace('$', '') for f in search_fields]
|
||||
break
|
||||
|
||||
return search_fields
|
||||
|
||||
def get_search_terms(self, request):
|
||||
"""
|
||||
Ensure that search field is a valid regex query. If not, we remove extra characters.
|
||||
"""
|
||||
terms = super().get_search_terms(request)
|
||||
if not all(self.validate_regex(term) for term in terms):
|
||||
# Invalid regex. If a ^ is prefixed to the search term, we remove it.
|
||||
terms = [term[1:] if term[0] == '^' else term for term in terms]
|
||||
# Same for dollars.
|
||||
terms = [term[:-1] if term[-1] == '$' else term for term in terms]
|
||||
return terms
|
Reference in New Issue
Block a user