Skip to content

Commit

Permalink
parallel cache files test fix (#1109)
Browse files Browse the repository at this point in the history
  • Loading branch information
IIaKyJIuH authored Jun 8, 2023
1 parent 4baa484 commit 40312cb
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 15 deletions.
2 changes: 1 addition & 1 deletion fedot/core/caching/base_cache_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(self, main_table: str = 'default', cache_dir: Optional[str] = None,
stats_keys: Sequence = ('default_hit', 'default_total')):
self._main_table = main_table
self._db_suffix = f'.{main_table}_db'
if cache_dir is None:
if cache_dir is None or Path(cache_dir).samefile(default_fedot_data_dir()):
self.db_path = Path(default_fedot_data_dir())
self._del_prev_temps()
else:
Expand Down
20 changes: 6 additions & 14 deletions test/integration/cache/test_parallel_cache_files.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,15 @@
import multiprocessing
import sqlite3
from functools import partial
from pathlib import Path
from typing import Callable

import psutil
from joblib import Parallel, cpu_count, delayed

from examples.simple.classification.api_classification import run_classification_example
from examples.simple.regression.api_regression import run_regression_example
from examples.simple.time_series_forecasting.api_forecasting import run_ts_forecasting_example
from fedot.core.utils import default_fedot_data_dir


def run_example(target: Callable):
target()


def get_unused_pid() -> int:
busy_pids = set(psutil.pids())
for test_pid in range(1, 10000):
Expand All @@ -34,18 +28,16 @@ def test_parallel_cache_files():
test_file_2.touch()

tasks = [
partial(run_regression_example, with_tuning=False),
partial(run_classification_example, timeout=2., with_tuning=False),
partial(run_ts_forecasting_example, dataset='beer', horizon=10, timeout=2., with_tuning=False),
delayed(run_regression_example)(with_tuning=False),
delayed(run_classification_example)(timeout=1., with_tuning=False),
delayed(run_ts_forecasting_example)(dataset='beer', horizon=10, timeout=1., with_tuning=False),
]

cpus = multiprocessing.cpu_count()
cpus = cpu_count()
if cpus > 1:
try:
with multiprocessing.Pool(processes=cpus) as pool:
list(pool.imap(run_example, tasks))
Parallel(n_jobs=cpus)(tasks)
except sqlite3.OperationalError:
assert False, 'DBs collides'

assert not test_file_1.exists()
assert not test_file_2.exists()

0 comments on commit 40312cb

Please sign in to comment.