diff --git a/src/pai_rag/tools/data_process/embed_workflow.py b/src/pai_rag/tools/data_process/embed_workflow.py
index 2694a987..bc6ffeb3 100644
--- a/src/pai_rag/tools/data_process/embed_workflow.py
+++ b/src/pai_rag/tools/data_process/embed_workflow.py
@@ -17,10 +17,9 @@ def main(args):
             "working_dir": args.working_dir,
             "config_file": args.config_file,
         },
-        concurrency=get_concurrency(),
+        concurrency=get_concurrency(args.num_cpus),
         batch_size=args.batch_size,
     )
-    logger.info("Embedding nodes completed.")
     logger.info(f"Write to {args.output_dir}")
     timestamp = time.strftime("%Y%m%d-%H%M%S")
     ds = ds.repartition(1)
@@ -31,7 +30,7 @@ def main(args):
         ),
         force_ascii=False,
     )
-    logger.info("Write completed.")
+    logger.info(f"Write to {args.output_dir} successfully.")
 
 
 if __name__ == "__main__":
diff --git a/src/pai_rag/tools/data_process/parse_workflow.py b/src/pai_rag/tools/data_process/parse_workflow.py
index 77fdbc53..441e61d2 100644
--- a/src/pai_rag/tools/data_process/parse_workflow.py
+++ b/src/pai_rag/tools/data_process/parse_workflow.py
@@ -16,10 +16,11 @@ def main(args):
     results = ray.get(run_tasks)
     logger.info("Master node completed processing files.")
     os.makedirs(args.output_dir, exist_ok=True)
+    logger.info(f"Write to {args.output_dir}")
     timestamp = time.strftime("%Y%m%d-%H%M%S")
     save_file = os.path.join(args.output_dir, f"{timestamp}.jsonl")
     parser.write_to_file.remote(results, save_file)
-    logger.info(f"Results written to {save_file} asynchronously.")
+    logger.info(f"Write to {save_file} successfully.")
 
 
 if __name__ == "__main__":
diff --git a/src/pai_rag/tools/data_process/split_workflow.py b/src/pai_rag/tools/data_process/split_workflow.py
index 176acf75..3d6487de 100644
--- a/src/pai_rag/tools/data_process/split_workflow.py
+++ b/src/pai_rag/tools/data_process/split_workflow.py
@@ -18,11 +18,10 @@ def main(args):
             "working_dir": args.working_dir,
             "config_file": args.config_file,
         },
-        concurrency=get_concurrency(),
+        concurrency=get_concurrency(args.num_cpus),
         batch_size=args.batch_size,
     )
-    logger.info("Splitting nodes completed.")
-
+    logger.info(f"Write to {args.output_dir}")
     timestamp = time.strftime("%Y%m%d-%H%M%S")
     ds = ds.repartition(1)
     ds.write_json(