Skip to content

Commit

Permalink
Merge pull request #14 from wingedsheep/feature/tidying_up
Browse files Browse the repository at this point in the history
Feature/tidying up
  • Loading branch information
wingedsheep authored Dec 11, 2021
2 parents ec935cb + 83b464e commit a2cd4f5
Show file tree
Hide file tree
Showing 22 changed files with 806 additions and 885 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## 0.4.0 - 2021-12-11

### Added
- More configuration options for all models.
- Added instrument mapping.

## 0.3.2 - 2021-12-03

### Added
Expand Down
39 changes: 17 additions & 22 deletions example/compound_word_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,26 @@
from mgt.datamanagers.compound_word_data_manager import CompoundWordDataManager
from mgt.models.compound_word_transformer_model import CompoundWordTransformerModel

"""
Example showing how to train a compound word model and generate new music with it.
"""

def run():
"""
Example showing how to train a compound word model and generate new music with it.
"""
midi_path = '../data/pop/'
midi_path = os.path.join(os.path.dirname(__file__), midi_path)
midis = glob.glob(midi_path + '*.mid')

midi_path = '../data/pop/'
midi_path = os.path.join(os.path.dirname(__file__), midi_path)
midis = glob.glob(midi_path + '*.mid')
data_manager = CompoundWordDataManager()
dataset = data_manager.prepare_data(midis)

data_manager = CompoundWordDataManager()
dataset = data_manager.prepare_data(midis)
model = CompoundWordTransformerModel()

model = CompoundWordTransformerModel()
print("Created model. Starting training for 50 epochs.")
model.train(x_train=dataset.data, epochs=50, stop_loss=0.1)

print("Created model. Starting training for 50 epochs.")
model.train(x_train=dataset.data, epochs=50, stop_loss=0.1)
# Generate music
print("Generating music.")
output = model.generate(1000)

# Generate music
print("Generating music.")
output = model.generate(1000)

# Restore events from input data
midi = data_manager.to_midi(output)
midi.save("result.midi")


run()
# Restore events from input data
midi = data_manager.to_midi(output)
midi.save("result.midi")
37 changes: 16 additions & 21 deletions example/from_and_to_midi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,21 @@

from mgt.datamanagers.time_shift_data_manager import TimeShiftDataManager

"""
Example showing how to parse midi files with different encodings.
Also outputs midi files showing how these encodings sound when converted back.
"""
midi_path = '../data/TheWeeknd-BlindingLights.midi'
midi_path = os.path.join(os.path.dirname(__file__), midi_path)

def run():
"""
Example showing how to parse midi files with different encodings.
Also outputs midi files showing how these encodings sound when converted back.
"""
midi_path = '../data/TheWeeknd-BlindingLights.midi'
midi_path = os.path.join(os.path.dirname(__file__), midi_path)
# Parse midi using timeshift method
timeshift_data_manager = TimeShiftDataManager()
timeshift_dataset = timeshift_data_manager.prepare_data([midi_path])
timeshift_midi = timeshift_data_manager.to_midi(timeshift_dataset.data[0])
timeshift_midi.save("timeshift.midi")

# Parse midi using timeshift method
timeshift_data_manager = TimeShiftDataManager()
timeshift_dataset = timeshift_data_manager.prepare_data([midi_path])
timeshift_midi = timeshift_data_manager.to_midi(timeshift_dataset.data[0])
timeshift_midi.save("timeshift.midi")

# Parse midi using remi method
remi_data_manager = RemiDataManager()
remi_dataset = remi_data_manager.prepare_data([midi_path])
remi_midi = remi_data_manager.to_midi(remi_dataset.data[0])
remi_midi.save("remi.midi")


run()
# Parse midi using remi method
remi_data_manager = RemiDataManager()
remi_dataset = remi_data_manager.prepare_data([midi_path])
remi_midi = remi_data_manager.to_midi(remi_dataset.data[0])
remi_midi.save("remi.midi")
29 changes: 12 additions & 17 deletions example/save_and_load_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,19 @@
from mgt.datamanagers.remi_data_manager import RemiDataManager


def run():
"""
Example showing how to save and load a created dataset.
Saving can be done in steps for large datasets that take a lot of time to parse.
"""
midi_path = '../data/pop/'
midi_path = os.path.join(os.path.dirname(__file__), midi_path)
midis = glob.glob(midi_path + '*.mid')
"""
Example showing how to save and load a created dataset.
Saving can be done in steps for large datasets that take a lot of time to parse.
"""

datamanager = RemiDataManager()
dataset = datamanager.prepare_data(midis)
DataHelper.save(dataset, 'test_dataset')
midi_path = '../data/pop/'
midi_path = os.path.join(os.path.dirname(__file__), midi_path)
midis = glob.glob(midi_path + '*.mid')

loaded_dataset = DataHelper.load('test_dataset')
print(len(loaded_dataset.data)) # Should contain the 10 parsed midis
datamanager = RemiDataManager()
dataset = datamanager.prepare_data(midis)
DataHelper.save(dataset, 'test_dataset')

run()
loaded_dataset = DataHelper.load('test_dataset')
print(len(loaded_dataset.data)) # Should contain the 10 parsed midis


# dataset = DataHelper.load('test_dataset')
# print(dataset.dictionary) # Should contain the 10 parsed midis
20 changes: 7 additions & 13 deletions example/save_and_load_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,10 @@

from mgt.models.transformer_model import TransformerModel


def run():
"""
Example showing how to save and load a model.
"""
dictionary = DictionaryGenerator.create_dictionary();
model = TransformerModel(dictionary)
model.save_checkpoint("test_model")
model2 = TransformerModel.load_checkpoint("test_model")
print(model2.generate(1))


run()
"""
Example showing how to save and load a model.
"""
dictionary = DictionaryGenerator.create_dictionary();
model = TransformerModel(dictionary)
model.save_checkpoint("test_model")
model2 = TransformerModel.load_checkpoint("test_model")
5 changes: 0 additions & 5 deletions example/test.py

This file was deleted.

41 changes: 16 additions & 25 deletions example/training_example.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,25 @@
from mgt.datamanagers.time_shift_data_manager import TimeShiftDataManager
from mgt.datamanagers.remi_data_manager import RemiDataManager
from mgt.models.transformer_model import TransformerModel

import os
import glob

# Collect the midi paths
midi_path = 'YOUR MIDI PATH'
midi_path = os.path.join(os.path.dirname(__file__), midi_path)
midis = glob.glob(midi_path + '*.mid')

def run():
"""
Example showing how to train a new model and generate new music with it.
"""
# Create the datamanager and prepare the data
datamanager = RemiDataManager()
dataset = datamanager.prepare_data(midis)

midi_path = '../data/pop/'
midi_path = os.path.join(os.path.dirname(__file__), midi_path)
midis = glob.glob(midi_path + '*.mid')
# Create and train the model
model = TransformerModel(dataset.dictionary)
model.train(x_train=dataset.data, epochs=50, stop_loss=0.1)

time_shift_data_manager = TimeShiftDataManager()
dataset = time_shift_data_manager.prepare_data(midis)
# Generate music
output = model.generate(1000)

model = TransformerModel(dataset.dictionary)

print("Created model. Starting training for 50 epochs.")
model.train(x_train=dataset.data, epochs=50, stop_loss=0.1)

# Generate music
print("Generating music.")
output = model.generate(1000)

# Restore events from input data
midi = time_shift_data_manager.to_midi(output)
midi.save("result.midi")


run()
# Restore events from input data
midi = datamanager.to_midi(output)
midi.save("result.midi")
45 changes: 32 additions & 13 deletions mgt/datamanagers/compound_word_data_manager.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,52 @@
from pretty_midi import pretty_midi

from mgt.datamanagers.compound_word.compound_word_mapper import CompoundWordMapper
from mgt.datamanagers.data_manager import DataManager, DataSet
from mgt.datamanagers.midi_wrapper import MidiWrapper, MidiToolkitWrapper
from mgt.datamanagers.remi import util
from mgt.datamanagers.remi.data_extractor import DataExtractor
from mgt.datamanagers.remi.dictionary_generator import DictionaryGenerator
from mgt.datamanagers.remi.to_midi_mapper import ToMidiMapper


defaults = {
'transposition_steps': [0],
'map_tracks_to_instruments': {}
}


class CompoundWordDataManager(DataManager):
"""
transposition_steps: Transposed copies of the data to include. For example [-1, 0, 1] has a copy that is transposed
One semitone down, once the original track, and once transposed one semitone up.
map_tracks_to_instruments: Whether to map certain track numbers to instruments. For example {0=0, 1=25} maps
track 0 to a grand piano, and track 1 to an acoustic guitar.
instrument_mapping: Maps instruments to different instruments. For example {1:0, 2:0, 3:0, 4:0, 5:0, 6:0, 7:0, 8:0}
maps all piano-like instruments to a grand piano. Mapping to None removes the instrument entirely.
"""

def __init__(self, transposition_steps=None, map_tracks_to_instruments=None):
if map_tracks_to_instruments is None:
map_tracks_to_instruments = {}
if transposition_steps is None:
transposition_steps = [0]
def __init__(
self,
transposition_steps=defaults['transposition_steps'],
map_tracks_to_instruments=defaults['map_tracks_to_instruments'],
instrument_mapping=defaults['instrument_mapping']
):
self.transposition_steps = transposition_steps
self.map_tracks_to_instruments = map_tracks_to_instruments
self.instrument_mapping = instrument_mapping
self.dictionary = DictionaryGenerator.create_dictionary()
self.compound_word_mapper = CompoundWordMapper(self.dictionary)
self.data_extractor = DataExtractor(
dictionary=self.dictionary,
map_tracks_to_instruments=self.map_tracks_to_instruments,
use_chords=False,
instrument_mapping=self.instrument_mapping
)
self.to_midi_mapper = ToMidiMapper(self.dictionary)

def prepare_data(self, midi_paths) -> DataSet:
training_data = []
for path in midi_paths:
for transposition_step in self.transposition_steps:
try:
data = util.extract_words(path,
transposition_steps=transposition_step,
map_tracks_to_instruments=self.map_tracks_to_instruments,
use_chords=False)
data = self.data_extractor.extract_data(path, transposition_step)

compound_words = self.compound_word_mapper.map_to_compound(data, self.dictionary)
compound_data = self.compound_word_mapper.map_compound_words_to_data(compound_words)
Expand All @@ -46,4 +65,4 @@ def to_remi(self, data):

def to_midi(self, data) -> MidiWrapper:
remi = self.compound_word_mapper.map_to_remi(data)
return MidiToolkitWrapper(util.to_midi(remi, self.dictionary))
return MidiToolkitWrapper(self.to_midi_mapper.to_midi(remi))
11 changes: 11 additions & 0 deletions mgt/datamanagers/remi/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import numpy as np

DRUM_INSTRUMENT = 128

DEFAULT_VELOCITY_BINS = np.linspace(0, 128, 32 + 1, dtype=np.int)
DEFAULT_FRACTION = 16
DEFAULT_DURATION_BINS = np.arange(60, 3841, 60, dtype=int)
DEFAULT_TEMPO_INTERVALS = [range(30, 90), range(90, 150), range(150, 210)]

# parameters for outputItem
DEFAULT_RESOLUTION = 480
Loading

0 comments on commit a2cd4f5

Please sign in to comment.