From 90731511edbfe5d58a0c2c7ebe601136b56ac1d4 Mon Sep 17 00:00:00 2001 From: Eric Pien Date: Wed, 4 Dec 2024 12:42:55 -0800 Subject: [PATCH] Update Screener - Setting Body to Screener returns the Screener Object to enable method chaining. (i.e. `r = yf.Screener().set_predefined_body("day_gainers").response`) - Limit query size to 250 and raise error if larger. This also avoids unnecessary network calls and makes the issue clearer. --- tests/test_screener.py | 28 ++++++++++++++--- yfinance/screener/screener.py | 57 +++++++++++++++++++++++++++++------ 2 files changed, 71 insertions(+), 14 deletions(-) diff --git a/tests/test_screener.py b/tests/test_screener.py index 17f4988a..0ec6c201 100644 --- a/tests/test_screener.py +++ b/tests/test_screener.py @@ -13,7 +13,7 @@ def setUpClass(self): self.query = EquityQuery('gt',['eodprice',3]) def test_set_default_body(self): - self.screener.set_default_body(self.query) + result = self.screener.set_default_body(self.query) self.assertEqual(self.screener.body['offset'], 0) self.assertEqual(self.screener.body['size'], 100) @@ -23,11 +23,13 @@ def test_set_default_body(self): self.assertEqual(self.screener.body['query'], self.query.to_dict()) self.assertEqual(self.screener.body['userId'], '') self.assertEqual(self.screener.body['userIdType'], 'guid') + self.assertEqual(self.screener, result) def test_set_predefined_body(self): k = 'most_actives' - self.screener.set_predefined_body(k) + result = self.screener.set_predefined_body(k) self.assertEqual(self.screener.body, PREDEFINED_SCREENER_BODY_MAP[k]) + self.assertEqual(self.screener, result) def test_set_predefined_body_invalid_key(self): with self.assertRaises(ValueError): @@ -44,9 +46,10 @@ def test_set_body(self): "userId": "", "userIdType": "guid" } - self.screener.set_body(body) + result = self.screener.set_body(body) self.assertEqual(self.screener.body, body) + self.assertEqual(self.screener, result) def test_set_body_missing_keys(self): body = { @@ -87,10 +90,11 @@ def test_patch_body(self): } self.screener.set_body(initial_body) patch_values = {"size": 50} - self.screener.patch_body(patch_values) + result = self.screener.patch_body(patch_values) self.assertEqual(self.screener.body['size'], 50) self.assertEqual(self.screener.body['query'], self.query.to_dict()) + self.assertEqual(self.screener, result) def test_patch_body_extra_keys(self): initial_body = { @@ -108,6 +112,22 @@ def test_patch_body_extra_keys(self): with self.assertRaises(ValueError): self.screener.patch_body(patch_values) + @patch('yfinance.screener.screener.YfData.post') + def test_set_large_size_in_body(self, mock_post): + body = { + "offset": 0, + "size": 251, # yahoo limits at 250 + "sortField": "ticker", + "sortType": "desc", + "quoteType": "equity", + "query": self.query.to_dict(), + "userId": "", + "userIdType": "guid" + } + + with self.assertRaises(ValueError): + self.screener.set_body(body).response + @patch('yfinance.screener.screener.YfData.post') def test_fetch(self, mock_post): mock_response = MagicMock() diff --git a/yfinance/screener/screener.py b/yfinance/screener/screener.py index cf6e1688..01ff667b 100644 --- a/yfinance/screener/screener.py +++ b/yfinance/screener/screener.py @@ -67,9 +67,22 @@ def predefined_bodies(self) -> Dict: """ return self._predefined_bodies - def set_default_body(self, query: Query, offset: int = 0, size: int = 100, sortField: str = "ticker", sortType: str = "desc", quoteType: str = "equity", userId: str = "", userIdType: str = "guid") -> None: + def set_default_body(self, query: Query, offset: int = 0, size: int = 100, sortField: str = "ticker", sortType: str = "desc", quoteType: str = "equity", userId: str = "", userIdType: str = "guid") -> 'Screener': """ - Set the default body using a custom query + Set the default body using a custom query. + + Args: + query (Query): The Query object to set as the body. + offset (Optional[int]): The offset for the results. Defaults to 0. + size (Optional[int]): The number of results to return. Defaults to 100. Maximum is 250 as set by Yahoo. + sortField (Optional[str]): The field to sort the results by. Defaults to "ticker". + sortType (Optional[str]): The type of sorting (e.g., "asc" or "desc"). Defaults to "desc". + quoteType (Optional[str]): The type of quote (e.g., "equity"). Defaults to "equity". + userId (Optional[str]): The user ID. Defaults to an empty string. + userIdType (Optional[str]): The type of user ID (e.g., "guid"). Defaults to "guid". + + Returns: + Screener: self Example: @@ -89,11 +102,18 @@ def set_default_body(self, query: Query, offset: int = 0, size: int = 100, sortF "userId": userId, "userIdType": userIdType } + return self - def set_predefined_body(self, k: str) -> None: + def set_predefined_body(self, predefined_key: str) -> 'Screener': """ Set a predefined body + Args: + predefined_key (str): key to one of predefined screens + + Returns: + Screener: self + Example: .. code-block:: python @@ -106,16 +126,23 @@ def set_predefined_body(self, k: str) -> None: :attr:`Screener.predefined_bodies ` supported predefined screens """ - body = PREDEFINED_SCREENER_BODY_MAP.get(k, None) + body = PREDEFINED_SCREENER_BODY_MAP.get(predefined_key, None) if not body: - raise ValueError(f'Invalid key {k} provided for predefined screener') + raise ValueError(f'Invalid key {predefined_key} provided for predefined screener') self._body_updated = True self._body = body + return self - def set_body(self, body: Dict) -> None: + def set_body(self, body: Dict) -> 'Screener': """ - Set the fully custom body + Set the fully custom body using dictionary input + + Args: + body (Dict): full query body + + Returns: + Screener: self Example: @@ -142,11 +169,17 @@ def set_body(self, body: Dict) -> None: self._body_updated = True self._body = body + return self - - def patch_body(self, values: Dict) -> None: + def patch_body(self, values: Dict) -> 'Screener': """ - Patch parts of the body + Patch parts of the body using dictionary input + + Args: + body (Dict): partial query body + + Returns: + Screener: self Example: @@ -161,10 +194,14 @@ def patch_body(self, values: Dict) -> None: self._body_updated = True for k in values: self._body[k] = values[k] + return self def _validate_body(self) -> None: if not all(k in self._body for k in self._accepted_body_keys): raise ValueError("Missing required keys in body") + + if self._body["size"] > 250: + raise ValueError("Yahoo limits query size to 250. Please decrease the size of the query.") def _fetch(self) -> Dict: params_dict = {"corsDomain": "finance.yahoo.com", "formatted": "false", "lang": "en-US", "region": "US"}