Skip to content

Commit

Permalink
Fix params for get_concurrency
Browse files (browse the repository at this point in the history)
  • Loading branch information
wwxxzz committed Dec 9, 2024
1 parent c958bc3 commit 820278b
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 7 deletions.
5 changes: 2 additions & 3 deletions src/pai_rag/tools/data_process/embed_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@ def main(args):
"working_dir": args.working_dir,
"config_file": args.config_file,
},
- concurrency=get_concurrency(),
+ concurrency=get_concurrency(args.num_cpus),
batch_size=args.batch_size,
)
- logger.info("Embedding nodes completed.")
logger.info(f"Write to {args.output_dir}")
timestamp = time.strftime("%Y%m%d-%H%M%S")
ds = ds.repartition(1)
Expand All @@ -31,7 +30,7 @@ def main(args):
),
force_ascii=False,
)
- logger.info("Write completed.")
+ logger.info(f"Write to {args.output_dir} successfully.")


if __name__ == "__main__":
Expand Down
3 changes: 2 additions & 1 deletion src/pai_rag/tools/data_process/parse_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@ def main(args):
results = ray.get(run_tasks)
logger.info("Master node completed processing files.")
os.makedirs(args.output_dir, exist_ok=True)
+ logger.info(f"Write to {args.output_dir}")
timestamp = time.strftime("%Y%m%d-%H%M%S")
save_file = os.path.join(args.output_dir, f"{timestamp}.jsonl")
parser.write_to_file.remote(results, save_file)
- logger.info(f"Results written to {save_file} asynchronously.")
+ logger.info(f"Write to {save_file} successfully.")


if __name__ == "__main__":
Expand Down
5 changes: 2 additions & 3 deletions src/pai_rag/tools/data_process/split_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,10 @@ def main(args):
"working_dir": args.working_dir,
"config_file": args.config_file,
},
- concurrency=get_concurrency(),
+ concurrency=get_concurrency(args.num_cpus),
batch_size=args.batch_size,
)
- logger.info("Splitting nodes completed.")
-
+ logger.info(f"Write to {args.output_dir}")
timestamp = time.strftime("%Y%m%d-%H%M%S")
ds = ds.repartition(1)
ds.write_json(
Expand Down

0 comments on commit 820278b

Please sign in to comment.