Skip to content

Commit

Permalink
Remove hardcoded combined.jsonl with a flag
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaohanzhan-db committed Jan 12, 2024
1 parent fa8f3d9 commit 876fab5
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
8 changes: 7 additions & 1 deletion scripts/data_prep/convert_delta_to_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ def fetch_DT(args: Namespace) -> None:
# combine downloaded jsonl into one big jsonl for IFT
iterative_combine_jsons(
args.json_output_path,
os.path.join(args.json_output_path, 'combined.jsonl'))
os.path.join(args.json_output_path, args.json_output_filename))


if __name__ == '__main__':
Expand Down Expand Up @@ -505,6 +505,12 @@ def fetch_DT(args: Namespace) -> None:
help=
'Use serverless or not. Make sure the workspace is entitled with serverless'
)
parser.add_argument(
'--json_output_filename',
required=False,
type=str,
default='train-00000-of-00001.jsonl',
help='The combined final jsonl that combines all partitioned jsonl')
args = parser.parse_args()

from databricks.sdk import WorkspaceClient
Expand Down
1 change: 1 addition & 0 deletions tests/a_scripts/data_prep/test_convert_delta_to_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def test_stream_delta_to_json(self, mock_workspace_client: Any,
args.cluster_id = '1234'
args.debug = False
args.use_serverless = False
args.json_output_filename = 'combined.jsonl'

mock_cluster_get = MagicMock()
mock_cluster_get.return_value = MagicMock(
Expand Down

0 comments on commit 876fab5

Please sign in to comment.