From 820278ba28370e16c8e3b301b6d5c56ad04139e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=AD=B1=E6=96=87?= Date: Mon, 9 Dec 2024 16:08:44 +0800 Subject: [PATCH] Fix params for get_concurrency --- src/pai_rag/tools/data_process/embed_workflow.py | 5 ++--- src/pai_rag/tools/data_process/parse_workflow.py | 3 ++- src/pai_rag/tools/data_process/split_workflow.py | 5 ++--- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/pai_rag/tools/data_process/embed_workflow.py b/src/pai_rag/tools/data_process/embed_workflow.py index 2694a987..bc6ffeb3 100644 --- a/src/pai_rag/tools/data_process/embed_workflow.py +++ b/src/pai_rag/tools/data_process/embed_workflow.py @@ -17,10 +17,9 @@ def main(args): "working_dir": args.working_dir, "config_file": args.config_file, }, - concurrency=get_concurrency(), + concurrency=get_concurrency(args.num_cpus), batch_size=args.batch_size, ) - logger.info("Embedding nodes completed.") logger.info(f"Write to {args.output_dir}") timestamp = time.strftime("%Y%m%d-%H%M%S") ds = ds.repartition(1) @@ -31,7 +30,7 @@ def main(args): ), force_ascii=False, ) - logger.info("Write completed.") + logger.info(f"Write to {args.output_dir} successfully.") if __name__ == "__main__": diff --git a/src/pai_rag/tools/data_process/parse_workflow.py b/src/pai_rag/tools/data_process/parse_workflow.py index 77fdbc53..441e61d2 100644 --- a/src/pai_rag/tools/data_process/parse_workflow.py +++ b/src/pai_rag/tools/data_process/parse_workflow.py @@ -16,10 +16,11 @@ def main(args): results = ray.get(run_tasks) logger.info("Master node completed processing files.") os.makedirs(args.output_dir, exist_ok=True) + logger.info(f"Write to {args.output_dir}") timestamp = time.strftime("%Y%m%d-%H%M%S") save_file = os.path.join(args.output_dir, f"{timestamp}.jsonl") parser.write_to_file.remote(results, save_file) - logger.info(f"Results written to {save_file} asynchronously.") + logger.info(f"Write to {save_file} successfully.") if __name__ == "__main__": diff --git a/src/pai_rag/tools/data_process/split_workflow.py b/src/pai_rag/tools/data_process/split_workflow.py index 176acf75..3d6487de 100644 --- a/src/pai_rag/tools/data_process/split_workflow.py +++ b/src/pai_rag/tools/data_process/split_workflow.py @@ -18,11 +18,10 @@ def main(args): "working_dir": args.working_dir, "config_file": args.config_file, }, - concurrency=get_concurrency(), + concurrency=get_concurrency(args.num_cpus), batch_size=args.batch_size, ) - logger.info("Splitting nodes completed.") - + logger.info(f"Write to {args.output_dir}") timestamp = time.strftime("%Y%m%d-%H%M%S") ds = ds.repartition(1) ds.write_json(