Skip to content

Commit

Permalink
feat(profil_precision): add filter profil_precisions in api
Browse files Browse the repository at this point in the history
  • Loading branch information
hlecuyer committed Nov 19, 2024
1 parent a75b885 commit 3f6fe40
Show file tree
Hide file tree
Showing 10 changed files with 208 additions and 4 deletions.
2 changes: 1 addition & 1 deletion api/CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ After running the main dag:
source .venv/bin/activate

# Launch command to import the Admin Express database
python src/data_inclusion/api/cli.py import_admin_express
python src/data_inclusion/api/cli.py import_communes

# Launch command to import data
python src/data_inclusion/api/cli.py load_inclusion_data
Expand Down
2 changes: 1 addition & 1 deletion api/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
"sqlalchemy",
"tqdm",
"uvicorn[standard]",
"data-inclusion-schema==0.17.0",
"data-inclusion-schema==0.20.0-dev1",
],
extras_require={
"test": [
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""add profils_precisions field in service
Revision ID: c947102bb23f
Revises: 68fe052dc63c
Create Date: 2024-10-28 17:22:23.374004
"""

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects.postgresql import TSVECTOR

# revision identifiers, used by Alembic.
revision = "c947102bb23f"
down_revision = "68fe052dc63c"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.add_column(
"api__services",
sa.Column(
"profils_precisions",
sa.String(),
nullable=True,
),
)
# can't use ARRAY_TO_STRING mutable function in a generation expression.
# So it must be overiden by an immutable function
op.execute("""
CREATE OR REPLACE FUNCTION generate_profils_precisions(
profils_precisions TEXT,
profils TEXT[]
)
RETURNS TSVECTOR AS $$
BEGIN
RETURN to_tsvector(
'french',
COALESCE(profils_precisions, '') ||
' '||
COALESCE(ARRAY_TO_STRING(profils, ' '), '')
);
END;
$$ LANGUAGE plpgsql IMMUTABLE;
""")
op.add_column(
"api__services",
sa.Column(
"searchable_index_profils_precisions",
TSVECTOR(),
sa.Computed(
"generate_profils_precisions(profils_precisions, profils)",
persisted=True,
),
),
)
op.create_index(
"ix_api__services_searchable_index_profils_precisions",
"api__services",
["searchable_index_profils_precisions"],
postgresql_using="gin",
)


def downgrade() -> None:
op.drop_column("api__services", "searchable_index_profils_precisions")
op.drop_column("api__services", "profils_precisions")
12 changes: 12 additions & 0 deletions api/src/data_inclusion/api/inclusion_data/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from datetime import date

import sqlalchemy as sqla
from sqlalchemy import Computed
from sqlalchemy.dialects.postgresql import TSVECTOR
from sqlalchemy.orm import Mapped, mapped_column, relationship

from data_inclusion.api.core.db import Base
Expand Down Expand Up @@ -92,6 +94,16 @@ class Service(Base):
presentation_resume: Mapped[str | None]
prise_rdv: Mapped[str | None]
profils: Mapped[list[str] | None]
profils_precisions: Mapped[str | None]
# generate_profils_precisions is a function that generates
# a TSVECTOR from profils_precisions and profils
# cf: 20241028_172223_c947102bb23f_add_profils_autres_field_in_service.py
searchable_index_profils_precisions: Mapped[str | None] = mapped_column(
TSVECTOR,
Computed(
"generate_profils_precisions(profils_precisions, profils)", persisted=True
),
)
recurrence: Mapped[str | None]
source: Mapped[str]
structure_id: Mapped[str]
Expand Down
10 changes: 10 additions & 0 deletions api/src/data_inclusion/api/inclusion_data/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,15 @@ def search_services_endpoint(
"""
),
] = None,
profils_precisions: Annotated[
Optional[str],
fastapi.Query(
description="""Une recherche elargie sur les profils.
Chaque résultat renvoyé correspond a la recherche fulltext sur
ce champs.
"""
),
] = None,
types: Annotated[
Optional[list[di_schema.TypologieService]],
fastapi.Query(
Expand Down Expand Up @@ -420,6 +429,7 @@ def search_services_endpoint(
frais=frais,
modes_accueil=modes_accueil,
profils=profils,
profils_precisions=profils_precisions,
types=types,
search_point=search_point,
include_outdated=inclure_suspendus,
Expand Down
19 changes: 18 additions & 1 deletion api/src/data_inclusion/api/inclusion_data/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import geoalchemy2
import sqlalchemy as sqla
from sqlalchemy import orm
from sqlalchemy import func, orm

import fastapi

Expand Down Expand Up @@ -137,6 +137,17 @@ def filter_services_by_profils(
)


def filter_services_by_profils_precisions(
query: sqla.Select,
profils_precisions: str,
):
return query.filter(
models.Service.searchable_index_profils_precisions.bool_op("@@")(
func.websearch_to_tsquery("french", profils_precisions)
)
)


def filter_services_by_types(
query: sqla.Select,
types: list[di_schema.TypologieService],
Expand Down Expand Up @@ -265,6 +276,7 @@ def filter_services(
thematiques: list[di_schema.Thematique] | None = None,
frais: list[di_schema.Frais] | None = None,
profils: list[di_schema.Profil] | None = None,
profils_precisions: str | None = None,
modes_accueil: list[di_schema.ModeAccueil] | None = None,
types: list[di_schema.TypologieService] | None = None,
include_outdated: bool | None = False,
Expand Down Expand Up @@ -292,6 +304,9 @@ def filter_services(
if not include_outdated:
query = filter_outdated_services(query)

if profils_precisions is not None:
query = filter_services_by_profils_precisions(query, profils_precisions)

return query


Expand Down Expand Up @@ -354,6 +369,7 @@ def search_services(
frais: list[di_schema.Frais] | None = None,
modes_accueil: list[di_schema.ModeAccueil] | None = None,
profils: list[di_schema.Profil] | None = None,
profils_precisions: str | None = None,
types: list[di_schema.TypologieService] | None = None,
search_point: str | None = None,
include_outdated: bool | None = False,
Expand Down Expand Up @@ -454,6 +470,7 @@ def search_services(
thematiques=thematiques,
frais=frais,
profils=profils,
profils_precisions=profils_precisions,
modes_accueil=modes_accueil,
types=types,
include_outdated=include_outdated,
Expand Down
33 changes: 33 additions & 0 deletions api/tests/e2e/api/__snapshots__/test_inclusion_data.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,17 @@
},
"description": "Une liste de profils.\n Chaque résultat renvoyé a (au moins) un profil dans cette liste.\n "
},
{
"name": "profils_precisions",
"in": "query",
"required": false,
"schema": {
"type": "string",
"description": "Une recherche elargie sur les profils.\n Chaque résultat renvoyé correspond a la recherche fulltext sur\n ce champs.\n ",
"title": "Profils Precisions"
},
"description": "Une recherche elargie sur les profils.\n Chaque résultat renvoyé correspond a la recherche fulltext sur\n ce champs.\n "
},
{
"name": "types",
"in": "query",
Expand Down Expand Up @@ -1496,6 +1507,17 @@
],
"title": "Profils"
},
"profils_precisions": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"title": "Profils Precisions"
},
"pre_requis": {
"anyOf": [
{
Expand Down Expand Up @@ -2590,6 +2612,17 @@
],
"title": "Profils"
},
"profils_precisions": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"title": "Profils Precisions"
},
"pre_requis": {
"anyOf": [
{
Expand Down
63 changes: 63 additions & 0 deletions api/tests/e2e/api/test_inclusion_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ def test_list_services_all(api_client):
"presentation_resume": "Puissant fine.",
"prise_rdv": "https://teixeira.fr/",
"profils": ["femmes"],
"profils_precisions": "femmes",
"recurrence": None,
"score_qualite": 0.5,
"source": "dora",
Expand Down Expand Up @@ -369,6 +370,68 @@ def test_list_services_filter_by_categorie_thematique(api_client):
assert resp_data["items"][0]["id"] == service.id


@pytest.mark.parametrize(
"profils_precisions,input,found",
[
("jeunes moins de 18 ans", "jeunes", True),
("jeune moins de 18 ans", "jeunes", True),
("jeunes et personne age", "vieux", False),
("jeunes et personne age", "personne OR âgée", True),
("jeunes et personne age", "personne jeune", True),
# FIXME: this test is failing because of the accent in the input
("jeunes et personne agee", "âgée", False),
],
)
@pytest.mark.with_token
def test_can_filter_resources_by_profils_precisions(
api_client, profils_precisions, input, found
):
resource = factories.ServiceFactory(
profils=None, profils_precisions=profils_precisions
)
factories.ServiceFactory(profils=None, profils_precisions="tests")

response = api_client.get(
"/api/v0/search/services", params={"profils_precisions": input}
)

assert response.status_code == 200
resp_data = response.json()
if found:
assert_paginated_response_data(resp_data, total=1)
assert list_resources_data(resp_data)[0]["id"] in [resource.id]
else:
assert_paginated_response_data(resp_data, total=0)


@pytest.mark.parametrize(
"profils,input,found",
[
([schema.Profil.FEMMES.value], "femme", True),
([schema.Profil.JEUNES_16_26.value], "jeune", True),
([schema.Profil.FEMMES.value], "jeune", False),
],
)
@pytest.mark.with_token
def test_can_filter_resources_by_profils_precisions_with_only_profils_data(
api_client, profils, input, found
):
resource = factories.ServiceFactory(profils=profils, profils_precisions="")
factories.ServiceFactory(profils=schema.Profil.RETRAITES, profils_precisions="")

response = api_client.get(
"/api/v0/search/services", params={"profils_precisions": input}
)

assert response.status_code == 200
resp_data = response.json()
if found:
assert_paginated_response_data(resp_data, total=1)
assert list_resources_data(resp_data)[0]["id"] in [resource.id]
else:
assert_paginated_response_data(resp_data, total=0)


@pytest.mark.parametrize(
"thematiques,input,found",
[
Expand Down
1 change: 1 addition & 0 deletions api/tests/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class Meta:
],
getter=lambda v: [v.value],
)
profils_precisions = factory.Faker("text", max_nb_chars=20, locale="fr_FR")
pre_requis = []
cumulable = False
justificatifs = []
Expand Down
2 changes: 1 addition & 1 deletion deployment/MIGRATION.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Here is the corrected and formatted version of your migration process:

1. **Connect via SSH to the instance**
```bash
ssh root@163.172.186.56
ssh root@<INSTANCE_IP>
```
2. **Install PostgreSQL 17**
```bash
Expand Down

0 comments on commit 3f6fe40

Please sign in to comment.