complete musecoco refactoring, progress to next stage
yhbcode000 committed Sep 10, 2024
1 parent 96c3567 commit 1bc4627
Showing 10 changed files with 912 additions and 87 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -164,4 +164,5 @@ cython_debug/
.vscode
*.pt
tmp
-generation
\ No newline at end of file
+generation
+log/*
402 changes: 379 additions & 23 deletions inference.ipynb

Large diffs are not rendered by default.

38 changes: 21 additions & 17 deletions main.py
@@ -1,18 +1,17 @@
import os
import sys
import shutil
-from src.control.musecoco.text2attribute_model import main
+from src.control.musecoco.text2attribute_model import Text2AttributePredictor, prepare_data
from src.control.musecoco.attribute2music_model import interactive_dict_v5_1billion
-from src.control.musecoco.text2attribute_model import stage2_pre  # Import the stage2 script

-def text2attribute():
+def init_text2attribute():
    # Step 1: Simulate terminal input by modifying sys.argv for text2attribute model
    # Define variables
    model_name_or_path = "IreneXu/MuseCoco_text2attribute"
-    test_file = "src/control/musecoco/text2attribute_model/data/predict.json"
+    test_file = "storage/input/predict.json"
    attributes_file = "src/control/musecoco/text2attribute_model/data/att_key.json"
    num_labels_file = "src/control/musecoco/text2attribute_model/num_labels.json"
-    output_dir = "src/control/musecoco/text2attribute_model//tmp"
+    output_dir = "storage/tmp"

    # Convert Python variables into sys.argv format
    sys.argv = [
@@ -27,26 +26,27 @@ def text2attribute():
    ]

    # Call the main function to process simulated inputs
-    main()
+    predictor = Text2AttributePredictor()
+
+    return predictor

def prepare_stage2():
    # Step 2: Prepare intermediate data by executing necessary scripts
    # Move to the directory for text2attribute model processing
-    # Run `stage2_pre.py` - you mentioned it's a script that can be imported
-    stage2_pre()
+    prepare_data()

    # Move generated `infer_test.bin` to the appropriate directory
-    source_path = "src/control/musecoco/text2attribute_model/infer_test.bin"
-    destination_path = "src/control/musecoco/attribute2music_model/data/infer_input/infer_test.bin"
+    source_path = "infer_test.bin"
+    destination_path = "storage/tmp/infer_test.bin"
    os.makedirs(os.path.dirname(destination_path), exist_ok=True)
    shutil.move(source_path, destination_path)

-def attribute2midi():
+def init_attribute2midi():
    # Step 3: Set up variables for attribute2music model
    start, end = 0, 100  # Example values for start and end
    model_size = "1billion"
    k = 15
    command_name = "infer_test"
    need_num = 2
    temp = 1.0
    ngram = 0
@@ -60,9 +60,9 @@ def attribute2midi():
    # Step 4: Define paths
    DATA_DIR = f"src/control/musecoco/attribute2music_model/data/{datasets_name}"
    checkpoint_path = f"src/control/musecoco/attribute2music_model/checkpoints/{model_name}/{checkpoint_name}.pt"
-    ctrl_command_path = f"src/control/musecoco/attribute2music_model/data/infer_input/{command_name}.bin"
-    save_root = f"src/control/musecoco/attribute2music_model/generation/{date}/{model_name}-{checkpoint_name}/{command_name}/topk{k}-t{temp}-ngram{ngram}"
-    log_root = f"src/control/musecoco/attribute2music_model/log/{date}/{model_name}"
+    ctrl_command_path = "storage/tmp/infer_test.bin"
+    save_root = f"storage/generation/{date}/{model_name}-{checkpoint_name}/topk{k}-t{temp}-ngram{ngram}"
+    log_root = f"storage/log/{date}/{model_name}"

    # Step 5: Set environment variables
    os.environ["CUDA_VISIBLE_DEVICES"] = device
@@ -95,9 +95,13 @@

    # Step 9: Call cli_main with modified arguments
    interactive_dict_v5_1billion.seed_everything(2024)  # Set random seed
-    interactive_dict_v5_1billion.cli_main()
+
+    return interactive_dict_v5_1billion.Attribute2MusicPredictor()

if __name__ == "__main__":
-    text2attribute()
+    text2attribute_predictor = init_text2attribute()
+    attribute2midi_predictor = init_attribute2midi()
+
+    text2attribute_predictor.predict()
    prepare_stage2()
-    attribute2midi()
+    attribute2midi_predictor.predict()
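For context: both init_* helpers reuse upstream CLI entry points by writing a fake argv before the parsers run. A minimal, self-contained sketch of that pattern (the argparse setup below is illustrative only, not MuseCoco's actual parser):

import sys
import argparse

def cli_main():
    # Stand-in for a library entry point that parses sys.argv itself.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path")
    parser.add_argument("--output_dir")
    args = parser.parse_args()
    print(f"would load {args.model_name_or_path} and write to {args.output_dir}")

# Simulate a terminal invocation: argv[0] is the program name, the rest are
# the flags the downstream parser expects to find.
sys.argv = [
    "main.py",
    "--model_name_or_path", "IreneXu/MuseCoco_text2attribute",
    "--output_dir", "storage/tmp",
]
cli_main()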
@@ -372,7 +372,7 @@ def build_generator(
compute_alignment=getattr(args, "print_alignment", False),
)

-from command_seq_generator import CommandSequenceGenerator
+from .command_seq_generator import CommandSequenceGenerator

# Choose search strategy. Defaults to Beam Search.
sampling = getattr(args, "sampling", False)
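The one-line change above swaps a bare import for a package-relative one. Once this module is imported as part of the src.control.musecoco package instead of being run from its own directory, command_seq_generator is no longer on sys.path, so the bare form raises ModuleNotFoundError. A sketch of the distinction (layout hypothetical):

# Layout (hypothetical):
#   mypackage/
#       __init__.py
#       tasks.py                   # the importing module
#       command_seq_generator.py

# In tasks.py, the bare form resolves only when mypackage/ itself is on
# sys.path (script-style execution from inside the directory):
#   from command_seq_generator import CommandSequenceGenerator
# The relative form resolves against the module's own package, no matter
# where the caller runs from:
#   from .command_seq_generator import CommandSequenceGenerator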
@@ -418,6 +418,266 @@ def cli_main():
# else:
# label_embedding(args)

class Attribute2MusicPredictor:
    def __init__(self):
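        # Extends fairseq's interactive generation parser with MuseCoco-specific
        # options; parse_args_and_arch() consumes sys.argv, which main.py
        # populates before constructing this predictor.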
        parser = options.get_interactive_generation_parser()
        parser.add_argument("--save_root", type=str)
        parser.add_argument("--need_num", type=int, default=32)
        parser.add_argument("--ctrl_command_path", type=str, default="")
        parser.add_argument("--start", type=int, default=None)
        parser.add_argument("--end", type=int, default=None)
        parser.add_argument("--use_gold_labels", type=int, default=0)
        args = options.parse_args_and_arch(parser)

        self.args = args

        start_time = time.time()
        self.total_translate_time = 0

        utils.import_user_module(args)

        if args.buffer_size < 1:
            args.buffer_size = 1
        if args.max_tokens is None and args.batch_size is None:
            args.batch_size = 1

        assert (
            not args.sampling or args.nbest == args.beam
        ), "--sampling requires --nbest to be equal to --beam"
        assert (
            not args.batch_size or args.batch_size <= args.buffer_size
        ), "--batch-size cannot be larger than --buffer-size"

        logger.info(args)

        # Fix seed for stochastic decoding
        if args.seed is not None and not args.no_seed_provided:
            np.random.seed(args.seed)
            utils.set_torch_seed(args.seed)

        use_cuda = torch.cuda.is_available() and not args.cpu

        # Setup task, e.g., translation_control
        task = tasks.setup_task(args)

        # Load ensemble
        logger.info("loading model(s) from {}".format(args.path))
        models, _model_args = checkpoint_utils.load_model_ensemble(
            args.path.split(os.pathsep),
            arg_overrides=eval(args.model_overrides),
            task=task,
            suffix=getattr(args, "checkpoint_suffix", ""),
            strict=(args.checkpoint_shard_count == 1),
            num_shards=args.checkpoint_shard_count,
        )

        for model in models:
            if args.fp16:
                model.half()
            if use_cuda and not args.pipeline_model_parallel:
                model.cuda()
            model.prepare_for_inference_(args)
            model.decoder.args.is_inference = True

        # Set dictionaries
        src_dict = task.source_dictionary
        tgt_dict = task.target_dictionary

        # Initialize generator
        generator = task.build_generator(models, args)

        # Handle tokenization and BPE
        tokenizer = encoders.build_tokenizer(args)
        bpe = encoders.build_bpe(args)

        def encode_fn(x):
            if tokenizer is not None:
                x = tokenizer.encode(x)
            if bpe is not None:
                x = bpe.encode(x)
            return x

        def decode_fn(x):
            if bpe is not None:
                x = bpe.decode(x)
            if tokenizer is not None:
                x = tokenizer.decode(x)
            return x

        # Load alignment dictionary for unknown word replacement
        # (None if no unknown word replacement, empty if no path to align dictionary)
        align_dict = utils.load_align_dict(args.replace_unk)

        max_positions = utils.resolve_max_positions(
            task.max_positions(), *[model.max_positions() for model in models]
        )

        if args.constraints:
            logger.warning(
                "NOTE: Constrained decoding currently assumes a shared subword vocabulary."
            )

        if args.buffer_size > 1:
            logger.info("Sentence buffer size: %s", args.buffer_size)
        logger.info("NOTE: hypothesis and token scores are output in base 2")
        logger.info("Type the input sentence and press return:")
        start_id = 0

        # for inputs in buffered_read(args.input, args.buffer_size):
        self.save_root = args.save_root
        os.makedirs(self.save_root, exist_ok=True)
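        # MidiDecoder("REMIGEN2") converts generated REMI-like token sequences
        # back into MIDI objects (used at the end of predict()).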
        midi_decoder = MidiDecoder("REMIGEN2")

        # test_command = np.load("../Text2Music_data/v2.1_20230218/full_0218_filter_by_5866/infer_command_balanced.npy",
        #                        allow_pickle=True).item()
        # test_command = np.load(args.ctrl_command_path, allow_pickle=True).item()
        # test_command = json.load(open(args.ctrl_command_path, "r"))

        if args.use_gold_labels:
            with open(args.save_root + "/Using_gold_labels!.txt", "w") as check_input:
                pass
        else:
            with open(args.save_root + "/Using_pred_labels!.txt", "w") as check_input:
                pass

        self.task = task
        self.max_positions = max_positions
        self.encode_fn = encode_fn
        self.use_cuda = use_cuda
        self.generator = generator
        self.models = models
        self.start_id = start_id
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict
        self.align_dict = align_dict
        self.midi_decoder = midi_decoder

    def predict(self):
        args = self.args
        test_command = pickle.load(open(args.ctrl_command_path, "rb"))
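        # ctrl_command_path points at the stage-2 pickle that main.py's
        # prepare_stage2() placed at storage/tmp/infer_test.bin.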
        if args.start is None:
            args.start = 0
            args.end = len(test_command)
        else:
            args.start = min(max(args.start, 0), len(test_command))
            args.end = min(max(args.end, 0), len(test_command))

        gen_command_list = []
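        # Each entry: [attribute token prefix, sample index (as a string),
        # repeat index, full command record].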
        for j in range(args.need_num):
            for i in range(args.start, args.end):
                if args.use_gold_labels:
                    pred_labels = test_command[i]["gold_labels"]
                else:
                    pred_labels = test_command[i]["pred_labels"]
                attribute_tokens = convert_vector_to_token(pred_labels)
                test_command[i]["infer_command_tokens"] = attribute_tokens
                gen_command_list.append([test_command[i]["infer_command_tokens"], f"{i}", j, test_command[i]])

        steps = len(gen_command_list) // args.batch_size
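        # steps + 1 iterations cover the final partial batch; empty slices
        # are skipped inside the loop.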
print(f"Starts to generate {args.start} to {args.end} of {len(gen_command_list)} samples in {steps + 1} batch steps!")


        for batch_step in range(steps + 1):
            infer_list = gen_command_list[batch_step * args.batch_size:(batch_step + 1) * args.batch_size]
            infer_command_token = [g[0] for g in infer_list]
            # assert infer_command.shape[1] == 133, f"error feature dim for {gen_key}!"
            if len(infer_list) == 0:
                continue
            if os.path.exists(self.save_root + f"/{infer_list[-1][1]}/remi/{infer_list[-1][2]}.txt"):
                print(f"Skipping the {batch_step}-th batch since it has already been generated!")
                continue

            # start_tokens = [f""]
            start_tokens = []
            sep_pos = []
            for attribute_prefix in infer_command_token:
                start_tokens.append(" ".join(attribute_prefix) + " <sep>")
                sep_pos.append(len(attribute_prefix))  # note: the <sep> position is len(attribute_prefix) in this sequence
            sep_pos = np.array(sep_pos)
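            # sep_pos is passed to the model via sample["net_input"] below so
            # decoding can resume right after each attribute prefix.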
            for inputs in [start_tokens]:  # "" for no prefix input
                results = []
                for batch in make_batches(inputs, args, self.task, self.max_positions, self.encode_fn):
                    bsz = batch.src_tokens.size(0)
                    src_tokens = batch.src_tokens
                    src_lengths = batch.src_lengths
                    constraints = batch.constraints

                    if self.use_cuda:
                        src_tokens = src_tokens.cuda()
                        src_lengths = src_lengths.cuda()
                        if constraints is not None:
                            constraints = constraints.cuda()

                    sample = {
                        "net_input": {
                            "src_tokens": src_tokens,
                            "src_lengths": src_lengths,
                            "sep_pos": sep_pos,
                        },
                    }
                    translate_start_time = time.time()
                    translations = self.task.inference_step(
                        self.generator, self.models, sample, constraints=constraints
                    )
                    translate_time = time.time() - translate_start_time
                    self.total_translate_time += translate_time
                    list_constraints = [[] for _ in range(bsz)]
                    if args.constraints:
                        list_constraints = [unpack_constraints(c) for c in constraints]

                    for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
                        src_tokens_i = utils.strip_pad(src_tokens[i], self.tgt_dict.pad())
                        constraints = list_constraints[i]
                        results.append(
                            (
                                self.start_id + id,
                                src_tokens_i,
                                hypos,
                                {
                                    "constraints": constraints,
                                    "time": translate_time / len(translations),
                                    "translation_shape": len(translations),
                                },
                            )
                        )

                # sort output to match input order
                for id_, src_tokens, hypos, info in sorted(results, key=lambda x: x[0]):
                    if self.src_dict is not None:
                        src_str = self.src_dict.string(src_tokens, args.remove_bpe)
                    # Process top predictions
                    for hypo in hypos[: min(len(hypos), args.nbest)]:
                        hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                            hypo_tokens=hypo["tokens"].int().cpu(),
                            src_str=src_str,
                            alignment=hypo["alignment"],
                            align_dict=self.align_dict,
                            tgt_dict=self.tgt_dict,
                            remove_bpe=args.remove_bpe,
                            extra_symbols_to_ignore=get_symbols_to_strip_from_output(self.generator),
                        )

                        os.makedirs(self.save_root + f"/{infer_list[id_][1]}", exist_ok=True)
                        if not os.path.exists(self.save_root + f"/{infer_list[id_][1]}/infer_command.json"):
                            with open(self.save_root + f"/{infer_list[id_][1]}/infer_command.json", "w") as f:
                                json.dump(infer_list[id_][-1], f)
                        save_id = infer_list[id_][2]

                        os.makedirs(self.save_root + f"/{infer_list[id_][1]}/remi", exist_ok=True)
                        with open(self.save_root + f"/{infer_list[id_][1]}/remi/{save_id}.txt", "w") as f:
                            f.write(hypo_str)
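                        # Tokens after <sep> form the generated REMI body; strip
                        # the attribute prefix before decoding to MIDI.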
                        remi_token = hypo_str.split(" ")[sep_pos[id_] + 1:]
                        print(f"batch:{batch_step} save_id:{save_id} over with length {len(hypo_str.split(' '))}; "
                              f"Average translation time:{info['time']} seconds; Remi seq length: {len(remi_token)}; "
                              f"Batch size:{args.batch_size}; Translation shape:{info['translation_shape']}.")
                        os.makedirs(self.save_root + f"/{infer_list[id_][1]}/midi", exist_ok=True)
                        midi_obj = self.midi_decoder.decode_from_token_str_list(remi_token)
                        midi_obj.dump(self.save_root + f"/{infer_list[id_][1]}/midi/{save_id}.mid")


if __name__ == "__main__":
    seed_everything(2024)  # 2023
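With cli_main() superseded by the class above, argument parsing and model loading happen once in the constructor, and generation can be re-run without reloading checkpoints. A minimal usage sketch, assuming sys.argv has been populated the way init_attribute2midi() does in main.py:

# sys.argv must already carry the fairseq options (checkpoint path,
# --save_root, --ctrl_command_path, ...) before construction, as arranged
# by init_attribute2midi() in main.py.
predictor = Attribute2MusicPredictor()  # parses args, loads the model ensemble
predictor.predict()                     # reads the command pickle, writes remi/ and midi/ outputs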
3 changes: 2 additions & 1 deletion src/control/musecoco/text2attribute_model/__init__.py
@@ -1 +1,2 @@
-from .main import main
+from .main import Text2AttributePredictor
+from .stage2_pre import prepare_data