Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an input URL list parameter #38

Merged
merged 16 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions tests/test_ecommerce.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,22 @@ def test_metadata():
"title": "URL",
"type": "string",
},
"urls": {
"anyOf": [
{"items": {"type": "string"}, "type": "array"},
{"type": "null"},
],
"default": None,
"description": (
"Initial URLs for the crawl, separated by new lines. Enter the "
"full URL including http(s), you can copy and paste it from your "
"browser. Example: https://toscrape.com/"
),
"exclusiveRequired": True,
"group": "inputs",
"title": "URLs",
"widget": "textarea",
},
"urls_file": {
"default": "",
"description": (
Expand Down Expand Up @@ -706,12 +722,24 @@ def test_input_none():

def test_input_multiple():
crawler = get_crawler()
with pytest.raises(ValueError):
EcommerceSpider.from_crawler(
crawler,
url="https://a.example",
urls=["https://b.example"],
)
with pytest.raises(ValueError):
EcommerceSpider.from_crawler(
crawler,
url="https://a.example",
urls_file="https://b.example",
)
with pytest.raises(ValueError):
EcommerceSpider.from_crawler(
crawler,
urls=["https://a.example"],
urls_file="https://b.example",
)


def test_url_invalid():
Expand All @@ -720,6 +748,47 @@ def test_url_invalid():
EcommerceSpider.from_crawler(crawler, url="foo")


def test_urls(caplog):
crawler = get_crawler()
url = "https://example.com"

spider = EcommerceSpider.from_crawler(crawler, urls=[url])
start_requests = list(spider.start_requests())
assert len(start_requests) == 1
assert start_requests[0].url == url
assert start_requests[0].callback == spider.parse_navigation

spider = EcommerceSpider.from_crawler(crawler, urls=url)
start_requests = list(spider.start_requests())
assert len(start_requests) == 1
assert start_requests[0].url == url
assert start_requests[0].callback == spider.parse_navigation

caplog.clear()
spider = EcommerceSpider.from_crawler(
crawler,
urls="https://a.example\n \nhttps://b.example\nhttps://c.example\nfoo\n\n",
)
assert "'foo', from the 'urls' spider argument, is not a valid URL" in caplog.text
start_requests = list(spider.start_requests())
assert len(start_requests) == 3
assert all(
request.callback == spider.parse_navigation for request in start_requests
)
assert start_requests[0].url == "https://a.example"
assert start_requests[1].url == "https://b.example"
assert start_requests[2].url == "https://c.example"

caplog.clear()
with pytest.raises(ValueError):
spider = EcommerceSpider.from_crawler(
crawler,
urls="foo\nbar",
)
assert "'foo', from the 'urls' spider argument, is not a valid URL" in caplog.text
assert "'bar', from the 'urls' spider argument, is not a valid URL" in caplog.text


def test_urls_file():
crawler = get_crawler()
url = "https://example.com"
Expand Down
54 changes: 53 additions & 1 deletion zyte_spider_templates/params.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import json
import re
from enum import Enum
from typing import Dict, Optional, Union
from logging import getLogger
from typing import Dict, List, Optional, Union

from pydantic import BaseModel, ConfigDict, Field, field_validator

Expand All @@ -12,6 +14,8 @@

from .utils import _URL_PATTERN

logger = getLogger(__name__)


@document_enum
class ExtractFrom(str, Enum):
Expand Down Expand Up @@ -110,6 +114,54 @@ class UrlParam(BaseModel):
)


class UrlsParam(BaseModel):
urls: Optional[List[str]] = Field(
title="URLs",
description=(
"Initial URLs for the crawl, separated by new lines. Enter the "
"full URL including http(s), you can copy and paste it from your "
"browser. Example: https://toscrape.com/"
),
default=None,
json_schema_extra={
"group": "inputs",
"exclusiveRequired": True,
"widget": "textarea",
},
)

@field_validator("urls", mode="before")
@classmethod
def validate_url_list(cls, value: Union[List[str], str]) -> List[str]:
"""Validate a list of URLs.

If a string is received as input, it is split into multiple strings
on new lines.

List items that do not match a URL pattern trigger a warning and are
removed from the list. If all URLs are invalid, validation fails.
"""
if isinstance(value, str):
value = value.split("\n")
if not value:
return value
result = []
for v in value:
v = v.strip()
if not v:
continue
if not re.search(_URL_PATTERN, v):
logger.warning(
f"{v!r}, from the 'urls' spider argument, is not a "
f"valid URL and will be ignored."
)
continue
result.append(v)
if not result:
raise ValueError(f"No valid URL found in {value!r}")
return result


class PostalAddress(BaseModel):
"""
Represents a postal address with various optional components such as
Expand Down
4 changes: 3 additions & 1 deletion zyte_spider_templates/spiders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,21 @@
MaxRequestsParam,
UrlParam,
UrlsFileParam,
UrlsParam,
)

# Higher priority than command-line-defined settings (40).
ARG_SETTING_PRIORITY: int = 50

_INPUT_FIELDS = ("url", "urls_file")
_INPUT_FIELDS = ("url", "urls", "urls_file")


class BaseSpiderParams(
ExtractFromParam,
MaxRequestsParam,
GeolocationParam,
UrlsFileParam,
UrlsParam,
UrlParam,
BaseModel,
):
Expand Down
2 changes: 2 additions & 0 deletions zyte_spider_templates/spiders/ecommerce.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ def _init_input(self):
urls = load_url_list(response.text)
self.logger.info(f"Loaded {len(urls)} initial URLs from {urls_file}.")
self.start_urls = urls
elif self.args.urls:
self.start_urls = self.args.urls
else:
self.start_urls = [self.args.url]
self.allowed_domains = list(set(get_domain(url) for url in self.start_urls))
Expand Down
Loading