Skip to content
This repository has been archived by the owner on Feb 16, 2023. It is now read-only.

Commit

Permalink
more tests and bugfixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
jonaswinkler committed Nov 27, 2020
1 parent 6c30811 commit bc4192e
Show file tree
Hide file tree
Showing 7 changed files with 180 additions and 15 deletions.
2 changes: 1 addition & 1 deletion src/documents/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
import pickle
import re

from django.conf import settings
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
from sklearn.utils.multiclass import type_of_target

from documents.models import Document, MatchingModel
from paperless import settings


class IncompatibleClassifierVersionError(Exception):
Expand Down
104 changes: 104 additions & 0 deletions src/documents/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pathvalidate import ValidationError
from rest_framework.test import APITestCase

from documents import index
from documents.models import Document, Correspondent, DocumentType, Tag
from documents.tests.utils import DirectoriesMixin

Expand Down Expand Up @@ -162,6 +163,109 @@ def test_document_filters(self):
results = response.data['results']
self.assertEqual(len(results), 3)

def test_search_no_query(self):
response = self.client.get("/api/search/")
results = response.data['results']

self.assertEqual(len(results), 0)

def test_search(self):
d1=Document.objects.create(title="invoice", content="the thing i bought at a shop and paid with bank account", checksum="A", pk=1)
d2=Document.objects.create(title="bank statement 1", content="things i paid for in august", pk=2, checksum="B")
d3=Document.objects.create(title="bank statement 3", content="things i paid for in september", pk=3, checksum="C")
with index.open_index(False).writer() as writer:
# Note to future self: there is a reason we dont use a model signal handler to update the index: some operations edit many documents at once
# (retagger, renamer) and we don't want to open a writer for each of these, but rather perform the entire operation with one writer.
# That's why we cant open the writer in a model on_save handler or something.
index.update_document(writer, d1)
index.update_document(writer, d2)
index.update_document(writer, d3)
response = self.client.get("/api/search/?query=bank")
results = response.data['results']
self.assertEqual(response.data['count'], 3)
self.assertEqual(response.data['page'], 1)
self.assertEqual(response.data['page_count'], 1)
self.assertEqual(len(results), 3)

response = self.client.get("/api/search/?query=september")
results = response.data['results']
self.assertEqual(response.data['count'], 1)
self.assertEqual(response.data['page'], 1)
self.assertEqual(response.data['page_count'], 1)
self.assertEqual(len(results), 1)

response = self.client.get("/api/search/?query=statement")
results = response.data['results']
self.assertEqual(response.data['count'], 2)
self.assertEqual(response.data['page'], 1)
self.assertEqual(response.data['page_count'], 1)
self.assertEqual(len(results), 2)

response = self.client.get("/api/search/?query=sfegdfg")
results = response.data['results']
self.assertEqual(response.data['count'], 0)
self.assertEqual(response.data['page'], 0)
self.assertEqual(response.data['page_count'], 0)
self.assertEqual(len(results), 0)

def test_search_multi_page(self):
with index.open_index(False).writer() as writer:
for i in range(55):
doc = Document.objects.create(checksum=str(i), pk=i+1, title=f"Document {i+1}", content="content")
index.update_document(writer, doc)

# This is here so that we test that no document gets returned twice (might happen if the paging is not working)
seen_ids = []

for i in range(1, 6):
response = self.client.get(f"/api/search/?query=content&page={i}")
results = response.data['results']
self.assertEqual(response.data['count'], 55)
self.assertEqual(response.data['page'], i)
self.assertEqual(response.data['page_count'], 6)
self.assertEqual(len(results), 10)

for result in results:
self.assertNotIn(result['id'], seen_ids)
seen_ids.append(result['id'])

response = self.client.get(f"/api/search/?query=content&page=6")
results = response.data['results']
self.assertEqual(response.data['count'], 55)
self.assertEqual(response.data['page'], 6)
self.assertEqual(response.data['page_count'], 6)
self.assertEqual(len(results), 5)

for result in results:
self.assertNotIn(result['id'], seen_ids)
seen_ids.append(result['id'])

response = self.client.get(f"/api/search/?query=content&page=7")
results = response.data['results']
self.assertEqual(response.data['count'], 55)
self.assertEqual(response.data['page'], 6)
self.assertEqual(response.data['page_count'], 6)
self.assertEqual(len(results), 5)

def test_search_invalid_page(self):
with index.open_index(False).writer() as writer:
for i in range(15):
doc = Document.objects.create(checksum=str(i), pk=i+1, title=f"Document {i+1}", content="content")
index.update_document(writer, doc)

first_page = self.client.get(f"/api/search/?query=content&page=1").data
second_page = self.client.get(f"/api/search/?query=content&page=2").data
should_be_first_page_1 = self.client.get(f"/api/search/?query=content&page=0").data
should_be_first_page_2 = self.client.get(f"/api/search/?query=content&page=dgfd").data
should_be_first_page_3 = self.client.get(f"/api/search/?query=content&page=").data
should_be_first_page_4 = self.client.get(f"/api/search/?query=content&page=-7868").data

self.assertDictEqual(first_page, should_be_first_page_1)
self.assertDictEqual(first_page, should_be_first_page_2)
self.assertDictEqual(first_page, should_be_first_page_3)
self.assertDictEqual(first_page, should_be_first_page_4)
self.assertNotEqual(len(first_page['results']), len(second_page['results']))

@mock.patch("documents.index.autocomplete")
def test_search_autocomplete(self, m):
m.side_effect = lambda ix, term, limit: [term for _ in range(limit)]
Expand Down
8 changes: 6 additions & 2 deletions src/documents/tests/test_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@

from documents.classifier import DocumentClassifier, IncompatibleClassifierVersionError
from documents.models import Correspondent, Document, Tag, DocumentType
from documents.tests.utils import DirectoriesMixin


class TestClassifier(TestCase):
class TestClassifier(DirectoriesMixin, TestCase):

def setUp(self):
super(TestClassifier, self).setUp()
self.classifier = DocumentClassifier()

def generate_test_data(self):
Expand Down Expand Up @@ -80,12 +82,14 @@ def testVersionIncreased(self):
self.assertTrue(self.classifier.train())
self.assertFalse(self.classifier.train())

self.classifier.save_classifier()

classifier2 = DocumentClassifier()

current_ver = DocumentClassifier.FORMAT_VERSION
with mock.patch("documents.classifier.DocumentClassifier.FORMAT_VERSION", current_ver+1):
# assure that we won't load old classifiers.
self.assertRaises(IncompatibleClassifierVersionError, self.classifier.reload)
self.assertRaises(IncompatibleClassifierVersionError, classifier2.reload)

self.classifier.save_classifier()

Expand Down
7 changes: 0 additions & 7 deletions src/documents/tests/test_document_retagger.py

This file was deleted.

58 changes: 58 additions & 0 deletions src/documents/tests/test_management_retagger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from django.core.management import call_command
from django.test import TestCase

from documents.models import Document, Tag, Correspondent, DocumentType
from documents.tests.utils import DirectoriesMixin


class TestRetagger(DirectoriesMixin, TestCase):

def make_models(self):
self.d1 = Document.objects.create(checksum="A", title="A", content="first document")
self.d2 = Document.objects.create(checksum="B", title="B", content="second document")
self.d3 = Document.objects.create(checksum="C", title="C", content="unrelated document")

self.tag_first = Tag.objects.create(name="tag1", match="first", matching_algorithm=Tag.MATCH_ANY)
self.tag_second = Tag.objects.create(name="tag2", match="second", matching_algorithm=Tag.MATCH_ANY)

self.correspondent_first = Correspondent.objects.create(
name="c1", match="first", matching_algorithm=Correspondent.MATCH_ANY)
self.correspondent_second = Correspondent.objects.create(
name="c2", match="second", matching_algorithm=Correspondent.MATCH_ANY)

self.doctype_first = DocumentType.objects.create(
name="dt1", match="first", matching_algorithm=DocumentType.MATCH_ANY)
self.doctype_second = DocumentType.objects.create(
name="dt2", match="second", matching_algorithm=DocumentType.MATCH_ANY)

def get_updated_docs(self):
return Document.objects.get(title="A"), Document.objects.get(title="B"), Document.objects.get(title="C")

def setUp(self) -> None:
super(TestRetagger, self).setUp()
self.make_models()

def test_add_tags(self):
call_command('document_retagger', '--tags')
d_first, d_second, d_unrelated = self.get_updated_docs()

self.assertEqual(d_first.tags.count(), 1)
self.assertEqual(d_second.tags.count(), 1)
self.assertEqual(d_unrelated.tags.count(), 0)

self.assertEqual(d_first.tags.first(), self.tag_first)
self.assertEqual(d_second.tags.first(), self.tag_second)

def test_add_type(self):
call_command('document_retagger', '--document_type')
d_first, d_second, d_unrelated = self.get_updated_docs()

self.assertEqual(d_first.document_type, self.doctype_first)
self.assertEqual(d_second.document_type, self.doctype_second)

def test_add_correspondent(self):
call_command('document_retagger', '--correspondent')
d_first, d_second, d_unrelated = self.get_updated_docs()

self.assertEqual(d_first.correspondent, self.correspondent_first)
self.assertEqual(d_second.correspondent, self.correspondent_second)
13 changes: 8 additions & 5 deletions src/documents/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@ def setup_directories():
dirs.scratch_dir = tempfile.mkdtemp()
dirs.media_dir = tempfile.mkdtemp()
dirs.consumption_dir = tempfile.mkdtemp()
dirs.index_dir = os.path.join(dirs.data_dir, "documents", "originals")
dirs.index_dir = os.path.join(dirs.data_dir, "index")
dirs.originals_dir = os.path.join(dirs.media_dir, "documents", "originals")
dirs.thumbnail_dir = os.path.join(dirs.media_dir, "documents", "thumbnails")
os.makedirs(dirs.index_dir)
os.makedirs(dirs.originals_dir)
os.makedirs(dirs.thumbnail_dir)

os.makedirs(dirs.index_dir, exist_ok=True)
os.makedirs(dirs.originals_dir, exist_ok=True)
os.makedirs(dirs.thumbnail_dir, exist_ok=True)

override_settings(
DATA_DIR=dirs.data_dir,
Expand All @@ -28,7 +29,9 @@ def setup_directories():
ORIGINALS_DIR=dirs.originals_dir,
THUMBNAIL_DIR=dirs.thumbnail_dir,
CONSUMPTION_DIR=dirs.consumption_dir,
INDEX_DIR=dirs.index_dir
INDEX_DIR=dirs.index_dir,
MODEL_FILE=os.path.join(dirs.data_dir, "classification_model.pickle")

).enable()

return dirs
Expand Down
3 changes: 3 additions & 0 deletions src/documents/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,9 @@ def get(self, request, format=None):
except (ValueError, TypeError):
page = 1

if page < 1:
page = 1

with index.query_page(self.ix, query, page) as result_page:
return Response(
{'count': len(result_page),
Expand Down

0 comments on commit bc4192e

Please sign in to comment.