diff --git a/taipy/rest/config/__init__.py b/taipy/rest/config/__init__.py new file mode 100644 index 0000000000..675fb0a3d4 --- /dev/null +++ b/taipy/rest/config/__init__.py @@ -0,0 +1,10 @@ +# Copyright 2021-2025 Avaiga Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. diff --git a/taipy/rest/config/rest_checker.py b/taipy/rest/config/rest_checker.py new file mode 100644 index 0000000000..12e66ae09b --- /dev/null +++ b/taipy/rest/config/rest_checker.py @@ -0,0 +1,63 @@ +# Copyright 2021-2025 Avaiga Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +from typing import cast + +from taipy.common.config._config import _Config +from taipy.common.config.checker._checkers._config_checker import _ConfigChecker +from taipy.common.config.checker.issue_collector import IssueCollector + +from .rest_config import RestConfig + + +class _RestConfigChecker(_ConfigChecker): + def __init__(self, config: _Config, collector: IssueCollector): + super().__init__(config, collector) + + def _check(self) -> IssueCollector: + rest_configs = cast(dict, self._config._sections.get(RestConfig.name, {})) + + for rest_config_id, rest_config in rest_configs.items(): + if rest_config_id != _Config.DEFAULT_KEY: + self._check_port(rest_config_id, rest_config) + self._check_host(rest_config_id, rest_config) + self._check_https_settings(rest_config_id, rest_config) + + return self._collector + + def _check_port(self, rest_config_id: str, rest_config: RestConfig): + if not isinstance(rest_config.port, int) or not (1 <= rest_config.port <= 65535): + self._error( + "port", + rest_config.port, + f"The port of RestConfig `{rest_config_id}` must be an integer between 1 and 65535.", + ) + + def _check_host(self, rest_config_id: str, rest_config: RestConfig): + if not isinstance(rest_config.host, str) or not rest_config.host: + self._error( + "host", rest_config.host, f"The host of RestConfig `{rest_config_id}` must be a non-empty string." + ) + + def _check_https_settings(self, rest_config_id: str, rest_config: RestConfig): + if rest_config.use_https: + if not rest_config.ssl_cert or not rest_config.ssl_key: + self._error( + "ssl_cert/ssl_key", + (rest_config.ssl_cert, rest_config.ssl_key), + f"When HTTPS is enabled in RestConfig `{rest_config_id}`, both ssl_cert and ssl_key must be set.", + ) + elif not isinstance(rest_config.ssl_cert, str) or not isinstance(rest_config.ssl_key, str): + self._error( + "ssl_cert/ssl_key", + (rest_config.ssl_cert, rest_config.ssl_key), + f"The ssl_cert and ssl_key of RestConfig `{rest_config_id}` must be valid strings.", + ) diff --git a/taipy/rest/config/rest_config.py b/taipy/rest/config/rest_config.py new file mode 100644 index 0000000000..0337682684 --- /dev/null +++ b/taipy/rest/config/rest_config.py @@ -0,0 +1,99 @@ +# Copyright 2021-2025 Avaiga Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +from typing import Optional, Tuple + +from taipy.common.config import Section, _inject_section + + +class RestConfig(Section): + name: str = "REST" + + def __init__(self): + super().__init__("rest") + self._port: int = 5000 + self._host: str = "127.0.0.1" + self._use_https: bool = False + self._ssl_cert: Optional[str] = None + self._ssl_key: Optional[str] = None + + def __copy__(self): + new_instance = RestConfig() + new_instance._port = self._port + new_instance._host = self._host + new_instance._use_https = self._use_https + new_instance._ssl_cert = self._ssl_cert + new_instance._ssl_key = self._ssl_key + return new_instance + + @property + def port(self) -> int: + return self._port + + @port.setter + def port(self, value: int): + self._port = value + + @property + def host(self) -> str: + return self._host + + @host.setter + def host(self, value: str): + self._host = value + + @property + def use_https(self) -> bool: + return self._use_https + + @use_https.setter + def use_https(self, value: bool): + self._use_https = value + + @property + def ssl_cert(self) -> Optional[str]: + return self._ssl_cert + + @ssl_cert.setter + def ssl_cert(self, value: Optional[str]): + self._ssl_cert = value + + @property + def ssl_key(self) -> Optional[str]: + return self._ssl_key + + @ssl_key.setter + def ssl_key(self, value: Optional[str]): + self._ssl_key = value + + @property + def ssl_context(self) -> Optional[Tuple[Optional[str], Optional[str]]]: + return (self._ssl_cert, self._ssl_key) if self._use_https else None + + def configure_rest( + self, + port: int = 5000, + host: str = "127.0.0.1", + use_https: bool = False, + ssl_cert: Optional[str] = None, + ssl_key: Optional[str] = None, + ): + self.port = port + self.host = host + self.use_https = use_https + self.ssl_cert = ssl_cert + self.ssl_key = ssl_key + + +# At the end of the file +_inject_section( + RestConfig, "rest", default=RestConfig(), configuration_methods=[("configure_rest", RestConfig.configure_rest)] +) diff --git a/taipy/rest/rest.py b/taipy/rest/rest.py index def98b195c..6b51d23192 100644 --- a/taipy/rest/rest.py +++ b/taipy/rest/rest.py @@ -44,4 +44,14 @@ def run(self, **kwargs): Arguments: **kwargs : Options to provide to the application server. """ + rest_config = Config.rest + kwargs.update( + { + "port": rest_config.port, + "host": rest_config.host, + "ssl_context": (rest_config.get("ssl_cert"), rest_config.get("ssl_key")) + if (rest_config.get("use_https", False)) + else None, + } + ) self._app.run(**kwargs) diff --git a/tests/rest/test_rest_config.py b/tests/rest/test_rest_config.py new file mode 100644 index 0000000000..07c17ed412 --- /dev/null +++ b/tests/rest/test_rest_config.py @@ -0,0 +1,121 @@ +# Copyright 2021-2025 Avaiga Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +from taipy.common.config._config import _Config +from taipy.common.config.checker.issue_collector import IssueCollector +from taipy.rest.config.rest_checker import _RestConfigChecker +from taipy.rest.config.rest_config import RestConfig + + +def test_rest_config_default_values(): + rest_config = RestConfig() + assert rest_config.port == 5000 + assert rest_config.host == "127.0.0.1" + assert rest_config.use_https is False + assert rest_config.ssl_cert is None + assert rest_config.ssl_key is None + + +def test_rest_config_custom_values(): + rest_config = RestConfig() + rest_config.configure_rest(port=8080, host="0.0.0.0", use_https=True, ssl_cert="cert.pem", ssl_key="key.pem") + + assert rest_config.port == 8080 + assert rest_config.host == "0.0.0.0" + assert rest_config.use_https is True + assert rest_config.ssl_cert == "cert.pem" + assert rest_config.ssl_key == "key.pem" + + +def test_rest_config_copy(): + rest_config = RestConfig() + rest_config.configure_rest(port=8080, host="0.0.0.0", use_https=True, ssl_cert="cert.pem", ssl_key="key.pem") + rest_config_copy = rest_config.__copy__() + + assert rest_config_copy.port == 8080 + assert rest_config_copy.host == "0.0.0.0" + assert rest_config_copy.use_https is True + assert rest_config_copy.ssl_cert == "cert.pem" + assert rest_config_copy.ssl_key == "key.pem" + + # Ensure it's a deep copy + rest_config_copy.port = 9090 + assert rest_config.port == 8080 + + +def test_rest_config_checker_valid_config(): + config = _Config() + collector = IssueCollector() + rest_config = RestConfig() + rest_config.configure_rest(port=8080, host="0.0.0.0", use_https=True, ssl_cert="cert.pem", ssl_key="key.pem") + + config._sections[RestConfig.name] = {"test_rest_config": rest_config} + checker = _RestConfigChecker(config, collector) + issues = checker._check() + + assert len(issues.errors) == 0 + assert len(issues.warnings) == 0 + + +def test_rest_config_checker_invalid_port(): + config = _Config() + collector = IssueCollector() + rest_config = RestConfig() + rest_config.configure_rest(port=70000) # Invalid port + + config._sections[RestConfig.name] = {"test_rest_config": rest_config} + checker = _RestConfigChecker(config, collector) + issues = checker._check() + + assert len(issues.errors) == 1 + assert "port" in issues.errors[0].field + + +def test_rest_config_checker_invalid_host(): + config = _Config() + collector = IssueCollector() + rest_config = RestConfig() + rest_config.configure_rest(host="") # Invalid host + + config._sections[RestConfig.name] = {"test_rest_config": rest_config} + checker = _RestConfigChecker(config, collector) + issues = checker._check() + + assert len(issues.errors) == 1 + assert "host" in issues.errors[0].field + + +def test_rest_config_checker_https_missing_cert_and_key(): + config = _Config() + collector = IssueCollector() + rest_config = RestConfig() + rest_config.configure_rest(use_https=True) # Missing ssl_cert and ssl_key + + config._sections[RestConfig.name] = {"test_rest_config": rest_config} + checker = _RestConfigChecker(config, collector) + issues = checker._check() + + assert len(issues.errors) == 1 + assert "ssl_cert/ssl_key" in issues.errors[0].field + + +def test_rest_config_checker_https_invalid_cert_and_key(): + config = _Config() + collector = IssueCollector() + rest_config = RestConfig() + rest_config.configure_rest(use_https=True, ssl_cert=123, ssl_key=456) # Invalid types for ssl_cert and ssl_key + + config._sections[RestConfig.name] = {"test_rest_config": rest_config} + checker = _RestConfigChecker(config, collector) + issues = checker._check() + + assert len(issues.errors) == 1 + assert "ssl_cert/ssl_key" in issues.errors[0].field