Skip to content

Commit

Permalink
Update Screener
Browse files Browse the repository at this point in the history
- Setting Body to Screener returns the Screener Object to enable method chaining. (i.e. `r = yf.Screener().set_predefined_body("day_gainers").response`)
- Limit query size to 250 and raise error if larger. This also avoids unnecessary network calls and makes the issue clearer.
  • Loading branch information
ericpien committed Dec 4, 2024
1 parent a8f0998 commit 9073151
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 14 deletions.
28 changes: 24 additions & 4 deletions tests/test_screener.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def setUpClass(self):
self.query = EquityQuery('gt',['eodprice',3])

def test_set_default_body(self):
self.screener.set_default_body(self.query)
result = self.screener.set_default_body(self.query)

self.assertEqual(self.screener.body['offset'], 0)
self.assertEqual(self.screener.body['size'], 100)
Expand All @@ -23,11 +23,13 @@ def test_set_default_body(self):
self.assertEqual(self.screener.body['query'], self.query.to_dict())
self.assertEqual(self.screener.body['userId'], '')
self.assertEqual(self.screener.body['userIdType'], 'guid')
self.assertEqual(self.screener, result)

def test_set_predefined_body(self):
k = 'most_actives'
self.screener.set_predefined_body(k)
result = self.screener.set_predefined_body(k)
self.assertEqual(self.screener.body, PREDEFINED_SCREENER_BODY_MAP[k])
self.assertEqual(self.screener, result)

def test_set_predefined_body_invalid_key(self):
with self.assertRaises(ValueError):
Expand All @@ -44,9 +46,10 @@ def test_set_body(self):
"userId": "",
"userIdType": "guid"
}
self.screener.set_body(body)
result = self.screener.set_body(body)

self.assertEqual(self.screener.body, body)
self.assertEqual(self.screener, result)

def test_set_body_missing_keys(self):
body = {
Expand Down Expand Up @@ -87,10 +90,11 @@ def test_patch_body(self):
}
self.screener.set_body(initial_body)
patch_values = {"size": 50}
self.screener.patch_body(patch_values)
result = self.screener.patch_body(patch_values)

self.assertEqual(self.screener.body['size'], 50)
self.assertEqual(self.screener.body['query'], self.query.to_dict())
self.assertEqual(self.screener, result)

def test_patch_body_extra_keys(self):
initial_body = {
Expand All @@ -108,6 +112,22 @@ def test_patch_body_extra_keys(self):
with self.assertRaises(ValueError):
self.screener.patch_body(patch_values)

@patch('yfinance.screener.screener.YfData.post')
def test_set_large_size_in_body(self, mock_post):
body = {
"offset": 0,
"size": 251, # yahoo limits at 250
"sortField": "ticker",
"sortType": "desc",
"quoteType": "equity",
"query": self.query.to_dict(),
"userId": "",
"userIdType": "guid"
}

with self.assertRaises(ValueError):
self.screener.set_body(body).response

@patch('yfinance.screener.screener.YfData.post')
def test_fetch(self, mock_post):
mock_response = MagicMock()
Expand Down
57 changes: 47 additions & 10 deletions yfinance/screener/screener.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,22 @@ def predefined_bodies(self) -> Dict:
"""
return self._predefined_bodies

def set_default_body(self, query: Query, offset: int = 0, size: int = 100, sortField: str = "ticker", sortType: str = "desc", quoteType: str = "equity", userId: str = "", userIdType: str = "guid") -> None:
def set_default_body(self, query: Query, offset: int = 0, size: int = 100, sortField: str = "ticker", sortType: str = "desc", quoteType: str = "equity", userId: str = "", userIdType: str = "guid") -> 'Screener':
"""
Set the default body using a custom query
Set the default body using a custom query.
Args:
query (Query): The Query object to set as the body.
offset (Optional[int]): The offset for the results. Defaults to 0.
size (Optional[int]): The number of results to return. Defaults to 100. Maximum is 250 as set by Yahoo.
sortField (Optional[str]): The field to sort the results by. Defaults to "ticker".
sortType (Optional[str]): The type of sorting (e.g., "asc" or "desc"). Defaults to "desc".
quoteType (Optional[str]): The type of quote (e.g., "equity"). Defaults to "equity".
userId (Optional[str]): The user ID. Defaults to an empty string.
userIdType (Optional[str]): The type of user ID (e.g., "guid"). Defaults to "guid".
Returns:
Screener: self
Example:
Expand All @@ -89,11 +102,18 @@ def set_default_body(self, query: Query, offset: int = 0, size: int = 100, sortF
"userId": userId,
"userIdType": userIdType
}
return self

def set_predefined_body(self, k: str) -> None:
def set_predefined_body(self, predefined_key: str) -> 'Screener':
"""
Set a predefined body
Args:
predefined_key (str): key to one of predefined screens
Returns:
Screener: self
Example:
.. code-block:: python
Expand All @@ -106,16 +126,23 @@ def set_predefined_body(self, k: str) -> None:
:attr:`Screener.predefined_bodies <yfinance.Screener.predefined_bodies>`
supported predefined screens
"""
body = PREDEFINED_SCREENER_BODY_MAP.get(k, None)
body = PREDEFINED_SCREENER_BODY_MAP.get(predefined_key, None)
if not body:
raise ValueError(f'Invalid key {k} provided for predefined screener')
raise ValueError(f'Invalid key {predefined_key} provided for predefined screener')

self._body_updated = True
self._body = body
return self

def set_body(self, body: Dict) -> None:
def set_body(self, body: Dict) -> 'Screener':
"""
Set the fully custom body
Set the fully custom body using dictionary input
Args:
body (Dict): full query body
Returns:
Screener: self
Example:
Expand All @@ -142,11 +169,17 @@ def set_body(self, body: Dict) -> None:

self._body_updated = True
self._body = body
return self


def patch_body(self, values: Dict) -> None:
def patch_body(self, values: Dict) -> 'Screener':
"""
Patch parts of the body
Patch parts of the body using dictionary input
Args:
body (Dict): partial query body
Returns:
Screener: self
Example:
Expand All @@ -161,10 +194,14 @@ def patch_body(self, values: Dict) -> None:
self._body_updated = True
for k in values:
self._body[k] = values[k]
return self

def _validate_body(self) -> None:
if not all(k in self._body for k in self._accepted_body_keys):
raise ValueError("Missing required keys in body")

if self._body["size"] > 250:
raise ValueError("Yahoo limits query size to 250. Please decrease the size of the query.")

def _fetch(self) -> Dict:
params_dict = {"corsDomain": "finance.yahoo.com", "formatted": "false", "lang": "en-US", "region": "US"}
Expand Down

0 comments on commit 9073151

Please sign in to comment.