diff --git a/.gitignore b/.gitignore index 53e0f2f..9bb3ecd 100644 --- a/.gitignore +++ b/.gitignore @@ -164,4 +164,5 @@ cython_debug/ .vscode *.pt tmp -generation \ No newline at end of file +generation +log/* \ No newline at end of file diff --git a/inference.ipynb b/inference.ipynb index 984e9e9..c68dea6 100644 --- a/inference.ipynb +++ b/inference.ipynb @@ -21,53 +21,409 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# ! cd modules/musecoco/1-text2attribute_model ;\\\n", + "# cat data/predict.json ;\\\n", + "# bash predict.sh ;\\\n", + "# python stage2_pre.py ;\\\n", + "# mv infer_test.bin ../2-attribute2music_model/data/infer_input\n", + "\n", + "# ! cd /workspace/Chat_Midi/muzic/musecoco/2-attribute2music_model ;\\\n", + "# bash interactive_1billion.sh 0 2" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# ! cd muzic/musecoco/evaluation ;\\\n", + "# python eval_acc_v3.py --root /workspace/Chat_Midi/muzic/musecoco/2-attribute2music_model/generation/0505/linear_mask-1billion-checkpoint_2_280000/infer_test/topk15-t1.0-ngram0" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "from main import *" + ] + }, + { + "cell_type": "code", + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "[\n", - " {\n", - " \"text\": \"This music has a meter of 4/4 and a balanced beat. Its playtime is about 40 seconds. The use of grand piano, guitar, bass, violin, synthesizer and drum is vital to the music's overall sound and performance. The song spans approximately 13 ~ 16 bars.\"\n", - " },\n", - " {\n", - " \"text\": \"The music's limited pitch range of 5 octaves allows for a greater emphasis on the nuances of tone and phrasing, while its use of major key creates a distinct atmosphere. With a runtime of 31 ~ 45 seconds, this song showcases a highly vigorous rhythm and features grand piano, guitar, bass, violin, synthesizer and drum. It is played at a moderate speed, adhering to a 4/4 meter, and is characterized by its religious and pop sound.\"\n", - " }\n", - "]" + "2024-09-11 02:01:38 | WARNING | src.control.musecoco.text2attribute_model.main | Process rank: -1, device: cuda:0, n_gpu: 2distributed training: False, 16-bits training: False\n", + "2024-09-11 02:01:38 | INFO | src.control.musecoco.text2attribute_model.main | Training/evaluation parameters TrainingArguments(\n", + "_n_gpu=2,\n", + "adafactor=False,\n", + "adam_beta1=0.9,\n", + "adam_beta2=0.999,\n", + "adam_epsilon=1e-08,\n", + "auto_find_batch_size=False,\n", + "bf16=False,\n", + "bf16_full_eval=False,\n", + "data_seed=None,\n", + "dataloader_drop_last=False,\n", + "dataloader_num_workers=0,\n", + "dataloader_pin_memory=True,\n", + "ddp_bucket_cap_mb=None,\n", + "ddp_find_unused_parameters=None,\n", + "ddp_timeout=1800,\n", + "debug=[],\n", + "deepspeed=None,\n", + "disable_tqdm=False,\n", + "do_eval=False,\n", + "do_predict=True,\n", + "do_train=False,\n", + "eval_accumulation_steps=None,\n", + "eval_delay=0,\n", + "eval_steps=None,\n", + "evaluation_strategy=no,\n", + "fp16=False,\n", + "fp16_backend=auto,\n", + "fp16_full_eval=False,\n", + "fp16_opt_level=O1,\n", + "fsdp=[],\n", + "fsdp_min_num_params=0,\n", + "fsdp_transformer_layer_cls_to_wrap=None,\n", + "full_determinism=False,\n", + "gradient_accumulation_steps=1,\n", + "gradient_checkpointing=False,\n", + "greater_is_better=None,\n", + "group_by_length=False,\n", + "half_precision_backend=auto,\n", + "hub_model_id=None,\n", + "hub_private_repo=False,\n", + "hub_strategy=every_save,\n", + "hub_token=,\n", + "ignore_data_skip=False,\n", + "include_inputs_for_metrics=False,\n", + "jit_mode_eval=False,\n", + "label_names=None,\n", + "label_smoothing_factor=0.0,\n", + "learning_rate=5e-05,\n", + "length_column_name=length,\n", + "load_best_model_at_end=False,\n", + "local_rank=-1,\n", + "log_level=passive,\n", + "log_level_replica=passive,\n", + "log_on_each_node=True,\n", + "logging_dir=storage/tmp/runs/Sep11_02-01-38_2742e12a356b,\n", + "logging_first_step=False,\n", + "logging_nan_inf_filter=True,\n", + "logging_steps=500,\n", + "logging_strategy=steps,\n", + "lr_scheduler_type=linear,\n", + "max_grad_norm=1.0,\n", + "max_steps=-1,\n", + "metric_for_best_model=None,\n", + "mp_parameters=,\n", + "no_cuda=False,\n", + "num_train_epochs=3.0,\n", + "optim=adamw_hf,\n", + "optim_args=None,\n", + "output_dir=storage/tmp,\n", + "overwrite_output_dir=True,\n", + "past_index=-1,\n", + "per_device_eval_batch_size=8,\n", + "per_device_train_batch_size=8,\n", + "prediction_loss_only=False,\n", + "push_to_hub=False,\n", + "push_to_hub_model_id=None,\n", + "push_to_hub_organization=None,\n", + "push_to_hub_token=,\n", + "ray_scope=last,\n", + "remove_unused_columns=True,\n", + "report_to=['tensorboard', 'wandb'],\n", + "resume_from_checkpoint=None,\n", + "run_name=storage/tmp,\n", + "save_on_each_node=False,\n", + "save_steps=500,\n", + "save_strategy=steps,\n", + "save_total_limit=None,\n", + "seed=42,\n", + "sharded_ddp=[],\n", + "skip_memory_metrics=True,\n", + "tf32=None,\n", + "torch_compile=False,\n", + "torch_compile_backend=None,\n", + "torch_compile_mode=None,\n", + "torchdynamo=None,\n", + "tpu_metrics_debug=False,\n", + "tpu_num_cores=None,\n", + "use_ipex=False,\n", + "use_legacy_prediction_loop=False,\n", + "use_mps_device=False,\n", + "warmup_ratio=0.0,\n", + "warmup_steps=0,\n", + "weight_decay=0.0,\n", + "xpu_backend=None,\n", + ")\n", + "2024-09-11 02:01:38 | INFO | src.control.musecoco.text2attribute_model.main | load a local file for test: storage/input/predict.json\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/root/miniconda3/envs/MuseCoco/lib/python3.8/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", + " warnings.warn(\n", + "[INFO|configuration_utils.py:660] 2024-09-11 02:01:38,854 >> loading configuration file config.json from cache at /root/.cache/huggingface/hub/models--IreneXu--MuseCoco_text2attribute/snapshots/7ae4925d3b78107069e16d24ee3755aa26be944e/config.json\n", + "[INFO|configuration_utils.py:712] 2024-09-11 02:01:38,855 >> Model config BertConfig {\n", + " \"_name_or_path\": \"bert-large-uncased\",\n", + " \"architectures\": [\n", + " \"BertForAttributModel\"\n", + " ],\n", + " \"attention_probs_dropout_prob\": 0.1,\n", + " \"classifier_dropout\": null,\n", + " \"gradient_checkpointing\": false,\n", + " \"hidden_act\": \"gelu\",\n", + " \"hidden_dropout_prob\": 0.1,\n", + " \"hidden_size\": 1024,\n", + " \"initializer_range\": 0.02,\n", + " \"intermediate_size\": 4096,\n", + " \"layer_norm_eps\": 1e-12,\n", + " \"max_position_embeddings\": 512,\n", + " \"model_type\": \"bert\",\n", + " \"num_attention_heads\": 16,\n", + " \"num_hidden_layers\": 24,\n", + " \"pad_token_id\": 0,\n", + " \"position_embedding_type\": \"absolute\",\n", + " \"torch_dtype\": \"float32\",\n", + " \"transformers_version\": \"4.26.0\",\n", + " \"type_vocab_size\": 2,\n", + " \"use_cache\": true,\n", + " \"vocab_size\": 30522\n", + "}\n", + "\n", + "[INFO|tokenization_utils_base.py:1802] 2024-09-11 02:01:39,294 >> loading file vocab.txt from cache at /root/.cache/huggingface/hub/models--IreneXu--MuseCoco_text2attribute/snapshots/7ae4925d3b78107069e16d24ee3755aa26be944e/vocab.txt\n", + "[INFO|tokenization_utils_base.py:1802] 2024-09-11 02:01:39,295 >> loading file tokenizer.json from cache at /root/.cache/huggingface/hub/models--IreneXu--MuseCoco_text2attribute/snapshots/7ae4925d3b78107069e16d24ee3755aa26be944e/tokenizer.json\n", + "[INFO|tokenization_utils_base.py:1802] 2024-09-11 02:01:39,295 >> loading file added_tokens.json from cache at None\n", + "[INFO|tokenization_utils_base.py:1802] 2024-09-11 02:01:39,295 >> loading file special_tokens_map.json from cache at /root/.cache/huggingface/hub/models--IreneXu--MuseCoco_text2attribute/snapshots/7ae4925d3b78107069e16d24ee3755aa26be944e/special_tokens_map.json\n", + "[INFO|tokenization_utils_base.py:1802] 2024-09-11 02:01:39,296 >> loading file tokenizer_config.json from cache at /root/.cache/huggingface/hub/models--IreneXu--MuseCoco_text2attribute/snapshots/7ae4925d3b78107069e16d24ee3755aa26be944e/tokenizer_config.json\n", + "[INFO|modeling_utils.py:2275] 2024-09-11 02:01:39,308 >> loading weights file pytorch_model.bin from cache at /root/.cache/huggingface/hub/models--IreneXu--MuseCoco_text2attribute/snapshots/7ae4925d3b78107069e16d24ee3755aa26be944e/pytorch_model.bin\n", + "[INFO|modeling_utils.py:2857] 2024-09-11 02:01:43,995 >> All model checkpoint weights were used when initializing BertForAttributModel.\n", + "\n", + "[INFO|modeling_utils.py:2865] 2024-09-11 02:01:43,996 >> All the weights of BertForAttributModel were initialized from the model checkpoint at IreneXu/MuseCoco_text2attribute.\n", + "If your task is similar to the task the model of the checkpoint was trained on, you can already use BertForAttributModel for predictions without further training.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2024-09-11 02:01:47 | INFO | fairseq_cli.interactive | Namespace(add_bos_token=False, all_gather_list_size=16384, batch_size=2, batch_size_valid=2, beam=1, bf16=False, bpe=None, broadcast_buffers=False, bucket_cap_mb=25, buffer_size=2, checkpoint_shard_count=1, checkpoint_suffix='', command_embed_dim=None, command_mask_prob=0.4, command_path=None, constraints=None, cpu=False, criterion='cross_entropy', ctrl_command_path='storage/tmp/infer_test.bin', curriculum=0, data='src/control/musecoco/attribute2music_model/data/truncated_2560/data-bin', data_buffer_size=10, dataset_impl=None, ddp_backend='c10d', decoding_format=None, device_id=0, disable_validation=False, distributed_backend='nccl', distributed_init_method=None, distributed_no_spawn=False, distributed_port=-1, distributed_rank=0, distributed_world_size=1, distributed_wrapper='DDP', diverse_beam_groups=-1, diverse_beam_strength=0.5, diversity_rate=-1.0, empty_cache_freq=0, end=100, fast_stat_sync=False, find_unused_parameters=False, fix_batches_to_gpus=False, fixed_validation_seed=None, force_anneal=None, fp16=False, fp16_init_scale=128, fp16_no_flatten_grads=False, fp16_scale_tolerance=0.0, fp16_scale_window=None, future_target=False, gen_subset='test', input='-', is_inference=False, iter_decode_eos_penalty=0.0, iter_decode_force_max_iter=False, iter_decode_max_iter=10, iter_decode_with_beam=1, iter_decode_with_external_reranker=False, lenpen=1, lm_path=None, lm_weight=0.0, localsgd_frequency=3, log_format=None, log_interval=100, lr_scheduler='fixed', lr_shrink=0.1, match_source_len=False, max_len_a=0, max_len_b=2560, max_target_positions=None, max_tokens=None, max_tokens_valid=None, memory_efficient_bf16=False, memory_efficient_fp16=False, min_len=512.0, min_loss_scale=0.0001, model_overrides='{}', model_parallel_size=1, nbest=1, need_num=2, no_beamable_mm=False, no_early_stop=False, no_progress_bar=False, no_repeat_ngram_size=0, no_seed_provided=False, nprocs_per_node=2, num_shards=1, num_workers=1, optimizer=None, output_dictionary_size=-1, padding_to_max_length=0, past_target=False, path='src/control/musecoco/attribute2music_model/checkpoints/linear_mask-1billion/checkpoint_2_280000.pt', pipeline_balance=None, pipeline_checkpoint='never', pipeline_chunks=0, pipeline_decoder_balance=None, pipeline_decoder_devices=None, pipeline_devices=None, pipeline_encoder_balance=None, pipeline_encoder_devices=None, pipeline_model_parallel=False, prefix_size=0, print_alignment=False, print_step=False, profile=False, quantization_config_path=None, quiet=False, remove_bpe=None, replace_unk=None, required_batch_size_multiple=8, required_seq_len_multiple=1, results_path=None, retain_dropout=False, retain_dropout_modules=None, retain_iter_history=False, sacrebleu=False, sample_break_mode='none', sampling=True, sampling_topk=15, sampling_topp=-1.0, save_root='storage/generation/0505/linear_mask-1billion-checkpoint_2_280000/topk15-t1.0-ngram0', score_reference=False, scoring='bleu', seed=1, self_target=False, shard_id=0, shorten_data_split_list='', shorten_method='none', skip_invalid_size_inputs_valid_test=False, slowmo_algorithm='LocalSGD', slowmo_momentum=None, start=0, task='language_modeling_control', temperature=1.0, tensorboard_logdir=None, threshold_loss_scale=None, tokenizer=None, tokens_per_sample=1024, tpu=False, train_subset='train', truncated_length=5868, unkpen=0, unnormalized=False, use_gold_labels=0, user_dir=None, valid_subset='valid', validate_after_updates=0, validate_interval=1, validate_interval_updates=0, warmup_updates=0, zero_sharding='none')\n", + "2024-09-11 02:01:47 | INFO | fairseq.tasks.language_modeling | dictionary: 1253 types\n", + "2024-09-11 02:01:47 | INFO | fairseq_cli.interactive | loading model(s) from src/control/musecoco/attribute2music_model/checkpoints/linear_mask-1billion/checkpoint_2_280000.pt\n", + "2024-09-11 02:02:42 | INFO | fairseq_cli.interactive | Sentence buffer size: 2\n", + "2024-09-11 02:02:42 | INFO | fairseq_cli.interactive | NOTE: hypothesis and token scores are output in base 2\n", + "2024-09-11 02:02:42 | INFO | fairseq_cli.interactive | Type the input sentence and press return:\n" ] } ], "source": [ - "! cd modules/musecoco/1-text2attribute_model ;\\\n", - "cat data/predict.json ;\\\n", - "bash predict.sh ;\\\n", - "python stage2_pre.py ;\\\n", - "mv infer_test.bin ../2-attribute2music_model/data/infer_input\n", - "\n", - "! cd /workspace/Chat_Midi/muzic/musecoco/2-attribute2music_model ;\\\n", - "bash interactive_1billion.sh 0 2" + "text2attribute_predictor = init_text2attribute()\n", + "attribute2midi_predictor = init_attribute2midi()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/root/miniconda3/envs/MuseCoco/lib/python3.8/site-packages/datasets/load.py:2566: FutureWarning: 'use_auth_token' was deprecated in favor of 'token' in version 2.14.0 and will be removed in 3.0.0.\n", + "You can remove this warning by passing 'token=' instead.\n", + " warnings.warn(\n", + "Using custom data configuration default-6c07d90dc18a6c34\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2024-09-11 02:03:20 | INFO | datasets.builder | Using custom data configuration default-6c07d90dc18a6c34\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading Dataset Infos from /root/miniconda3/envs/MuseCoco/lib/python3.8/site-packages/datasets/packaged_modules/json\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2024-09-11 02:03:20 | INFO | datasets.info | Loading Dataset Infos from /root/miniconda3/envs/MuseCoco/lib/python3.8/site-packages/datasets/packaged_modules/json\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Overwrite dataset info from restored data version if exists.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2024-09-11 02:03:20 | INFO | datasets.builder | Overwrite dataset info from restored data version if exists.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading Dataset info from /root/.cache/huggingface/datasets/json/default-6c07d90dc18a6c34/0.0.0/f4e89e8750d5d5ffbef2c078bf0ddfedef29dc2faff52a6255cf513c05eb1092\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2024-09-11 02:03:20 | INFO | datasets.info | Loading Dataset info from /root/.cache/huggingface/datasets/json/default-6c07d90dc18a6c34/0.0.0/f4e89e8750d5d5ffbef2c078bf0ddfedef29dc2faff52a6255cf513c05eb1092\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Found cached dataset json (/root/.cache/huggingface/datasets/json/default-6c07d90dc18a6c34/0.0.0/f4e89e8750d5d5ffbef2c078bf0ddfedef29dc2faff52a6255cf513c05eb1092)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2024-09-11 02:03:20 | INFO | datasets.builder | Found cached dataset json (/root/.cache/huggingface/datasets/json/default-6c07d90dc18a6c34/0.0.0/f4e89e8750d5d5ffbef2c078bf0ddfedef29dc2faff52a6255cf513c05eb1092)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading Dataset info from /root/.cache/huggingface/datasets/json/default-6c07d90dc18a6c34/0.0.0/f4e89e8750d5d5ffbef2c078bf0ddfedef29dc2faff52a6255cf513c05eb1092\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2024-09-11 02:03:20 | INFO | datasets.info | Loading Dataset info from /root/.cache/huggingface/datasets/json/default-6c07d90dc18a6c34/0.0.0/f4e89e8750d5d5ffbef2c078bf0ddfedef29dc2faff52a6255cf513c05eb1092\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "1293e84c4ad04064b0c6a28982377bfa", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Running tokenizer on dataset: 0%| | 0/1 [00:00> The following columns in the test set don't have a corresponding argument in `BertForAttributModel.forward` and have been ignored: text. If text are not expected by `BertForAttributModel.forward`, you can safely ignore this message.\n", + "/root/miniconda3/envs/MuseCoco/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py:30: UserWarning: \n", + " There is an imbalance between your GPUs. You may want to exclude GPU 0 which\n", + " has less than 75% of the memory or cores of GPU 1. You can do so by setting\n", + " the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES\n", + " environment variable.\n", + " warnings.warn(imbalance_warn.format(device_ids[min_pos], device_ids[max_pos]))\n", + "[INFO|trainer.py:2964] 2024-09-11 02:03:23,028 >> ***** Running Prediction *****\n", + "[INFO|trainer.py:2966] 2024-09-11 02:03:23,029 >> Num examples = 1\n", + "[INFO|trainer.py:2969] 2024-09-11 02:03:23,029 >> Batch size = 16\n", + "/root/miniconda3/envs/MuseCoco/lib/python3.8/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", + " warnings.warn('Was asked to gather along dimension 0, but all '\n" + ] + }, + { + "data": { + "text/html": [], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "text2attribute_predictor.predict()\n", + "prepare_stage2()\n" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Extract attributes: 100%|█████████████████████████| 3/3 [00:02<00:00, 1.39it/s]\n", - "Comput Accuracy: 100%|██████████████████████████| 2/2 [00:00<00:00, 1582.76it/s]\n", - "ASA: 0.45833333333333337\n" + "Starts to generate 0 to 1 of 2 samples in 2 batch steps!\n", + "2024-09-11 02:03:23 | INFO | src.control.musecoco.attribute2music_model.linear_mask.A2M_task_new | Using max_positions limit (100000) for unknown\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/workspace/Chat_Midi/ChatPiano/modules/musecoco-text2midi-service/src/control/musecoco/attribute2music_model/linear_mask/command_seq_generator.py:657: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n", + " unfin_idx = idx // beam_size\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "batch:0 save_id:0 over with length 940; Average translation time:45.81053590774536 seconds; Remi seq length: 876; Batch size:2; Translation shape:2.\n", + "batch:0 save_id:1 over with length 769; Average translation time:45.81053590774536 seconds; Remi seq length: 705; Batch size:2; Translation shape:2.\n" ] } ], "source": [ - "! cd muzic/musecoco/evaluation ;\\\n", - "python eval_acc_v3.py --root /workspace/Chat_Midi/muzic/musecoco/2-attribute2music_model/generation/0505/linear_mask-1billion-checkpoint_2_280000/infer_test/topk15-t1.0-ngram0" + "attribute2midi_predictor.predict()" ] }, { diff --git a/main.py b/main.py index 74ded3d..40396a4 100644 --- a/main.py +++ b/main.py @@ -1,18 +1,17 @@ import os import sys import shutil -from src.control.musecoco.text2attribute_model import main +from src.control.musecoco.text2attribute_model import Text2AttributePredictor, prepare_data from src.control.musecoco.attribute2music_model import interactive_dict_v5_1billion -from src.control.musecoco.text2attribute_model import stage2_pre # Import the stage2 script -def text2attribute(): +def init_text2attribute(): # Step 1: Simulate terminal input by modifying sys.argv for text2attribute model # Define variables model_name_or_path = "IreneXu/MuseCoco_text2attribute" - test_file = "src/control/musecoco/text2attribute_model/data/predict.json" + test_file = "storage/input/predict.json" attributes_file = "src/control/musecoco/text2attribute_model/data/att_key.json" num_labels_file = "src/control/musecoco/text2attribute_model/num_labels.json" - output_dir = "src/control/musecoco/text2attribute_model//tmp" + output_dir = "storage/tmp" # Convert Python variables into sys.argv format sys.argv = [ @@ -27,26 +26,27 @@ def text2attribute(): ] # Call the main function to process simulated inputs - main() + predictor = Text2AttributePredictor() + + return predictor def prepare_stage2(): # Step 2: Prepare intermediate data by executing necessary scripts # Move to the directory for text2attribute model processing # Run `stage2_pre.py` - you mentioned it's a script that can be imported - stage2_pre() + prepare_data() # Move generated `infer_test.bin` to the appropriate directory - source_path = "src/control/musecoco/text2attribute_model/infer_test.bin" - destination_path = "src/control/musecoco/attribute2music_model/data/infer_input/infer_test.bin" + source_path = "infer_test.bin" + destination_path = "storage/tmp/infer_test.bin" os.makedirs(os.path.dirname(destination_path), exist_ok=True) shutil.move(source_path, destination_path) -def attribute2midi(): +def init_attribute2midi(): # Step 3: Set up variables for attribute2music model start, end = 0, 100 # Example values for start and end model_size = "1billion" k = 15 - command_name = "infer_test" need_num = 2 temp = 1.0 ngram = 0 @@ -60,9 +60,9 @@ def attribute2midi(): # Step 4: Define paths DATA_DIR = f"src/control/musecoco/attribute2music_model/data/{datasets_name}" checkpoint_path = f"src/control/musecoco/attribute2music_model/checkpoints/{model_name}/{checkpoint_name}.pt" - ctrl_command_path = f"src/control/musecoco/attribute2music_model/data/infer_input/{command_name}.bin" - save_root = f"src/control/musecoco/attribute2music_model/generation/{date}/{model_name}-{checkpoint_name}/{command_name}/topk{k}-t{temp}-ngram{ngram}" - log_root = f"src/control/musecoco/attribute2music_model/log/{date}/{model_name}" + ctrl_command_path = f"storage/tmp/infer_test.bin" + save_root = f"storage/generation/{date}/{model_name}-{checkpoint_name}/topk{k}-t{temp}-ngram{ngram}" + log_root = f"storage/log/{date}/{model_name}" # Step 5: Set environment variables os.environ["CUDA_VISIBLE_DEVICES"] = device @@ -95,9 +95,13 @@ def attribute2midi(): # Step 9: Call cli_main with modified arguments interactive_dict_v5_1billion.seed_everything(2024) # Set random seed - interactive_dict_v5_1billion.cli_main() + + return interactive_dict_v5_1billion.Attribute2MusicPredictor() if __name__ == "__main__": - text2attribute() + text2attribute_predictor = init_text2attribute() + attribute2midi_predictor = init_attribute2midi() + + text2attribute_predictor.predict() prepare_stage2() - attribute2midi() \ No newline at end of file + attribute2midi_predictor.predict() diff --git a/src/control/musecoco/attribute2music_model/linear_mask/A2M_task_new.py b/src/control/musecoco/attribute2music_model/linear_mask/A2M_task_new.py index 1b73709..5687408 100644 --- a/src/control/musecoco/attribute2music_model/linear_mask/A2M_task_new.py +++ b/src/control/musecoco/attribute2music_model/linear_mask/A2M_task_new.py @@ -372,7 +372,7 @@ def build_generator( compute_alignment=getattr(args, "print_alignment", False), ) - from command_seq_generator import CommandSequenceGenerator + from .command_seq_generator import CommandSequenceGenerator # Choose search strategy. Defaults to Beam Search. sampling = getattr(args, "sampling", False) diff --git a/src/control/musecoco/attribute2music_model/linear_mask/interactive_dict_v5_1billion.py b/src/control/musecoco/attribute2music_model/linear_mask/interactive_dict_v5_1billion.py index a49eb77..2801290 100644 --- a/src/control/musecoco/attribute2music_model/linear_mask/interactive_dict_v5_1billion.py +++ b/src/control/musecoco/attribute2music_model/linear_mask/interactive_dict_v5_1billion.py @@ -418,6 +418,266 @@ def cli_main(): # else: # label_embedding(args) +class Attribute2MusicPredictor: + def __init__(self): + parser = options.get_interactive_generation_parser() + parser.add_argument("--save_root", type=str) + parser.add_argument("--need_num", type=int, default=32) + parser.add_argument("--ctrl_command_path", type=str, default="") + parser.add_argument("--start", type = int, default=None) + parser.add_argument("--end", type = int, default=None) + parser.add_argument("--use_gold_labels", type = int, default=0) + args = options.parse_args_and_arch(parser) + + self.args = args + + args = self.args + + start_time = time.time() + self.total_translate_time = 0 + + utils.import_user_module(args) + + if args.buffer_size < 1: + args.buffer_size = 1 + if args.max_tokens is None and args.batch_size is None: + args.batch_size = 1 + + assert ( + not args.sampling or args.nbest == args.beam + ), "--sampling requires --nbest to be equal to --beam" + assert ( + not args.batch_size or args.batch_size <= args.buffer_size + ), "--batch-size cannot be larger than --buffer-size" + + logger.info(args) + + # Fix seed for stochastic decoding + if args.seed is not None and not args.no_seed_provided: + np.random.seed(args.seed) + utils.set_torch_seed(args.seed) + + use_cuda = torch.cuda.is_available() and not args.cpu + + # Setup task, e.g., translation_control + task = tasks.setup_task(args) + + # Load ensemble + logger.info("loading model(s) from {}".format(args.path)) + models, _model_args = checkpoint_utils.load_model_ensemble( + args.path.split(os.pathsep), + arg_overrides=eval(args.model_overrides), + task=task, + suffix=getattr(args, "checkpoint_suffix", ""), + strict=(args.checkpoint_shard_count == 1), + num_shards=args.checkpoint_shard_count, + ) + + for model in models: + if args.fp16: + model.half() + if use_cuda and not args.pipeline_model_parallel: + model.cuda() + model.prepare_for_inference_(args) + model.decoder.args.is_inference = True + + # Set dictionaries + src_dict = task.source_dictionary + tgt_dict = task.target_dictionary + + # Initialize generator + generator = task.build_generator(models, args) + + # Handle tokenization and BPE + tokenizer = encoders.build_tokenizer(args) + bpe = encoders.build_bpe(args) + + def encode_fn(x): + if tokenizer is not None: + x = tokenizer.encode(x) + if bpe is not None: + x = bpe.encode(x) + return x + + def decode_fn(x): + if bpe is not None: + x = bpe.decode(x) + if tokenizer is not None: + x = tokenizer.decode(x) + return x + + # Load alignment dictionary for unknown word replacement + # (None if no unknown word replacement, empty if no path to align dictionary) + align_dict = utils.load_align_dict(args.replace_unk) + + max_positions = utils.resolve_max_positions( + task.max_positions(), *[model.max_positions() for model in models] + ) + + if args.constraints: + logger.warning( + "NOTE: Constrained decoding currently assumes a shared subword vocabulary." + ) + + if args.buffer_size > 1: + logger.info("Sentence buffer size: %s", args.buffer_size) + logger.info("NOTE: hypothesis and token scores are output in base 2") + logger.info("Type the input sentence and press return:") + start_id = 0 + + # for inputs in buffered_read(args.input, args.buffer_size): + self.save_root = args.save_root + os.makedirs(self.save_root, exist_ok=True) + midi_decoder = MidiDecoder("REMIGEN2") + + # test_command = np.load("../Text2Music_data/v2.1_20230218/full_0218_filter_by_5866/infer_command_balanced.npy", + # allow_pickle=True).item() + # test_command = np.load(args.ctrl_command_path, allow_pickle=True).item() + + # test_command = json.load(open(args.ctrl_command_path, "r")) + if args.use_gold_labels: + with open(args.save_root + "/Using_gold_labels!.txt", "w") as check_input: + pass + else: + with open(args.save_root + "/Using_pred_labels!.txt", "w") as check_input: + pass + + self.task = task + self.max_positions = max_positions + self.encode_fn = encode_fn + self.use_cuda = use_cuda + self.generator = generator + self.models = models + self.tgt_dict = tgt_dict + self.start_id = start_id + self.src_dict = src_dict + self.align_dict = align_dict + self.tgt_dict = tgt_dict + self.midi_decoder = midi_decoder + + + def predict(self): + args = self.args + test_command = pickle.load(open(args.ctrl_command_path, "rb")) + if args.start is None: + args.start = 0 + args.end = len(test_command) + else: + args.start = min(max(args.start, 0), len(test_command)) + args.end = min(max(args.end, 0), len(test_command)) + + gen_command_list = [] + for j in range(args.need_num): + for i in range(args.start, args.end): + if args.use_gold_labels: + pred_labels = test_command[i]["gold_labels"] + else: + pred_labels = test_command[i]["pred_labels"] + attribute_tokens = convert_vector_to_token(pred_labels) + test_command[i]["infer_command_tokens"] = attribute_tokens + gen_command_list.append([test_command[i]["infer_command_tokens"], f"{i}", j, test_command[i]]) + + steps = len(gen_command_list) // args.batch_size + print(f"Starts to generate {args.start} to {args.end} of {len(gen_command_list)} samples in {steps + 1} batch steps!") + + + for batch_step in range(steps + 1): + infer_list = gen_command_list[batch_step*args.batch_size:(batch_step+1)*args.batch_size] + infer_command_token = [g[0] for g in infer_list] + # assert infer_command.shape[1] == 133, f"error feature dim for {gen_key}!" + if len(infer_list) == 0: + continue + if os.path.exists(self.save_root + f"/{infer_list[-1][1]}/remi/{infer_list[-1][2]}.txt"): + print(f"Skip the {batch_step}-th batch since has been generated!") + continue + + # start_tokens = [f""] + start_tokens = [] + sep_pos = [] + for attribute_prefix in infer_command_token: + start_tokens.append(" ".join(attribute_prefix) + " ") + sep_pos.append(len(attribute_prefix)) # notice that pos is len(attribute_prefix) in this sequence + sep_pos = np.array(sep_pos) + for inputs in [start_tokens]: # "" for none prefix input + results = [] + for batch in make_batches(inputs, args, self.task, self.max_positions, self.encode_fn): + bsz = batch.src_tokens.size(0) + src_tokens = batch.src_tokens + src_lengths = batch.src_lengths + constraints = batch.constraints + + if self.use_cuda: + src_tokens = src_tokens.cuda() + src_lengths = src_lengths.cuda() + if constraints is not None: + constraints = constraints.cuda() + + sample = { + "net_input": { + "src_tokens": src_tokens, + "src_lengths": src_lengths, + "sep_pos": sep_pos, + }, + } + translate_start_time = time.time() + translations = self.task.inference_step( + self.generator, self.models, sample, constraints=constraints + ) + translate_time = time.time() - translate_start_time + self.total_translate_time += translate_time + list_constraints = [[] for _ in range(bsz)] + if args.constraints: + list_constraints = [unpack_constraints(c) for c in constraints] + + for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)): + src_tokens_i = utils.strip_pad(src_tokens[i], self.tgt_dict.pad()) + constraints = list_constraints[i] + results.append( + ( + self.start_id + id, + src_tokens_i, + hypos, + { + "constraints": constraints, + "time": translate_time / len(translations), + "translation_shape":len(translations), + }, + ) + ) + + # sort output to match input order + for id_, src_tokens, hypos, info in sorted(results, key=lambda x: x[0]): + if self.src_dict is not None: + src_str = self.src_dict.string(src_tokens, args.remove_bpe) + # Process top predictions + for hypo in hypos[: min(len(hypos), args.nbest)]: + hypo_tokens, hypo_str, alignment = utils.post_process_prediction( + hypo_tokens=hypo["tokens"].int().cpu(), + src_str=src_str, + alignment=hypo["alignment"], + align_dict=self.align_dict, + tgt_dict=self.tgt_dict, + remove_bpe=args.remove_bpe, + extra_symbols_to_ignore=get_symbols_to_strip_from_output(self.generator), + ) + + os.makedirs(self.save_root + f"/{infer_list[id_][1]}", exist_ok=True) + if not os.path.exists(self.save_root + f"/{infer_list[id_][1]}/infer_command.json"): + with open(self.save_root + f"/{infer_list[id_][1]}/infer_command.json", "w") as f: + json.dump(infer_list[id_][-1], f) + save_id = infer_list[id_][2] + + os.makedirs(self.save_root + f"/{infer_list[id_][1]}/remi", exist_ok=True) + with open(self.save_root + f"/{infer_list[id_][1]}/remi/{save_id}.txt", "w") as f: + f.write(hypo_str) + remi_token = hypo_str.split(" ")[sep_pos[id_] + 1:] + print(f"batch:{batch_step} save_id:{save_id} over with length {len(hypo_str.split(' '))}; " + f"Average translation time:{info['time']} seconds; Remi seq length: {len(remi_token)}; Batch size:{args.batch_size}; \ + Translation shape:{info['translation_shape']}.") + os.makedirs(self.save_root + f"/{infer_list[id_][1]}/midi", exist_ok=True) + midi_obj = self.midi_decoder.decode_from_token_str_list(remi_token) + midi_obj.dump(self.save_root + f"/{infer_list[id_][1]}/midi/{save_id}.mid") + if __name__ == "__main__": seed_everything(2024) # 2023 diff --git a/src/control/musecoco/text2attribute_model/__init__.py b/src/control/musecoco/text2attribute_model/__init__.py index c28a133..ad5e467 100644 --- a/src/control/musecoco/text2attribute_model/__init__.py +++ b/src/control/musecoco/text2attribute_model/__init__.py @@ -1 +1,2 @@ -from .main import main +from .main import Text2AttributePredictor +from .stage2_pre import prepare_data \ No newline at end of file diff --git a/src/control/musecoco/text2attribute_model/main.py b/src/control/musecoco/text2attribute_model/main.py index 14ee718..ed612f4 100644 --- a/src/control/musecoco/text2attribute_model/main.py +++ b/src/control/musecoco/text2attribute_model/main.py @@ -472,5 +472,188 @@ def compute_metrics(p: EvalPrediction): json.dump(result_output, open(os.path.join(training_args.output_dir, f"predict_attributes.json"),'w')) json.dump(softmaxprobs, open(os.path.join(training_args.output_dir, f"softmax_probs.json"),'w')) + + +class Text2AttributePredictor: + def __init__(self): + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) + self.model_args, self.data_args, self.training_args = parser.parse_args_into_dataclasses() + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + log_level = self.training_args.get_process_log_level() + logger.setLevel(log_level) + datasets.utils.logging.set_verbosity(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + # Log on each process the small summary: + logger.warning( + f"Process rank: {self.training_args.local_rank}, device: {self.training_args.device}, n_gpu: {self.training_args.n_gpu}" + + f"distributed training: {bool(self.training_args.local_rank != -1)}, 16-bits training: {self.training_args.fp16}" + ) + logger.info(f"Training/evaluation parameters {self.training_args}") + + # Detecting last checkpoint. + last_checkpoint = None + + # Set seed before initializing model. + set_seed(self.training_args.seed) + + # Get the datasets: + self.data_files = {} + + if self.data_args.test_file is not None: + test_extension = self.data_args.test_file.split(".")[-1] + assert ( + test_extension == 'json' + ), "`test_file` should have the extension `json`" + self.data_files["test"] = self.data_args.test_file + else: + raise ValueError("Need a test file for `do_predict`.") + + for key in self.data_files.keys(): + logger.info(f"load a local file for {key}: {self.data_files[key]}") + + + # Attribute values / labels + self.attributes = json.load(open(self.data_args.attributes, 'r')) + num_labels = OrderedDict() + if self.data_args.num_labels: + num_labels = json.load(open(self.data_args.num_labels)) + else: + raise ValueError("Need `num_label`.") + + # Load pretrained model and tokenizer + # + # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. + config = BertConfig.from_pretrained( + self.model_args.config_name if self.model_args.config_name else self.model_args.model_name_or_path, + cache_dir=self.model_args.cache_dir, + revision=self.model_args.model_revision, + use_auth_token=True if self.model_args.use_auth_token else None, + ) + self.tokenizer = AutoTokenizer.from_pretrained( + self.model_args.tokenizer_name if self.model_args.tokenizer_name else self.model_args.model_name_or_path, + cache_dir=self.model_args.cache_dir, + use_fast=self.model_args.use_fast_tokenizer, + revision=self.model_args.model_revision, + use_auth_token=True if self.model_args.use_auth_token else None, + ) + model = BertForAttributModel.from_pretrained( + self.model_args.model_name_or_path, + from_tf=bool(".ckpt" in self.model_args.model_name_or_path), + config=config, + num_labels = num_labels, + tokenizer = self.tokenizer, + cache_dir=self.model_args.cache_dir, + revision=self.model_args.model_revision, + use_auth_token=True if self.model_args.use_auth_token else None, + ignore_mismatched_sizes=self.model_args.ignore_mismatched_sizes, + ) + + model.tokenizer = self.tokenizer + + # Preprocessing the raw_datasets + # Padding strategy + if self.data_args.pad_to_max_length: + self.padding = "max_length" + else: + # We will pad later, dynamically at batch creation, to the max sequence length in each batch + self.padding = False + + if self.data_args.max_seq_length > self.tokenizer.model_max_length: + logger.warning( + f"The max_seq_length passed ({self.data_args.max_seq_length}) is larger than the maximum length for the" + f"model ({self.tokenizer.model_max_length}). Using max_seq_length={self.tokenizer.model_max_length}." + ) + self.max_seq_length = min(self.data_args.max_seq_length, self.tokenizer.model_max_length) + + + # Data collator will default to DataCollatorWithPadding when the tokenizer is passed to Trainer, so we change it if + # we already did the padding. + if self.data_args.pad_to_max_length: + data_collator = default_data_collator + elif self.training_args.fp16: + data_collator = DataCollatorWithPadding(self.tokenizer, pad_to_multiple_of=8) + else: + data_collator = None + + # Initialize our Trainer + self.trainer = Trainer( + model=model, + args=self.training_args, + train_dataset=None, + eval_dataset=None, + compute_metrics=None, + tokenizer=self.tokenizer, + data_collator=data_collator, + # callbacks=[EarlyStoppingCallback(5, 0.01)] + ) + + + def preprocess_function(self, examples, attributes): + # Tokenize the texts + result = self.tokenizer(examples['text'], + padding=self.padding, + max_length=self.max_seq_length, + truncation=True) + if 'labels' in examples: + for idx in range(len(examples['labels'])): + att_value = OrderedDict() + for order, att in enumerate(attributes): + att_value[att] = examples['labels'][idx][order].index(1) + examples['labels'][idx] = deepcopy(att_value) + result['labels']= examples['labels'] + return result + + + def predict(self): + # Loading a dataset from local json files + raw_datasets = load_dataset( + "json", + data_files=self.data_files, + cache_dir=self.model_args.cache_dir, + use_auth_token=True if self.model_args.use_auth_token else None, + ) + + with self.training_args.main_process_first(desc="dataset map pre-processing"): + raw_datasets = raw_datasets.map( + partial(self.preprocess_function, attributes=self.attributes), + batched=True, + load_from_cache_file=not self.data_args.overwrite_cache, + desc="Running tokenizer on dataset", + ) + + + if "test" not in raw_datasets: + raise ValueError("--do_predict requires a test dataset") + predict_dataset = raw_datasets["test"] + if self.data_args.max_predict_samples is not None: + max_predict_samples = min(len(predict_dataset), self.data_args.max_predict_samples) + predict_dataset = predict_dataset.select(range(max_predict_samples)) + + logger.info("*** Predict ***") + predictions = self.trainer.predict(predict_dataset, metric_key_prefix="predict").predictions + result_output = {} + softmaxprobs = {} + for k,v in predictions.items(): + pred = np.argmax(predictions[k], axis=1) + softmax = torch.nn.Softmax(dim=1) + softmaxprobs[k] = softmax(torch.from_numpy(predictions[k])).tolist() + result_output[k] = np.zeros(predictions[k].shape, dtype=np.int8) + for p in range(len(pred)): + result_output[k][p, pred[p]] = 1 + result_output[k] = result_output[k].tolist() + json.dump(result_output, open(os.path.join(self.training_args.output_dir, f"predict_attributes.json"),'w')) + json.dump(softmaxprobs, open(os.path.join(self.training_args.output_dir, f"softmax_probs.json"),'w')) + + if __name__ == "__main__": main() diff --git a/src/control/musecoco/text2attribute_model/stage2_pre.py b/src/control/musecoco/text2attribute_model/stage2_pre.py index ee36225..1fc6705 100644 --- a/src/control/musecoco/text2attribute_model/stage2_pre.py +++ b/src/control/musecoco/text2attribute_model/stage2_pre.py @@ -1,52 +1,56 @@ import json, pickle from copy import deepcopy -test = json.load(open('data/predict.json','r')) -pred = json.load(open('tmp/predict_attributes.json','r')) -probs = json.load(open('tmp/softmax_probs.json','r')) -att_key = json.load(open('data/att_key.json','r')) +def prepare_data(): + test = json.load(open('storage/input/predict.json','r')) + pred = json.load(open('storage/tmp/predict_attributes.json','r')) + probs = json.load(open('storage/tmp/softmax_probs.json','r')) + att_key = json.load(open('src/control/musecoco/text2attribute_model/data/att_key.json','r')) -final = [] -for line in test: - ins = {} - ins['text'] = line['text'] - ins['pred_labels'] = {} - ins['pred_probs'] = {} - final.append(deepcopy(ins)) + final = [] + for line in test: + ins = {} + ins['text'] = line['text'] + ins['pred_labels'] = {} + ins['pred_probs'] = {} + final.append(deepcopy(ins)) -for k, v in pred.items(): - for j in range(len(v)): - final[j]['pred_labels'][k] = deepcopy(v[j]) -for k, v in probs.items(): - for j in range(len(v)): - final[j]['pred_probs'][k] = deepcopy(v[j]) + for k, v in pred.items(): + for j in range(len(v)): + final[j]['pred_labels'][k] = deepcopy(v[j]) + for k, v in probs.items(): + for j in range(len(v)): + final[j]['pred_probs'][k] = deepcopy(v[j]) -I1s2_key = [] -S4_key = [] -for att in att_key: - if att[:4]=="I1s2": - I1s2_key.append(att) - if att[:2]=="S4": - S4_key.append(att) + I1s2_key = [] + S4_key = [] + for att in att_key: + if att[:4]=="I1s2": + I1s2_key.append(att) + if att[:2]=="S4": + S4_key.append(att) -for idx in range(len(final)): - pred_labels_I1s2 = [] - pred_probs_I1s2 = [] - pred_labels_S4 = [] - pred_probs_S4 = [] - for i1s2 in I1s2_key: - pred_labels_I1s2.append(deepcopy(final[idx]['pred_labels'][i1s2])) - pred_probs_I1s2.append(deepcopy(final[idx]['pred_probs'][i1s2])) - final[idx]['pred_labels'].pop(i1s2) - final[idx]['pred_probs'].pop(i1s2) - for s4 in S4_key: - pred_labels_S4.append(deepcopy(final[idx]['pred_labels'][s4])) - pred_probs_S4.append(deepcopy(final[idx]['pred_probs'][s4])) - final[idx]['pred_labels'].pop(s4) - final[idx]['pred_probs'].pop(s4) - final[idx]['pred_probs']['I1s2'] = deepcopy(pred_probs_I1s2) - final[idx]['pred_probs']['S4'] = deepcopy(pred_probs_S4) - final[idx]['pred_labels']['I1s2'] = deepcopy(pred_labels_I1s2) - final[idx]['pred_labels']['S4'] = deepcopy(pred_labels_S4) + for idx in range(len(final)): + pred_labels_I1s2 = [] + pred_probs_I1s2 = [] + pred_labels_S4 = [] + pred_probs_S4 = [] + for i1s2 in I1s2_key: + pred_labels_I1s2.append(deepcopy(final[idx]['pred_labels'][i1s2])) + pred_probs_I1s2.append(deepcopy(final[idx]['pred_probs'][i1s2])) + final[idx]['pred_labels'].pop(i1s2) + final[idx]['pred_probs'].pop(i1s2) + for s4 in S4_key: + pred_labels_S4.append(deepcopy(final[idx]['pred_labels'][s4])) + pred_probs_S4.append(deepcopy(final[idx]['pred_probs'][s4])) + final[idx]['pred_labels'].pop(s4) + final[idx]['pred_probs'].pop(s4) + final[idx]['pred_probs']['I1s2'] = deepcopy(pred_probs_I1s2) + final[idx]['pred_probs']['S4'] = deepcopy(pred_probs_S4) + final[idx]['pred_labels']['I1s2'] = deepcopy(pred_labels_I1s2) + final[idx]['pred_labels']['S4'] = deepcopy(pred_labels_S4) -pickle.dump(final, open('infer_test.bin','wb')) \ No newline at end of file + pickle.dump(final, open('infer_test.bin','wb')) + +if __name__ == "__main__": + prepare_data() \ No newline at end of file diff --git a/storage/input/predict.json b/storage/input/predict.json new file mode 100644 index 0000000..7400229 --- /dev/null +++ b/storage/input/predict.json @@ -0,0 +1,5 @@ +[ + { + "text": "The unique and resonant sound of this music is conveyed through its use of the minor key and the grand piano and strings used to create it. With a fast tempo and sullenness-laden projection, this song consists of about 14 bars, bringing together a complete musical experience that is both powerful and memorable. Its pitch range is within 2 octaves." + } +] \ No newline at end of file diff --git a/storage/input/predict_backup.json b/storage/input/predict_backup.json new file mode 100644 index 0000000..5a39e5c --- /dev/null +++ b/storage/input/predict_backup.json @@ -0,0 +1,11 @@ +[ + { + "text": "This music has a meter of 4/4 and a balanced beat. Its playtime is about 40 seconds. The use of grand piano, guitar, bass, violin, synthesizer and drum is vital to the music's overall sound and performance. The song spans approximately 13 ~ 16 bars." + }, + { + "text": "The music's limited pitch range of 5 octaves allows for a greater emphasis on the nuances of tone and phrasing, while its use of major key creates a distinct atmosphere. With a runtime of 31 ~ 45 seconds, this song showcases a highly vigorous rhythm and features grand piano, guitar, bass, violin, synthesizer and drum. It is played at a moderate speed, adhering to a 4/4 meter, and is characterized by its religious and pop sound." + }, + { + "text": "The unique and resonant sound of this music is conveyed through its use of the minor key and the grand piano and strings used to create it. With a fast tempo and sullenness-laden projection, this song consists of about 14 bars, bringing together a complete musical experience that is both powerful and memorable. Its pitch range is within 2 octaves." + } +] \ No newline at end of file