Skip to content

Commit

Permalink
Revert "Remove deprecated advocate package (#6944)"
Browse files Browse the repository at this point in the history
This reverts commit bd115e7, as
it turns out to be a useful security feature.

In order to remove this in a better way, we'll need to replace it
with something that provides equivalent functionality.
  • Loading branch information
justinclift committed May 6, 2024
1 parent bd115e7 commit 62890c3
Show file tree
Hide file tree
Showing 8 changed files with 137 additions and 18 deletions.
77 changes: 76 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ readme = "README.md"

[tool.poetry.dependencies]
python = ">=3.8,<3.11"
advocate = "1.0.0"
aniso8601 = "8.0.0"
authlib = "0.15.5"
backoff = "2.2.1"
Expand Down
15 changes: 11 additions & 4 deletions redash/query_runner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@
from contextlib import ExitStack
from functools import wraps

import requests
import sqlparse
from dateutil import parser
from rq.timeouts import JobTimeoutException
from sshtunnel import open_tunnel

from redash import settings, utils
from redash.utils.requests_session import (
UnacceptableAddressException,
requests_or_advocate,
requests_session,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -375,7 +379,7 @@ def get_response(self, url, auth=None, http_method="get", **kwargs):
error = None
response = None
try:
response = requests.request(http_method, url, auth=auth, **kwargs)
response = requests_session.request(http_method, url, auth=auth, **kwargs)
# Raise a requests HTTP exception with the appropriate reason
# for 4xx and 5xx response status codes which is later caught
# and passed back.
Expand All @@ -385,11 +389,14 @@ def get_response(self, url, auth=None, http_method="get", **kwargs):
if response.status_code != 200:
error = "{} ({}).".format(self.response_error, response.status_code)

except requests.HTTPError as exc:
except requests_or_advocate.HTTPError as exc:
logger.exception(exc)
error = "Failed to execute query. "
f"Return Code: {response.status_code} Reason: {response.text}"
except requests.RequestException as exc:
except UnacceptableAddressException as exc:
logger.exception(exc)
error = "Can't query private addresses."
except requests_or_advocate.RequestException as exc:
# Catch all other requests exceptions and return the error.
logger.exception(exc)
error = str(exc)
Expand Down
10 changes: 8 additions & 2 deletions redash/query_runner/csv.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import io
import logging

import requests
import yaml

from redash.query_runner import BaseQueryRunner, NotSupported, register
from redash.utils.requests_session import (
UnacceptableAddressException,
requests_or_advocate,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -56,7 +59,7 @@ def run_query(self, query, user):
pass

try:
response = requests.get(url=path, headers={"User-agent": ua})
response = requests_or_advocate.get(url=path, headers={"User-agent": ua})
workbook = pd.read_csv(io.BytesIO(response.content), sep=",", **args)

df = workbook.copy()
Expand Down Expand Up @@ -96,6 +99,9 @@ def run_query(self, query, user):
except KeyboardInterrupt:
error = "Query cancelled by user."
data = None
except UnacceptableAddressException:
error = "Can't query private addresses."
data = None
except Exception as e:
error = "Error reading {0}. {1}".format(path, str(e))
data = None
Expand Down
10 changes: 8 additions & 2 deletions redash/query_runner/excel.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import logging

import requests
import yaml

from redash.query_runner import BaseQueryRunner, NotSupported, register
from redash.utils.requests_session import (
UnacceptableAddressException,
requests_or_advocate,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -54,7 +57,7 @@ def run_query(self, query, user):
pass

try:
response = requests.get(url=path, headers={"User-agent": ua})
response = requests_or_advocate.get(url=path, headers={"User-agent": ua})
workbook = pd.read_excel(response.content, **args)

df = workbook.copy()
Expand Down Expand Up @@ -94,6 +97,9 @@ def run_query(self, query, user):
except KeyboardInterrupt:
error = "Query cancelled by user."
data = None
except UnacceptableAddressException:
error = "Can't query private addresses."
data = None
except Exception as e:
error = "Error reading {0}. {1}".format(path, str(e))
data = None
Expand Down
3 changes: 3 additions & 0 deletions redash/settings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@
# Whether file downloads are enforced or not.
ENFORCE_FILE_SAVE = parse_boolean(os.environ.get("REDASH_ENFORCE_FILE_SAVE", "true"))

# Whether api calls using the json query runner will block private addresses
ENFORCE_PRIVATE_ADDRESS_BLOCK = parse_boolean(os.environ.get("REDASH_ENFORCE_PRIVATE_IP_BLOCK", "true"))

# Whether to use secure cookies by default.
COOKIES_SECURE = parse_boolean(os.environ.get("REDASH_COOKIES_SECURE", str(ENFORCE_HTTPS)))
# Whether the session cookie is set to secure.
Expand Down
18 changes: 18 additions & 0 deletions redash/utils/requests_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from advocate.exceptions import UnacceptableAddressException # noqa: F401

from redash import settings

if settings.ENFORCE_PRIVATE_ADDRESS_BLOCK:
import advocate as requests_or_advocate
else:
import requests as requests_or_advocate


class ConfiguredSession(requests_or_advocate.Session):
def request(self, *args, **kwargs):
if not settings.REQUESTS_ALLOW_REDIRECTS:
kwargs.update({"allow_redirects": False})
return super().request(*args, **kwargs)


requests_session = ConfiguredSession()
21 changes: 12 additions & 9 deletions tests/query_runner/test_http.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from unittest import TestCase

import mock
import requests

from redash.query_runner import BaseHTTPQueryRunner
from redash.utils.requests_session import (
ConfiguredSession,
requests_or_advocate,
)


class RequiresAuthQueryRunner(BaseHTTPQueryRunner):
Expand Down Expand Up @@ -34,7 +37,7 @@ def test_get_auth_empty_requires_authentication(self):
query_runner = RequiresAuthQueryRunner({})
self.assertRaisesRegex(ValueError, "Username and Password required", query_runner.get_auth)

@mock.patch("requests.request")
@mock.patch.object(ConfiguredSession, "request")
def test_get_response_success(self, mock_get):
mock_response = mock.Mock()
mock_response.status_code = 200
Expand All @@ -48,7 +51,7 @@ def test_get_response_success(self, mock_get):
self.assertEqual(response.status_code, 200)
self.assertIsNone(error)

@mock.patch("requests.request")
@mock.patch.object(ConfiguredSession, "request")
def test_get_response_success_custom_auth(self, mock_get):
mock_response = mock.Mock()
mock_response.status_code = 200
Expand All @@ -63,7 +66,7 @@ def test_get_response_success_custom_auth(self, mock_get):
self.assertEqual(response.status_code, 200)
self.assertIsNone(error)

@mock.patch("requests.request")
@mock.patch.object(ConfiguredSession, "request")
def test_get_response_failure(self, mock_get):
mock_response = mock.Mock()
mock_response.status_code = 301
Expand All @@ -76,12 +79,12 @@ def test_get_response_failure(self, mock_get):
mock_get.assert_called_once_with("get", url, auth=None)
self.assertIn(query_runner.response_error, error)

@mock.patch("requests.request")
@mock.patch.object(ConfiguredSession, "request")
def test_get_response_httperror_exception(self, mock_get):
mock_response = mock.Mock()
mock_response.status_code = 500
mock_response.text = "Server Error"
http_error = requests.HTTPError()
http_error = requests_or_advocate.HTTPError()
mock_response.raise_for_status.side_effect = http_error
mock_get.return_value = mock_response

Expand All @@ -92,13 +95,13 @@ def test_get_response_httperror_exception(self, mock_get):
self.assertIsNotNone(error)
self.assertIn("Failed to execute query", error)

@mock.patch("requests.request")
@mock.patch.object(ConfiguredSession, "request")
def test_get_response_requests_exception(self, mock_get):
mock_response = mock.Mock()
mock_response.status_code = 500
mock_response.text = "Server Error"
exception_message = "Some requests exception"
requests_exception = requests.RequestException(exception_message)
requests_exception = requests_or_advocate.RequestException(exception_message)
mock_response.raise_for_status.side_effect = requests_exception
mock_get.return_value = mock_response

Expand All @@ -109,7 +112,7 @@ def test_get_response_requests_exception(self, mock_get):
self.assertIsNotNone(error)
self.assertEqual(exception_message, error)

@mock.patch("requests.request")
@mock.patch.object(ConfiguredSession, "request")
def test_get_response_generic_exception(self, mock_get):
mock_response = mock.Mock()
mock_response.status_code = 500
Expand Down

0 comments on commit 62890c3

Please sign in to comment.