diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9a40396 --- /dev/null +++ b/.gitignore @@ -0,0 +1,135 @@ +# Removing datasets + +wavlm/*.pt + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ +generate_wavlm_reps.ipynb +generate_wavlm_reps.ipynb diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..e290f98 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,28 @@ +FROM nvidia/cuda:12.2.0-devel-ubuntu22.04 + + +ENV PATH="/root/miniconda3/bin:${PATH}" +ARG PATH="/root/miniconda3/bin:${PATH}" + +RUN apt-get update +RUN apt-get install -y wget git nano ffmpeg + +RUN wget \ + https://repo.anaconda.com/miniconda/Miniconda3-py37_4.8.3-Linux-x86_64.sh \ + && mkdir /root/.conda \ + && bash Miniconda3-py37_4.8.3-Linux-x86_64.sh -b \ + && rm -f Miniconda3-py37_4.8.3-Linux-x86_64.sh + +RUN conda --version + +WORKDIR /root +COPY environment.yml /root + +RUN conda install tqdm -f +RUN conda update conda +RUN conda install pip +RUN conda --version +RUN conda env create -f environment.yml + +SHELL ["conda", "run", "-n", "ggvad", "/bin/bash", "-c"] +RUN pip install git+https://github.com/openai/CLIP.git diff --git a/README.md b/README.md index 28ce97f..fff6ad0 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,99 @@ # stylistic-gesture -Stylistic Co-Speech Gesture Generation: Modeling Personality and Communicative Styles in Virtual Agents +Official repository for the paper Stylistic Co-Speech Gesture Generation: Modeling Personality and Communicative Styles in Virtual Agents. + +## Preparing environment + +1. Git clone this repo + +2. Enter the repo and create docker image using + +```sh +docker build -t ggvad . +``` + +3. 
Run container using
+
+```sh
+docker run --rm -it --gpus device=GPU_NUMBER --userns=host --shm-size 64G -v /MY_DIR/ggvad-genea2023:/workspace/ggvad/ -p PORT_NUMBER --name CONTAINER_NAME ggvad:latest /bin/bash
+```
+
+for example:
+```sh
+docker run --rm -it --gpus device=0 --userns=host --shm-size 64G -v C:\ProgramFiles\ggvad-genea2023:/workspace/my_repo -p '8888:8888' --name my_container ggvad:latest /bin/bash
+```
+
+> ### Cuda version < 12.0:
+>
+> If you have an earlier CUDA or nvcc release, you will need to adjust the Dockerfile. Change the first line to `FROM pytorch/pytorch:1.6.0-cuda10.1-cudnn7-devel` and remove lines 10-14 (conda is already installed in the pytorch image). Then, run the container using:
+>
+> ```sh
+> nvidia-docker run --rm -it -e NVIDIA_VISIBLE_DEVICES=GPU_NUMBER --runtime=nvidia --userns=host --shm-size 64G -v /work/rodolfo.tonoli/GestureDiffusion:/workspace/gesture-diffusion/ -p $port --name gestdiff_container$number multimodal-research-group-mdm:latest /bin/bash
+> ```
+
+
+OR use the shell script ggvad_container.sh (don't forget to change the volume) with the flags -g, -n, and -p.
+For example:
+```sh
+sh ggvad_container.sh -g 0 -n my_container -p '8888:8888'
+```
+
+4. Activate the cuda environment:
+```sh
+source activate ggvad
+```
+
+## Data pre-processing
+
+1. Get the GENEA Challenge 2023 dataset and put it into `./dataset/`
+(Our system is monadic, so you will only need the main-agent's data.)
+
+2. Download the [WavLM Base+](https://github.com/microsoft/unilm/tree/master/wavlm) checkpoint (`WavLM-Base+.pt`) and put it into the folder `/wavlm/`
+
+3. Inside the folder `/workspace/ggvad`, run
+
+```sh
+python -m data_loaders.gesture.scripts.genea_prep
+```
+
+This will convert the bvh files to npy representations, downsample the wav files to 16 kHz and save them as npy arrays, and convert these arrays to WavLM representations. The VAD data must be processed separately due to Python library incompatibilities.
+
+4. (Optional) Process VAD data
+
+We provide the speech activity information (from speechbrain's VAD), but if you wish to process it yourself, redo the steps of "Preparing environment" above for the speechbrain environment: build the image using the Dockerfile inside speechbrain (`docker build -t speechbrain .`), run the container (`docker run ... --name CONTAINER_NAME speechbrain:latest /bin/bash`), and run:
+
+```sh
+python -m data_loaders.gesture.scripts.genea_prep_vad
+```
+
+## Train model
+
+To train the model described in the paper, use the following command inside the repo:
+
+```sh
+python -m train.train_mdm --save_dir save/my_model_run --dataset genea2023+ --step 10 --use_text --use_vad True --use_wavlm True
+```
+
+## Gesture Generation
+
+Generate motion with the trained model by running the following command. If you wish to generate gestures with the pretrained model of the GENEA Challenge, use `--model_path ./save/default_vad_wavlm/model000290000.pt`.
+
+```sh
+python -m sample.generate --model_path ./save/my_model_run/model000XXXXXX.pt
+```
+
+## Render
+
+To render the official GENEA 2023 visualizations, follow the instructions provided [here](https://github.com/TeoNikolov/genea_visualizer/).
+
+## Cite
+
+If you wish to cite this repo or the paper:
+
+```text
+@article{tonoli2024stylistic,
+  author = {Tonoli, Rodolfo L. and Costa, Paula D. 
P.}, + title = {Stylistic Co-Speech Gesture Generation: Modeling Personality and Communicative Styles in Virtual Agents}, + journal = {N/A}, + year = {N/A}, +} +``` \ No newline at end of file diff --git a/data_loaders/__init__.py b/data_loaders/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/data_loaders/gesture/data/__init__.py b/data_loaders/gesture/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/data_loaders/gesture/data/dataset.py b/data_loaders/gesture/data/dataset.py new file mode 100644 index 0000000..72c5839 --- /dev/null +++ b/data_loaders/gesture/data/dataset.py @@ -0,0 +1,337 @@ +import torch +from torch.utils import data +import csv +import os +import numpy as np +from python_speech_features import mfcc +import librosa +import torch.nn.functional as F + +class Genea2023(data.Dataset): + def __init__(self, name, split='trn', datapath='./dataset/Genea2023/', step=30, window=80, fps=30, sr=22050, n_seed_poses=10, use_wavlm=False, use_vad=False, vadfromtext=False): + + self.split = split + if self.split not in ['trn', 'val', 'tst']: + raise ValueError('Split not recognized') + srcpath = os.path.join(datapath, self.split, 'main-agent/') + + if use_wavlm: + self.sr = 16000 + self.audiopath = os.path.join(srcpath, 'audio16k_npy') + else: + self.sr = sr + self.audiopath = os.path.join(srcpath, 'audio_npy') + + self.name = name + self.step = step + + self.datapath = datapath + self.window=window + self.fps = fps + self.n_seed_poses = n_seed_poses + + self.loadstats(os.path.join(datapath, 'trn/main-agent/')) + self.std = np.array([ item if item != 0 else 1 for item in self.std ]) + self.vel_std = np.array([ item if item != 0 else 1 for item in self.vel_std ]) + self.rot6dpos_std = np.array([ item if item != 0 else 1 for item in self.rot6dpos_std ]) + + if self.split in ['trn', 'val']: + self.motionpath = os.path.join(srcpath, 'motion_npy_rotpos') + self.motionpath_rot6d = os.path.join(srcpath, 'motion_npy_rot6dpos') + + + with open(os.path.join(srcpath, '../metadata.csv')) as csvfile: + reader = csv.reader(csvfile, delimiter=',') + self.takes = [take for take in reader] + self.takes = self.takes[1:] + for take in self.takes: + take[0] += '_main-agent' + + # Here we use the motion files to get the number of frames per take (there are some takes that audio files are longer for the trn and val sets) + self.frames = [] + if self.split in ['trn', 'val']: + for motionfile in self.takes: + motion = np.load(os.path.join(self.motionpath_rot6d, motionfile[0] + '.npy')) + self.frames.append(motion.shape[0]) + elif self.split in ['tst']: + for audiofile in os.listdir(self.audiopath): + if audiofile.endswith('.npy'): + audio = np.load(os.path.join(self.audiopath, audiofile)) + self.frames.append( int(audio.shape[0]/self.sr*self.fps)) + self.frames = np.array(self.frames) + self.frames = np.array(self.frames) + + self.samples_per_file = [int(np.floor( (n - self.window ) / self.step)) for n in self.frames] + self.samples_cumulative = [np.sum(self.samples_per_file[:i+1]) for i in range(len(self.samples_per_file))] + self.length = self.samples_cumulative[-1] + self.textpath = os.path.join(srcpath, 'tsv') + + self.use_wavlm = use_wavlm + if self.use_wavlm: + self.wavlm_rep_path = os.path.join(srcpath, 'wavlm_representations') + + self.use_vad = use_vad + if self.use_vad: + self.vad_path = os.path.join(srcpath, "vad") + self.vadfromtext = vadfromtext + if self.vadfromtext: print('Getting speech activity from text') + + self.alljoints = 
{'body_world':0,'b_root':1,'b_spine0':2,'b_spine1':3,'b_spine2':4,'b_spine3':5,'b_neck0':6,'b_head':7,'b_head_null':8,'b_l_eye':9,'b_r_eye':10,'b_jaw':11,'b_jaw_null':12,'b_teeth':13,'b_tongue0':14,'b_tongue1':15,'b_tongue2':16,'b_tongue3':17,'b_tongue4':18,'b_l_tongue4_1':19,'b_r_tongue4_1':20,'b_l_tongue3_1':21,'b_r_tongue3_1':22,'b_l_tongue2_1':23,'b_r_tongue2_1':24,'b_r_tongue1_1':25,'b_l_tongue1_1':26,'b_r_shoulder':27,'p_r_scap':28,'b_r_arm':29,'b_r_arm_twist':30,'b_r_forearm':31,'b_r_wrist_twist':32,'b_r_wrist':33,'b_r_index1':34,'b_r_index2':35,'b_r_index3':36,'b_r_ring1':37,'b_r_ring2':38,'b_r_ring3':39,'b_r_middle1':40,'b_r_middle2':41,'b_r_middle3':42,'b_r_pinky1':43,'b_r_pinky2':44,'b_r_pinky3':45,'b_r_thumb0':46,'b_r_thumb1':47,'b_r_thumb2':48,'b_r_thumb3':49,'b_l_shoulder':50,'p_l_delt':51,'p_l_scap':52,'b_l_arm':53,'b_l_arm_twist':54,'b_l_forearm':55,'b_l_wrist_twist':56,'b_l_wrist':57,'b_l_thumb0':58,'b_l_thumb1':59,'b_l_thumb2':60,'b_l_thumb3':61,'b_l_index1':62,'b_l_index2':63,'b_l_index3':64,'b_l_middle1':65,'b_l_middle2':66,'b_l_middle3':67,'b_l_ring1':68,'b_l_ring2':69,'b_l_ring3':70,'b_l_pinky1':71,'b_l_pinky2':72,'b_l_pinky3':73,'p_navel':74,'b_r_upleg':75,'b_r_leg':76,'b_r_foot_twist':77,'b_r_foot':78,'b_l_upleg':79,'b_l_leg':80,'b_l_foot_twist':81,'b_l_foot':82} + + if False: + for take in self.takes: + name = take[0] + m = os.path.join(self.motionpath, name+'.npy') + a = os.path.join(self.audiopath, name+'.npy') + t = os.path.join(self.textpath, name+'.tsv') + assert os.path.isfile( m ), "Motion file {} not found".format(m) + assert os.path.isfile( a ), "Audio file {} not found".format(a) + assert os.path.isfile( t ), "Text file {} not found".format(t) + + def __getitem__(self, idx): + if self.split == 'tst': + raise ValueError('Test set does should not use __getitem__(), use gettestbatch() instead') + # find the file that the sample belongs two + file_idx = np.searchsorted(self.samples_cumulative, idx+1, side='left') + # find sample's index + if file_idx > 0: + sample = idx - self.samples_cumulative[file_idx-1] + else: + sample = idx + take_name = self.takes[file_idx][0] + motion, seed_poses = self.__getmotion( file_idx, sample) + audio, audio_rep = self.__getaudiofeats(file_idx, sample) + n_text, text, tokens, vad = self.__gettext(file_idx, sample) + if self.use_vad: + if not self.vadfromtext: + vad = self.__getvad(file_idx, sample) + else: + vad = np.ones(int(self.window)) # Dummy + return motion, text, self.window, audio, audio_rep, seed_poses, vad, take_name + + def __len__(self): + return self.length + + def __getvad(self, file, sample): + # Cut Chunk + vad_file = np.load(os.path.join(self.vad_path,self.takes[file][0]+'.npy')) + vad_vals = vad_file[sample*self.step: sample*self.step + self.window] # [CHUNK_LEN, ] + + # Reshape + vad_vals = np.expand_dims(vad_vals, 1) # [CHUNK_LEN, 1] + vad_vals = np.transpose(vad_vals, (1,0)) # [1, CHUNK_LEN] + return vad_vals + + def __getmotion(self, file, sample): + if self.name == 'genea2023+': + # loading rot6d and position representations + rot6dpos_file = np.load(os.path.join(self.motionpath_rot6d,self.takes[file][0]+'.npy')) + rot6dpos = (rot6dpos_file[sample*self.step: sample*self.step + self.window,:] - self.rot6dpos_mean) / self.rot6dpos_std + + # loading rotpos representation and computing velocity + rotpos_file = np.load(os.path.join(self.motionpath,self.takes[file][0]+'.npy')) + rotpos_file[1:,:] = rotpos_file[1:,:] - rotpos_file[:-1,:] + rotpos_file[0,:] = np.zeros(rotpos_file.shape[1]) + rotpos = 
(rotpos_file[sample*self.step: sample*self.step + self.window,:] - self.vel_mean) / self.vel_std + if sample*self.step - self.n_seed_poses < 0: + rot6dpos_seed = np.zeros((self.n_seed_poses, rot6dpos.shape[1])) + rotpos_seed = np.zeros((self.n_seed_poses, rotpos.shape[1])) + else: + rot6dpos_seed = (rot6dpos_file[sample*self.step - self.n_seed_poses: sample*self.step ,:] - self.rot6dpos_mean) / self.rot6dpos_std + rotpos_seed = (rotpos_file[sample*self.step - self.n_seed_poses: sample*self.step,:] - self.vel_mean) / self.vel_std + + motion = np.concatenate((rot6dpos, rotpos), axis=1) + seed_poses = np.concatenate((rot6dpos_seed, rotpos_seed), axis=1) + + else: + motion_file = np.load(os.path.join(self.motionpath,self.takes[file][0]+'.npy')) + motion = (motion_file[sample*self.step: sample*self.step + self.window,:] - self.mean) / self.std + if sample*self.step - self.n_seed_poses < 0: + seed_poses = np.zeros((self.n_seed_poses, motion.shape[1])) + else: + seed_poses = (motion_file[sample*self.step - self.n_seed_poses: sample*self.step,:] - self.mean) / self.std + + return motion, seed_poses + + def __getaudiofeats(self, file, sample): + + # Load Audio + signal = np.load(os.path.join(self.motionpath,'..', 'audio16k_npy',self.takes[file][0]+'.npy')) + + # Cut Chunk + i = sample*16000*self.step/self.fps + signal = signal[ int(i) : int(i+self.window*16000/self.fps) ] + + if self.use_wavlm: + # Cut Chunk + representation_file = np.load(os.path.join(self.wavlm_rep_path,self.takes[file][0]+'.npy')) + wavlm_reps = representation_file[sample*self.step: sample*self.step + self.window,:] # [CHUNK_LEN, WAVLM_DIM] + + # Reshape + wavlm_reps = np.transpose(wavlm_reps, (1,0)) # [WAVLM_DIM, CHUNK_LEN] + wavlm_reps = np.expand_dims(wavlm_reps, 1) # [WAVLM_DIM, 1, CHUNK_LEN] + wavlm_reps = np.expand_dims(wavlm_reps, 0) # [1, WAVLM_DIM, 1, CHUNK_LEN] + return signal, wavlm_reps + else: + return self.__compute_audiofeats(signal) + + def __compute_audiofeats(self, signal): + + # MFCCs + mfcc_vectors = mfcc(signal, winlen=0.06, winstep= (1/self.fps), samplerate=16000, numcep=27, nfft=5000) + + # Normalize + #mfcc_vectors = (mfcc_vectors - self.mfcc_mean) / self.mfcc_std + + # Format + mfcc_vectors = mfcc_vectors.T + mfcc_vectors = np.expand_dims(mfcc_vectors, 1) + mfcc_vectors = np.expand_dims(mfcc_vectors, 0) # should be [1, MFCC_DIM, 1, CHUNK_LEN] + return signal, mfcc_vectors + + def __gettext(self, file, sample): + with open(os.path.join(self.textpath, self.takes[file][0]+'.tsv')) as tsv: + reader = csv.reader(tsv, delimiter='\t') + file = [ [float(word[0])*self.fps, float(word[1])*self.fps, word[2]] for word in reader] + begin = self.search_time(file, sample*self.step) + end = self.search_time(file, sample*self.step + self.window) + text = [ word[-1] for word in file[begin: end] ] + tokens = self.__gentokens(text) + vad = None + if self.vadfromtext: + times = [(np.floor(word[0] - sample*self.step).astype(int), np.ceil(word[1] - sample*self.step).astype(int)) for word in file[begin: end]] + vad = np.zeros(self.window) + for (i, f) in times: + vad[i:f] = 1 + vad = np.expand_dims(vad, 1) # [CHUNK_LEN, 1] + vad = np.transpose(vad, (1,0)) # [1, CHUNK_LEN] + return len(text), ' '.join(text), tokens, vad + + def __gentokens(self, text): + tokens = [ word+'/OTHER' for word in text] + tokens = '_'.join(tokens) + tokens = 'sos/OTHER_' + tokens + '_eos/OTHER' + return tokens + + def search_time(self, text, frame): + for i in range(len(text)): + if frame <= text[i][0]: + return i if (frame > text[i-1][1] or i==0) else 
i-1 + + def inv_transform(self, data): + if self.name == 'genea2023': + return data * self.std + self.mean + elif self.name == 'genea2023+': + return data * np.concatenate((self.rot6dpos_std, self.vel_std)) + np.concatenate((self.rot6dpos_mean, self.vel_mean)) + else: + raise ValueError('Dataset name not recognized') + + + def gettime(self): + import time + start = time.time() + for i in range(200): + sample = self.__getitem__(i) + print(time.time()-start) + + def loadstats(self, statspath): + self.std = np.load(os.path.join(statspath, 'rotpos_Std.npy')) + self.mean = np.load(os.path.join(statspath, 'rotpos_Mean.npy')) + #self.mfcc_std = np.load(os.path.join(statspath, 'mfccs_Std.npy')) + #self.mfcc_mean = np.load(os.path.join(statspath, 'mfccs_Mean.npy')) + self.rot6dpos_std = np.load(os.path.join(statspath, 'rot6dpos_Std.npy')) + self.rot6dpos_mean = np.load(os.path.join(statspath, 'rot6dpos_Mean.npy')) + self.vel_std = np.load(os.path.join(statspath, 'velrotpos_Std.npy')) + self.vel_mean = np.load(os.path.join(statspath, 'velrotpos_Mean.npy')) + + def gettestbatch(self, num_samples): + max_length = max(self.frames[:num_samples]) + max_length = max_length + self.window - max_length%self.window # increase length so it can be divisible by window + batch_audio = [] + batch_audio_rep = [] + batch_text = [] + vad_vals = [] + for i, _ in enumerate(self.takes[:num_samples]): + # Get audio file + audio_feats = [] + signal = np.zeros(int(max_length*self.sr/self.fps)) + signal_ = np.load(os.path.join(self.audiopath,self.takes[i][0]+'.npy')) + signal[:len(signal_)] = signal_ + + if self.use_wavlm: + # Cut Chunk + wavlm_reps_ = np.load(os.path.join(self.wavlm_rep_path,self.takes[i][0]+'.npy')) + audio_feat = np.zeros((max_length, wavlm_reps_.shape[1])) + audio_feat[:wavlm_reps_.shape[0],:] = wavlm_reps_ + + # Reshape + audio_feat = np.transpose(audio_feat, (1,0)) # [WAVLM_DIM, CHUNK_LEN] + audio_feat = np.expand_dims(audio_feat, 1) # [WAVLM_DIM, 1, CHUNK_LEN] + audio_feat = np.expand_dims(audio_feat, 0) # [1, WAVLM_DIM, 1, CHUNK_LEN] + audio_feats.append(audio_feat) + + if self.use_vad: + # Cut Chunk + vad_val_ = np.load(os.path.join(self.vad_path,self.takes[i][0]+'.npy')) + vad_val = np.zeros(max_length) + vad_val[:vad_val_.shape[0]] = vad_val_ # [CHUNK_LEN, ] + + # Reshape + vad_val = np.expand_dims(vad_val, 1) # [CHUNK_LEN, 1] + vad_val = np.transpose(vad_val, (1,0)) # [1, CHUNK_LEN] + vad_vals.append(vad_val) + + # Get text file + text_feats = [] + with open(os.path.join(self.textpath, self.takes[i][0]+'.tsv')) as tsv: + reader = csv.reader(tsv, delimiter='\t') + file = [ [float(word[0])*self.fps, float(word[1])*self.fps, word[2]] for word in reader] + + for chunk in range(int(max_length/self.window)): + if not self.use_wavlm: + # Get audio features + k = chunk*self.window*self.sr/self.fps + _, audio_feat = self.__compute_audiofeats(signal[int(k) : int(k+self.window*self.sr/self.fps)]) + audio_feats.append(audio_feat) + + # Get text + begin = self.search_time(file, chunk*self.window) + end = self.search_time(file, chunk*self.window + self.window) + text = [ word[-1] for word in file[begin: end] ] if begin or end else [] + text_feats.append(' '.join(text)) + + + audio_feats = np.concatenate(audio_feats, axis=-1) + end_audio = int(len(signal_)/self.sr*self.fps) + audio_feats[..., end_audio:] = np.zeros_like(audio_feats[..., end_audio:]) # zero audio feats after end of audio + batch_audio_rep.append(audio_feats) + batch_text.append(text_feats) + batch_audio.append(signal_) + + # Dummy motions and 
seed poses + feats = 1245 if self.name == 'genea2023+' else 498 + motion, seed_poses = np.zeros((self.window, feats)), np.zeros((self.n_seed_poses, feats)) #dummy + + # Attention: this is not collate-ready! + return motion, batch_text, self.window, batch_audio, batch_audio_rep, seed_poses, max_length, vad_vals, self.takes[:num_samples] + + def getvalbatch(self, num_takes, index): + # Get batch of data from the validation set, index refer to the chunk that you want to get + # Example: index = 0 and num_takes = 10 will return the first chunk of the first 10 takes + # index = 5 and num_takes = 30 will return the moment starting at 5*num_frames (120 by default) and ending at 6*num_frames of the first 30 takes + # num_takes = batch_size + batch = [] + assert num_takes <= len(self.takes) + # for each take + for take in np.arange(num_takes): + # get the corresponding index to call __getitem__ + sampleindex = self.samples_cumulative[take-1] + index if take != 0 else index + # check if the index is from the take and call __getitem__ + out = self.__getitem__(sampleindex) if sampleindex <= self.samples_per_file[take] + sampleindex - index else None + batch.append(out) + return batch + + def getjoints(self, toget= ['b_r_forearm','b_l_forearm']): + #toget = ['b_r_shoulder','b_r_arm','b_r_arm_twist','b_r_forearm','b_r_wrist_twist','b_r_wrist', + # 'b_l_shoulder','b_l_arm','b_l_arm_twist','b_l_forearm','b_l_wrist_twist','b_l_wrist'] + return {k:self.alljoints[k] for k in self.alljoints if k in toget} \ No newline at end of file diff --git a/data_loaders/gesture/scripts/genea_prep.py b/data_loaders/gesture/scripts/genea_prep.py new file mode 100644 index 0000000..8c8dd13 --- /dev/null +++ b/data_loaders/gesture/scripts/genea_prep.py @@ -0,0 +1,154 @@ +from argparse import ArgumentParser +import os +import numpy as np +import librosa +import torch +from wavlm.WavLM import WavLM, WavLMConfig +import torch.nn.functional as F +from data_loaders.gesture.scripts.motion_process import bvh2representations2 +import bvhsdk +from tqdm import tqdm + +def main(args): + #paths_check(args.data_dir) + assert args.split in ['all', 'trn', 'tst', 'val'], f"Split {args.split} not recognized. Options: \'all\', \'trn\', \'tst\', \'val\'" # Check if user is trying to process a split that does not exist + splits = [args.split] if args.split != 'all' else ['trn', 'tst', 'val'] + assert args.step in ['all', 'bvh', 'wav', 'wavlm'], f"Step {args.step} not recognized. 
Options: \'all\', \'bvh\', \'wav\', \'wavlm\'" # Check if user is trying to process a step that does not exist + steps = [args.step] if args.step != 'all' else ['bvh', 'wav', 'wavlm'] + print('WARNING: Running all steps and all splits will take a long time.') + print('Processing splits: ', splits) + print('Processing steps: ', steps) + for split in splits: + print(f'Processing {split} split') + if 'bvh' in steps and split != 'tst': + print(f'Processing bvh for {split} split') + r6p, rp = process_bvh(args.data_dir, split) + statspath = os.path.join(args.data_dir, split, 'main-agent') + print(f'Computing mean and std for {split} split') + compute_meanstd(r6p, os.path.join(statspath, 'rot6dpos'), npstep=5) + compute_meanstd(rp, os.path.join(statspath, 'rotpos'), npstep=5) + compute_meanstd(rp, os.path.join(statspath, 'velrotpos'), npstep=5, vel=True) + if 'wav' in steps: + print(f'Processing wav for {split} split') + process_wav(args.data_dir, split) + if 'wavlm' in steps: + print(f'Processing wavlm for {split} split') + process_wavlm(args.data_dir, split) + +def process_bvh(path, split): + sourcepath = os.path.join(path, split, 'main-agent', 'bvh') + savepathrot6d = os.path.join(path, split, 'main-agent', 'motion_npy_rot6dpos') + savepathrot = os.path.join(path, split, 'main-agent', 'motion_npy_rotpos') + assert not os.path.exists(savepathrot6d), f"motion_npy_rot6dpos already exists in {savepathrot6d}. Delete it to process again." + assert not os.path.exists(savepathrot), f"motion_npy_rotpos already exists in {savepathrot}. Delete it to process again." + if not os.path.exists(savepathrot6d): + os.mkdir(savepathrot6d) + if not os.path.exists(savepathrot): + os.mkdir(savepathrot) + for file in tqdm(os.listdir(sourcepath)): + #if not os.path.exists(os.path.join(savepathrot6d, file[:-4] + '.npy')) or not os.path.exists(os.path.join(savepathrot, file[:-4] + '.npy')): + anim = bvhsdk.ReadFile(os.path.join(sourcepath, file)) + rot6dpos, rotpos = bvh2representations2(anim) + np.save(os.path.join(savepathrot6d, file[:-4]), rot6dpos) + np.save(os.path.join(savepathrot, file[:-4]), rotpos) + return savepathrot6d, savepathrot + +def compute_meanstd(path, savepath, npstep=1, vel=False): + all_data = [] + for f in os.listdir(path)[::npstep]: + data = np.load(os.path.join(path, f)) + if vel: + data = data[1:,:] - data[:-1,:] + data[0,:] = np.zeros(data.shape[1]) + all_data.append(data) + all_data = np.vstack(all_data) + mean = np.mean(all_data, axis=0) + std = np.std(all_data, axis=0) + np.save(savepath + '_Mean.npy', mean) + np.save(savepath + '_Std.npy', std) + + +def process_wav(path, split, sr=16000): + sourcepath = os.path.join(path, split, 'main-agent', 'wav') + savepath = os.path.join(path, split, 'main-agent', 'audio16k_npy') + assert not os.path.exists(savepath), f"audio_16k_npy already exists in {savepath}. Delete it to process again." 
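+    # librosa.load below resamples each GENEA wav to mono 16 kHz and the raw waveform is
+    # saved as a float32 numpy array, so the dataset can later slice it by sample index
+    # (roughly frame * sr / fps) without decoding audio at training time.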
+ if not os.path.exists(savepath): + os.mkdir(savepath) + for file in tqdm(os.listdir(sourcepath)): + #if not os.path.exists(os.path.join(savepath, file[:-4] + '.npy')): + signal, _sr = librosa.load(os.path.join(sourcepath, file), mono=True, sr=sr) + assert _sr == sr + np.save(os.path.join(savepath, file[:-4]+'.npy'), signal) + return savepath + +def process_wavlm(path, split): + wavlm_layer = 11 + fps=30 + sr=16000 + device = 'cuda' if torch.cuda.is_available() else 'cpu' + sourcepath = os.path.join(path, split, 'main-agent', 'audio16k_npy') + savepath = os.path.join(path, split, 'main-agent', 'wavlm_representations') + #assert os.path.exists(sourcepath), f"audio16k_npy not found in {sourcepath}. Required to process wavlm representations, make sure wav files were processed first." + #assert os.path.exists(savepath), f"wavlm model directory not found in current directory." + if not os.path.exists(savepath): + os.mkdir(savepath) + checkpoint = torch.load('./wavlm/WavLM-Base+.pt') + wavlm_cfg = WavLMConfig(checkpoint['cfg']) + wavlm = WavLM(wavlm_cfg) + #wavlm.to(device) + wavlm.load_state_dict(checkpoint['model']) + wavlm.eval() + for file in tqdm(os.listdir(sourcepath)): + if not os.path.exists(os.path.join(savepath, file)): + audio_path = os.path.join(sourcepath, file) + # Load with Numpy + signal = np.load(audio_path) + # Set to model innput format + signal = torch.tensor(signal).unsqueeze(0)#.to(device) + # Normalize + if wavlm_cfg.normalize: + signal_norm = torch.nn.functional.layer_norm(signal , signal.shape) + else: + signal_norm = signal + # Run Model (rep=Desired Layer, layer_results=all layers) + rep, layer_results = wavlm.extract_features(signal_norm, output_layer=wavlm_layer, ret_layer_results=True)[0] + layer_reps = [x.transpose(0, 1) for x, _ in layer_results] # fix shape + # Get Number of Seconds of Audio File + n_secs = signal.shape[1] / sr + # Get Number of poses equivalent to audio file duration, given fps (alignment len) + n_pose = n_secs * fps + # Interpolate number of representations to match number of poses corresponding to audio file + interp_reps = F.interpolate(rep.transpose(1, 2), size=int(n_pose), align_corners=True, mode='linear') + # Prepare to save + interp_reps = interp_reps.squeeze(0).transpose(0,1).cpu().detach().data.cpu().numpy() + # Double check dimension + assert (interp_reps.shape[0] == int(np.ceil(n_pose)) or interp_reps.shape[0] == int(np.floor(n_pose))) + np.save(os.path.join(savepath, file), interp_reps) + + +def paths_check(data_dir): + # First check if everything is in place + for split in ['trn', 'tst', 'val']: + split_dir = os.path.join(data_dir, split) + assert os.path.exists(split_dir), f"Split {split} not found in {data_dir}" + main_agent_dir = os.path.join(split_dir, 'main-agent') + assert os.path.exists(main_agent_dir), f"main_agent not found in {split_dir}" + tsv_dir = os.path.join(main_agent_dir, 'tsv') + wav_dir = os.path.join(main_agent_dir, 'wav') + assert os.path.exists(tsv_dir), f"tsv not found in {main_agent_dir}" + assert os.path.exists(wav_dir), f"wav not found in {main_agent_dir}" + assert len(os.listdir(tsv_dir)) == len(os.listdir(wav_dir)), f"tsv and wav have different number of files in {main_agent_dir}" + if split != 'tst': + bvh_dir = os.path.join(main_agent_dir, 'bvh') + assert os.path.exists(bvh_dir), f"bvhs not found in {main_agent_dir}" + assert len(os.listdir(tsv_dir)) == len(os.listdir(bvh_dir)), f"tsv and bvh have different number of files in {main_agent_dir}" + print('Data paths and files seem correct') + + +if 
__name__ == '__main__': + parser = ArgumentParser() + parser.add_argument('--data_dir', type=str, default='./dataset/Genea2023', help='path to the dataset directory') + parser.add_argument('--split', type=str, default='all', help='Which split to process. Use \'all\' to process all splits') + parser.add_argument('--step', type=str, default='all', help='Which step to process. Use \'all\' to process all steps. Options: \'bvh\', \'wav\', \'wavlm\'') + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/data_loaders/gesture/scripts/genea_prep_vad.py b/data_loaders/gesture/scripts/genea_prep_vad.py new file mode 100644 index 0000000..cff5fc6 --- /dev/null +++ b/data_loaders/gesture/scripts/genea_prep_vad.py @@ -0,0 +1,72 @@ +from argparse import ArgumentParser +import os +import numpy as np +import torch +from speechbrain.inference.VAD import VAD +import torchaudio +from scipy.signal import resample +from tqdm import tqdm + + + +def main(args): + #paths_check(args.data_dir) + assert args.split in ['all', 'trn', 'tst', 'val'], f"Split {args.split} not recognized. Options: \'all\', \'trn\', \'tst\', \'val\'" # Check if user is trying to process a split that does not exist + splits = [args.split] if args.split != 'all' else ['trn', 'tst', 'val'] + + print('Processing VAD.') + print('Processing splits: ', splits) + for split in splits: + print(f'Processing vad for {split} split') + process_vad(args.data_dir, split) + + + +def process_vad(path, split): + sr=16000 + fps=30 + sourcepath = os.path.join(path, split, 'main-agent', 'wav') + savepathrot = os.path.join(path, split, 'main-agent', 'vad') + _VAD = VAD.from_hparams(source= "speechbrain/vad-crdnn-libriparty", savedir= os.path.join(path, '..','..','speechbrain', 'pretrained_models', 'vad-crdnn-libriparty')) + #assert not os.path.exists(savepathrot), f"vad already exists in {savepathrot}. Delete it to process again." + if not os.path.exists(savepathrot): + os.mkdir(savepathrot) + # VAD requires a torch tensor with sample rate = 16k. This process saves a temporary wav file with 16k sr. It can be deleted after processing. 
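+    # For each take: resample the 44.1 kHz GENEA wav to 16 kHz, write it to a temporary wav,
+    # ask speechbrain's CRDNN VAD for frame-level speech probabilities (the code assumes
+    # 100 probability frames per second), resample those probabilities to the 30 fps motion
+    # rate, and binarize at 0.5 before saving one npy array per take.
+    # Note: orig_freq=44100 is hard-coded; the old_sr returned by torchaudio.load could be
+    # used instead if the source sample rate ever differs.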
+ for file in tqdm(os.listdir(sourcepath)): + audio, old_sr = torchaudio.load(os.path.join(sourcepath,file)) + audio = torchaudio.functional.resample(audio, orig_freq=44100, new_freq=sr) + tmpfile = "tmp.wav" + torchaudio.save( + tmpfile , audio, sr + ) + boundaries = _VAD.get_speech_prob_file(audio_file=tmpfile, large_chunk_size=4, small_chunk_size=0.2) + boundaries = resample(boundaries[0,:,0], int(boundaries.shape[1]*fps/100)) + boundaries[boundaries>=0.5] = 1 + boundaries[boundaries<0.5] = 0 + np.save(os.path.join(savepathrot, file[:-4]+'.npy'), boundaries) + +def paths_check(data_dir): + # First check if everything is in place + for split in ['trn', 'tst', 'val']: + split_dir = os.path.join(data_dir, split) + assert os.path.exists(split_dir), f"Split {split} not found in {data_dir}" + main_agent_dir = os.path.join(split_dir, 'main-agent') + assert os.path.exists(main_agent_dir), f"main_agent not found in {split_dir}" + tsv_dir = os.path.join(main_agent_dir, 'tsv') + wav_dir = os.path.join(main_agent_dir, 'wav') + assert os.path.exists(tsv_dir), f"tsv not found in {main_agent_dir}" + assert os.path.exists(wav_dir), f"wav not found in {main_agent_dir}" + assert len(os.listdir(tsv_dir)) == len(os.listdir(wav_dir)), f"tsv and wav have different number of files in {main_agent_dir}" + if split != 'tst': + bvh_dir = os.path.join(main_agent_dir, 'bvh') + assert os.path.exists(bvh_dir), f"bvhs not found in {main_agent_dir}" + assert len(os.listdir(tsv_dir)) == len(os.listdir(bvh_dir)), f"tsv and bvh have different number of files in {main_agent_dir}" + print('Data paths and files seem correct') + + +if __name__ == '__main__': + parser = ArgumentParser() + parser.add_argument('--data_dir', type=str, default='./dataset/Genea2023', help='path to the dataset directory') + parser.add_argument('--split', type=str, default='all', help='Which split to process. 
Use \'all\' to process all splits') + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/data_loaders/gesture/scripts/motion_process.py b/data_loaders/gesture/scripts/motion_process.py new file mode 100644 index 0000000..2c9b819 --- /dev/null +++ b/data_loaders/gesture/scripts/motion_process.py @@ -0,0 +1,124 @@ +import numpy as np +import utils.rotation_conversions as geometry +import bvhsdk +from scipy.signal import savgol_filter + + +def get_indexes(dataset): + n_joints = 83 + if dataset == 'genea2023': + idx_positions = np.asarray([ [i*6+3, i*6+4, i*6+5] for i in range(n_joints) ]).flatten() + idx_rotations = np.asarray([ [i*6, i*6+1, i*6+2] for i in range(n_joints) ]).flatten() + elif dataset == 'genea2023+': + idx_positions = np.asarray([ [i*9+6, i*9+7, i*9+8] for i in range(n_joints) ]).flatten() + idx_rotations = np.asarray([ [i*9, i*9+1, i*9+2, i*9+3, i*9+4, i*9+5] for i in range(n_joints) ]).flatten() + else: + raise NotImplementedError("This dataset is not implemented.") + return idx_positions, idx_rotations + +def split_pos_rot(dataset, data): + # Split the data into positions and rotations + # Shape expected [num_samples(bs), 1, chunk_len, 1245 or 498] + # Output shape [num_samples(bs), 1, chunk_len, 498 or 249] + idx_positions, idx_rotations = get_indexes(dataset) + return data[..., idx_positions], data[..., idx_rotations] + +def rot6d_to_euler(data): + # Convert numpy array to euler angles + # Shape expected [num_samples(bs), 1, chunk_len, 498] + # Output shape [num_samples(bs) * chunk_len, n_joints, 3] + n_joints = 83 + assert data.shape[-1] == n_joints*6 + sample_rot = data.view(data.shape[:-1] + (-1, 6)) # [num_samples(bs), 1, chunk_len, n_joints, 6] + sample_rot = geometry.rotation_6d_to_matrix(sample_rot) # [num_samples(bs), 1, chunk_len, n_joints, 3, 3] + sample_rot = geometry.matrix_to_euler_angles(sample_rot, "ZXY")[..., [1, 2, 0] ]*180/np.pi # [num_samples(bs), 1, chunk_len, n_joints, 3] + sample_rot = sample_rot.view(-1, *sample_rot.shape[2:]).permute(0, 2, 3, 1).squeeze() # [num_samples(bs), n_joints, 3, chunk_len] + return sample_rot + +def tobvh(bvhreference, rotation, position=None): + # Converts to bvh format + # Shape expected [njoints, 3, frames] + # returns a bvh object + rotation = rotation.transpose(2, 0, 1) # [frames, njoints, 3] + bvhreference.frames = rotation.shape[0] + for j, joint in enumerate(bvhreference.getlistofjoints()): + joint.rotation = rotation[:, j, :] + joint.translation = np.tile(joint.offset, (bvhreference.frames, 1)) + if position.any(): + position = position.transpose(2, 0, 1) # [frames, njoints, 3] + bvhreference.root.translation = position[:, 0, :] + return bvhreference + +def posfrombvh(bvh): + # Extracts positions from bvh + # returns a numpy array shaped [frames, njoints, 3] + position = np.zeros((bvh.frames, len(bvh.getlistofjoints()) * 3)) + # This way takes advantage of the implementarion of getPosition (16.9 seconds ~4000 frames) + for frame in range(bvh.frames): + for i, joint in enumerate(bvh.getlistofjoints()): + position[frame, i*3:i*3+3] = joint.getPosition(frame) + return position + + +def filter_and_interp(rotation, position, num_frames=120, chunks=None): + # Smooth chunk transitions + # + n_chunks = chunks if chunks else int(rotation.shape[-1]/num_frames) + inter_range = 10 #interpolation range in frames + for transition in np.arange(1, n_chunks-1)*num_frames: + position[..., transition:transition+2] = np.tile(np.expand_dims(position[..., transition]/2 + position[..., 
transition-1]/2,-1),2) + rotation[..., transition:transition+2] = np.tile(np.expand_dims(rotation[..., transition]/2 + rotation[..., transition-1]/2,-1),2) + for i, s in enumerate(np.linspace(0, 1, inter_range-1)): + forward = transition-inter_range+i + backward = transition+inter_range-i + position[..., forward] = position[..., forward]*(1-s) + position[:, :, :, transition-1]*s + position[..., backward] = position[..., backward]*(1-s) + position[:, :, :, transition]*s + rotation[..., forward] = rotation[..., forward]*(1-s) + rotation[:, :, :, transition-1]*s + rotation[..., backward] = rotation[..., backward]*(1-s) + rotation[:, :, :, transition]*s + + position = savgol_filter(position, 9, 3, axis=-1) + rotation = savgol_filter(rotation, 9, 3, axis=-1) + + return position, rotation + +def np_matrix_to_rotation_6d(matrix: np.ndarray) -> np.ndarray: + """ + Same as utils.rotation_conversions.matrix_to_rotation_6d but for numpy arrays. + """ + return matrix[..., :2, :].copy().reshape(6) + +def bvh2representations2(anim: bvhsdk.Animation): + # Converts bvh to two representations: 6d rotations and 3d positions + # And 3d rotations (euler angles) and 3d positions + # The 3d positions of both representations are the same (duplicated data) + # This representation is used in the genea challenge + njoints = len(anim.getlistofjoints()) + npyrot6dpos = np.empty(shape=(anim.frames, 9*njoints)) + npyrotpos = np.empty(shape=(anim.frames, 6*njoints)) + for i, joint in enumerate(anim.getlistofjoints()): + npyrot6dpos[:,i*9:i*9+6] = [ np_matrix_to_rotation_6d(joint.getLocalTransform(frame)[:-1,:-1]) for frame in range(anim.frames) ] + npyrotpos[:,i*6:i*6+3] = [ joint.rotation[frame] for frame in range(anim.frames) ] + + for frame in range(anim.frames): + for i, joint in enumerate(anim.getlistofjoints()): + pos = joint.getPosition(frame) + npyrot6dpos[frame, i*9+6:i*9+9] = pos + npyrotpos[frame, i*6+3:i*6+6] = pos + + return npyrot6dpos, npyrotpos + +def bvh2representations1(anim: bvhsdk.Animation): + # Converts bvh to three representations: 6d rotations, 3d positions (euler angles) and 3d positions + njoints = len(anim.getlistofjoints()) + npyrot6d = np.empty(shape=(anim.frames, 6*njoints)) + npyrot = np.empty(shape=(anim.frames, 3*njoints)) + npypos = np.empty(shape=(anim.frames, 3*njoints)) + for i, joint in enumerate(anim.getlistofjoints()): + npyrot6d[:,i*6:i*6+6] = [ np_matrix_to_rotation_6d(joint.getLocalTransform(frame)[:-1,:-1]) for frame in range(anim.frames) ] + npyrot[:,i*3:i*3+3] = [ joint.rotation[frame] for frame in range(anim.frames) ] + + for frame in range(anim.frames): + for i, joint in enumerate(anim.getlistofjoints()): + npypos[frame, i*3:i*3+3] = joint.getPosition(frame) + + return npyrot6d, npyrot, npypos \ No newline at end of file diff --git a/data_loaders/get_data.py b/data_loaders/get_data.py new file mode 100644 index 0000000..30f2053 --- /dev/null +++ b/data_loaders/get_data.py @@ -0,0 +1,34 @@ +from torch.utils.data import DataLoader +from data_loaders.tensors import collate as all_collate +from data_loaders.tensors import gg_collate + +def get_dataset_class(name): + if name in ["genea2023", "genea2023+"]: + from data_loaders.gesture.data.dataset import Genea2023 + return Genea2023 + else: + raise ValueError(f'Unsupported dataset name [{name}]') + +def get_collate_fn(name, hml_mode='train'): + if name in ["genea2023", "genea2023+"]: + return gg_collate + else: + raise ValueError(f'Unsupported dataset name [{name}]') + +def get_dataset(name, data_dir,num_frames, seed_poses, 
step, use_wavlm, use_vad, vadfromtext, split='trn', hml_mode='train', ): + DATA = get_dataset_class(name) + dataset = DATA(name=name, datapath=data_dir, split=split, window=num_frames, n_seed_poses=seed_poses, step=step, use_wavlm=use_wavlm, use_vad=use_vad, vadfromtext=vadfromtext) + return dataset + + +def get_dataset_loader(name, data_dir, batch_size, num_frames, step, use_wavlm, use_vad, vadfromtext, split='trn', hml_mode='train', seed_poses=10): + dataset = get_dataset(name, data_dir,num_frames, seed_poses, step, use_wavlm, use_vad, vadfromtext, split, hml_mode) + collate = get_collate_fn(name, hml_mode) + + shuffled = True if split == 'trn' else False + loader = DataLoader( + dataset, batch_size=batch_size, shuffle=shuffled, + num_workers=16, drop_last=True, collate_fn=collate + ) + + return loader \ No newline at end of file diff --git a/data_loaders/tensors.py b/data_loaders/tensors.py new file mode 100644 index 0000000..1688f69 --- /dev/null +++ b/data_loaders/tensors.py @@ -0,0 +1,76 @@ +import torch + +def lengths_to_mask(lengths, max_len): + # max_len = max(lengths) + mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1) + return mask + + +def collate_tensors(batch): + dims = batch[0].dim() + max_size = [max([b.size(i) for b in batch]) for i in range(dims)] + size = (len(batch),) + tuple(max_size) + canvas = batch[0].new_zeros(size=size) + for i, b in enumerate(batch): + sub_tensor = canvas[i] + for d in range(dims): + sub_tensor = sub_tensor.narrow(d, 0, b.size(d)) + sub_tensor.add_(b) + return canvas + + +def collate(batch): + notnone_batches = [b for b in batch if b is not None] + databatch = [b['inp'] for b in notnone_batches] + if 'lengths' in notnone_batches[0]: + lenbatch = [b['lengths'] for b in notnone_batches] + else: + lenbatch = [len(b['inp'][0][0]) for b in notnone_batches] + + + databatchTensor = collate_tensors(databatch) + lenbatchTensor = torch.as_tensor(lenbatch) + maskbatchTensor = lengths_to_mask(lenbatchTensor, databatchTensor.shape[-1]).unsqueeze(1).unsqueeze(1) # unqueeze for broadcasting + + motion = databatchTensor + cond = {'y': {'mask': maskbatchTensor, 'lengths': lenbatchTensor}} + + if 'text' in notnone_batches[0]: + textbatch = [b['text'] for b in notnone_batches] + cond['y'].update({'text': textbatch}) + if 'audio_rep' in notnone_batches[0]: + audio_repbatch = [b['audio_rep'] for b in notnone_batches] + audio_repbatch = torch.cat(audio_repbatch, dim=0) + cond['y'].update({'audio_rep': audio_repbatch}) + if 'audio' in notnone_batches[0]: + audiobatch = [b['audio'] for b in notnone_batches] + audiobatch = torch.cat(audiobatch, dim=0) + cond['y'].update({'audio': audiobatch}) + if 'seed' in notnone_batches[0]: + seedbatch = [b['seed'].unsqueeze(0) for b in notnone_batches] + seedbatch = torch.cat(seedbatch, dim=0) + cond['y'].update({'seed': seedbatch}) + if 'vad' in notnone_batches[0]: + vadbatch = [b['vad'] for b in notnone_batches] + vadbatch = torch.cat(vadbatch, dim=0) + cond['y'].update({'vad': vadbatch}) + if 'takename' in notnone_batches[0]: + takenamebatch = [b['takename'] for b in notnone_batches] + cond['y'].update({'takename': takenamebatch}) + return motion, cond + +# an adapter to our collate func +def gg_collate(batch): + # batch.sort(key=lambda x: x[3], reverse=True) + adapted_batch = [{ + 'inp': torch.tensor(b[0].T).float().unsqueeze(1), # [seqlen, J] -> [J, 1, seqlen] + 'text': b[1], #b[0]['caption'] + 'lengths': b[2], + 'audio': torch.tensor(b[3]).unsqueeze(0), # [seqlen] -> [1, 
seqlen] + 'audio_rep': torch.from_numpy(b[4]).float(), # [1, AUDIO_HID_DIM, 1, CHUNK_LEN] , (AUDIO_HID_DIM = MFCC_DIM or 768) + 'seed': torch.tensor(b[5].T).float().unsqueeze(1), # [n_seed_poses, J] -> [J, 1, n_seed_poses] + 'vad': torch.from_numpy(b[6]).long(), # [1, CHUNK_LEN] + 'takename': b[7] + } for b in batch] + return collate(adapted_batch) + diff --git a/diffusion/fp16_util.py b/diffusion/fp16_util.py new file mode 100644 index 0000000..1ccb93e --- /dev/null +++ b/diffusion/fp16_util.py @@ -0,0 +1,236 @@ +""" +Helpers to train with 16-bit precision. +""" + +import numpy as np +import torch as th +import torch.nn as nn +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + +from diffusion import logger + +INITIAL_LOG_LOSS_SCALE = 20.0 + + +def convert_module_to_f16(l): + """ + Convert primitive modules to float16. + """ + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + +def convert_module_to_f32(l): + """ + Convert primitive modules to float32, undoing convert_module_to_f16(). + """ + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + l.weight.data = l.weight.data.float() + if l.bias is not None: + l.bias.data = l.bias.data.float() + + +def make_master_params(param_groups_and_shapes): + """ + Copy model parameters into a (differently-shaped) list of full-precision + parameters. + """ + master_params = [] + for param_group, shape in param_groups_and_shapes: + master_param = nn.Parameter( + _flatten_dense_tensors( + [param.detach().float() for (_, param) in param_group] + ).view(shape) + ) + master_param.requires_grad = True + master_params.append(master_param) + return master_params + + +def model_grads_to_master_grads(param_groups_and_shapes, master_params): + """ + Copy the gradients from the model parameters into the master parameters + from make_master_params(). + """ + for master_param, (param_group, shape) in zip( + master_params, param_groups_and_shapes + ): + master_param.grad = _flatten_dense_tensors( + [param_grad_or_zeros(param) for (_, param) in param_group] + ).view(shape) + + +def master_params_to_model_params(param_groups_and_shapes, master_params): + """ + Copy the master parameter data back into the model parameters. + """ + # Without copying to a list, if a generator is passed, this will + # silently not copy any parameters. 
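+    # Each flattened full-precision master tensor is unflattened back to the per-parameter
+    # shapes and copied in place into the corresponding (possibly fp16) model parameters.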
+ for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): + for (_, param), unflat_master_param in zip( + param_group, unflatten_master_params(param_group, master_param.view(-1)) + ): + param.detach().copy_(unflat_master_param) + + +def unflatten_master_params(param_group, master_param): + return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group]) + + +def get_param_groups_and_shapes(named_model_params): + named_model_params = list(named_model_params) + scalar_vector_named_params = ( + [(n, p) for (n, p) in named_model_params if p.ndim <= 1], + (-1), + ) + matrix_named_params = ( + [(n, p) for (n, p) in named_model_params if p.ndim > 1], + (1, -1), + ) + return [scalar_vector_named_params, matrix_named_params] + + +def master_params_to_state_dict( + model, param_groups_and_shapes, master_params, use_fp16 +): + if use_fp16: + state_dict = model.state_dict() + for master_param, (param_group, _) in zip( + master_params, param_groups_and_shapes + ): + for (name, _), unflat_master_param in zip( + param_group, unflatten_master_params(param_group, master_param.view(-1)) + ): + assert name in state_dict + state_dict[name] = unflat_master_param + else: + state_dict = model.state_dict() + for i, (name, _value) in enumerate(model.named_parameters()): + assert name in state_dict + state_dict[name] = master_params[i] + return state_dict + + +def state_dict_to_master_params(model, state_dict, use_fp16): + if use_fp16: + named_model_params = [ + (name, state_dict[name]) for name, _ in model.named_parameters() + ] + param_groups_and_shapes = get_param_groups_and_shapes(named_model_params) + master_params = make_master_params(param_groups_and_shapes) + else: + master_params = [state_dict[name] for name, _ in model.named_parameters()] + return master_params + + +def zero_master_grads(master_params): + for param in master_params: + param.grad = None + + +def zero_grad(model_params): + for param in model_params: + # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group + if param.grad is not None: + param.grad.detach_() + param.grad.zero_() + + +def param_grad_or_zeros(param): + if param.grad is not None: + return param.grad.data.detach() + else: + return th.zeros_like(param) + + +class MixedPrecisionTrainer: + def __init__( + self, + *, + model, + use_fp16=False, + fp16_scale_growth=1e-3, + initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE, + ): + self.model = model + self.use_fp16 = use_fp16 + self.fp16_scale_growth = fp16_scale_growth + + self.model_params = list(self.model.parameters()) + self.master_params = self.model_params + self.param_groups_and_shapes = None + self.lg_loss_scale = initial_lg_loss_scale + + if self.use_fp16: + self.param_groups_and_shapes = get_param_groups_and_shapes( + self.model.named_parameters() + ) + self.master_params = make_master_params(self.param_groups_and_shapes) + self.model.convert_to_fp16() + + def zero_grad(self): + zero_grad(self.model_params) + + def backward(self, loss: th.Tensor): + if self.use_fp16: + loss_scale = 2 ** self.lg_loss_scale + (loss * loss_scale).backward() + else: + loss.backward() + + def optimize(self, opt: th.optim.Optimizer): + if self.use_fp16: + return self._optimize_fp16(opt) + else: + return self._optimize_normal(opt) + + def _optimize_fp16(self, opt: th.optim.Optimizer): + logger.logkv_mean("lg_loss_scale", self.lg_loss_scale) + model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params) + grad_norm, param_norm = 
self._compute_norms(grad_scale=2 ** self.lg_loss_scale) + if check_overflow(grad_norm): + self.lg_loss_scale -= 1 + logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") + zero_master_grads(self.master_params) + return False + + logger.logkv_mean("grad_norm", grad_norm) + logger.logkv_mean("param_norm", param_norm) + + self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale)) + opt.step() + zero_master_grads(self.master_params) + master_params_to_model_params(self.param_groups_and_shapes, self.master_params) + self.lg_loss_scale += self.fp16_scale_growth + return True + + def _optimize_normal(self, opt: th.optim.Optimizer): + grad_norm, param_norm = self._compute_norms() + logger.logkv_mean("grad_norm", grad_norm) + logger.logkv_mean("param_norm", param_norm) + opt.step() + return True + + def _compute_norms(self, grad_scale=1.0): + grad_norm = 0.0 + param_norm = 0.0 + for p in self.master_params: + with th.no_grad(): + param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2 + if p.grad is not None: + grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2 + return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) + + def master_params_to_state_dict(self, master_params): + return master_params_to_state_dict( + self.model, self.param_groups_and_shapes, master_params, self.use_fp16 + ) + + def state_dict_to_master_params(self, state_dict): + return state_dict_to_master_params(self.model, state_dict, self.use_fp16) + + +def check_overflow(value): + return (value == float("inf")) or (value == -float("inf")) or (value != value) diff --git a/diffusion/gaussian_diffusion.py b/diffusion/gaussian_diffusion.py new file mode 100644 index 0000000..f242432 --- /dev/null +++ b/diffusion/gaussian_diffusion.py @@ -0,0 +1,1562 @@ +# This code is based on https://github.com/openai/guided-diffusion +""" +This code started out as a PyTorch port of Ho et al's diffusion models: +https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py + +Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules. +""" + +import enum +import math + +import numpy as np +import torch +import torch as th +from copy import deepcopy +from diffusion.nn import mean_flat, sum_flat +from diffusion.losses import normal_kl, discretized_gaussian_log_likelihood + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps, scale_betas=1.): + """ + Get a pre-defined beta schedule for the given name. + + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. 
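+        # The 1000 / num_diffusion_timesteps factor rescales the original endpoints
+        # (beta_1 = 1e-4, beta_T = 0.02 at T = 1000) so that roughly the same total amount
+        # of noise is injected when a different number of timesteps is used.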
+ scale = scale_betas * 1000 / num_diffusion_timesteps + beta_start = scale * 0.0001 + beta_end = scale * 0.02 + return np.linspace( + beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 + ) + elif schedule_name == "cosine": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. + """ + + PREVIOUS_X = enum.auto() # the model predicts x_{t-1} + START_X = enum.auto() # the model predicts x_0 + EPSILON = enum.auto() # the model predicts epsilon + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. + """ + + LEARNED = enum.auto() + FIXED_SMALL = enum.auto() + FIXED_LARGE = enum.auto() + LEARNED_RANGE = enum.auto() + + +class LossType(enum.Enum): + MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = ( + enum.auto() + ) # use raw MSE loss (with RESCALED_KL when learning variances) + KL = enum.auto() # use the variational lower-bound + RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +class GaussianDiffusion: + """ + Utilities for training and sampling diffusion models. + + Ported directly from here, and then adapted over time to further experimentation. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + :param model_mean_type: a ModelMeanType determining what the model outputs. + :param model_var_type: a ModelVarType determining how variance is output. + :param loss_type: a LossType determining the loss function to use. + :param rescale_timesteps: if True, pass floating point timesteps into the + model so that they are always scaled like in the + original paper (0 to 1000). 
+ """ + + def __init__( + self, + *, + betas, + model_mean_type, + model_var_type, + loss_type, + rescale_timesteps=False, + lambda_rcxyz=0., + lambda_vel=0., + lambda_pose=1., + lambda_orient=1., + lambda_loc=1., + data_rep='rot6d', + lambda_root_vel=0., + lambda_vel_rcxyz=0., + lambda_fc=0., + ): + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.loss_type = loss_type + self.rescale_timesteps = rescale_timesteps + self.data_rep = data_rep + + if data_rep != 'rot_vel' and lambda_pose != 1.: + raise ValueError('lambda_pose is relevant only when training on velocities!') + self.lambda_pose = lambda_pose + self.lambda_orient = lambda_orient + self.lambda_loc = lambda_loc + + self.lambda_rcxyz = lambda_rcxyz + self.lambda_vel = lambda_vel + self.lambda_root_vel = lambda_root_vel + self.lambda_vel_rcxyz = lambda_vel_rcxyz + self.lambda_fc = lambda_fc + + if self.lambda_rcxyz > 0. or self.lambda_vel > 0. or self.lambda_root_vel > 0. or \ + self.lambda_vel_rcxyz > 0. or self.lambda_fc > 0.: + assert self.loss_type == LossType.MSE, 'Geometric losses are supported by MSE loss type only!' + + # Use float64 for accuracy. + betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = ( + betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + # log calculation clipped because the posterior variance is 0 at the + # beginning of the diffusion chain. + self.posterior_log_variance_clipped = np.log( + np.append(self.posterior_variance[1], self.posterior_variance[1:]) + ) + self.posterior_mean_coef1 = ( + betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + self.posterior_mean_coef2 = ( + (1.0 - self.alphas_cumprod_prev) + * np.sqrt(alphas) + / (1.0 - self.alphas_cumprod) + ) + + self.l2_loss = lambda a, b: (a - b) ** 2 # th.nn.MSELoss(reduction='none') # must be None for handling mask later on. + self.smooth_l1_loss = th.nn.SmoothL1Loss(reduction='none') + + def masked_l2(self, a, b, mask): + # assuming a.shape == b.shape == bs, J, Jdim, seqlen + # assuming mask.shape == bs, 1, 1, seqlen + #loss = self.l2_loss(a, b) + loss = self.smooth_l1_loss(a, b) + loss = sum_flat(loss * mask.float()) # gives \sigma_euclidean over unmasked elements + n_entries = a.shape[1] * a.shape[2] + non_zero_elements = sum_flat(mask) * n_entries + # print('mask', mask.shape) + # print('non_zero_elements', non_zero_elements) + # print('loss', loss) + mse_loss_val = loss / non_zero_elements + # print('mse_loss_val', mse_loss_val) + return mse_loss_val + + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). 
+ + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + ) + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor( + self.log_one_minus_alphas_cumprod, t, x_start.shape + ) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the dataset for a given number of diffusion steps. + + In other words, sample from q(x_t | x_0). + + :param x_start: the initial dataset batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. + """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) + * noise + ) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + + q(x_{t-1} | x_t, x_0) + + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance( + self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None + ): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + if model_kwargs is None: + model_kwargs = {} + + B, C = x.shape[:2] + assert t.shape == (B,) + model_output = model(x, self._scale_timesteps(t), **model_kwargs) + + if 'inpainting_mask' in model_kwargs['y'].keys() and 'inpainted_motion' in model_kwargs['y'].keys(): + inpainting_mask, inpainted_motion = model_kwargs['y']['inpainting_mask'], model_kwargs['y']['inpainted_motion'] + assert self.model_mean_type == ModelMeanType.START_X, 'This feature supports only X_start pred for mow!' 
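+            # Observed-motion inpainting: wherever inpainting_mask is True, the model's
+            # x_start prediction is replaced below by the provided inpainted_motion, so
+            # the conditioned frames/joints stay fixed throughout sampling.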
+ assert model_output.shape == inpainting_mask.shape == inpainted_motion.shape + model_output = (model_output * ~inpainting_mask) + (inpainted_motion * inpainting_mask) + # print('model_output', model_output.shape, model_output) + # print('inpainting_mask', inpainting_mask.shape, inpainting_mask[0,0,0,:]) + # print('inpainted_motion', inpainted_motion.shape, inpainted_motion) + + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert model_output.shape == (B, C * 2, *x.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + if self.model_var_type == ModelVarType.LEARNED: + model_log_variance = model_var_values + model_variance = th.exp(model_log_variance) + else: + min_log = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x.shape + ) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. + ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + # print('model_variance', model_variance) + # print('model_log_variance',model_log_variance) + # print('self.posterior_variance', self.posterior_variance) + # print('self.posterior_log_variance_clipped', self.posterior_log_variance_clipped) + # print('self.model_var_type', self.model_var_type) + + + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + # print('clip_denoised', clip_denoised) + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.PREVIOUS_X: + pred_xstart = process_xstart( + self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output) + ) + model_mean = model_output + elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]: # THIS IS US! 
+ if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart( + self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) + ) + model_mean, _, _ = self.q_posterior_mean_variance( + x_start=pred_xstart, x_t=x, t=t + ) + else: + raise NotImplementedError(self.model_mean_type) + + assert ( + model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + ) + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def _predict_xstart_from_xprev(self, x_t, t, xprev): + assert x_t.shape == xprev.shape + return ( # (xprev - coef2*x_t) / coef1 + _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev + - _extract_into_tensor( + self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape + ) + * x_t + ) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _scale_timesteps(self, t): + if self.rescale_timesteps: + return t.float() * (1000.0 / self.num_timesteps) + return t + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs) + new_mean = ( + p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() + ) + return new_mean + + def condition_mean_with_grad(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + gradient = cond_fn(x, t, p_mean_var, **model_kwargs) + new_mean = ( + p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() + ) + return new_mean + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + + See condition_mean() for details on cond_fn. + + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). 
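+
+        Concretely, the predicted noise is shifted by the guidance gradient,
+
+            eps <- eps - sqrt(1 - alpha_bar_t) * cond_fn(x, t),
+
+        and 'pred_xstart' and 'mean' are then recomputed from the adjusted eps.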
+ """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn( + x, self._scale_timesteps(t), **model_kwargs + ) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance( + x_start=out["pred_xstart"], x_t=x, t=t + ) + return out + + def condition_score_with_grad(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + + See condition_mean() for details on cond_fn. + + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). + """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn( + x, t, p_mean_var, **model_kwargs + ) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance( + x_start=out["pred_xstart"], x_t=x, t=t + ) + return out + + def p_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + const_noise=False, + ): + """ + Sample x_{t-1} from the model at the given timestep. + + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + noise = th.randn_like(x) + # print('const_noise', const_noise) + if const_noise: + noise = noise[[0]].repeat(x.shape[0], 1, 1, 1) + + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + if cond_fn is not None: + out["mean"] = self.condition_mean( + cond_fn, out, x, t, model_kwargs=model_kwargs + ) + # print('mean', out["mean"].shape, out["mean"]) + # print('log_variance', out["log_variance"].shape, out["log_variance"]) + # print('nonzero_mask', nonzero_mask.shape, nonzero_mask) + sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_with_grad( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + ): + """ + Sample x_{t-1} from the model at the given timestep. + + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. 
+ :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + with th.enable_grad(): + x = x.detach().requires_grad_() + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + noise = th.randn_like(x) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + if cond_fn is not None: + out["mean"] = self.condition_mean_with_grad( + cond_fn, out, x, t, model_kwargs=model_kwargs + ) + sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"].detach()} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + skip_timesteps=0, + init_image=None, + randomize_class=False, + cond_fn_with_grad=False, + dump_steps=None, + const_noise=False, + ): + """ + Generate samples from the model. + + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :param const_noise: If True, will noise all samples with the same noise throughout sampling + :return: a non-differentiable batch of samples. + """ + final = None + if dump_steps is not None: + dump = [] + + for i, sample in enumerate(self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + skip_timesteps=skip_timesteps, + init_image=init_image, + randomize_class=randomize_class, + cond_fn_with_grad=cond_fn_with_grad, + const_noise=const_noise, + )): + if dump_steps is not None and i in dump_steps: + dump.append(deepcopy(sample["sample"])) + final = sample + if dump_steps is not None: + return dump + return final["sample"] + + def p_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + skip_timesteps=0, + init_image=None, + randomize_class=False, + cond_fn_with_grad=False, + const_noise=False, + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). 
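+
+        A typical consumption pattern (sketch mirroring p_sample_loop() above;
+        'diffusion' is a hypothetical name for an instance of this class):
+
+            final = None
+            for out in diffusion.p_sample_loop_progressive(model, shape, progress=True):
+                final = out  # each 'out' holds 'sample' and 'pred_xstart' for one timestep
+            sample = final["sample"]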
+ """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + + if skip_timesteps and init_image is None: + init_image = th.zeros_like(img) + + indices = list(range(self.num_timesteps - skip_timesteps))[::-1] + + if init_image is not None: + my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0] + img = self.q_sample(init_image, my_t, img) + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + if randomize_class and 'y' in model_kwargs: + model_kwargs['y'] = th.randint(low=0, high=model.num_classes, + size=model_kwargs['y'].shape, + device=model_kwargs['y'].device) + with th.no_grad(): + sample_fn = self.p_sample_with_grad if cond_fn_with_grad else self.p_sample + out = sample_fn( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + const_noise=const_noise, + ) + yield out + img = out["sample"] + + def ddim_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t-1} from the model using DDIM. + + Same usage as p_sample(). + """ + out_orig = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out_orig, x, t, model_kwargs=model_kwargs) + else: + out = out_orig + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + sigma = ( + eta + * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) + * th.sqrt(1 - alpha_bar / alpha_bar_prev) + ) + # Equation 12. + noise = th.randn_like(x) + mean_pred = ( + out["pred_xstart"] * th.sqrt(alpha_bar_prev) + + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps + ) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out_orig["pred_xstart"]} + + def ddim_sample_with_grad( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t-1} from the model using DDIM. + + Same usage as p_sample(). + """ + with th.enable_grad(): + x = x.detach().requires_grad_() + out_orig = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score_with_grad(cond_fn, out_orig, x, t, + model_kwargs=model_kwargs) + else: + out = out_orig + + out["pred_xstart"] = out["pred_xstart"].detach() + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. 
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + sigma = ( + eta + * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) + * th.sqrt(1 - alpha_bar / alpha_bar_prev) + ) + # Equation 12. + noise = th.randn_like(x) + mean_pred = ( + out["pred_xstart"] * th.sqrt(alpha_bar_prev) + + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps + ) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out_orig["pred_xstart"].detach()} + + def ddim_reverse_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. + """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x + - out["pred_xstart"] + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) + + # Equation 12. reversed + mean_pred = ( + out["pred_xstart"] * th.sqrt(alpha_bar_next) + + th.sqrt(1 - alpha_bar_next) * eps + ) + + return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} + + def ddim_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + skip_timesteps=0, + init_image=None, + randomize_class=False, + cond_fn_with_grad=False, + dump_steps=None, + const_noise=False, + ): + """ + Generate samples from the model using DDIM. + + Same usage as p_sample_loop(). + """ + if dump_steps is not None: + raise NotImplementedError() + if const_noise == True: + raise NotImplementedError() + + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, + skip_timesteps=skip_timesteps, + init_image=init_image, + randomize_class=randomize_class, + cond_fn_with_grad=cond_fn_with_grad, + ): + final = sample + return final["sample"] + + def ddim_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + skip_timesteps=0, + init_image=None, + randomize_class=False, + cond_fn_with_grad=False, + ): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + + Same usage as p_sample_loop_progressive(). 
+ """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + + if skip_timesteps and init_image is None: + init_image = th.zeros_like(img) + + indices = list(range(self.num_timesteps - skip_timesteps))[::-1] + + if init_image is not None: + my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0] + img = self.q_sample(init_image, my_t, img) + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + if randomize_class and 'y' in model_kwargs: + model_kwargs['y'] = th.randint(low=0, high=model.num_classes, + size=model_kwargs['y'].shape, + device=model_kwargs['y'].device) + with th.no_grad(): + sample_fn = self.ddim_sample_with_grad if cond_fn_with_grad else self.ddim_sample + out = sample_fn( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + eta=eta, + ) + yield out + img = out["sample"] + + def plms_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + cond_fn_with_grad=False, + order=2, + old_out=None, + ): + """ + Sample x_{t-1} from the model using Pseudo Linear Multistep. + + Same usage as p_sample(). + """ + if not int(order) or not 1 <= order <= 4: + raise ValueError('order is invalid (should be int from 1-4).') + + def get_model_output(x, t): + with th.set_grad_enabled(cond_fn_with_grad and cond_fn is not None): + x = x.detach().requires_grad_() if cond_fn_with_grad else x + out_orig = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + if cond_fn_with_grad: + out = self.condition_score_with_grad(cond_fn, out_orig, x, t, model_kwargs=model_kwargs) + x = x.detach() + else: + out = self.condition_score(cond_fn, out_orig, x, t, model_kwargs=model_kwargs) + else: + out = out_orig + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. 
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + return eps, out, out_orig + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + eps, out, out_orig = get_model_output(x, t) + + if order > 1 and old_out is None: + # Pseudo Improved Euler + old_eps = [eps] + mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps + eps_2, _, _ = get_model_output(mean_pred, t - 1) + eps_prime = (eps + eps_2) / 2 + pred_prime = self._predict_xstart_from_eps(x, t, eps_prime) + mean_pred = pred_prime * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps_prime + else: + # Pseudo Linear Multistep (Adams-Bashforth) + old_eps = old_out["old_eps"] + old_eps.append(eps) + cur_order = min(order, len(old_eps)) + if cur_order == 1: + eps_prime = old_eps[-1] + elif cur_order == 2: + eps_prime = (3 * old_eps[-1] - old_eps[-2]) / 2 + elif cur_order == 3: + eps_prime = (23 * old_eps[-1] - 16 * old_eps[-2] + 5 * old_eps[-3]) / 12 + elif cur_order == 4: + eps_prime = (55 * old_eps[-1] - 59 * old_eps[-2] + 37 * old_eps[-3] - 9 * old_eps[-4]) / 24 + else: + raise RuntimeError('cur_order is invalid.') + pred_prime = self._predict_xstart_from_eps(x, t, eps_prime) + mean_pred = pred_prime * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps_prime + + if len(old_eps) >= order: + old_eps.pop(0) + + nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + sample = mean_pred * nonzero_mask + out["pred_xstart"] * (1 - nonzero_mask) + + return {"sample": sample, "pred_xstart": out_orig["pred_xstart"], "old_eps": old_eps} + + def plms_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + skip_timesteps=0, + init_image=None, + randomize_class=False, + cond_fn_with_grad=False, + order=2, + ): + """ + Generate samples from the model using Pseudo Linear Multistep. + + Same usage as p_sample_loop(). + """ + final = None + for sample in self.plms_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + skip_timesteps=skip_timesteps, + init_image=init_image, + randomize_class=randomize_class, + cond_fn_with_grad=cond_fn_with_grad, + order=order, + ): + final = sample + return final["sample"] + + def plms_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + skip_timesteps=0, + init_image=None, + randomize_class=False, + cond_fn_with_grad=False, + order=2, + ): + """ + Use PLMS to sample from the model and yield intermediate samples from each + timestep of PLMS. + + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + + if skip_timesteps and init_image is None: + init_image = th.zeros_like(img) + + indices = list(range(self.num_timesteps - skip_timesteps))[::-1] + + if init_image is not None: + my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0] + img = self.q_sample(init_image, my_t, img) + + if progress: + # Lazy import so that we don't depend on tqdm. 
+ from tqdm.auto import tqdm + + indices = tqdm(indices) + + old_out = None + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + if randomize_class and 'y' in model_kwargs: + model_kwargs['y'] = th.randint(low=0, high=model.num_classes, + size=model_kwargs['y'].shape, + device=model_kwargs['y'].device) + with th.no_grad(): + out = self.plms_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + cond_fn_with_grad=cond_fn_with_grad, + order=order, + old_out=old_out, + ) + yield out + old_out = out + img = out["sample"] + + def _vb_terms_bpd( + self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None + ): + """ + Get a term for the variational lower-bound. + + The resulting units are bits (rather than nats, as one might expect). + This allows for comparison to other papers. + + :return: a dict with the following keys: + - 'output': a shape [N] tensor of NLLs or KLs. + - 'pred_xstart': the x_0 predictions. + """ + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + ) + out = self.p_mean_variance( + model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs + ) + kl = normal_kl( + true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] + ) + kl = mean_flat(kl) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] + ) + assert decoder_nll.shape == x_start.shape + decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = th.where((t == 0), decoder_nll, kl) + return {"output": output, "pred_xstart": out["pred_xstart"]} + + def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, dataset=None): + """ + Compute training losses for a single timestep. + + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. 
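+                 For the MSE-type losses, "loss" combines the masked rotation term
+                 ("rot_mse") with any learned-variance term ("vb") and the
+                 lambda-weighted geometric terms computed below (velocity, joint
+                 positions, foot contact).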
+ """ + + # enc = model.model._modules['module'] + enc = model.model + mask = model_kwargs['y']['mask'] + get_xyz = lambda sample: enc.rot2xyz(sample, mask=None, pose_rep=enc.pose_rep, translation=enc.translation, + glob=enc.glob, + # jointstype='vertices', # 3.4 iter/sec # USED ALSO IN MotionCLIP + jointstype='smpl', # 3.4 iter/sec + vertstrans=False) + + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) + + terms = {} + + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + terms["loss"] = self._vb_terms_bpd( + model=model, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + model_kwargs=model_kwargs, + )["output"] + if self.loss_type == LossType.RESCALED_KL: + terms["loss"] *= self.num_timesteps + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + model_output = model(x_t, self._scale_timesteps(t), **model_kwargs) + + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, C = x_t.shape[:2] + assert model_output.shape == (B, C * 2, *x_t.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["vb"] *= self.num_timesteps / 1000.0 + + target = { + ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + )[0], + ModelMeanType.START_X: x_start, + ModelMeanType.EPSILON: noise, + }[self.model_mean_type] + assert model_output.shape == target.shape == x_start.shape # [bs, njoints, nfeats, nframes] + + terms["rot_mse"] = self.masked_l2(target, model_output, mask) # mean_flat(rot_mse) + + target_xyz, model_output_xyz = None, None + + if self.lambda_rcxyz > 0.: + target_xyz = get_xyz(target) # [bs, nvertices(vertices)/njoints(smpl), 3, nframes] + model_output_xyz = get_xyz(model_output) # [bs, nvertices, 3, nframes] + terms["rcxyz_mse"] = self.masked_l2(target_xyz, model_output_xyz, mask) # mean_flat((target_xyz - model_output_xyz) ** 2) + + if self.lambda_vel_rcxyz > 0.: + if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']: + target_xyz = get_xyz(target) if target_xyz is None else target_xyz + model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz + target_xyz_vel = (target_xyz[:, :, :, 1:] - target_xyz[:, :, :, :-1]) + model_output_xyz_vel = (model_output_xyz[:, :, :, 1:] - model_output_xyz[:, :, :, :-1]) + terms["vel_xyz_mse"] = self.masked_l2(target_xyz_vel, model_output_xyz_vel, mask[:, :, :, 1:]) + + if self.lambda_fc > 0.: + torch.autograd.set_detect_anomaly(True) + if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']: + target_xyz = get_xyz(target) if target_xyz is None else target_xyz + model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz + # 'L_Ankle', # 7, 'R_Ankle', # 8 , 'L_Foot', # 10, 'R_Foot', # 11 + l_ankle_idx, r_ankle_idx, l_foot_idx, r_foot_idx = 7, 8, 10, 11 + relevant_joints = [l_ankle_idx, l_foot_idx, 
r_ankle_idx, r_foot_idx] + gt_joint_xyz = target_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames] + gt_joint_vel = torch.linalg.norm(gt_joint_xyz[:, :, :, 1:] - gt_joint_xyz[:, :, :, :-1], axis=2) # [BatchSize, 4, Frames] + fc_mask = torch.unsqueeze((gt_joint_vel <= 0.01), dim=2).repeat(1, 1, 3, 1) + pred_joint_xyz = model_output_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames] + pred_vel = pred_joint_xyz[:, :, :, 1:] - pred_joint_xyz[:, :, :, :-1] + pred_vel[~fc_mask] = 0 + terms["fc"] = self.masked_l2(pred_vel, + torch.zeros(pred_vel.shape, device=pred_vel.device), + mask[:, :, :, 1:]) + if self.lambda_vel > 0.: + target_vel = (target[..., 1:] - target[..., :-1]) + model_output_vel = (model_output[..., 1:] - model_output[..., :-1]) + terms["vel_mse"] = self.masked_l2(target_vel[:, :-1, :, :], # Remove last joint, is the root location! + model_output_vel[:, :-1, :, :], + mask[:, :, :, 1:]) # mean_flat((target_vel - model_output_vel) ** 2) + + terms["loss"] = terms["rot_mse"] + terms.get('vb', 0.) +\ + (self.lambda_vel * terms.get('vel_mse', 0.)) +\ + (self.lambda_rcxyz * terms.get('rcxyz_mse', 0.)) + \ + (self.lambda_fc * terms.get('fc', 0.)) + + else: + raise NotImplementedError(self.loss_type) + + return terms + + def fc_loss_rot_repr(self, gt_xyz, pred_xyz, mask): + def to_np_cpu(x): + return x.detach().cpu().numpy() + """ + pose_xyz: SMPL batch tensor of shape: [BatchSize, 24, 3, Frames] + """ + # 'L_Ankle', # 7, 'R_Ankle', # 8 , 'L_Foot', # 10, 'R_Foot', # 11 + + l_ankle_idx, r_ankle_idx = 7, 8 + l_foot_idx, r_foot_idx = 10, 11 + """ Contact calculated by 'Kfir Method' Commented code)""" + # contact_signal = torch.zeros((pose_xyz.shape[0], pose_xyz.shape[3], 2), device=pose_xyz.device) # [BatchSize, Frames, 2] + # left_xyz = 0.5 * (pose_xyz[:, l_ankle_idx, :, :] + pose_xyz[:, l_foot_idx, :, :]) # [BatchSize, 3, Frames] + # right_xyz = 0.5 * (pose_xyz[:, r_ankle_idx, :, :] + pose_xyz[:, r_foot_idx, :, :]) + # left_z, right_z = left_xyz[:, 2, :], right_xyz[:, 2, :] # [BatchSize, Frames] + # left_velocity = torch.linalg.norm(left_xyz[:, :, 2:] - left_xyz[:, :, :-2], axis=1) # [BatchSize, Frames] + # right_velocity = torch.linalg.norm(left_xyz[:, :, 2:] - left_xyz[:, :, :-2], axis=1) + # + # left_z_mask = left_z <= torch.mean(torch.sort(left_z)[0][:, :left_z.shape[1] // 5], axis=-1) + # left_z_mask = torch.stack([left_z_mask, left_z_mask], dim=-1) # [BatchSize, Frames, 2] + # left_z_mask[:, :, 1] = False # Blank right side + # contact_signal[left_z_mask] = 0.4 + # + # right_z_mask = right_z <= torch.mean(torch.sort(right_z)[0][:, :right_z.shape[1] // 5], axis=-1) + # right_z_mask = torch.stack([right_z_mask, right_z_mask], dim=-1) # [BatchSize, Frames, 2] + # right_z_mask[:, :, 0] = False # Blank left side + # contact_signal[right_z_mask] = 0.4 + # contact_signal[left_z <= (torch.mean(torch.sort(left_z)[:left_z.shape[0] // 5]) + 20), 0] = 1 + # contact_signal[right_z <= (torch.mean(torch.sort(right_z)[:right_z.shape[0] // 5]) + 20), 1] = 1 + + # plt.plot(to_np_cpu(left_z[0]), label='left_z') + # plt.plot(to_np_cpu(left_velocity[0]), label='left_velocity') + # plt.plot(to_np_cpu(contact_signal[0, :, 0]), label='left_fc') + # plt.grid() + # plt.legend() + # plt.show() + # plt.plot(to_np_cpu(right_z[0]), label='right_z') + # plt.plot(to_np_cpu(right_velocity[0]), label='right_velocity') + # plt.plot(to_np_cpu(contact_signal[0, :, 1]), label='right_fc') + # plt.grid() + # plt.legend() + # plt.show() + + gt_joint_xyz = gt_xyz[:, [l_ankle_idx, l_foot_idx, r_ankle_idx, 
r_foot_idx], :, :] # [BatchSize, 4, 3, Frames] + gt_joint_vel = torch.linalg.norm(gt_joint_xyz[:, :, :, 1:] - gt_joint_xyz[:, :, :, :-1], axis=2) # [BatchSize, 4, Frames] + fc_mask = (gt_joint_vel <= 0.01) + pred_joint_xyz = pred_xyz[:, [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx], :, :] # [BatchSize, 4, 3, Frames] + pred_joint_vel = torch.linalg.norm(pred_joint_xyz[:, :, :, 1:] - pred_joint_xyz[:, :, :, :-1], axis=2) # [BatchSize, 4, Frames] + pred_joint_vel[~fc_mask] = 0 # Blank non-contact velocities frames. [BS,4,FRAMES] + pred_joint_vel = torch.unsqueeze(pred_joint_vel, dim=2) + + """DEBUG CODE""" + # print(f'mask: {mask.shape}') + # print(f'pred_joint_vel: {pred_joint_vel.shape}') + # plt.title(f'Joint: {joint_idx}') + # plt.plot(to_np_cpu(gt_joint_vel[0]), label='velocity') + # plt.plot(to_np_cpu(fc_mask[0]), label='fc') + # plt.grid() + # plt.legend() + # plt.show() + return self.masked_l2(pred_joint_vel, torch.zeros(pred_joint_vel.shape, device=pred_joint_vel.device), + mask[:, :, :, 1:]) + # TODO - NOT USED YET, JUST COMMITING TO NOT DELETE THIS AND KEEP INITIAL IMPLEMENTATION, NOT DONE! + def foot_contact_loss_humanml3d(self, target, model_output): + # root_rot_velocity (B, seq_len, 1) + # root_linear_velocity (B, seq_len, 2) + # root_y (B, seq_len, 1) + # ric_data (B, seq_len, (joint_num - 1)*3) , XYZ + # rot_data (B, seq_len, (joint_num - 1)*6) , 6D + # local_velocity (B, seq_len, joint_num*3) , XYZ + # foot contact (B, seq_len, 4) , + + target_fc = target[:, -4:, :, :] + root_rot_velocity = target[:, :1, :, :] + root_linear_velocity = target[:, 1:3, :, :] + root_y = target[:, 3:4, :, :] + ric_data = target[:, 4:67, :, :] # 4+(3*21)=67 + rot_data = target[:, 67:193, :, :] # 67+(6*21)=193 + local_velocity = target[:, 193:259, :, :] # 193+(3*22)=259 + contact = target[:, 259:, :, :] # 193+(3*22)=259 + contact_mask_gt = contact > 0.5 # contact mask order for indexes are fid_l [7, 10], fid_r [8, 11] + vel_lf_7 = local_velocity[:, 7 * 3:8 * 3, :, :] + vel_rf_8 = local_velocity[:, 8 * 3:9 * 3, :, :] + vel_lf_10 = local_velocity[:, 10 * 3:11 * 3, :, :] + vel_rf_11 = local_velocity[:, 11 * 3:12 * 3, :, :] + + calc_vel_lf_7 = ric_data[:, 6 * 3:7 * 3, :, 1:] - ric_data[:, 6 * 3:7 * 3, :, :-1] + calc_vel_rf_8 = ric_data[:, 7 * 3:8 * 3, :, 1:] - ric_data[:, 7 * 3:8 * 3, :, :-1] + calc_vel_lf_10 = ric_data[:, 9 * 3:10 * 3, :, 1:] - ric_data[:, 9 * 3:10 * 3, :, :-1] + calc_vel_rf_11 = ric_data[:, 10 * 3:11 * 3, :, 1:] - ric_data[:, 10 * 3:11 * 3, :, :-1] + + # vel_foots = torch.stack([vel_lf_7, vel_lf_10, vel_rf_8, vel_rf_11], dim=1) + for chosen_vel_foot_calc, chosen_vel_foot, joint_idx, contact_mask_idx in zip( + [calc_vel_lf_7, calc_vel_rf_8, calc_vel_lf_10, calc_vel_rf_11], + [vel_lf_7, vel_lf_10, vel_rf_8, vel_rf_11], + [7, 10, 8, 11], + [0, 1, 2, 3]): + tmp_mask_gt = contact_mask_gt[:, contact_mask_idx, :, :].cpu().detach().numpy().reshape(-1).astype(int) + chosen_vel_norm = np.linalg.norm(chosen_vel_foot.cpu().detach().numpy().reshape((3, -1)), axis=0) + chosen_vel_calc_norm = np.linalg.norm(chosen_vel_foot_calc.cpu().detach().numpy().reshape((3, -1)), + axis=0) + + print(tmp_mask_gt.shape) + print(chosen_vel_foot.shape) + print(chosen_vel_calc_norm.shape) + import matplotlib.pyplot as plt + plt.plot(tmp_mask_gt, label='FC mask') + plt.plot(chosen_vel_norm, label='Vel. XYZ norm (from vector)') + plt.plot(chosen_vel_calc_norm, label='Vel. 
XYZ norm (calculated diff XYZ)') + + plt.title(f'FC idx {contact_mask_idx}, Joint Index {joint_idx}') + plt.legend() + plt.show() + # print(vel_foots.shape) + return 0 + + + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + + This term can't be optimized, as it only depends on the encoder. + + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl( + mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 + ) + return mean_flat(kl_prior) / np.log(2.0) + + def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): + """ + Compute the entire variational lower-bound, measured in bits-per-dim, + as well as other related quantities. + + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param clip_denoised: if True, clip denoised samples. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + + :return: a dict containing the following keys: + - total_bpd: the total variational lower-bound, per batch element. + - prior_bpd: the prior term in the lower-bound. + - vb: an [N x T] tensor of terms in the lower-bound. + - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. + - mse: an [N x T] tensor of epsilon MSEs for each timestep. + """ + device = x_start.device + batch_size = x_start.shape[0] + + vb = [] + xstart_mse = [] + mse = [] + for t in list(range(self.num_timesteps))[::-1]: + t_batch = th.tensor([t] * batch_size, device=device) + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) + # Calculate VLB term at the current timestep + with th.no_grad(): + out = self._vb_terms_bpd( + model, + x_start=x_start, + x_t=x_t, + t=t_batch, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + ) + vb.append(out["output"]) + xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) + eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) + mse.append(mean_flat((eps - noise) ** 2)) + + vb = th.stack(vb, dim=1) + xstart_mse = th.stack(xstart_mse, dim=1) + mse = th.stack(mse, dim=1) + + prior_bpd = self._prior_bpd(x_start) + total_bpd = vb.sum(dim=1) + prior_bpd + return { + "total_bpd": total_bpd, + "prior_bpd": prior_bpd, + "vb": vb, + "xstart_mse": xstart_mse, + "mse": mse, + } + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 
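+    For example, with arr of shape [T] and timesteps of shape [N], the gathered
+    values are reshaped to [N, 1, ..., 1] and expanded to broadcast_shape.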
+ """ + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res.expand(broadcast_shape) diff --git a/diffusion/logger.py b/diffusion/logger.py new file mode 100644 index 0000000..b1d856d --- /dev/null +++ b/diffusion/logger.py @@ -0,0 +1,495 @@ +""" +Logger copied from OpenAI baselines to avoid extra RL-based dependencies: +https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py +""" + +import os +import sys +import shutil +import os.path as osp +import json +import time +import datetime +import tempfile +import warnings +from collections import defaultdict +from contextlib import contextmanager + +DEBUG = 10 +INFO = 20 +WARN = 30 +ERROR = 40 + +DISABLED = 50 + + +class KVWriter(object): + def writekvs(self, kvs): + raise NotImplementedError + + +class SeqWriter(object): + def writeseq(self, seq): + raise NotImplementedError + + +class HumanOutputFormat(KVWriter, SeqWriter): + def __init__(self, filename_or_file): + if isinstance(filename_or_file, str): + self.file = open(filename_or_file, "wt") + self.own_file = True + else: + assert hasattr(filename_or_file, "read"), ( + "expected file or str, got %s" % filename_or_file + ) + self.file = filename_or_file + self.own_file = False + + def writekvs(self, kvs): + # Create strings for printing + key2str = {} + for (key, val) in sorted(kvs.items()): + if hasattr(val, "__float__"): + valstr = "%-8.3g" % val + else: + valstr = str(val) + key2str[self._truncate(key)] = self._truncate(valstr) + + # Find max widths + if len(key2str) == 0: + print("WARNING: tried to write empty key-value dict") + return + else: + keywidth = max(map(len, key2str.keys())) + valwidth = max(map(len, key2str.values())) + + # Write out the data + dashes = "-" * (keywidth + valwidth + 7) + lines = [dashes] + for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()): + lines.append( + "| %s%s | %s%s |" + % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val))) + ) + lines.append(dashes) + self.file.write("\n".join(lines) + "\n") + + # Flush the output to the file + self.file.flush() + + def _truncate(self, s): + maxlen = 30 + return s[: maxlen - 3] + "..." 
if len(s) > maxlen else s + + def writeseq(self, seq): + seq = list(seq) + for (i, elem) in enumerate(seq): + self.file.write(elem) + if i < len(seq) - 1: # add space unless this is the last one + self.file.write(" ") + self.file.write("\n") + self.file.flush() + + def close(self): + if self.own_file: + self.file.close() + + +class JSONOutputFormat(KVWriter): + def __init__(self, filename): + self.file = open(filename, "wt") + + def writekvs(self, kvs): + for k, v in sorted(kvs.items()): + if hasattr(v, "dtype"): + kvs[k] = float(v) + self.file.write(json.dumps(kvs) + "\n") + self.file.flush() + + def close(self): + self.file.close() + + +class CSVOutputFormat(KVWriter): + def __init__(self, filename): + self.file = open(filename, "w+t") + self.keys = [] + self.sep = "," + + def writekvs(self, kvs): + # Add our current row to the history + extra_keys = list(kvs.keys() - self.keys) + extra_keys.sort() + if extra_keys: + self.keys.extend(extra_keys) + self.file.seek(0) + lines = self.file.readlines() + self.file.seek(0) + for (i, k) in enumerate(self.keys): + if i > 0: + self.file.write(",") + self.file.write(k) + self.file.write("\n") + for line in lines[1:]: + self.file.write(line[:-1]) + self.file.write(self.sep * len(extra_keys)) + self.file.write("\n") + for (i, k) in enumerate(self.keys): + if i > 0: + self.file.write(",") + v = kvs.get(k) + if v is not None: + self.file.write(str(v)) + self.file.write("\n") + self.file.flush() + + def close(self): + self.file.close() + + +class TensorBoardOutputFormat(KVWriter): + """ + Dumps key/value pairs into TensorBoard's numeric format. + """ + + def __init__(self, dir): + os.makedirs(dir, exist_ok=True) + self.dir = dir + self.step = 1 + prefix = "events" + path = osp.join(osp.abspath(dir), prefix) + import tensorflow as tf + from tensorflow.python import pywrap_tensorflow + from tensorflow.core.util import event_pb2 + from tensorflow.python.util import compat + + self.tf = tf + self.event_pb2 = event_pb2 + self.pywrap_tensorflow = pywrap_tensorflow + self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) + + def writekvs(self, kvs): + def summary_val(k, v): + kwargs = {"tag": k, "simple_value": float(v)} + return self.tf.Summary.Value(**kwargs) + + summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) + event = self.event_pb2.Event(wall_time=time.time(), summary=summary) + event.step = ( + self.step + ) # is there any reason why you'd want to specify the step? 
+ self.writer.WriteEvent(event) + self.writer.Flush() + self.step += 1 + + def close(self): + if self.writer: + self.writer.Close() + self.writer = None + + +def make_output_format(format, ev_dir, log_suffix=""): + os.makedirs(ev_dir, exist_ok=True) + if format == "stdout": + return HumanOutputFormat(sys.stdout) + elif format == "log": + return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) + elif format == "json": + return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) + elif format == "csv": + return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)) + elif format == "tensorboard": + return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)) + else: + raise ValueError("Unknown format specified: %s" % (format,)) + + +# ================================================================ +# API +# ================================================================ + + +def logkv(key, val): + """ + Log a value of some diagnostic + Call this once for each diagnostic quantity, each iteration + If called many times, last value will be used. + """ + get_current().logkv(key, val) + + +def logkv_mean(key, val): + """ + The same as logkv(), but if called many times, values averaged. + """ + get_current().logkv_mean(key, val) + + +def logkvs(d): + """ + Log a dictionary of key-value pairs + """ + for (k, v) in d.items(): + logkv(k, v) + + +def dumpkvs(): + """ + Write all of the diagnostics from the current iteration + """ + return get_current().dumpkvs() + + +def getkvs(): + return get_current().name2val + + +def log(*args, level=INFO): + """ + Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). + """ + get_current().log(*args, level=level) + + +def debug(*args): + log(*args, level=DEBUG) + + +def info(*args): + log(*args, level=INFO) + + +def warn(*args): + log(*args, level=WARN) + + +def error(*args): + log(*args, level=ERROR) + + +def set_level(level): + """ + Set logging threshold on current logger. + """ + get_current().set_level(level) + + +def set_comm(comm): + get_current().set_comm(comm) + + +def get_dir(): + """ + Get directory that log files are being written to. + will be None if there is no output directory (i.e., if you didn't call start) + """ + return get_current().get_dir() + + +record_tabular = logkv +dump_tabular = dumpkvs + + +@contextmanager +def profile_kv(scopename): + logkey = "wait_" + scopename + tstart = time.time() + try: + yield + finally: + get_current().name2val[logkey] += time.time() - tstart + + +def profile(n): + """ + Usage: + @profile("my_func") + def my_func(): code + """ + + def decorator_with_name(func): + def func_wrapper(*args, **kwargs): + with profile_kv(n): + return func(*args, **kwargs) + + return func_wrapper + + return decorator_with_name + + +# ================================================================ +# Backend +# ================================================================ + + +def get_current(): + if Logger.CURRENT is None: + _configure_default_logger() + + return Logger.CURRENT + + +class Logger(object): + DEFAULT = None # A logger with no output files. 
(See right below class definition) + # So that you can still log to the terminal without setting up any output files + CURRENT = None # Current logger being used by the free functions above + + def __init__(self, dir, output_formats, comm=None): + self.name2val = defaultdict(float) # values this iteration + self.name2cnt = defaultdict(int) + self.level = INFO + self.dir = dir + self.output_formats = output_formats + self.comm = comm + + # Logging API, forwarded + # ---------------------------------------- + def logkv(self, key, val): + self.name2val[key] = val + + def logkv_mean(self, key, val): + oldval, cnt = self.name2val[key], self.name2cnt[key] + self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) + self.name2cnt[key] = cnt + 1 + + def dumpkvs(self): + if self.comm is None: + d = self.name2val + else: + d = mpi_weighted_mean( + self.comm, + { + name: (val, self.name2cnt.get(name, 1)) + for (name, val) in self.name2val.items() + }, + ) + if self.comm.rank != 0: + d["dummy"] = 1 # so we don't get a warning about empty dict + out = d.copy() # Return the dict for unit testing purposes + for fmt in self.output_formats: + if isinstance(fmt, KVWriter): + fmt.writekvs(d) + self.name2val.clear() + self.name2cnt.clear() + return out + + def log(self, *args, level=INFO): + if self.level <= level: + self._do_log(args) + + # Configuration + # ---------------------------------------- + def set_level(self, level): + self.level = level + + def set_comm(self, comm): + self.comm = comm + + def get_dir(self): + return self.dir + + def close(self): + for fmt in self.output_formats: + fmt.close() + + # Misc + # ---------------------------------------- + def _do_log(self, args): + for fmt in self.output_formats: + if isinstance(fmt, SeqWriter): + fmt.writeseq(map(str, args)) + + +def get_rank_without_mpi_import(): + # check environment variables here instead of importing mpi4py + # to avoid calling MPI_Init() when this module is imported + for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]: + if varname in os.environ: + return int(os.environ[varname]) + return 0 + + +def mpi_weighted_mean(comm, local_name2valcount): + """ + Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110 + Perform a weighted average over dicts that are each on a different node + Input: local_name2valcount: dict mapping key -> (value, count) + Returns: key -> mean + """ + all_name2valcount = comm.gather(local_name2valcount) + if comm.rank == 0: + name2sum = defaultdict(float) + name2count = defaultdict(float) + for n2vc in all_name2valcount: + for (name, (val, count)) in n2vc.items(): + try: + val = float(val) + except ValueError: + if comm.rank == 0: + warnings.warn( + "WARNING: tried to compute mean on non-float {}={}".format( + name, val + ) + ) + else: + name2sum[name] += val * count + name2count[name] += count + return {name: name2sum[name] / name2count[name] for name in name2sum} + else: + return {} + + +def configure(dir=None, format_strs=None, comm=None, log_suffix=""): + """ + If comm is provided, average all numerical stats across that comm + """ + if dir is None: + dir = os.getenv("OPENAI_LOGDIR") + if dir is None: + dir = osp.join( + tempfile.gettempdir(), + datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"), + ) + assert isinstance(dir, str) + dir = os.path.expanduser(dir) + os.makedirs(os.path.expanduser(dir), exist_ok=True) + + rank = get_rank_without_mpi_import() + if rank > 0: + log_suffix = log_suffix + "-rank%03i" % 
rank + + if format_strs is None: + if rank == 0: + format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",") + else: + format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",") + format_strs = filter(None, format_strs) + output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] + + Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) + if output_formats: + log("Logging to %s" % dir) + + +def _configure_default_logger(): + configure() + Logger.DEFAULT = Logger.CURRENT + + +def reset(): + if Logger.CURRENT is not Logger.DEFAULT: + Logger.CURRENT.close() + Logger.CURRENT = Logger.DEFAULT + log("Reset logger") + + +@contextmanager +def scoped_configure(dir=None, format_strs=None, comm=None): + prevlogger = Logger.CURRENT + configure(dir=dir, format_strs=format_strs, comm=comm) + try: + yield + finally: + Logger.CURRENT.close() + Logger.CURRENT = prevlogger + diff --git a/diffusion/losses.py b/diffusion/losses.py new file mode 100644 index 0000000..e3fded1 --- /dev/null +++ b/diffusion/losses.py @@ -0,0 +1,77 @@ +# This code is based on https://github.com/openai/guided-diffusion +""" +Helpers for various likelihood-based losses. These are ported from the original +Ho et al. diffusion models codebase: +https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py +""" + +import numpy as np +import torch as th + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + Compute the KL divergence between two gaussians. + + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, th.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for th.exp(). + logvar1, logvar2 = [ + x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + th.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * th.exp(-logvar2) + ) + + +def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. + """ + return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. + + :param x: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). 
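+
+    The probability of each value is obtained by integrating the Gaussian over its
+    discretization bin (edges at x +/- 1/255 in [-1, 1] space), with the lowest and
+    highest bins extended to -inf and +inf respectively.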
+ """ + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = th.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = th.where( + x < -0.999, + log_cdf_plus, + th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs diff --git a/diffusion/nn.py b/diffusion/nn.py new file mode 100644 index 0000000..41c18e7 --- /dev/null +++ b/diffusion/nn.py @@ -0,0 +1,197 @@ +# This code is based on https://github.com/openai/guided-diffusion +""" +Various utilities for neural networks. +""" + +import math + +import torch as th +import torch.nn as nn + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * th.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def update_ema(target_params, source_params, rate=0.99): + """ + Update target parameters to be closer to those of source parameters using + an exponential moving average. + + :param target_params: the target parameter sequence. + :param source_params: the source parameter sequence. + :param rate: the EMA rate (closer to 1 means slower). + """ + for targ, src in zip(target_params, source_params): + targ.detach().mul_(rate).add_(src, alpha=1 - rate) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + +def sum_flat(tensor): + """ + Take the sum over all non-batch dimensions. + """ + return tensor.sum(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. 
+ :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = dim // 2 + freqs = th.exp( + -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) + if dim % 2: + embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(th.autograd.Function): + @staticmethod + @th.cuda.amp.custom_fwd + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_length = length + ctx.save_for_backward(*args) + with th.no_grad(): + output_tensors = ctx.run_function(*args[:length]) + return output_tensors + + @staticmethod + @th.cuda.amp.custom_bwd + def backward(ctx, *output_grads): + args = list(ctx.saved_tensors) + + # Filter for inputs that require grad. If none, exit early. + input_indices = [i for (i, x) in enumerate(args) if x.requires_grad] + if not input_indices: + return (None, None) + tuple(None for _ in args) + + with th.enable_grad(): + for i in input_indices: + if i < ctx.input_length: + # Not sure why the OAI code does this little + # dance. It might not be necessary. + args[i] = args[i].detach().requires_grad_() + args[i] = args[i].view_as(args[i]) + output_tensors = ctx.run_function(*args[:ctx.input_length]) + + if isinstance(output_tensors, th.Tensor): + output_tensors = [output_tensors] + + # Filter for outputs that require grad. If none, exit early. + out_and_grads = [(o, g) for (o, g) in zip(output_tensors, output_grads) if o.requires_grad] + if not out_and_grads: + return (None, None) + tuple(None for _ in args) + + # Compute gradients on the filtered tensors. + computed_grads = th.autograd.grad( + [o for (o, g) in out_and_grads], + [args[i] for i in input_indices], + [g for (o, g) in out_and_grads] + ) + + # Reassemble the complete gradient tuple. + input_grads = [None for _ in args] + for (i, g) in zip(input_indices, computed_grads): + input_grads[i] = g + return (None, None) + tuple(input_grads) diff --git a/diffusion/resample.py b/diffusion/resample.py new file mode 100644 index 0000000..c82eccd --- /dev/null +++ b/diffusion/resample.py @@ -0,0 +1,154 @@ +from abc import ABC, abstractmethod + +import numpy as np +import torch as th +import torch.distributed as dist + + +def create_named_schedule_sampler(name, diffusion): + """ + Create a ScheduleSampler from a library of pre-defined samplers. + + :param name: the name of the sampler. + :param diffusion: the diffusion object to sample for. 
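+
+    Note (added for clarity): the supported names below are "uniform" and
+    "loss-second-moment"; any other name raises NotImplementedError.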
+ """ + if name == "uniform": + return UniformSampler(diffusion) + elif name == "loss-second-moment": + return LossSecondMomentResampler(diffusion) + else: + raise NotImplementedError(f"unknown schedule sampler: {name}") + + +class ScheduleSampler(ABC): + """ + A distribution over timesteps in the diffusion process, intended to reduce + variance of the objective. + + By default, samplers perform unbiased importance sampling, in which the + objective's mean is unchanged. + However, subclasses may override sample() to change how the resampled + terms are reweighted, allowing for actual changes in the objective. + """ + + @abstractmethod + def weights(self): + """ + Get a numpy array of weights, one per diffusion step. + + The weights needn't be normalized, but must be positive. + """ + + def sample(self, batch_size, device): + """ + Importance-sample timesteps for a batch. + + :param batch_size: the number of timesteps. + :param device: the torch device to save to. + :return: a tuple (timesteps, weights): + - timesteps: a tensor of timestep indices. + - weights: a tensor of weights to scale the resulting losses. + """ + w = self.weights() + p = w / np.sum(w) + indices_np = np.random.choice(len(p), size=(batch_size,), p=p) + indices = th.from_numpy(indices_np).long().to(device) + weights_np = 1 / (len(p) * p[indices_np]) + weights = th.from_numpy(weights_np).float().to(device) + return indices, weights + + +class UniformSampler(ScheduleSampler): + def __init__(self, diffusion): + self.diffusion = diffusion + self._weights = np.ones([diffusion.num_timesteps]) + + def weights(self): + return self._weights + + +class LossAwareSampler(ScheduleSampler): + def update_with_local_losses(self, local_ts, local_losses): + """ + Update the reweighting using losses from a model. + + Call this method from each rank with a batch of timesteps and the + corresponding losses for each of those timesteps. + This method will perform synchronization to make sure all of the ranks + maintain the exact same reweighting. + + :param local_ts: an integer Tensor of timesteps. + :param local_losses: a 1D Tensor of losses. + """ + batch_sizes = [ + th.tensor([0], dtype=th.int32, device=local_ts.device) + for _ in range(dist.get_world_size()) + ] + dist.all_gather( + batch_sizes, + th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), + ) + + # Pad all_gather batches to be the maximum batch size. + batch_sizes = [x.item() for x in batch_sizes] + max_bs = max(batch_sizes) + + timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] + loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] + dist.all_gather(timestep_batches, local_ts) + dist.all_gather(loss_batches, local_losses) + timesteps = [ + x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] + ] + losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] + self.update_with_all_losses(timesteps, losses) + + @abstractmethod + def update_with_all_losses(self, ts, losses): + """ + Update the reweighting using losses from a model. + + Sub-classes should override this method to update the reweighting + using losses from the model. + + This method directly updates the reweighting without synchronizing + between workers. It is called by update_with_local_losses from all + ranks with identical arguments. Thus, it should have deterministic + behavior to maintain state across workers. + + :param ts: a list of int timesteps. + :param losses: a list of float losses, one per timestep. 
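+
+        Note (added for clarity): LossSecondMomentResampler below implements this
+        by keeping a fixed-length loss history per timestep and, once warmed up,
+        weighting each timestep by the RMS of its recorded losses.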
+ """ + + +class LossSecondMomentResampler(LossAwareSampler): + def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): + self.diffusion = diffusion + self.history_per_term = history_per_term + self.uniform_prob = uniform_prob + self._loss_history = np.zeros( + [diffusion.num_timesteps, history_per_term], dtype=np.float64 + ) + self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) + + def weights(self): + if not self._warmed_up(): + return np.ones([self.diffusion.num_timesteps], dtype=np.float64) + weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) + weights /= np.sum(weights) + weights *= 1 - self.uniform_prob + weights += self.uniform_prob / len(weights) + return weights + + def update_with_all_losses(self, ts, losses): + for t, loss in zip(ts, losses): + if self._loss_counts[t] == self.history_per_term: + # Shift out the oldest loss term. + self._loss_history[t, :-1] = self._loss_history[t, 1:] + self._loss_history[t, -1] = loss + else: + self._loss_history[t, self._loss_counts[t]] = loss + self._loss_counts[t] += 1 + + def _warmed_up(self): + return (self._loss_counts == self.history_per_term).all() diff --git a/diffusion/respace.py b/diffusion/respace.py new file mode 100644 index 0000000..13a3c06 --- /dev/null +++ b/diffusion/respace.py @@ -0,0 +1,129 @@ +# This code is based on https://github.com/openai/guided-diffusion +import numpy as np +import torch as th + +from .gaussian_diffusion import GaussianDiffusion + + +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. + + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + + If the stride is a string starting with "ddim", then the fixed striding + from the DDIM paper is used, and only one section is allowed. + + :param num_timesteps: the number of diffusion steps in the original + process to divide up. + :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. 
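+
+    Example (added for illustration): space_timesteps(100, "ddim10") returns
+    {0, 10, 20, ..., 90}, while space_timesteps(300, [10, 15, 20]) keeps 10
+    evenly spaced steps from the first 100 original steps, 15 from the second
+    100, and 20 from the last 100.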
+ """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError( + f"cannot create exactly {num_timesteps} steps with an integer stride" + ) + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError( + f"cannot divide section of {size} steps into {section_count}" + ) + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +class SpacedDiffusion(GaussianDiffusion): + """ + A diffusion process which can skip steps in a base diffusion process. + + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process. + """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + def training_losses( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().training_losses(self._wrap_model(model), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel( + model, self.timestep_map, self.rescale_timesteps, self.original_num_steps + ) + + def _scale_timesteps(self, t): + # Scaling is done by the wrapped model. 
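+        # (Added note) _WrappedModel below maps the retained timestep indices back
+        # to the original schedule and, when rescale_timesteps is set, rescales
+        # them to the usual [0, 1000) range, so no extra scaling is needed here.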
+ return t + + +class _WrappedModel: + def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): + self.model = model + self.timestep_map = timestep_map + self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, ts, **kwargs): + map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) + new_ts = map_tensor[ts] + if self.rescale_timesteps: + new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + return self.model(x, new_ts, **kwargs) diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..d8bc72d --- /dev/null +++ b/environment.yml @@ -0,0 +1,217 @@ +name: ggvad +channels: + - pytorch + - conda-forge + - anaconda + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - beautifulsoup4=4.11.1=pyha770c72_0 + - blas=1.0=mkl + - brotlipy=0.7.0=py37h540881e_1004 + - ca-certificates=2022.07.19=h06a4308_0 + - catalogue=2.0.8=py37h89c1867_0 + - certifi=2022.6.15=py37h06a4308_0 + - cffi=1.15.1=py37h74dc2b5_0 + - charset-normalizer=2.1.1=pyhd8ed1ab_0 + - colorama=0.4.5=pyhd8ed1ab_0 + - cryptography=35.0.0=py37hf1a17b8_2 + - cudatoolkit=11.0.221=h6bb024c_0 + - cycler=0.11.0=pyhd3eb1b0_0 + - cymem=2.0.6=py37hd23a5d3_3 + - dataclasses=0.8=pyhc8e2a94_3 + - dbus=1.13.18=hb2f20db_0 + - expat=2.4.9=h6a678d5_0 + - fftw=3.3.9=h27cfd23_1 + - filelock=3.8.0=pyhd8ed1ab_0 + - fontconfig=2.13.1=h6c09931_0 + - freetype=2.11.0=h70c0345_0 + - gdown=4.5.1=pyhd8ed1ab_0 + - giflib=5.2.1=h7b6447c_0 + - glib=2.69.1=h4ff587b_1 + - gst-plugins-base=1.14.0=h8213a91_2 + - gstreamer=1.14.0=h28cd5cc_2 + - h5py=3.7.0=py37h737f45e_0 + - hdf5=1.10.6=h3ffc7dd_1 + - icu=58.2=he6710b0_3 + - idna=3.4=pyhd8ed1ab_0 + - intel-openmp=2021.4.0=h06a4308_3561 + - jinja2=3.1.2=pyhd8ed1ab_1 + - joblib=1.1.0=pyhd3eb1b0_0 + - jpeg=9b=h024ee3a_2 + - kiwisolver=1.4.2=py37h295c915_0 + - langcodes=3.3.0=pyhd8ed1ab_0 + - lcms2=2.12=h3be6417_0 + - ld_impl_linux-64=2.38=h1181459_1 + - libffi=3.3=he6710b0_2 + - libgcc-ng=11.2.0=h1234567_1 + - libgfortran-ng=11.2.0=h00389a5_1 + - libgfortran5=11.2.0=h1234567_1 + - libgomp=11.2.0=h1234567_1 + - libpng=1.6.37=hbc83047_0 + - libstdcxx-ng=11.2.0=h1234567_1 + - libtiff=4.1.0=h2733197_1 + - libuuid=1.0.3=h7f8727e_2 + - libuv=1.40.0=h7b6447c_0 + - libwebp=1.2.0=h89dd481_0 + - libxcb=1.15=h7f8727e_0 + - libxml2=2.9.14=h74e7548_0 + - lz4-c=1.9.3=h295c915_1 + - markupsafe=2.1.1=py37h540881e_1 + - matplotlib=3.1.3=py37_0 + - matplotlib-base=3.1.3=py37hef1b27d_0 + - mkl=2021.4.0=h06a4308_640 + - mkl-service=2.4.0=py37h7f8727e_0 + - mkl_fft=1.3.1=py37hd3c417c_0 + - mkl_random=1.2.2=py37h51133e4_0 + - ncurses=6.3=h5eee18b_3 + - ninja=1.10.2=h06a4308_5 + - ninja-base=1.10.2=hd09550d_5 + - numpy=1.21.5=py37h6c91a56_3 + - numpy-base=1.21.5=py37ha15fc14_3 + - openssl=1.1.1q=h7f8727e_0 + - packaging=21.3=pyhd8ed1ab_0 + - pathy=0.6.2=pyhd8ed1ab_0 + - pcre=8.45=h295c915_0 + - pillow=9.2.0=py37hace64e9_1 + - pip=22.2.2=py37h06a4308_0 + - pycparser=2.21=pyhd8ed1ab_0 + - pydantic=1.8.2=py37h5e8e339_2 + - pyopenssl=22.0.0=pyhd8ed1ab_1 + - pyparsing=3.0.9=py37h06a4308_0 + - pyqt=5.9.2=py37h05f1152_2 + - pysocks=1.7.1=py37h89c1867_5 + - python=3.7.13=h12debd9_0 + - python-dateutil=2.8.2=pyhd3eb1b0_0 + - python_abi=3.7=2_cp37m + - pytorch=1.7.1=py3.7_cuda11.0.221_cudnn8.0.5_0 + - qt=5.9.7=h5867ecd_1 + - readline=8.1.2=h7f8727e_1 + - requests=2.28.1=pyhd8ed1ab_1 + - scikit-learn=1.0.2=py37h51133e4_1 + - scipy=1.7.3=py37h6c91a56_2 + - setuptools=63.4.1=py37h06a4308_0 + - 
shellingham=1.5.0=pyhd8ed1ab_0 + - sip=4.19.8=py37hf484d3e_0 + - six=1.16.0=pyhd3eb1b0_1 + - smart_open=5.2.1=pyhd8ed1ab_0 + - soupsieve=2.3.2.post1=pyhd8ed1ab_0 + - spacy=3.3.1=py37h79cecc1_0 + - spacy-legacy=3.0.10=pyhd8ed1ab_0 + - spacy-loggers=1.0.3=pyhd8ed1ab_0 + - sqlite=3.39.3=h5082296_0 + - threadpoolctl=2.2.0=pyh0d69192_0 + - tk=8.6.12=h1ccaba5_0 + - torchaudio=0.7.2=py37 + - torchvision=0.8.2=py37_cu110 + - tornado=6.2=py37h5eee18b_0 + - tqdm=4.64.1=py37h06a4308_0 + - trimesh=3.15.3=pyh1a96a4e_0 + - typer=0.4.2=pyhd8ed1ab_0 + - wheel=0.37.1=pyhd3eb1b0_0 + - xz=5.2.6=h5eee18b_0 + - zipp=3.8.1=pyhd8ed1ab_0 + - zlib=1.2.12=h5eee18b_3 + - zstd=1.4.9=haebb681_0 + - pip: + - anyio==3.7.1 + - appdirs==1.4.4 + - argon2-cffi==21.3.0 + - argon2-cffi-bindings==21.2.0 + - attrs==23.1.0 + - audioread==3.0.0 + - backcall==0.2.0 + - bleach==6.0.0 + - blis==0.7.8 + - blobfile==2.0.2 + - chumpy==0.70 + - click==8.1.3 + - comm==0.1.4 + - confection==0.0.2 + - debugpy==1.6.7.post1 + - decorator==5.1.1 + - defusedxml==0.7.1 + - docker-pycreds==0.4.0 + - einops==0.6.1 + - entrypoints==0.4 + - exceptiongroup==1.1.2 + - fastjsonschema==2.18.0 + - ftfy==6.1.1 + - gitdb==4.0.10 + - gitpython==3.1.32 + - importlib-metadata==5.0.0 + - importlib-resources==5.12.0 + - ipykernel==6.16.2 + - ipython==7.34.0 + - ipython-genutils==0.2.0 + - ipywidgets==8.1.0 + - jedi==0.19.0 + - jsonschema==4.17.3 + - jupyter==1.0.0 + - jupyter-client==7.4.9 + - jupyter-console==6.6.3 + - jupyter-core==4.12.0 + - jupyter-server==1.24.0 + - jupyterlab-pygments==0.2.2 + - jupyterlab-widgets==3.0.8 + - lazy-loader==0.2 + - librosa==0.10.0.post2 + - llvmlite==0.39.1 + - lxml==4.9.1 + - matplotlib-inline==0.1.6 + - mistune==3.0.1 + - msgpack==1.0.5 + - murmurhash==1.0.8 + - nbclassic==1.0.0 + - nbclient==0.7.4 + - nbconvert==7.6.0 + - nbformat==5.8.0 + - nest-asyncio==1.5.7 + - notebook==6.5.5 + - notebook-shim==0.2.3 + - numba==0.56.4 + - pandocfilters==1.5.0 + - parso==0.8.3 + - pathtools==0.1.2 + - pexpect==4.8.0 + - pickleshare==0.7.5 + - pkgutil-resolve-name==1.3.10 + - pooch==1.6.0 + - preshed==3.0.7 + - prometheus-client==0.17.1 + - prompt-toolkit==3.0.39 + - protobuf==4.24.0 + - psutil==5.9.5 + - ptyprocess==0.7.0 + - pycryptodomex==3.15.0 + - pygments==2.16.1 + - pyrsistent==0.19.3 + - python-speech-features==0.6 + - pyyaml==6.0 + - pyzmq==24.0.1 + - qtconsole==5.4.3 + - qtpy==2.3.1 + - regex==2022.9.13 + - send2trash==1.8.2 + - sentry-sdk==1.29.2 + - setproctitle==1.3.2 + - smmap==5.0.0 + - smplx==0.1.28 + - sniffio==1.3.0 + - soundfile==0.12.1 + - soxr==0.3.5 + - srsly==2.4.4 + - terminado==0.17.1 + - thinc==8.0.17 + - tinycss2==1.2.1 + - traitlets==5.9.0 + - typing-extensions==4.1.1 + - urllib3==1.26.12 + - wandb==0.15.8 + - wasabi==0.10.1 + - wcwidth==0.2.5 + - webencodings==0.5.1 + - websocket-client==1.6.1 + - widgetsnbextension==4.0.8 +prefix: /root/miniconda3/envs/ggvad \ No newline at end of file diff --git a/eval/eval_genea.py b/eval/eval_genea.py new file mode 100644 index 0000000..52d0d84 --- /dev/null +++ b/eval/eval_genea.py @@ -0,0 +1,200 @@ +from data_loaders.gesture.scripts import motion_process as mp +from data_loaders.get_data import get_dataset_loader +import numpy as np +from tqdm import tqdm +from utils import dist_util +import torch +import bvhsdk +from evaluation_metric.embedding_space_evaluator import EmbeddingSpaceEvaluator +from evaluation_metric.train_AE import make_tensor +import matplotlib.pyplot as plt + +# Imports for calling from command line +from utils.parser_util import generate_args +from 
utils.fixseed import fixseed +from utils.model_util import create_model_and_diffusion, load_model_wo_clip + + +class GeneaEvaluator: + def __init__(self, args, model, diffusion): + self.args = args + self.model = model + self.diffusion = diffusion + self.dataloader = get_dataset_loader(name=args.dataset, + batch_size=args.batch_size, + data_dir=args.data_dir, + num_frames=args.num_frames, + step=args.num_frames, #no overlap + use_wavlm=args.use_wavlm, + use_vad=True, #Hard-coded to get vad from files but the model will not use it since args.use_vad=False + vadfromtext=args.vadfromtext, + split='val') + self.data = self.dataloader.dataset + self.bvhreference = bvhsdk.ReadFile(args.bvh_reference_file, skipmotion=True) + self.idx_positions, self.idx_rotations = mp.get_indexes('genea2023') # hard-coded 'genea2023' because std and mean vec are computed for this representation + self.fgd_evaluator = EmbeddingSpaceEvaluator(args.fgd_embedding, args.num_frames, dist_util.dev()) + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + + def eval(self, samples=None): + print('Starting evaluation...') + n_samples = samples if samples else len(self.data.takes) + n_chunks = np.min(self.data.samples_per_file) + rot, gt_rot, pos, gt_pos, vad = self.sampleval(n_samples, n_chunks) + pos, rot = mp.filter_and_interp(rot, pos, num_frames=self.args.num_frames) + + listpos, listposgt = [], [] + print('Converting to BVH and back to get positions...') + for i in tqdm(range(len(pos))): + # Transform to BVH and get positions of sampled motion + bvhreference = mp.tobvh(self.bvhreference, rot[i], pos[i]) + listpos.append(mp.posfrombvh(bvhreference)) + # Transform to BVH and get positions of ground truth motion + # This is just a sanity check since we could get directly from the npy files + #bvhreference = mp.tobvh(self.bvhreference, gt_rot[i], gt_pos[i]) + #listposgt.append(mp.posfrombvh(bvhreference)) + + # Compute FGD + fgd_on_feat = self.fgd(listpos, listposgt, n_samples, stride=40) + + histfig = self.getvelhist(rot, vad) + + return fgd_on_feat, histfig + + def sampleval(self, samples=None, chunks=None): + assert chunks <= np.min(self.data.samples_per_file) # assert that we don't go over the number of chunks per file + allsampledmotion = [] + allsampleposition = [] + allgtmotion = [] + allgtposition = [] + allvad = [] + print('Evaluating validation set') + for idx in range(chunks): + print('### Sampling chunk {} of {}'.format(idx+1, chunks)) + batch = self.data.getvalbatch(num_takes=samples, index=idx) + gt_motion, model_kwargs = self.dataloader.collate_fn(batch) # gt_motion: [num_samples(bs), njoints, 1, chunk_len] + model_kwargs['y'] = {key: val.to(dist_util.dev()) if torch.is_tensor(val) else val for key, val in model_kwargs['y'].items()} #seed: [num_samples(bs), njoints, 1, seed_len] + if idx > 0 and self.args.seed_poses > 0: + model_kwargs['y']['seed'] = sample_out[...,-self.data.n_seed_poses:] + sample_fn = self.diffusion.p_sample_loop + sample_out = sample_fn( + self.model, + (samples, self.model.njoints, self.model.nfeats, self.args.num_frames), + clip_denoised=False, + model_kwargs=model_kwargs, + skip_timesteps=0, # 0 is the default value - i.e. 
don't skip any step + init_image=None, + progress=True, + dump_steps=None, + noise=None, + const_noise=False, + ) # [num_samples(bs), njoints, 1, chunk_len] + + sample = self.data.inv_transform(sample_out.cpu().permute(0, 2, 3, 1)).float() # [num_samples(bs), 1, chunk_len, njoints] + gt_motion = self.data.inv_transform(gt_motion.cpu().permute(0, 2, 3, 1)).float() # [num_samples(bs), 1, chunk_len, njoints] + + # Split the data into positions and rotations + gt_pos, gt_rot = mp.split_pos_rot(self.args.dataset, gt_motion) + sample_pos, sample_rot = mp.split_pos_rot(self.args.dataset, sample) + + # Convert numpy array to euler angles + if self.args.dataset == 'genea2023+': + gt_rot = mp.rot6d_to_euler(gt_rot) + sample_rot = mp.rot6d_to_euler(sample_rot) + + sample_pos = sample_pos.view(sample_pos.shape[:-1] + (-1, 3)) # [num_samples(bs), 1, chunk_len, n_joints/3, 3] + sample_pos = sample_pos.view(-1, *sample_pos.shape[2:]).permute(0, 2, 3, 1) + gt_pos = gt_pos.view(gt_pos.shape[:-1] + (-1, 3)) # [num_samples(bs), 1, chunk_len, n_joints/3, 3] + gt_pos = gt_pos.view(-1, *gt_pos.shape[2:]).permute(0, 2, 3, 1) + + allsampledmotion.append(sample_rot.cpu().numpy()) + allgtmotion.append(gt_rot.cpu().numpy()) + allsampleposition.append(sample_pos.squeeze().cpu().numpy()) + allgtposition.append(gt_pos.squeeze().cpu().numpy()) + allvad.append(model_kwargs['y']['vad'].cpu().numpy()) + + allsampledmotion = np.concatenate(allsampledmotion, axis=3) + allgtmotion = np.concatenate(allgtmotion, axis=3) + allsampleposition = np.concatenate(allsampleposition, axis=3) + allgtposition = np.concatenate(allgtposition, axis=3) + allvad = np.concatenate(allvad, axis=1) + + return allsampledmotion, allgtmotion, allsampleposition, allgtposition, allvad + + def getvelhist(self, motion, vad): + joints = self.data.getjoints() + fvad = vad[:,1:].flatten() + wvad, wovad = [], [] + for joint, index in joints.items(): + vels = np.sum(np.abs((motion[:,index,:, 1:] - motion[:,index,:, :-1])), axis=1).flatten() + wvad += list(vels[fvad==1]) + wovad += list(vels[fvad==0]) + n_bins =200 + fig, axs = plt.subplots(1, 1, sharex = True, tight_layout=True, figsize=(20,20)) + axs.hist(wvad, bins = n_bins, histtype='step', label='VAD = 1', linewidth=4, color='red') + axs.hist(wovad, bins = n_bins, histtype='step', label='VAD = 0', linewidth=4, color='black') + axs.set_yscale('log') + return fig + + def fgd(self, listpos, listposgt=None, n_samples=100, stride=None): + # "Direct" ground truth positions + real_val = make_tensor(f'./dataset/Genea2023/val/main-agent/motion_npy_rotpos', self.args.num_frames, max_files=n_samples, n_chunks=None, stride=stride).to(self.device) + real_trn = make_tensor(f'./dataset/Genea2023/trn/main-agent/motion_npy_rotpos', self.args.num_frames, max_files=n_samples, n_chunks=None, stride=stride).to(self.device) + + #gt_data = self.fgd_prep(listposgt).to(self.device) + test_data = self.fgd_prep(listpos, stride=stride).to(self.device) + + print('Samples shape:') + print(test_data.shape) + print('Validation shape:') + print(real_val.shape) + print('Train shape:') + print(real_trn.shape) + + fgd_on_feat = self.run_fgd(real_val, test_data) + print(f'Sampled to validation: {fgd_on_feat:8.3f}') + + fgd_on_feat_ = self.run_fgd(real_trn, test_data) + print(f'Sampled to train: {fgd_on_feat_:8.3f}') + #fgd_on_feat = self.run_fgd(gt_data, test_data) + #print(f'Sampled to validation from pipeline: {fgd_on_feat:8.3f}') + + #fgd_on_feat = self.run_fgd(real_val, gt_data) + #print(f'Validation from pipeline to validation (should 
be zero): {fgd_on_feat:8.3f}') + return fgd_on_feat + + def fgd_prep(self, data, n_frames=120, stride=None): + # Prepare samples for FGD evaluation + samples = [] + stride = n_frames // 2 if stride is None else stride + for take in data: + for i in range(0, len(take) - n_frames, stride): + sample = take[i:i+n_frames] + sample = (sample - self.data.mean[self.idx_positions]) / self.data.std[self.idx_positions] + samples.append(sample) + return torch.Tensor(samples) + + def run_fgd(self, gt_data, test_data): + # Run FGD evaluation on the given data + self.fgd_evaluator.reset() + self.fgd_evaluator.push_real_samples(gt_data) + self.fgd_evaluator.push_generated_samples(test_data) + fgd_on_feat = self.fgd_evaluator.get_fgd(use_feat_space=True) + return fgd_on_feat + +def main(): + args = generate_args() + fixseed(args.seed) + dist_util.setup_dist(args.device) + print("Creating model and diffusion...") + model, diffusion = create_model_and_diffusion(args, None) + print(f"Loading checkpoints from [{args.model_path}]...") + state_dict = torch.load(args.model_path, map_location='cpu') + load_model_wo_clip(model, state_dict) + model.to(dist_util.dev()) + model.eval() # disable random masking + GeneaEvaluator(args, model, diffusion).eval() + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/eval/eval_geneaval.py b/eval/eval_geneaval.py new file mode 100644 index 0000000..e85dc61 --- /dev/null +++ b/eval/eval_geneaval.py @@ -0,0 +1,62 @@ +from data_loaders.gesture.scripts import motion_process as mp +from data_loaders.get_data import get_dataset_loader +import numpy as np +from tqdm import tqdm +from utils import dist_util +import torch +import bvhsdk +from evaluation_metric.embedding_space_evaluator import EmbeddingSpaceEvaluator +from evaluation_metric.train_AE import make_tensor +import matplotlib.pyplot as plt + +# Imports for calling from command line +from utils.parser_util import generate_args +from utils.fixseed import fixseed +from utils.model_util import create_model_and_diffusion, load_model_wo_clip + + +class GeneaEvaluator: + def __init__(self): + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + self.fgd_evaluator = EmbeddingSpaceEvaluator('./evaluation_metric/output/model_checkpoint_120.bin', 120, self.device) + + def eval(self, samples=None, chunks=None): + print('Starting evaluation...') + + # Compute FGD + fgd_on_feat = self.fgd(40, None) + + def fgd(self, n_samples=100, n_chunks=1): + # "Direct" ground truth positions + real_val = make_tensor(f'./dataset/Genea2023/trn/main-agent/motion_npy_rotpos', 120, max_files=n_samples, n_chunks=n_chunks, stride=40).to(self.device) + test_data = make_tensor(f'./dataset/Genea2023/val/main-agent/motion_npy_rotpos', 120, max_files=n_samples, n_chunks=n_chunks, stride=40).to(self.device) + + fgd_on_feat = self.run_fgd(real_val, test_data) + print(f'Validation to train: {fgd_on_feat:8.3f}') + return fgd_on_feat + + def fgd_prep(self, data, n_frames=120, stride=None): + # Prepare samples for FGD evaluation + samples = [] + stride = n_frames // 2 if stride is None else stride + for take in data: + for i in range(0, len(take) - n_frames, stride): + sample = take[i:i+n_frames] + sample = (sample - self.data.mean[self.idx_positions]) / self.data.std[self.idx_positions] + samples.append(sample) + return torch.Tensor(samples) + + def run_fgd(self, gt_data, test_data): + # Run FGD evaluation on the given data + self.fgd_evaluator.reset() + self.fgd_evaluator.push_real_samples(gt_data) + 
self.fgd_evaluator.push_generated_samples(test_data)
+        fgd_on_feat = self.fgd_evaluator.get_fgd(use_feat_space=True)
+        return fgd_on_feat
+
+def main():
+    GeneaEvaluator().eval()
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/evaluation_metric/README.md b/evaluation_metric/README.md
new file mode 100644
index 0000000..5bc2187
--- /dev/null
+++ b/evaluation_metric/README.md
@@ -0,0 +1,31 @@
+## Fréchet Gesture Distance (FGD)
+
+Scripts to calculate FGD for the GENEA Gesture Generation Challenge 2023.
+We followed the FGD implementation in [Speech Gesture Generation from the Trimodal Context of Text, Audio, and Speaker Identity (ACM TOG, 2020)](https://arxiv.org/abs/2009.02119). It compares the distributions of human motion and generated motion to evaluate how similar the generated motion is to human motion. Note that FGD only considers main-agent motion, not speech and interlocutor context.
+
+**Disclaimer: The official evaluation of the GENEA Challenge 2023 is a subjective human evaluation. We provide this objective metric to help participants evaluate their models faster during the development phase, since we found a moderate correlation between FGD and subjective evaluation ratings. Please see [this arXiv paper](https://arxiv.org/pdf/2303.08737.pdf) for more details. Again, a good (low) FGD does not guarantee that humans will prefer the motion.**
+
+The scripts were developed and tested on Ubuntu 22.04, Python 3.6, and PyTorch 1.8.0.
+
+### Usage
+1. Prepare the pre-trained feature extractor checkpoint (included in this repo; see the `output` folder).
+2. Convert the generated motion to 3D joint coordinates. You can refer to the `extract_joint_positions.py` script.
+3. Calculate FGD between the sets of natural motion and generated motion
+   ```bash
+   # make sure you're correctly loading the generated motion
+   $ python evaluate_FGD.py
+   ```
+
+### Training the feature extractor
+
+You can follow the steps below to train the feature extractor.
+
+1. Download the GENEA 2023 dataset
+2. Convert the BVH files to 3D joint coordinates
+   ```bash
+   $ python extract_joint_positions.py
+   ```
+3. Train an autoencoder on the train set. You can set `n_frames` in `train_AE.py` to change the number of frames in a sample.
+ ```bash + $ python train_AE.py + ``` diff --git a/evaluation_metric/calculate_mean_std.py b/evaluation_metric/calculate_mean_std.py new file mode 100644 index 0000000..60c82fc --- /dev/null +++ b/evaluation_metric/calculate_mean_std.py @@ -0,0 +1,26 @@ +''' +calculate mean and std of the train dataset +''' + +import glob +import numpy as np + +files = glob.glob(f'../dataset/genea2023_dataset/trn/main-agent/npy/*.npy') +files = files[::5] # use a subset to avoid out of memory + +all_data = [] +for file in files: + data = np.load(file) + all_data.append(data) + +all_data = np.vstack(all_data) + +print(all_data.shape) + +mean = np.mean(all_data, axis=0) +std = np.std(all_data, axis=0) +print(mean.shape) +print(std.shape) + +print(*mean, sep=',') +print(*std, sep=',') diff --git a/evaluation_metric/embedding_net.py b/evaluation_metric/embedding_net.py new file mode 100644 index 0000000..2c81d4d --- /dev/null +++ b/evaluation_metric/embedding_net.py @@ -0,0 +1,137 @@ +import torch +import torch.nn as nn + + +def ConvNormRelu(in_channels, out_channels, downsample=False, padding=0, batchnorm=True): + if not downsample: + k = 3 + s = 1 + else: + k = 4 + s = 2 + + conv_block = nn.Conv1d(in_channels, out_channels, kernel_size=k, stride=s, padding=padding) + norm_block = nn.BatchNorm1d(out_channels) + + if batchnorm: + net = nn.Sequential( + conv_block, + norm_block, + nn.LeakyReLU(0.2, True) + ) + else: + net = nn.Sequential( + conv_block, + nn.LeakyReLU(0.2, True) + ) + + return net + + +class PoseEncoderConv(nn.Module): + def __init__(self, dim, length): + super().__init__() + + self.net = nn.Sequential( + ConvNormRelu(dim, 128, batchnorm=True), + ConvNormRelu(128, 64, batchnorm=True), + ConvNormRelu(64, 64, True, batchnorm=True), + nn.Conv1d(64, 32, 3) + ) + + if length == 30: + in_channels = 320 + elif length == 60: + in_channels = 800 + elif length == 90: + in_channels = 1280 + elif length == 120: + in_channels = 1760 + else: + assert False + + self.out_net = nn.Sequential( + nn.Linear(in_channels, 256), + nn.BatchNorm1d(256), + nn.LeakyReLU(True), + nn.Linear(256, 128), + nn.BatchNorm1d(128), + nn.LeakyReLU(True), + nn.Linear(128, 32), + ) + + def forward(self, poses): + poses = poses.transpose(1, 2) # to (bs, dim, seq) + out = self.net(poses) + out = out.flatten(1) + z = self.out_net(out) + + return z + + +class PoseDecoderConv(nn.Module): + def __init__(self, dim, length): + super().__init__() + + if length == 30: + out_channels = 120 + elif length == 60: + out_channels = 240 + elif length == 90: + out_channels = 360 + elif length == 120: + out_channels = 480 + else: + assert False + + self.pre_net = nn.Sequential( + nn.Linear(32, 64), + nn.BatchNorm1d(64), + nn.LeakyReLU(True), + nn.Linear(64, out_channels), + ) + + self.net = nn.Sequential( + nn.ConvTranspose1d(4, 32, 3), + nn.BatchNorm1d(32), + nn.LeakyReLU(0.2, True), + nn.ConvTranspose1d(32, 32, 3), + nn.BatchNorm1d(32), + nn.LeakyReLU(0.2, True), + nn.Conv1d(32, 32, 3), + nn.Conv1d(32, dim, 3), + ) + + def forward(self, feat): + out = self.pre_net(feat) + out = out.view(feat.shape[0], 4, -1) + out = self.net(out) + out = out.transpose(1, 2) + return out + + +class EmbeddingNet(nn.Module): + def __init__(self, pose_dim, n_frames): + super().__init__() + self.pose_encoder = PoseEncoderConv(pose_dim, n_frames) + self.decoder = PoseDecoderConv(pose_dim, n_frames) + + def forward(self, poses): + poses_feat = self.pose_encoder(poses) + out_poses = self.decoder(poses_feat) + return poses_feat, out_poses + + +if __name__ == '__main__': # 
model test + n_frames = 90 + pose_dim = 174 + encoder = PoseEncoderConv(pose_dim, n_frames) + decoder = PoseDecoderConv(pose_dim, n_frames) + + poses = torch.randn((4, n_frames, pose_dim)) + feat = encoder(poses) + recon_poses = decoder(feat) + + print('input', poses.shape) + print('feat', feat.shape) + print('output', recon_poses.shape) diff --git a/evaluation_metric/embedding_space_evaluator.py b/evaluation_metric/embedding_space_evaluator.py new file mode 100644 index 0000000..6909909 --- /dev/null +++ b/evaluation_metric/embedding_space_evaluator.py @@ -0,0 +1,115 @@ +import numpy as np +import torch +from scipy import linalg + +from evaluation_metric.embedding_net import EmbeddingNet + +import warnings +warnings.filterwarnings("ignore", category=RuntimeWarning) # ignore warnings + + +class EmbeddingSpaceEvaluator: + def __init__(self, embed_net_path, n_frames, device): + # init embed net + ckpt = torch.load(embed_net_path, map_location=device) + self.pose_dim = ckpt['pose_dim'] + self.net = EmbeddingNet(self.pose_dim, n_frames).to(device) + self.net.load_state_dict(ckpt['gen_dict']) + self.net.train(False) + + self.reset() + + def reset(self): + self.real_samples = [] + self.generate_samples = [] + self.real_feat_list = [] + self.generated_feat_list = [] + + def get_no_of_samples(self): + return len(self.real_feat_list) + + def push_real_samples(self, samples): + feat, _ = self.net(samples) + self.real_samples.append(samples.cpu().numpy().reshape(samples.shape[0], -1)) + self.real_feat_list.append(feat.data.cpu().numpy()) + + def push_generated_samples(self, samples): + feat, _ = self.net(samples) + self.generate_samples.append(samples.cpu().numpy().reshape(samples.shape[0], -1)) + self.generated_feat_list.append(feat.data.cpu().numpy()) + + def get_fgd(self, use_feat_space=True): + if use_feat_space: # on feature space + generated_data = np.vstack(self.generated_feat_list) + real_data = np.vstack(self.real_feat_list) + else: # on raw pose space + generated_data = np.vstack(self.generate_samples) + real_data = np.vstack(self.real_samples) + + frechet_dist = self.frechet_distance(generated_data, real_data) + return frechet_dist + + def frechet_distance(self, samples_A, samples_B): + A_mu = np.mean(samples_A, axis=0) + A_sigma = np.cov(samples_A, rowvar=False) + B_mu = np.mean(samples_B, axis=0) + B_sigma = np.cov(samples_B, rowvar=False) + try: + print('Computing frechet distance') + frechet_dist = self.calculate_frechet_distance(A_mu, A_sigma, B_mu, B_sigma) + except ValueError: + print('Something went wrong') + frechet_dist = 1e+10 + return frechet_dist + + @staticmethod + def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): + """ from https://github.com/mseitzer/pytorch-fid/blob/master/fid_score.py """ + """Numpy implementation of the Frechet Distance. + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + Stable version by Dougal J. Sutherland. + Params: + -- mu1 : Numpy array containing the activations of a layer of the + inception net (like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations, precalculated on an + representative data set. + -- sigma1: The covariance matrix over activations for generated samples. + -- sigma2: The covariance matrix over activations, precalculated on an + representative data set. + Returns: + -- : The Frechet Distance. 
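+
+        Note (added for clarity): mu1/mu2 are expected to be 1-D arrays of equal
+        length and sigma1/sigma2 matching square covariance matrices; eps is only
+        added to the diagonals when sqrtm(sigma1.dot(sigma2)) is not finite.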
+ """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert mu1.shape == mu2.shape, \ + 'Training and test mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, \ + 'Training and test covariances have different dimensions' + + diff = mu1 - mu2 + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ('fid calculation produces singular product; ' + 'adding %s to diagonal of cov estimates') % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError('Imaginary component {}'.format(m)) + covmean = covmean.real + tr_covmean = np.trace(covmean) + + return (diff.dot(diff) + np.trace(sigma1) + + np.trace(sigma2) - 2 * tr_covmean) diff --git a/evaluation_metric/evaluate_FGD.py b/evaluation_metric/evaluate_FGD.py new file mode 100644 index 0000000..f2f0fa6 --- /dev/null +++ b/evaluation_metric/evaluate_FGD.py @@ -0,0 +1,57 @@ +import os + +import numpy as np +import torch + +from embedding_space_evaluator import EmbeddingSpaceEvaluator +from train_AE import make_tensor + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +def run_fgd(fgd_evaluator, gt_data, test_data): + fgd_evaluator.reset() + + print('Pushing real samples...') + fgd_evaluator.push_real_samples(gt_data) + print('Pushing generated samples...') + fgd_evaluator.push_generated_samples(test_data) + print('Calculating FGD with feat space...') + fgd_on_feat = fgd_evaluator.get_fgd(use_feat_space=True) + #print('Calculating FGD without feat space...') + #fdg_on_raw = fgd_evaluator.get_fgd(use_feat_space=False) + return fgd_on_feat#, fdg_on_raw + + +def add_noise(x): + noise_level = 1 + x_noise = x + (noise_level ** 0.5) * torch.randn(x.size()).to(device) + return x_noise + + +def exp_base(chunk_len): + # AE model + ae_path = f'./evaluation_metric/output/model_checkpoint_{chunk_len}.bin' + fgd_evaluator = EmbeddingSpaceEvaluator(ae_path, chunk_len, device) + + # load human data + human_data = make_tensor(f'./dataset/Genea2023/trn/main-agent/motion_npy_rotpos', chunk_len).to(device) + print(human_data.size()) + + # simulate generated motion by adding noise to human motion + # load the generated motion when you actually use this code for model evaluation + #test_data = add_noise(human_data) + test_data = make_tensor(f'./dataset/Genea2023/val/main-agent/motion_npy_rotpos', chunk_len).to(device) + print(human_data.size()) + + print(f'----- Experiment (motion chunk length: {chunk_len}) -----') + print('FGDs on feature space and raw data space') + #fgd_on_feat, fgd_on_raw = run_fgd(fgd_evaluator, human_data, test_data) + fgd_on_feat = run_fgd(fgd_evaluator, human_data, test_data) + #print(f'{fgd_on_feat:8.3f}, {fgd_on_raw:8.3f}') + print(f'{fgd_on_feat:8.3f}') + print() + + +if __name__ == '__main__': + exp_base(120) diff --git a/evaluation_metric/extract_joint_positions.py b/evaluation_metric/extract_joint_positions.py new file mode 100644 index 0000000..5dfbe13 --- /dev/null +++ b/evaluation_metric/extract_joint_positions.py @@ -0,0 +1,35 @@ +import glob + +from pymo.parsers import BVHParser +from pymo.preprocessing import * +from sklearn.pipeline import Pipeline + + 
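+# (Added comment, describing the pipeline below) convert_bvh parses a single
+# BVH file with PyMO, keeps the joints listed in target_joints, converts their
+# rotations to 3D joint positions via MocapParameterizer('position'), and
+# returns a per-frame numpy array of the flattened position channels (Numpyfier).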
+def convert_bvh(bvhfile): + parser = BVHParser() + parsed_data = parser.parse(bvhfile) + + # use a subset of joints + target_joints = ['body_world', 'b_root', 'b_l_upleg', 'b_l_leg', 'b_r_upleg', 'b_r_leg', 'b_spine0', 'b_spine1', 'b_spine2', 'b_spine3', 'b_l_shoulder', 'b_l_arm', 'b_l_arm_twist', 'b_l_forearm', 'b_l_wrist_twist', 'b_l_wrist', 'b_l_pinky1', 'b_l_pinky2', 'b_l_pinky3', 'b_l_ring1', 'b_l_ring2', 'b_l_ring3', 'b_l_middle1', 'b_l_middle2', 'b_l_middle3', 'b_l_index1', 'b_l_index2', 'b_l_index3', 'b_l_thumb0', 'b_l_thumb1', 'b_l_thumb2', 'b_l_thumb3', 'b_r_shoulder', 'b_r_arm', 'b_r_arm_twist', 'b_r_forearm', 'b_r_wrist_twist', 'b_r_wrist', 'b_r_thumb0', 'b_r_thumb1', 'b_r_thumb2', 'b_r_thumb3', 'b_r_pinky1', 'b_r_pinky2', 'b_r_pinky3', 'b_r_middle1', 'b_r_middle2', 'b_r_middle3', 'b_r_ring1', 'b_r_ring2', 'b_r_ring3', 'b_r_index1', 'b_r_index2', 'b_r_index3', 'b_neck0', 'b_head'] + + pipe = Pipeline([ + ('param', MocapParameterizer('position')), + ('jtsel', JointSelector(target_joints, include_root=False)), + ('np', Numpyfier()), + ]) + pos_data = pipe.fit_transform([parsed_data])[0] + + return pos_data + + +if __name__ == '__main__': + bvhfiles = glob.glob('../dataset/genea2023_dataset/trn/main-agent/bvh/*.bvh') + # bvhfiles = glob.glob('../dataset/genea2023_dataset/val/main-agent/bvh/*.bvh') + + for bvhfile in sorted(bvhfiles): + print(bvhfile) + npy_data = convert_bvh(bvhfile) + print(npy_data.shape) + + out_file = bvhfile.replace('.bvh', '.npy').replace('/bvh/', '/npy/') + np.save(out_file, npy_data) diff --git a/evaluation_metric/output/.gitkeep b/evaluation_metric/output/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/evaluation_metric/output/model_checkpoint_120.bin b/evaluation_metric/output/model_checkpoint_120.bin new file mode 100644 index 0000000..12a59e4 Binary files /dev/null and b/evaluation_metric/output/model_checkpoint_120.bin differ diff --git a/evaluation_metric/output/model_checkpoint_30.bin b/evaluation_metric/output/model_checkpoint_30.bin new file mode 100644 index 0000000..6adf8c2 Binary files /dev/null and b/evaluation_metric/output/model_checkpoint_30.bin differ diff --git a/evaluation_metric/pymo/Pivots.py b/evaluation_metric/pymo/Pivots.py new file mode 100644 index 0000000..a29285d --- /dev/null +++ b/evaluation_metric/pymo/Pivots.py @@ -0,0 +1,126 @@ +import numpy as np + +from pymo.Quaternions import Quaternions + + +class Pivots: + """ + Pivots is an ndarray of angular rotations + + This wrapper provides some functions for + working with pivots. 
+ + These are particularly useful as a number + of atomic operations (such as adding or + subtracting) cannot be achieved using + the standard arithmatic and need to be + defined differently to work correctly + """ + + def __init__(self, ps): + self.ps = np.array(ps) + + def __str__(self): + return "Pivots(" + str(self.ps) + ")" + + def __repr__(self): + return "Pivots(" + repr(self.ps) + ")" + + def __add__(self, other): + return Pivots(np.arctan2(np.sin(self.ps + other.ps), np.cos(self.ps + other.ps))) + + def __sub__(self, other): + return Pivots(np.arctan2(np.sin(self.ps - other.ps), np.cos(self.ps - other.ps))) + + def __mul__(self, other): + return Pivots(self.ps * other.ps) + + def __div__(self, other): + return Pivots(self.ps / other.ps) + + def __mod__(self, other): + return Pivots(self.ps % other.ps) + + def __pow__(self, other): + return Pivots(self.ps ** other.ps) + + def __lt__(self, other): + return self.ps < other.ps + + def __le__(self, other): + return self.ps <= other.ps + + def __eq__(self, other): + return self.ps == other.ps + + def __ne__(self, other): + return self.ps != other.ps + + def __ge__(self, other): + return self.ps >= other.ps + + def __gt__(self, other): + return self.ps > other.ps + + def __abs__(self): + return Pivots(abs(self.ps)) + + def __neg__(self): + return Pivots(-self.ps) + + def __iter__(self): + return iter(self.ps) + + def __len__(self): + return len(self.ps) + + def __getitem__(self, k): + return Pivots(self.ps[k]) + + def __setitem__(self, k, v): + self.ps[k] = v.ps + + def _ellipsis(self): + return tuple(map(lambda x: slice(None), self.shape)) + + def quaternions(self, plane='xz'): + fa = self._ellipsis() + axises = np.ones(self.ps.shape + (3,)) + axises[fa + ("xyz".index(plane[0]),)] = 0.0 + axises[fa + ("xyz".index(plane[1]),)] = 0.0 + return Quaternions.from_angle_axis(self.ps, axises) + + def directions(self, plane='xz'): + dirs = np.zeros((len(self.ps), 3)) + dirs["xyz".index(plane[0])] = np.sin(self.ps) + dirs["xyz".index(plane[1])] = np.cos(self.ps) + return dirs + + def normalized(self): + xs = np.copy(self.ps) + while np.any(xs > np.pi): xs[xs > np.pi] = xs[xs > np.pi] - 2 * np.pi + while np.any(xs < -np.pi): xs[xs < -np.pi] = xs[xs < -np.pi] + 2 * np.pi + return Pivots(xs) + + def interpolate(self, ws): + dir = np.average(self.directions, weights=ws, axis=0) + return np.arctan2(dir[2], dir[0]) + + def copy(self): + return Pivots(np.copy(self.ps)) + + @property + def shape(self): + return self.ps.shape + + @classmethod + def from_quaternions(cls, qs, forward='z', plane='xz'): + ds = np.zeros(qs.shape + (3,)) + ds[..., 'xyz'.index(forward)] = 1.0 + return Pivots.from_directions(qs * ds, plane=plane) + + @classmethod + def from_directions(cls, ds, plane='xz'): + ys = ds[..., 'xyz'.index(plane[0])] + xs = ds[..., 'xyz'.index(plane[1])] + return Pivots(np.arctan2(ys, xs)) diff --git a/evaluation_metric/pymo/Quaternions.py b/evaluation_metric/pymo/Quaternions.py new file mode 100644 index 0000000..8718439 --- /dev/null +++ b/evaluation_metric/pymo/Quaternions.py @@ -0,0 +1,500 @@ +import numpy as np + + +class Quaternions: + """ + Quaternions is a wrapper around a numpy ndarray + that allows it to act as if it were an narray of + a quaternion data type. + + Therefore addition, subtraction, multiplication, + division, negation, absolute, are all defined + in terms of quaternion operations such as quaternion + multiplication. 
+ + This allows for much neater code and many routines + which conceptually do the same thing to be written + in the same way for point data and for rotation data. + + The Quaternions class has been desgined such that it + should support broadcasting and slicing in all of the + usual ways. + """ + + def __init__(self, qs): + if isinstance(qs, np.ndarray): + + if len(qs.shape) == 1: qs = np.array([qs]) + self.qs = qs + return + + if isinstance(qs, Quaternions): + self.qs = qs.qs + return + + raise TypeError('Quaternions must be constructed from iterable, numpy array, or Quaternions, not %s' % type(qs)) + + def __str__(self): + return "Quaternions(" + str(self.qs) + ")" + + def __repr__(self): + return "Quaternions(" + repr(self.qs) + ")" + + """ Helper Methods for Broadcasting and Data extraction """ + + @classmethod + def _broadcast(cls, sqs, oqs, scalar=False): + + if isinstance(oqs, float): return sqs, oqs * np.ones(sqs.shape[:-1]) + + ss = np.array(sqs.shape) if not scalar else np.array(sqs.shape[:-1]) + os = np.array(oqs.shape) + + if len(ss) != len(os): + raise TypeError('Quaternions cannot broadcast together shapes %s and %s' % (sqs.shape, oqs.shape)) + + if np.all(ss == os): return sqs, oqs + + if not np.all((ss == os) | (os == np.ones(len(os))) | (ss == np.ones(len(ss)))): + raise TypeError('Quaternions cannot broadcast together shapes %s and %s' % (sqs.shape, oqs.shape)) + + sqsn, oqsn = sqs.copy(), oqs.copy() + + for a in np.where(ss == 1)[0]: sqsn = sqsn.repeat(os[a], axis=a) + for a in np.where(os == 1)[0]: oqsn = oqsn.repeat(ss[a], axis=a) + + return sqsn, oqsn + + """ Adding Quaterions is just Defined as Multiplication """ + + def __add__(self, other): + return self * other + + def __sub__(self, other): + return self / other + + """ Quaterion Multiplication """ + + def __mul__(self, other): + """ + Quaternion multiplication has three main methods. + + When multiplying a Quaternions array by Quaternions + normal quaternion multiplication is performed. + + When multiplying a Quaternions array by a vector + array of the same shape, where the last axis is 3, + it is assumed to be a Quaternion by 3D-Vector + multiplication and the 3D-Vectors are rotated + in space by the Quaternions. + + When multipplying a Quaternions array by a scalar + or vector of different shape it is assumed to be + a Quaternions by Scalars multiplication and the + Quaternions are scaled using Slerp and the identity + quaternions. 
+ """ + + """ If Quaternions type do Quaternions * Quaternions """ + if isinstance(other, Quaternions): + sqs, oqs = Quaternions._broadcast(self.qs, other.qs) + + q0 = sqs[..., 0]; + q1 = sqs[..., 1]; + q2 = sqs[..., 2]; + q3 = sqs[..., 3]; + r0 = oqs[..., 0]; + r1 = oqs[..., 1]; + r2 = oqs[..., 2]; + r3 = oqs[..., 3]; + + qs = np.empty(sqs.shape) + qs[..., 0] = r0 * q0 - r1 * q1 - r2 * q2 - r3 * q3 + qs[..., 1] = r0 * q1 + r1 * q0 - r2 * q3 + r3 * q2 + qs[..., 2] = r0 * q2 + r1 * q3 + r2 * q0 - r3 * q1 + qs[..., 3] = r0 * q3 - r1 * q2 + r2 * q1 + r3 * q0 + + return Quaternions(qs) + + """ If array type do Quaternions * Vectors """ + if isinstance(other, np.ndarray) and other.shape[-1] == 3: + vs = Quaternions(np.concatenate([np.zeros(other.shape[:-1] + (1,)), other], axis=-1)) + return (self * (vs * -self)).imaginaries + + """ If float do Quaternions * Scalars """ + if isinstance(other, np.ndarray) or isinstance(other, float): + return Quaternions.slerp(Quaternions.id_like(self), self, other) + + raise TypeError('Cannot multiply/add Quaternions with type %s' % str(type(other))) + + def __div__(self, other): + """ + When a Quaternion type is supplied, division is defined + as multiplication by the inverse of that Quaternion. + + When a scalar or vector is supplied it is defined + as multiplicaion of one over the supplied value. + Essentially a scaling. + """ + + if isinstance(other, Quaternions): return self * (-other) + if isinstance(other, np.ndarray): return self * (1.0 / other) + if isinstance(other, float): return self * (1.0 / other) + raise TypeError('Cannot divide/subtract Quaternions with type %s' + str(type(other))) + + def __eq__(self, other): + return self.qs == other.qs + + def __ne__(self, other): + return self.qs != other.qs + + def __neg__(self): + """ Invert Quaternions """ + return Quaternions(self.qs * np.array([[1, -1, -1, -1]])) + + def __abs__(self): + """ Unify Quaternions To Single Pole """ + qabs = self.normalized().copy() + top = np.sum((qabs.qs) * np.array([1, 0, 0, 0]), axis=-1) + bot = np.sum((-qabs.qs) * np.array([1, 0, 0, 0]), axis=-1) + qabs.qs[top < bot] = -qabs.qs[top < bot] + return qabs + + def __iter__(self): + return iter(self.qs) + + def __len__(self): + return len(self.qs) + + def __getitem__(self, k): + return Quaternions(self.qs[k]) + + def __setitem__(self, k, v): + self.qs[k] = v.qs + + @property + def lengths(self): + return np.sum(self.qs ** 2.0, axis=-1) ** 0.5 + + @property + def reals(self): + return self.qs[..., 0] + + @property + def imaginaries(self): + return self.qs[..., 1:4] + + @property + def shape(self): + return self.qs.shape[:-1] + + def repeat(self, n, **kwargs): + return Quaternions(self.qs.repeat(n, **kwargs)) + + def normalized(self): + return Quaternions(self.qs / self.lengths[..., np.newaxis]) + + def log(self): + norm = abs(self.normalized()) + imgs = norm.imaginaries + lens = np.sqrt(np.sum(imgs ** 2, axis=-1)) + lens = np.arctan2(lens, norm.reals) / (lens + 1e-10) + return imgs * lens[..., np.newaxis] + + def constrained(self, axis): + + rl = self.reals + im = np.sum(axis * self.imaginaries, axis=-1) + + t1 = -2 * np.arctan2(rl, im) + np.pi + t2 = -2 * np.arctan2(rl, im) - np.pi + + top = Quaternions.exp(axis[np.newaxis] * (t1[:, np.newaxis] / 2.0)) + bot = Quaternions.exp(axis[np.newaxis] * (t2[:, np.newaxis] / 2.0)) + img = self.dot(top) > self.dot(bot) + + ret = top.copy() + ret[img] = top[img] + ret[~img] = bot[~img] + return ret + + def constrained_x(self): + return self.constrained(np.array([1, 0, 0])) + + def 
constrained_y(self): + return self.constrained(np.array([0, 1, 0])) + + def constrained_z(self): + return self.constrained(np.array([0, 0, 1])) + + def dot(self, q): + return np.sum(self.qs * q.qs, axis=-1) + + def copy(self): + return Quaternions(np.copy(self.qs)) + + def reshape(self, s): + self.qs.reshape(s) + return self + + def interpolate(self, ws): + return Quaternions.exp(np.average(abs(self).log, axis=0, weights=ws)) + + def euler(self, order='xyz'): + + q = self.normalized().qs + q0 = q[..., 0] + q1 = q[..., 1] + q2 = q[..., 2] + q3 = q[..., 3] + es = np.zeros(self.shape + (3,)) + + if order == 'xyz': + es[..., 0] = np.arctan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) + es[..., 1] = np.arcsin((2 * (q0 * q2 - q3 * q1)).clip(-1, 1)) + es[..., 2] = np.arctan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) + elif order == 'yzx': + es[..., 0] = np.arctan2(2 * (q1 * q0 - q2 * q3), -q1 * q1 + q2 * q2 - q3 * q3 + q0 * q0) + es[..., 1] = np.arctan2(2 * (q2 * q0 - q1 * q3), q1 * q1 - q2 * q2 - q3 * q3 + q0 * q0) + es[..., 2] = np.arcsin((2 * (q1 * q2 + q3 * q0)).clip(-1, 1)) + else: + raise NotImplementedError('Cannot convert from ordering %s' % order) + + """ + + # These conversion don't appear to work correctly for Maya. + # http://bediyap.com/programming/convert-quaternion-to-euler-rotations/ + + if order == 'xyz': + es[...,0] = np.arctan2(2 * (q0 * q3 - q1 * q2), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3) + es[...,1] = np.arcsin((2 * (q1 * q3 + q0 * q2)).clip(-1,1)) + es[...,2] = np.arctan2(2 * (q0 * q1 - q2 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3) + elif order == 'yzx': + es[...,0] = np.arctan2(2 * (q0 * q1 - q2 * q3), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3) + es[...,1] = np.arcsin((2 * (q1 * q2 + q0 * q3)).clip(-1,1)) + es[...,2] = np.arctan2(2 * (q0 * q2 - q1 * q3), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3) + elif order == 'zxy': + es[...,0] = np.arctan2(2 * (q0 * q2 - q1 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3) + es[...,1] = np.arcsin((2 * (q0 * q1 + q2 * q3)).clip(-1,1)) + es[...,2] = np.arctan2(2 * (q0 * q3 - q1 * q2), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3) + elif order == 'xzy': + es[...,0] = np.arctan2(2 * (q0 * q2 + q1 * q3), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3) + es[...,1] = np.arcsin((2 * (q0 * q3 - q1 * q2)).clip(-1,1)) + es[...,2] = np.arctan2(2 * (q0 * q1 + q2 * q3), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3) + elif order == 'yxz': + es[...,0] = np.arctan2(2 * (q1 * q2 + q0 * q3), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3) + es[...,1] = np.arcsin((2 * (q0 * q1 - q2 * q3)).clip(-1,1)) + es[...,2] = np.arctan2(2 * (q1 * q3 + q0 * q2), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3) + elif order == 'zyx': + es[...,0] = np.arctan2(2 * (q0 * q1 + q2 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3) + es[...,1] = np.arcsin((2 * (q0 * q2 - q1 * q3)).clip(-1,1)) + es[...,2] = np.arctan2(2 * (q0 * q3 + q1 * q2), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3) + else: + raise KeyError('Unknown ordering %s' % order) + + """ + + # https://github.com/ehsan/ogre/blob/master/OgreMain/src/OgreMatrix3.cpp + # Use this class and convert from matrix + + return es + + def average(self): + + if len(self.shape) == 1: + + import numpy.core.umath_tests as ut + system = ut.matrix_multiply(self.qs[:, :, np.newaxis], self.qs[:, np.newaxis, :]).sum(axis=0) + w, v = np.linalg.eigh(system) + qiT_dot_qref = (self.qs[:, :, np.newaxis] * v[np.newaxis, :, :]).sum(axis=1) + return Quaternions(v[:, np.argmin((1. 
- qiT_dot_qref ** 2).sum(axis=0))]) + + else: + + raise NotImplementedError('Cannot average multi-dimensionsal Quaternions') + + def angle_axis(self): + + norm = self.normalized() + s = np.sqrt(1 - (norm.reals ** 2.0)) + s[s == 0] = 0.001 + + angles = 2.0 * np.arccos(norm.reals) + axis = norm.imaginaries / s[..., np.newaxis] + + return angles, axis + + def transforms(self): + + qw = self.qs[..., 0] + qx = self.qs[..., 1] + qy = self.qs[..., 2] + qz = self.qs[..., 3] + + x2 = qx + qx; + y2 = qy + qy; + z2 = qz + qz; + xx = qx * x2; + yy = qy * y2; + wx = qw * x2; + xy = qx * y2; + yz = qy * z2; + wy = qw * y2; + xz = qx * z2; + zz = qz * z2; + wz = qw * z2; + + m = np.empty(self.shape + (3, 3)) + m[..., 0, 0] = 1.0 - (yy + zz) + m[..., 0, 1] = xy - wz + m[..., 0, 2] = xz + wy + m[..., 1, 0] = xy + wz + m[..., 1, 1] = 1.0 - (xx + zz) + m[..., 1, 2] = yz - wx + m[..., 2, 0] = xz - wy + m[..., 2, 1] = yz + wx + m[..., 2, 2] = 1.0 - (xx + yy) + + return m + + def ravel(self): + return self.qs.ravel() + + @classmethod + def id(cls, n): + + if isinstance(n, tuple): + qs = np.zeros(n + (4,)) + qs[..., 0] = 1.0 + return Quaternions(qs) + + if isinstance(n, int) or isinstance(n, long): + qs = np.zeros((n, 4)) + qs[:, 0] = 1.0 + return Quaternions(qs) + + raise TypeError('Cannot Construct Quaternion from %s type' % str(type(n))) + + @classmethod + def id_like(cls, a): + qs = np.zeros(a.shape + (4,)) + qs[..., 0] = 1.0 + return Quaternions(qs) + + @classmethod + def exp(cls, ws): + + ts = np.sum(ws ** 2.0, axis=-1) ** 0.5 + ts[ts == 0] = 0.001 + ls = np.sin(ts) / ts + + qs = np.empty(ws.shape[:-1] + (4,)) + qs[..., 0] = np.cos(ts) + qs[..., 1] = ws[..., 0] * ls + qs[..., 2] = ws[..., 1] * ls + qs[..., 3] = ws[..., 2] * ls + + return Quaternions(qs).normalized() + + @classmethod + def slerp(cls, q0s, q1s, a): + + fst, snd = cls._broadcast(q0s.qs, q1s.qs) + fst, a = cls._broadcast(fst, a, scalar=True) + snd, a = cls._broadcast(snd, a, scalar=True) + + len = np.sum(fst * snd, axis=-1) + + neg = len < 0.0 + len[neg] = -len[neg] + snd[neg] = -snd[neg] + + amount0 = np.zeros(a.shape) + amount1 = np.zeros(a.shape) + + linear = (1.0 - len) < 0.01 + omegas = np.arccos(len[~linear]) + sinoms = np.sin(omegas) + + amount0[linear] = 1.0 - a[linear] + amount1[linear] = a[linear] + amount0[~linear] = np.sin((1.0 - a[~linear]) * omegas) / sinoms + amount1[~linear] = np.sin(a[~linear] * omegas) / sinoms + + return Quaternions( + amount0[..., np.newaxis] * fst + + amount1[..., np.newaxis] * snd) + + @classmethod + def between(cls, v0s, v1s): + a = np.cross(v0s, v1s) + w = np.sqrt((v0s ** 2).sum(axis=-1) * (v1s ** 2).sum(axis=-1)) + (v0s * v1s).sum(axis=-1) + return Quaternions(np.concatenate([w[..., np.newaxis], a], axis=-1)).normalized() + + @classmethod + def from_angle_axis(cls, angles, axis): + axis = axis / (np.sqrt(np.sum(axis ** 2, axis=-1)) + 1e-10)[..., np.newaxis] + sines = np.sin(angles / 2.0)[..., np.newaxis] + cosines = np.cos(angles / 2.0)[..., np.newaxis] + return Quaternions(np.concatenate([cosines, axis * sines], axis=-1)) + + @classmethod + def from_euler(cls, es, order='xyz', world=False): + + axis = { + 'x': np.array([1, 0, 0]), + 'y': np.array([0, 1, 0]), + 'z': np.array([0, 0, 1]), + } + + q0s = Quaternions.from_angle_axis(es[..., 0], axis[order[0]]) + q1s = Quaternions.from_angle_axis(es[..., 1], axis[order[1]]) + q2s = Quaternions.from_angle_axis(es[..., 2], axis[order[2]]) + + return (q2s * (q1s * q0s)) if world else (q0s * (q1s * q2s)) + + @classmethod + def from_transforms(cls, ts): + + d0, 
d1, d2 = ts[..., 0, 0], ts[..., 1, 1], ts[..., 2, 2] + + q0 = (d0 + d1 + d2 + 1.0) / 4.0 + q1 = (d0 - d1 - d2 + 1.0) / 4.0 + q2 = (-d0 + d1 - d2 + 1.0) / 4.0 + q3 = (-d0 - d1 + d2 + 1.0) / 4.0 + + q0 = np.sqrt(q0.clip(0, None)) + q1 = np.sqrt(q1.clip(0, None)) + q2 = np.sqrt(q2.clip(0, None)) + q3 = np.sqrt(q3.clip(0, None)) + + c0 = (q0 >= q1) & (q0 >= q2) & (q0 >= q3) + c1 = (q1 >= q0) & (q1 >= q2) & (q1 >= q3) + c2 = (q2 >= q0) & (q2 >= q1) & (q2 >= q3) + c3 = (q3 >= q0) & (q3 >= q1) & (q3 >= q2) + + q1[c0] *= np.sign(ts[c0, 2, 1] - ts[c0, 1, 2]) + q2[c0] *= np.sign(ts[c0, 0, 2] - ts[c0, 2, 0]) + q3[c0] *= np.sign(ts[c0, 1, 0] - ts[c0, 0, 1]) + + q0[c1] *= np.sign(ts[c1, 2, 1] - ts[c1, 1, 2]) + q2[c1] *= np.sign(ts[c1, 1, 0] + ts[c1, 0, 1]) + q3[c1] *= np.sign(ts[c1, 0, 2] + ts[c1, 2, 0]) + + q0[c2] *= np.sign(ts[c2, 0, 2] - ts[c2, 2, 0]) + q1[c2] *= np.sign(ts[c2, 1, 0] + ts[c2, 0, 1]) + q3[c2] *= np.sign(ts[c2, 2, 1] + ts[c2, 1, 2]) + + q0[c3] *= np.sign(ts[c3, 1, 0] - ts[c3, 0, 1]) + q1[c3] *= np.sign(ts[c3, 2, 0] + ts[c3, 0, 2]) + q2[c3] *= np.sign(ts[c3, 2, 1] + ts[c3, 1, 2]) + + qs = np.empty(ts.shape[:-2] + (4,)) + qs[..., 0] = q0 + qs[..., 1] = q1 + qs[..., 2] = q2 + qs[..., 3] = q3 + + return cls(qs) + + diff --git a/evaluation_metric/pymo/__init__.py b/evaluation_metric/pymo/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/evaluation_metric/pymo/data.py b/evaluation_metric/pymo/data.py new file mode 100644 index 0000000..448fbb0 --- /dev/null +++ b/evaluation_metric/pymo/data.py @@ -0,0 +1,53 @@ +import numpy as np + +class Joint(): + def __init__(self, name, parent=None, children=None): + self.name = name + self.parent = parent + self.children = children + +class MocapData(): + def __init__(self): + self.skeleton = {} + self.values = None + self.channel_names = [] + self.framerate = 0.0 + self.root_name = '' + + def traverse(self, j=None): + stack = [self.root_name] + while stack: + joint = stack.pop() + yield joint + for c in self.skeleton[joint]['children']: + stack.append(c) + + def clone(self): + import copy + new_data = MocapData() + new_data.skeleton = copy.deepcopy(self.skeleton) + new_data.values = copy.deepcopy(self.values) + new_data.channel_names = copy.deepcopy(self.channel_names) + new_data.root_name = copy.deepcopy(self.root_name) + new_data.framerate = copy.deepcopy(self.framerate) + return new_data + + def get_all_channels(self): + '''Returns all of the channels parsed from the file as a 2D numpy array''' + + frames = [f[1] for f in self.values] + return np.asarray([[channel[2] for channel in frame] for frame in frames]) + + def get_skeleton_tree(self): + tree = [] + root_key = [j for j in self.skeleton if self.skeleton[j]['parent']==None][0] + + root_joint = Joint(root_key) + + def get_empty_channels(self): + #TODO + pass + + def get_constant_channels(self): + #TODO + pass diff --git a/evaluation_metric/pymo/features.py b/evaluation_metric/pymo/features.py new file mode 100644 index 0000000..fec29ed --- /dev/null +++ b/evaluation_metric/pymo/features.py @@ -0,0 +1,43 @@ +''' +A set of mocap feature extraction functions + +Created by Omid Alemi | Nov 17 2017 + +''' +import numpy as np +import pandas as pd +import peakutils +import matplotlib.pyplot as plt + +def get_foot_contact_idxs(signal, t=0.02, min_dist=120): + up_idxs = peakutils.indexes(signal, thres=t/max(signal), min_dist=min_dist) + down_idxs = peakutils.indexes(-signal, thres=t/min(signal), min_dist=min_dist) + + return [up_idxs, down_idxs] + + +def 
create_foot_contact_signal(mocap_track, col_name, start=1, t=0.02, min_dist=120): + signal = mocap_track.values[col_name].values + idxs = get_foot_contact_idxs(signal, t, min_dist) + + step_signal = [] + + c = start + for f in range(len(signal)): + if f in idxs[1]: + c = 0 + elif f in idxs[0]: + c = 1 + + step_signal.append(c) + + return step_signal + +def plot_foot_up_down(mocap_track, col_name, t=0.02, min_dist=120): + + signal = mocap_track.values[col_name].values + idxs = get_foot_contact_idxs(signal, t, min_dist) + + plt.plot(mocap_track.values.index, signal) + plt.plot(mocap_track.values.index[idxs[0]], signal[idxs[0]], 'ro') + plt.plot(mocap_track.values.index[idxs[1]], signal[idxs[1]], 'go') diff --git a/evaluation_metric/pymo/parsers.py b/evaluation_metric/pymo/parsers.py new file mode 100644 index 0000000..b5fece8 --- /dev/null +++ b/evaluation_metric/pymo/parsers.py @@ -0,0 +1,260 @@ +''' +BVH Parser Class + +By Omid Alemi +Created: June 12, 2017 + +Based on: https://gist.github.com/johnfredcee/2007503 + +''' +import re +import numpy as np +from pymo.data import Joint, MocapData + +class BVHScanner(): + ''' + A wrapper class for re.Scanner + ''' + def __init__(self): + + def identifier(scanner, token): + return 'IDENT', token + + def operator(scanner, token): + return 'OPERATOR', token + + def digit(scanner, token): + return 'DIGIT', token + + def open_brace(scanner, token): + return 'OPEN_BRACE', token + + def close_brace(scanner, token): + return 'CLOSE_BRACE', token + + self.scanner = re.Scanner([ + (r'[a-zA-Z_]\w*', identifier), + #(r'-*[0-9]+(\.[0-9]+)?', digit), # won't work for .34 + #(r'[-+]?[0-9]*\.?[0-9]+', digit), # won't work for 4.56e-2 + #(r'[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?', digit), + (r'-*[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?', digit), + (r'}', close_brace), + (r'}', close_brace), + (r'{', open_brace), + (r':', None), + (r'\s+', None) + ]) + + def scan(self, stuff): + return self.scanner.scan(stuff) + + + +class BVHParser(): + ''' + A class to parse a BVH file. 
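+
+    Example (illustrative sketch; 'example.bvh' is a hypothetical placeholder path):
+
+        parser = BVHParser()
+        mocap = parser.parse('example.bvh')   # returns a MocapData instance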
+ + Extracts the skeleton and channel values + ''' + def __init__(self, filename=None): + self.reset() + + def reset(self): + self._skeleton = {} + self.bone_context = [] + self._motion_channels = [] + self._motions = [] + self.current_token = 0 + self.framerate = 0.0 + self.root_name = '' + + self.scanner = BVHScanner() + + self.data = MocapData() + + + def parse(self, filename, start=0, stop=-1): + self.reset() + + with open(filename, 'r') as bvh_file: + raw_contents = bvh_file.read() + tokens, remainder = self.scanner.scan(raw_contents) + self._parse_hierarchy(tokens) + self.current_token = self.current_token + 1 + self._parse_motion(tokens, start, stop) + + self.data.skeleton = self._skeleton + self.data.channel_names = self._motion_channels + self.data.values = self._to_DataFrame() + self.data.root_name = self.root_name + self.data.framerate = self.framerate + + return self.data + + def _to_DataFrame(self): + '''Returns all of the channels parsed from the file as a pandas DataFrame''' + + import pandas as pd + time_index = pd.to_timedelta([f[0] for f in self._motions], unit='s') + frames = [f[1] for f in self._motions] + channels = np.asarray([[channel[2] for channel in frame] for frame in frames]) + column_names = ['%s_%s'%(c[0], c[1]) for c in self._motion_channels] + + return pd.DataFrame(data=channels, index=time_index, columns=column_names) + + + def _new_bone(self, parent, name): + bone = {'parent': parent, 'channels': [], 'offsets': [], 'order': '','children': []} + return bone + + def _push_bone_context(self,name): + self.bone_context.append(name) + + def _get_bone_context(self): + return self.bone_context[len(self.bone_context)-1] + + def _pop_bone_context(self): + self.bone_context = self.bone_context[:-1] + return self.bone_context[len(self.bone_context)-1] + + def _read_offset(self, bvh, token_index): + if bvh[token_index] != ('IDENT', 'OFFSET'): + return None, None + token_index = token_index + 1 + offsets = [0.0] * 3 + for i in range(3): + offsets[i] = float(bvh[token_index][1]) + token_index = token_index + 1 + return offsets, token_index + + def _read_channels(self, bvh, token_index): + if bvh[token_index] != ('IDENT', 'CHANNELS'): + return None, None + token_index = token_index + 1 + channel_count = int(bvh[token_index][1]) + token_index = token_index + 1 + channels = [""] * channel_count + order = "" + for i in range(channel_count): + channels[i] = bvh[token_index][1] + token_index = token_index + 1 + if(channels[i] == "Xrotation" or channels[i]== "Yrotation" or channels[i]== "Zrotation"): + order += channels[i][0] + else : + order = "" + return channels, token_index, order + + def _parse_joint(self, bvh, token_index): + end_site = False + joint_id = bvh[token_index][1] + token_index = token_index + 1 + joint_name = bvh[token_index][1] + token_index = token_index + 1 + + parent_name = self._get_bone_context() + + if (joint_id == "End"): + joint_name = parent_name+ '_Nub' + end_site = True + joint = self._new_bone(parent_name, joint_name) + if bvh[token_index][0] != 'OPEN_BRACE': + print('Was expecting brance, got ', bvh[token_index]) + return None + token_index = token_index + 1 + offsets, token_index = self._read_offset(bvh, token_index) + joint['offsets'] = offsets + if not end_site: + channels, token_index, order = self._read_channels(bvh, token_index) + joint['channels'] = channels + joint['order'] = order + for channel in channels: + self._motion_channels.append((joint_name, channel)) + + self._skeleton[joint_name] = joint + 
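+        # link the newly parsed joint into its parent's children list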
self._skeleton[parent_name]['children'].append(joint_name)
+
+        while (bvh[token_index][0] == 'IDENT' and bvh[token_index][1] == 'JOINT') or (bvh[token_index][0] == 'IDENT' and bvh[token_index][1] == 'End'):
+            self._push_bone_context(joint_name)
+            token_index = self._parse_joint(bvh, token_index)
+            self._pop_bone_context()
+
+        if bvh[token_index][0] == 'CLOSE_BRACE':
+            return token_index + 1
+
+        print('Unexpected token ', bvh[token_index])
+
+    def _parse_hierarchy(self, bvh):
+        self.current_token = 0
+        if bvh[self.current_token] != ('IDENT', 'HIERARCHY'):
+            return None
+        self.current_token = self.current_token + 1
+        if bvh[self.current_token] != ('IDENT', 'ROOT'):
+            return None
+        self.current_token = self.current_token + 1
+        if bvh[self.current_token][0] != 'IDENT':
+            return None
+
+        root_name = bvh[self.current_token][1]
+        root_bone = self._new_bone(None, root_name)
+        self.current_token = self.current_token + 2 #skipping open brace
+        offsets, self.current_token = self._read_offset(bvh, self.current_token)
+        channels, self.current_token, order = self._read_channels(bvh, self.current_token)
+        root_bone['offsets'] = offsets
+        root_bone['channels'] = channels
+        root_bone['order'] = order
+        self._skeleton[root_name] = root_bone
+        self._push_bone_context(root_name)
+
+        for channel in channels:
+            self._motion_channels.append((root_name, channel))
+
+        while bvh[self.current_token][1] == 'JOINT':
+            self.current_token = self._parse_joint(bvh, self.current_token)
+
+        self.root_name = root_name
+
+    def _parse_motion(self, bvh, start, stop):
+        if bvh[self.current_token][0] != 'IDENT':
+            print('Unexpected text')
+            return None
+        if bvh[self.current_token][1] != 'MOTION':
+            print('No motion section')
+            return None
+        self.current_token = self.current_token + 1
+        if bvh[self.current_token][1] != 'Frames':
+            return None
+        self.current_token = self.current_token + 1
+        frame_count = int(bvh[self.current_token][1])
+
+        if stop<0 or stop>frame_count:
+            stop = frame_count
+
+        assert(start>=0)
+        assert(start<stop)
+
+        self.current_token = self.current_token + 1
+        if bvh[self.current_token][1] != 'Frame':
+            return None
+        self.current_token = self.current_token + 1
+        if bvh[self.current_token][1] != 'Time':
+            return None
+        self.current_token = self.current_token + 1
+        frame_rate = float(bvh[self.current_token][1])
+
+        self.framerate = frame_rate
+
+        self.current_token = self.current_token + 1
+
+        frame_time = 0.0
+        self._motions = [()] * (stop-start)
+        idx = 0
+        for i in range(stop):
+            channel_values = []
+            for channel in self._motion_channels:
+                self.current_token = self.current_token + 1
+                channel_values.append((channel[0], channel[1], float(bvh[self.current_token][1])))
+            if i>=start:
+                self._motions[idx] = (frame_time, channel_values)
+                frame_time = frame_time + frame_rate
+                idx+=1
diff --git a/evaluation_metric/pymo/preprocessing.py b/evaluation_metric/pymo/preprocessing.py
new file mode 100644
index 0000000..de09060
--- /dev/null
+++ b/evaluation_metric/pymo/preprocessing.py
@@ -0,0 +1,1196 @@
+'''
+Preprocessing Transformers Based on sci-kit's API
+
+By Omid Alemi
+Created on June 12, 2017
+
+Modified by Simon Alexanderson, 2020-06-24
+'''
+import copy
+import pandas as pd
+import numpy as np
+import transforms3d as t3d
+import scipy.ndimage.filters as filters
+from scipy.spatial.transform import Rotation as R
+from pymo.Quaternions import Quaternions
+from pymo.Pivots import Pivots
+from sklearn.base import BaseEstimator, TransformerMixin
+
+class MocapParameterizer(BaseEstimator, TransformerMixin):
+    def __init__(self, param_type = 'euler'):
+        '''
+
+        param_type = {'euler', 'quat', 'expmap', 'position', 'expmap2pos'}
+        '''
+        self.param_type = param_type
+
+    def fit(self, X, y=None):
+        return self
+
+    def transform(self, X, y=None):
+        # print("MocapParameterizer: " + self.param_type)
+        if self.param_type == 'euler':
+            return X
+        elif self.param_type == 'expmap':
+            return self._to_expmap(X)
+        elif self.param_type == 'quat':
+            return X
+        elif self.param_type == 'position':
+            return self._to_pos(X)
+        elif self.param_type == 'expmap2pos':
+            return self._expmap_to_pos(X)
+        else:
+            raise ValueError('param types: euler, quat, expmap, position, expmap2pos')
+
+# return X
+
+    def 
inverse_transform(self, X, copy=None): + if self.param_type == 'euler': + return X + elif self.param_type == 'expmap': + return self._expmap_to_euler(X) + elif self.param_type == 'quat': + raise 'quat2euler is not supported' + elif self.param_type == 'position': + raise 'positions 2 eulers is not supported' + return X + else: + raise 'param types: euler, quat, expmap, position' + + def fix_rotvec(self, rots): + '''fix problems with discontinuous rotation vectors''' + new_rots = rots.copy() + + # Compute angles and alternative rotation angles + angs = np.linalg.norm(rots, axis=1) + alt_angs=2*np.pi-angs + + #find discontinuities by checking if the alternative representation is closer + d_angs = np.diff(angs, axis=0) + d_angs2 = alt_angs[1:]-angs[:-1] + swps = np.where(np.abs(d_angs2)0: + y = np.zeros((n_sequences, self.window_size, vals.shape[1])) + + # extract sequences from the input data + for i in range(0,n_sequences): + frameIdx = (self.window_size-overlap_frames) * i + Q.append(vals[frameIdx:frameIdx+self.window_size,:]) + + return np.array(Q) + + def inverse_transform(self, X, copy=None): + Q = [] + + for track in X: + + new_mocap = self.org_mocap_.clone() + time_index = pd.to_timedelta([f for f in range(track.shape[0])], unit='s') + + new_df = pd.DataFrame(data=track, index=time_index, columns=self.org_mocap_.values.columns) + + new_mocap.values = new_df + + + Q.append(new_mocap) + + return Q + +class RootTransformer(BaseEstimator, TransformerMixin): + def __init__(self, method, position_smoothing=0, rotation_smoothing=0): + """ + Accepted methods: + abdolute_translation_deltas + pos_rot_deltas + """ + self.method = method + self.position_smoothing=position_smoothing + self.rotation_smoothing=rotation_smoothing + + def fit(self, X, y=None): + return self + + def transform(self, X, y=None): + print("RootTransformer") + Q = [] + + for track in X: + if self.method == 'abdolute_translation_deltas': + new_df = track.values.copy() + xpcol = '%s_Xposition'%track.root_name + ypcol = '%s_Yposition'%track.root_name + zpcol = '%s_Zposition'%track.root_name + + + dxpcol = '%s_dXposition'%track.root_name + dzpcol = '%s_dZposition'%track.root_name + + x=track.values[xpcol].copy() + z=track.values[zpcol].copy() + + if self.position_smoothing>0: + x_sm = filters.gaussian_filter1d(x, self.position_smoothing, axis=0, mode='nearest') + z_sm = filters.gaussian_filter1d(z, self.position_smoothing, axis=0, mode='nearest') + dx = pd.Series(data=x_sm, index=new_df.index).diff() + dz = pd.Series(data=z_sm, index=new_df.index).diff() + new_df[xpcol] = x-x_sm + new_df[zpcol] = z-z_sm + else: + dx = x.diff() + dz = z.diff() + new_df.drop([xpcol, zpcol], axis=1, inplace=True) + + dx[0] = dx[1] + dz[0] = dz[1] + + new_df[dxpcol] = dx + new_df[dzpcol] = dz + + new_track = track.clone() + new_track.values = new_df + # end of abdolute_translation_deltas + + elif self.method == 'pos_rot_deltas': + new_track = track.clone() + + # Absolute columns + xp_col = '%s_Xposition' % track.root_name + yp_col = '%s_Yposition' % track.root_name + zp_col = '%s_Zposition' % track.root_name + + # rot_order = track.skeleton[track.root_name]['order'] + # %(joint, rot_order[0]) + + rot_order = track.skeleton[track.root_name]['order'] + r1_col = '%s_%srotation' % (track.root_name, rot_order[0]) + r2_col = '%s_%srotation' % (track.root_name, rot_order[1]) + r3_col = '%s_%srotation' % (track.root_name, rot_order[2]) + + # Delta columns + dxp_col = '%s_dXposition' % track.root_name + dzp_col = '%s_dZposition' % track.root_name + + dxr_col 
= '%s_dXrotation' % track.root_name + dyr_col = '%s_dYrotation' % track.root_name + dzr_col = '%s_dZrotation' % track.root_name + + positions = np.transpose(np.array([track.values[xp_col], track.values[yp_col], track.values[zp_col]])) + rotations = np.pi / 180.0 * np.transpose( + np.array([track.values[r1_col], track.values[r2_col], track.values[r3_col]])) + + """ Get Trajectory and smooth it""" + trajectory_filterwidth = self.position_smoothing + reference = positions.copy() * np.array([1, 0, 1]) + if trajectory_filterwidth > 0: + reference = filters.gaussian_filter1d(reference, trajectory_filterwidth, axis=0, mode='nearest') + + """ Get Root Velocity """ + velocity = np.diff(reference, axis=0) + velocity = np.vstack((velocity[0, :], velocity)) + + """ Remove Root Translation """ + positions = positions - reference + + """ Get Forward Direction along the x-z plane, assuming character is facig z-forward """ + # forward = [Rotation(f, 'euler', from_deg=True, order=rot_order).rotmat[:,2] for f in rotations] # get the z-axis of the rotation matrix, assuming character is facig z-forward + # print("order:" + rot_order.lower()) + quats = Quaternions.from_euler(rotations, order=rot_order.lower(), world=False) + forward = quats * np.array([[0, 0, 1]]) + forward[:, 1] = 0 + + """ Smooth Forward Direction """ + direction_filterwidth = self.rotation_smoothing + if direction_filterwidth > 0: + forward = filters.gaussian_filter1d(forward, direction_filterwidth, axis=0, mode='nearest') + + forward = forward / np.sqrt((forward ** 2).sum(axis=-1))[..., np.newaxis] + + """ Remove Y Rotation """ + target = np.array([[0, 0, 1]]).repeat(len(forward), axis=0) + rotation = Quaternions.between(target, forward)[:, np.newaxis] + positions = (-rotation[:, 0]) * positions + new_rotations = (-rotation[:, 0]) * quats + + """ Get Root Rotation """ + # print(rotation[:,0]) + velocity = (-rotation[:, 0]) * velocity + rvelocity = Pivots.from_quaternions(rotation[1:] * -rotation[:-1]).ps + rvelocity = np.vstack((rvelocity[0], rvelocity)) + + eulers = np.array([t3d.euler.quat2euler(q, axes=('s' + rot_order.lower()[::-1]))[::-1] for q in + new_rotations]) * 180.0 / np.pi + + new_df = track.values.copy() + + root_pos_x = pd.Series(data=positions[:, 0], index=new_df.index) + root_pos_y = pd.Series(data=positions[:, 1], index=new_df.index) + root_pos_z = pd.Series(data=positions[:, 2], index=new_df.index) + root_pos_x_diff = pd.Series(data=velocity[:, 0], index=new_df.index) + root_pos_z_diff = pd.Series(data=velocity[:, 2], index=new_df.index) + + root_rot_1 = pd.Series(data=eulers[:, 0], index=new_df.index) + root_rot_2 = pd.Series(data=eulers[:, 1], index=new_df.index) + root_rot_3 = pd.Series(data=eulers[:, 2], index=new_df.index) + root_rot_y_diff = pd.Series(data=rvelocity[:, 0], index=new_df.index) + + # new_df.drop([xr_col, yr_col, zr_col, xp_col, zp_col], axis=1, inplace=True) + + new_df[xp_col] = root_pos_x + new_df[yp_col] = root_pos_y + new_df[zp_col] = root_pos_z + new_df[dxp_col] = root_pos_x_diff + new_df[dzp_col] = root_pos_z_diff + + new_df[r1_col] = root_rot_1 + new_df[r2_col] = root_rot_2 + new_df[r3_col] = root_rot_3 + # new_df[dxr_col] = root_rot_x_diff + new_df[dyr_col] = root_rot_y_diff + # new_df[dzr_col] = root_rot_z_diff + + new_track.values = new_df + + elif self.method == 'hip_centric': + new_track = track.clone() + + # Absolute columns + xp_col = '%s_Xposition'%track.root_name + yp_col = '%s_Yposition'%track.root_name + zp_col = '%s_Zposition'%track.root_name + + xr_col = 
'%s_Xrotation'%track.root_name + yr_col = '%s_Yrotation'%track.root_name + zr_col = '%s_Zrotation'%track.root_name + + new_df = track.values.copy() + + all_zeros = np.zeros(track.values[xp_col].values.shape) + + new_df[xp_col] = pd.Series(data=all_zeros, index=new_df.index) + new_df[yp_col] = pd.Series(data=all_zeros, index=new_df.index) + new_df[zp_col] = pd.Series(data=all_zeros, index=new_df.index) + + new_df[xr_col] = pd.Series(data=all_zeros, index=new_df.index) + new_df[yr_col] = pd.Series(data=all_zeros, index=new_df.index) + new_df[zr_col] = pd.Series(data=all_zeros, index=new_df.index) + + new_track.values = new_df + + #print(new_track.values.columns) + Q.append(new_track) + + return Q + + def inverse_transform(self, X, copy=None, start_pos=None): + Q = [] + + #TODO: simplify this implementation + + startx = 0 + startz = 0 + + if start_pos is not None: + startx, startz = start_pos + + for track in X: + new_track = track.clone() + if self.method == 'abdolute_translation_deltas': + new_df = new_track.values + xpcol = '%s_Xposition'%track.root_name + ypcol = '%s_Yposition'%track.root_name + zpcol = '%s_Zposition'%track.root_name + + + dxpcol = '%s_dXposition'%track.root_name + dzpcol = '%s_dZposition'%track.root_name + + dx = track.values[dxpcol].values + dz = track.values[dzpcol].values + + recx = [startx] + recz = [startz] + + for i in range(dx.shape[0]-1): + recx.append(recx[i]+dx[i+1]) + recz.append(recz[i]+dz[i+1]) + + # recx = [recx[i]+dx[i+1] for i in range(dx.shape[0]-1)] + # recz = [recz[i]+dz[i+1] for i in range(dz.shape[0]-1)] + # recx = dx[:-1] + dx[1:] + # recz = dz[:-1] + dz[1:] + if self.position_smoothing > 0: + new_df[xpcol] = pd.Series(data=new_df[xpcol]+recx, index=new_df.index) + new_df[zpcol] = pd.Series(data=new_df[zpcol]+recz, index=new_df.index) + else: + new_df[xpcol] = pd.Series(data=recx, index=new_df.index) + new_df[zpcol] = pd.Series(data=recz, index=new_df.index) + + new_df.drop([dxpcol, dzpcol], axis=1, inplace=True) + + new_track.values = new_df + # end of abdolute_translation_deltas + + Q.append(new_track) + + return Q + + +class RootNormalizer(BaseEstimator, TransformerMixin): + """ + Make subjects in TalkingWithHands16.2M face the same direction + This class is not for general uses. 
Only compatible to GENEA 2022 challenge dataset + Added by Youngwoo Yoon, April 2022 + """ + def __init__(self): + pass + + def fit(self, X, y=None): + return self + + def transform(self, X, y=None): + print("RootNormalizer") + Q = [] + + for track in X: + new_track = track.clone() + + xp_col = '%s_Xposition'%track.root_name + yp_col = '%s_Yposition'%track.root_name + zp_col = '%s_Zposition'%track.root_name + + xr_col = '%s_Xrotation'%track.root_name + yr_col = '%s_Yrotation'%track.root_name + zr_col = '%s_Zrotation'%track.root_name + + new_df = track.values.copy() + + all_zeros = np.zeros(track.values[xp_col].values.shape) + mean_xp = np.mean(track.values[xp_col].values) + mean_yp = np.mean(track.values[yp_col].values) + mean_zp = np.mean(track.values[zp_col].values) + + if track.values[xp_col].values[0] < 0: + new_yr = np.full(track.values[xp_col].values.shape, -90) + else: + new_yr = np.full(track.values[xp_col].values.shape, 90) + + new_df[xp_col] = pd.Series(data=track.values[xp_col]-mean_xp, index=new_df.index) + new_df[yp_col] = pd.Series(data=track.values[yp_col]-mean_yp, index=new_df.index) + new_df[zp_col] = pd.Series(data=track.values[zp_col]-mean_zp, index=new_df.index) + + new_df[xr_col] = pd.Series(data=all_zeros, index=new_df.index) + new_df[yr_col] = pd.Series(data=new_yr, index=new_df.index) + new_df[zr_col] = pd.Series(data=all_zeros, index=new_df.index) + + new_track.values = new_df + + Q.append(new_track) + + return Q + + def inverse_transform(self, X, copy=None): + # NOT IMPLEMENTED + return X + + +class RootCentricPositionNormalizer(BaseEstimator, TransformerMixin): + def __init__(self): + pass + + def fit(self, X, y=None): + return self + + def transform(self, X, y=None): + Q = [] + + for track in X: + new_track = track.clone() + + rxp = '%s_Xposition'%track.root_name + ryp = '%s_Yposition'%track.root_name + rzp = '%s_Zposition'%track.root_name + + projected_root_pos = track.values[[rxp, ryp, rzp]] + + projected_root_pos.loc[:,ryp] = 0 # we want the root's projection on the floor plane as the ref + + new_df = pd.DataFrame(index=track.values.index) + + all_but_root = [joint for joint in track.skeleton if track.root_name not in joint] + # all_but_root = [joint for joint in track.skeleton] + for joint in all_but_root: + new_df['%s_Xposition'%joint] = pd.Series(data=track.values['%s_Xposition'%joint]-projected_root_pos[rxp], index=new_df.index) + new_df['%s_Yposition'%joint] = pd.Series(data=track.values['%s_Yposition'%joint]-projected_root_pos[ryp], index=new_df.index) + new_df['%s_Zposition'%joint] = pd.Series(data=track.values['%s_Zposition'%joint]-projected_root_pos[rzp], index=new_df.index) + + + # keep the root as it is now + new_df[rxp] = track.values[rxp] + new_df[ryp] = track.values[ryp] + new_df[rzp] = track.values[rzp] + + new_track.values = new_df + + Q.append(new_track) + + return Q + + def inverse_transform(self, X, copy=None): + Q = [] + + for track in X: + new_track = track.clone() + + rxp = '%s_Xposition'%track.root_name + ryp = '%s_Yposition'%track.root_name + rzp = '%s_Zposition'%track.root_name + + projected_root_pos = track.values[[rxp, ryp, rzp]] + + projected_root_pos.loc[:,ryp] = 0 # we want the root's projection on the floor plane as the ref + + new_df = pd.DataFrame(index=track.values.index) + + for joint in track.skeleton: + new_df['%s_Xposition'%joint] = pd.Series(data=track.values['%s_Xposition'%joint]+projected_root_pos[rxp], index=new_df.index) + new_df['%s_Yposition'%joint] = 
pd.Series(data=track.values['%s_Yposition'%joint]+projected_root_pos[ryp], index=new_df.index) + new_df['%s_Zposition'%joint] = pd.Series(data=track.values['%s_Zposition'%joint]+projected_root_pos[rzp], index=new_df.index) + + + new_track.values = new_df + + Q.append(new_track) + + return Q + +class Flattener(BaseEstimator, TransformerMixin): + def __init__(self): + pass + + def fit(self, X, y=None): + return self + + def transform(self, X, y=None): + return np.concatenate(X, axis=0) + +class ConstantsRemover(BaseEstimator, TransformerMixin): + ''' + For now it just looks at the first track + ''' + + def __init__(self, eps = 1e-6): + self.eps = eps + + + def fit(self, X, y=None): + stds = X[0].values.std() + cols = X[0].values.columns.values + self.const_dims_ = [c for c in cols if (stds[c] < self.eps).any()] + self.const_values_ = {c:X[0].values[c].values[0] for c in cols if (stds[c] < self.eps).any()} + return self + + def transform(self, X, y=None): + Q = [] + + + for track in X: + print(self.const_dims_) + t2 = track.clone() + #for key in t2.skeleton.keys(): + # if key in self.ConstDims_: + # t2.skeleton.pop(key) + #print(track.values.columns.difference(self.const_dims_)) + t2.values.drop(self.const_dims_, axis=1, inplace=True) + #t2.values = track.values[track.values.columns.difference(self.const_dims_)] + Q.append(t2) + + return Q + + def inverse_transform(self, X, copy=None): + Q = [] + + for track in X: + t2 = track.clone() + for d in self.const_dims_: + t2.values[d] = self.const_values_[d] +# t2.values.assign(d=pd.Series(data=self.const_values_[d], index = t2.values.index)) + Q.append(t2) + + return Q + + +class ConstantsRemover_(BaseEstimator, TransformerMixin): + ''' + For now it just looks at the first track + ''' + + def __init__(self, eps=1e-6): + self.eps = eps + + def fit(self, X, y=None): + stds = X[0].values.std() + cols = X[0].values.columns.values + self.const_dims_ = [c for c in cols if "position" in c or "rotation" in c] + # self.const_dims_.remove("body_world_Xposition") + # self.const_dims_.remove("body_world_Yposition") + # self.const_dims_.remove("body_world_Zposition") + # self.const_dims_ = [c for c in cols if "position" in c or "rotation" in c or (stds[c] < self.eps).any()] + # self.const_dims_ = [c for c in cols if (stds[c] < self.eps).any()] + # self.const_dims_ = ['b_l_leg_beta', 'b_l_leg_gamma', 'b_r_leg_beta', 'b_r_leg_gamma', 'body_world_alpha', 'body_world_beta', 'body_world_gamma', 'b_root_Xposition', 'b_root_Yposition', 'b_root_Zposition', 'b_spine0_Xposition', 'b_spine0_Yposition', 'b_spine0_Zposition', 'b_spine1_Xposition', 'b_spine1_Yposition', 'b_spine1_Zposition', 'b_spine2_Xposition', 'b_spine2_Yposition', 'b_spine2_Zposition', 'b_spine3_Xposition', 'b_spine3_Yposition', 'b_spine3_Zposition', 'b_neck0_Xposition', 'b_neck0_Yposition', 'b_neck0_Zposition', 'b_head_Xposition', 'b_head_Yposition', 'b_head_Zposition', 'b_head_null_Xposition', 'b_head_null_Yposition', 'b_head_null_Zposition', 'b_head_null_Zrotation', 'b_head_null_Xrotation', 'b_head_null_Yrotation', 'b_r_shoulder_Xposition', 'b_r_shoulder_Yposition', 'b_r_shoulder_Zposition', 'b_r_arm_Xposition', 'b_r_arm_Yposition', 'b_r_arm_Zposition', 'b_r_arm_twist_Xposition', 'b_r_arm_twist_Yposition', 'b_r_arm_twist_Zposition', 'b_r_forearm_Xposition', 'b_r_forearm_Yposition', 'b_r_forearm_Zposition', 'b_r_wrist_twist_Xposition', 'b_r_wrist_twist_Yposition', 'b_r_wrist_twist_Zposition', 'b_r_wrist_Xposition', 'b_r_wrist_Yposition', 'b_r_wrist_Zposition', 'b_l_shoulder_Xposition', 
'b_l_shoulder_Yposition', 'b_l_shoulder_Zposition', 'b_l_arm_Xposition', 'b_l_arm_Yposition', 'b_l_arm_Zposition', 'b_l_arm_twist_Xposition', 'b_l_arm_twist_Yposition', 'b_l_arm_twist_Zposition', 'b_l_forearm_Xposition', 'b_l_forearm_Yposition', 'b_l_forearm_Zposition', 'b_l_wrist_twist_Xposition', 'b_l_wrist_twist_Yposition', 'b_l_wrist_twist_Zposition', 'b_l_wrist_Xposition', 'b_l_wrist_Yposition', 'b_l_wrist_Zposition', 'b_r_upleg_Xposition', 'b_r_upleg_Yposition', 'b_r_upleg_Zposition', 'b_r_leg_Xposition', 'b_r_leg_Yposition', 'b_r_leg_Zposition', 'b_r_foot_twist_Xposition', 'b_r_foot_twist_Yposition', 'b_r_foot_twist_Zposition', 'b_r_foot_twist_Zrotation', 'b_r_foot_twist_Xrotation', 'b_r_foot_twist_Yrotation', 'b_r_foot_Xposition', 'b_r_foot_Yposition', 'b_r_foot_Zposition', 'b_l_upleg_Xposition', 'b_l_upleg_Yposition', 'b_l_upleg_Zposition', 'b_l_leg_Xposition', 'b_l_leg_Yposition', 'b_l_leg_Zposition', 'b_l_foot_twist_Xposition', 'b_l_foot_twist_Yposition', 'b_l_foot_twist_Zposition', 'b_l_foot_twist_Zrotation', 'b_l_foot_twist_Xrotation', 'b_l_foot_twist_Yrotation', 'b_l_foot_Xposition', 'b_l_foot_Yposition', 'b_l_foot_Zposition'] + self.const_values_ = {c: X[0].values[c].values[0] for c in self.const_dims_} + + # dims = [c for c in cols if c not in self.const_dims_] + # for i, dim in enumerate(dims): + # print(i, dim) + # print(self.const_dims_) + return self + + def transform(self, X, y=None): + Q = [] + + for track in X: + # print(self.const_dims_) + t2 = track.clone() + # for key in t2.skeleton.keys(): + # if key in self.ConstDims_: + # t2.skeleton.pop(key) + # print(track.values.columns.difference(self.const_dims_)) + t2.values.drop(self.const_dims_, axis=1, inplace=True) + # t2.values = track.values[track.values.columns.difference(self.const_dims_)] + Q.append(t2) + + return Q + + def inverse_transform(self, X, copy=None): + Q = [] + + for track in X: + t2 = track.clone() + for d in self.const_dims_: + t2.values[d] = self.const_values_[d] + # t2.values.assign(d=pd.Series(data=self.const_values_[d], index = t2.values.index)) + Q.append(t2) + + return Q + + +class ConstantsRemover_withroot(BaseEstimator, TransformerMixin): + ''' + For now it just looks at the first track + ''' + + def __init__(self, eps=1e-6): + self.eps = eps + + def fit(self, X, y=None): + stds = X[0].values.std() + cols = X[0].values.columns.values + self.const_dims_ = [c for c in cols if "position" in c or "rotation" in c] + self.const_dims_.remove("body_world_Xposition") + self.const_dims_.remove("body_world_Yposition") + self.const_dims_.remove("body_world_Zposition") + self.const_dims_.append("body_world_alpha") + self.const_dims_.append("body_world_beta") + self.const_dims_.append("body_world_gamma") + # self.const_dims_ = [c for c in cols if "position" in c or "rotation" in c or (stds[c] < self.eps).any()] + # self.const_dims_ = [c for c in cols if (stds[c] < self.eps).any()] + # self.const_dims_ = ['b_l_leg_beta', 'b_l_leg_gamma', 'b_r_leg_beta', 'b_r_leg_gamma', 'body_world_alpha', 'body_world_beta', 'body_world_gamma', 'b_root_Xposition', 'b_root_Yposition', 'b_root_Zposition', 'b_spine0_Xposition', 'b_spine0_Yposition', 'b_spine0_Zposition', 'b_spine1_Xposition', 'b_spine1_Yposition', 'b_spine1_Zposition', 'b_spine2_Xposition', 'b_spine2_Yposition', 'b_spine2_Zposition', 'b_spine3_Xposition', 'b_spine3_Yposition', 'b_spine3_Zposition', 'b_neck0_Xposition', 'b_neck0_Yposition', 'b_neck0_Zposition', 'b_head_Xposition', 'b_head_Yposition', 'b_head_Zposition', 'b_head_null_Xposition', 
'b_head_null_Yposition', 'b_head_null_Zposition', 'b_head_null_Zrotation', 'b_head_null_Xrotation', 'b_head_null_Yrotation', 'b_r_shoulder_Xposition', 'b_r_shoulder_Yposition', 'b_r_shoulder_Zposition', 'b_r_arm_Xposition', 'b_r_arm_Yposition', 'b_r_arm_Zposition', 'b_r_arm_twist_Xposition', 'b_r_arm_twist_Yposition', 'b_r_arm_twist_Zposition', 'b_r_forearm_Xposition', 'b_r_forearm_Yposition', 'b_r_forearm_Zposition', 'b_r_wrist_twist_Xposition', 'b_r_wrist_twist_Yposition', 'b_r_wrist_twist_Zposition', 'b_r_wrist_Xposition', 'b_r_wrist_Yposition', 'b_r_wrist_Zposition', 'b_l_shoulder_Xposition', 'b_l_shoulder_Yposition', 'b_l_shoulder_Zposition', 'b_l_arm_Xposition', 'b_l_arm_Yposition', 'b_l_arm_Zposition', 'b_l_arm_twist_Xposition', 'b_l_arm_twist_Yposition', 'b_l_arm_twist_Zposition', 'b_l_forearm_Xposition', 'b_l_forearm_Yposition', 'b_l_forearm_Zposition', 'b_l_wrist_twist_Xposition', 'b_l_wrist_twist_Yposition', 'b_l_wrist_twist_Zposition', 'b_l_wrist_Xposition', 'b_l_wrist_Yposition', 'b_l_wrist_Zposition', 'b_r_upleg_Xposition', 'b_r_upleg_Yposition', 'b_r_upleg_Zposition', 'b_r_leg_Xposition', 'b_r_leg_Yposition', 'b_r_leg_Zposition', 'b_r_foot_twist_Xposition', 'b_r_foot_twist_Yposition', 'b_r_foot_twist_Zposition', 'b_r_foot_twist_Zrotation', 'b_r_foot_twist_Xrotation', 'b_r_foot_twist_Yrotation', 'b_r_foot_Xposition', 'b_r_foot_Yposition', 'b_r_foot_Zposition', 'b_l_upleg_Xposition', 'b_l_upleg_Yposition', 'b_l_upleg_Zposition', 'b_l_leg_Xposition', 'b_l_leg_Yposition', 'b_l_leg_Zposition', 'b_l_foot_twist_Xposition', 'b_l_foot_twist_Yposition', 'b_l_foot_twist_Zposition', 'b_l_foot_twist_Zrotation', 'b_l_foot_twist_Xrotation', 'b_l_foot_twist_Yrotation', 'b_l_foot_Xposition', 'b_l_foot_Yposition', 'b_l_foot_Zposition'] + self.const_values_ = {c: X[0].values[c].values[0] for c in self.const_dims_} + + # dims = [c for c in cols if c not in self.const_dims_] + # for i, dim in enumerate(dims): + # print(i, dim) + # print(self.const_dims_) + return self + + def transform(self, X, y=None): + Q = [] + + for track in X: + # print(self.const_dims_) + t2 = track.clone() + # for key in t2.skeleton.keys(): + # if key in self.ConstDims_: + # t2.skeleton.pop(key) + # print(track.values.columns.difference(self.const_dims_)) + t2.values.drop(self.const_dims_, axis=1, inplace=True) + # t2.values = track.values[track.values.columns.difference(self.const_dims_)] + Q.append(t2) + + return Q + + def inverse_transform(self, X, copy=None): + Q = [] + + for track in X: + t2 = track.clone() + for d in self.const_dims_: + t2.values[d] = self.const_values_[d] + # t2.values.assign(d=pd.Series(data=self.const_values_[d], index = t2.values.index)) + Q.append(t2) + + return Q + + + +class ListStandardScaler(BaseEstimator, TransformerMixin): + def __init__(self, is_DataFrame=False): + self.is_DataFrame = is_DataFrame + + def fit(self, X, y=None): + if self.is_DataFrame: + X_train_flat = np.concatenate([m.values for m in X], axis=0) + else: + X_train_flat = np.concatenate([m for m in X], axis=0) + + self.data_mean_ = np.mean(X_train_flat, axis=0) + self.data_std_ = np.std(X_train_flat, axis=0) + + return self + + def transform(self, X, y=None): + Q = [] + + for track in X: + if self.is_DataFrame: + normalized_track = track.copy() + normalized_track.values = (track.values - self.data_mean_) / self.data_std_ + else: + normalized_track = (track - self.data_mean_) / self.data_std_ + + Q.append(normalized_track) + + if self.is_DataFrame: + return Q + else: + return np.array(Q) + + def inverse_transform(self, X, 
copy=None): + Q = [] + + for track in X: + + if self.is_DataFrame: + unnormalized_track = track.copy() + unnormalized_track.values = (track.values * self.data_std_) + self.data_mean_ + else: + unnormalized_track = (track * self.data_std_) + self.data_mean_ + + Q.append(unnormalized_track) + + if self.is_DataFrame: + return Q + else: + return np.array(Q) + +class ListMinMaxScaler(BaseEstimator, TransformerMixin): + def __init__(self, is_DataFrame=False): + self.is_DataFrame = is_DataFrame + + def fit(self, X, y=None): + if self.is_DataFrame: + X_train_flat = np.concatenate([m.values for m in X], axis=0) + else: + X_train_flat = np.concatenate([m for m in X], axis=0) + + self.data_max_ = np.max(X_train_flat, axis=0) + self.data_min_ = np.min(X_train_flat, axis=0) + + return self + + def transform(self, X, y=None): + Q = [] + + for track in X: + if self.is_DataFrame: + normalized_track = track.copy() + normalized_track.values = (track.values - self.data_min_) / (self.data_max_ - self.data_min_) + else: + normalized_track = (track - self.data_min_) / (self.data_max_ - self.data_min_) + + Q.append(normalized_track) + + if self.is_DataFrame: + return Q + else: + return np.array(Q) + + def inverse_transform(self, X, copy=None): + Q = [] + + for track in X: + + if self.is_DataFrame: + unnormalized_track = track.copy() + unnormalized_track.values = (track.values * (self.data_max_ - self.data_min_)) + self.data_min_ + else: + unnormalized_track = (track * (self.data_max_ - self.data_min_)) + self.data_min_ + + Q.append(unnormalized_track) + + if self.is_DataFrame: + return Q + else: + return np.array(Q) + + +class DownSampler(BaseEstimator, TransformerMixin): + def __init__(self, tgt_fps, keep_all=True): + self.tgt_fps = tgt_fps + self.keep_all = keep_all + + + def fit(self, X, y=None): + + return self + + def transform(self, X, y=None): + Q = [] + + for track in X: + orig_fps=round(1.0/track.framerate) + rate = orig_fps//self.tgt_fps + if orig_fps%self.tgt_fps!=0: + print("error orig_fps (" + str(orig_fps) + ") is not dividable with tgt_fps (" + str(self.tgt_fps) + ")") + else: + print("downsampling with rate: " + str(rate)) + + #print(track.values.size) + for ii in range(0,rate): + new_track = track.clone() + new_track.values = track.values[ii:-1:rate].copy() + #print(new_track.values.size) + #new_track = track[0:-1:self.rate] + new_track.framerate = 1.0/self.tgt_fps + Q.append(new_track) + if not self.keep_all: + break + + return Q + + def inverse_transform(self, X, copy=None): + return X + + +class ReverseTime(BaseEstimator, TransformerMixin): + def __init__(self, append=True): + self.append = append + + + def fit(self, X, y=None): + + return self + + def transform(self, X, y=None): + Q = [] + if self.append: + for track in X: + Q.append(track) + for track in X: + new_track = track.clone() + new_track.values = track.values[-1::-1] + Q.append(new_track) + + return Q + + def inverse_transform(self, X, copy=None): + return X + +#TODO: JointsSelector (x) +#TODO: SegmentMaker +#TODO: DynamicFeaturesAdder +#TODO: ShapeFeaturesAdder +#TODO: DataFrameNumpier (x) + +class TemplateTransform(BaseEstimator, TransformerMixin): + def __init__(self): + pass + + def fit(self, X, y=None): + return self + + def transform(self, X, y=None): + return X + diff --git a/evaluation_metric/pymo/rotation_tools.py b/evaluation_metric/pymo/rotation_tools.py new file mode 100644 index 0000000..ee4a4f4 --- /dev/null +++ b/evaluation_metric/pymo/rotation_tools.py @@ -0,0 +1,220 @@ +''' +Tools for Manipulating and Converting 3D 
Rotations + +By Omid Alemi +Created: June 12, 2017 + +Adapted from that matlab file... +''' + +import math +import numpy as np +import transforms3d as t3d +from scipy.spatial.transform import Rotation as R + +def deg2rad(x): + return x/180*math.pi + + +def rad2deg(x): + return x/math.pi*180 + +def unroll(rots): + + new_rots = rots.copy() + + # Compute angles and alternative rotation angles + angs = np.linalg.norm(rots, axis=1) + dotprod = np.einsum('ij,ij->i', rots[:-1,:], rots[1:,:]) + alt_angs=2*np.pi-angs + + #find discontinuities + d_angs = np.diff(angs, axis=0) + d_angs2 = alt_angs[1:]-angs[:-1] + + # check if dot product is <0 + swps = np.where((dotprod<-1))[0] + #print(np.sum(swps)) + #swps = np.where((np.abs(d_ax)>0.5))[0] + #swps = np.where(np.abs(d_angs2) 1.0e-10: + vector = rot / theta + else: + vector = np.array([1.,0.,0.]) + theta=0.0 + eul = t3d.euler.axangle2euler(vector, theta, 'r' + order.lower()) + if use_deg: + return np.rad2deg(eul) + else: + return eul + +class Rotation(): + def __init__(self,rot, param_type, **params): + self.rotmat = [] + if param_type == 'euler': + self._from_euler(rot[0],rot[1],rot[2], params) + elif param_type == 'expmap': + self._from_expmap(rot[0], rot[1], rot[2], params) + + def _from_euler(self, alpha, beta, gamma, params): + '''Expecting degress''' + + if params['from_deg']==True: + alpha = deg2rad(alpha) + beta = deg2rad(beta) + gamma = deg2rad(gamma) + + order = "s" + ((params['order']).lower())[::-1] +# Quaternions.from_euler() + self.rotmat = np.transpose(t3d.euler.euler2mat(gamma, beta , alpha, axes=order)) + +# ca = math.cos(alpha) +# cb = math.cos(beta) +# cg = math.cos(gamma) +# sa = math.sin(alpha) +# sb = math.sin(beta) +# sg = math.sin(gamma) +# +# Rx = np.asarray([[1, 0, 0], +# [0, ca, sa], +# [0, -sa, ca] +# ]) +# +# Ry = np.asarray([[cb, 0, -sb], +# [0, 1, 0], +# [sb, 0, cb]]) +# +# Rz = np.asarray([[cg, sg, 0], +# [-sg, cg, 0], +# [0, 0, 1]]) +# +# self.rotmat = np.eye(3) +# +# order = params['order'] +# for i in range(0,len(order)): +# if order[i]=='X': +# self.rotmat = np.matmul(Rx, self.rotmat) +# elif order[i]=='Y': +# self.rotmat = np.matmul(Ry, self.rotmat) +# elif order[i]=='Z': +# self.rotmat = np.matmul(Rz, self.rotmat) +# else: +# print('unknown rotation axis: ' + order[i]) +# +# # self.rotmat = np.matmul(np.matmul(Rz, Ry), Rx) +# print ("------" + "TRUE") +# print (self.rotmat) + + def _from_expmap(self, alpha, beta, gamma, params): + if (alpha == 0 and beta == 0 and gamma == 0): + self.rotmat = np.eye(3) + return + + #TODO: Check exp map params + + theta = np.linalg.norm([alpha, beta, gamma]) + + expmap = [alpha, beta, gamma] / theta + + x = expmap[0] + y = expmap[1] + z = expmap[2] + + s = math.sin(theta/2) + c = math.cos(theta/2) + + self.rotmat = np.asarray([ + [2*(x**2-1)*s**2+1, 2*x*y*s**2-2*z*c*s, 2*x*z*s**2+2*y*c*s], + [2*x*y*s**2+2*z*c*s, 2*(y**2-1)*s**2+1, 2*y*z*s**2-2*x*c*s], + [2*x*z*s**2-2*y*c*s, 2*y*z*s**2+2*x*c*s , 2*(z**2-1)*s**2+1] + ]) + + + + def get_euler_axis(self): + R = self.rotmat + theta = math.acos((self.rotmat.trace() - 1) / 2) + axis = np.asarray([R[2,1] - R[1,2], R[0,2] - R[2,0], R[1,0] - R[0,1]]) + axis = axis/(2*math.sin(theta)) + return theta, axis + + def to_expmap(self): + axis, theta = t3d.axangles.mat2axangle(self.rotmat, unit_thresh=1e-05) +# theta, axis = self.get_euler_axis() + rot_arr = theta * axis + if np.isnan(rot_arr).any(): + rot_arr = [0, 0, 0] + return rot_arr + + def to_euler(self, use_deg=False, order='xyz'): + order = "s" + order.lower() + eulers = 
t3d.euler.mat2euler(np.transpose(self.rotmat), axes=order) + return eulers[::-1] + +# eulers = np.zeros((2, 3)) +# +# if np.absolute(np.absolute(self.rotmat[2, 0]) - 1) < 1e-12: +# #GIMBAL LOCK! +# print('Gimbal') +# if np.absolute(self.rotmat[2, 0]) - 1 < 1e-12: +# eulers[:,0] = math.atan2(-self.rotmat[0,1], -self.rotmat[0,2]) +# eulers[:,1] = -math.pi/2 +# else: +# eulers[:,0] = math.atan2(self.rotmat[0,1], -elf.rotmat[0,2]) +# eulers[:,1] = math.pi/2 +# +# return eulers +# +# theta = - math.asin(self.rotmat[2,0]) +# theta2 = math.pi - theta +# +# # psi1, psi2 +# eulers[0,0] = math.atan2(self.rotmat[2,1]/math.cos(theta), self.rotmat[2,2]/math.cos(theta)) +# eulers[1,0] = math.atan2(self.rotmat[2,1]/math.cos(theta2), self.rotmat[2,2]/math.cos(theta2)) +# +# # theta1, theta2 +# eulers[0,1] = theta +# eulers[1,1] = theta2 +# +# # phi1, phi2 +# eulers[0,2] = math.atan2(self.rotmat[1,0]/math.cos(theta), self.rotmat[0,0]/math.cos(theta)) +# eulers[1,2] = math.atan2(self.rotmat[1,0]/math.cos(theta2), self.rotmat[0,0]/math.cos(theta2)) +# + if use_deg: + eulers = rad2deg(eulers) + + return eulers + + def to_quat(self): + #TODO + pass + + def __str__(self): + return "Rotation Matrix: \n " + self.rotmat.__str__() + + + + diff --git a/evaluation_metric/pymo/viz_tools.py b/evaluation_metric/pymo/viz_tools.py new file mode 100644 index 0000000..be325be --- /dev/null +++ b/evaluation_metric/pymo/viz_tools.py @@ -0,0 +1,235 @@ +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +import os + +def save_fig(fig_id, tight_layout=True): + if tight_layout: + plt.tight_layout() + plt.savefig(fig_id + '.png', format='png', dpi=300) + + +def draw_stickfigure(mocap_track, frame, data=None, joints=None, draw_names=False, ax=None, figsize=(8,8)): + if ax is None: + fig = plt.figure(figsize=figsize) + ax = fig.add_subplot(111) + + if joints is None: + joints_to_draw = mocap_track.skeleton.keys() + else: + joints_to_draw = joints + + if data is None: + df = mocap_track.values + else: + df = data + + for joint in joints_to_draw: + ax.scatter(x=df['%s_Xposition'%joint][frame], + y=df['%s_Yposition'%joint][frame], + alpha=0.6, c='b', marker='o') + + parent_x = df['%s_Xposition'%joint][frame] + parent_y = df['%s_Yposition'%joint][frame] + + children_to_draw = [c for c in mocap_track.skeleton[joint]['children'] if c in joints_to_draw] + + for c in children_to_draw: + child_x = df['%s_Xposition'%c][frame] + child_y = df['%s_Yposition'%c][frame] + ax.plot([parent_x, child_x], [parent_y, child_y], 'k-', lw=2) + + if draw_names: + ax.annotate(joint, + (df['%s_Xposition'%joint][frame] + 0.1, + df['%s_Yposition'%joint][frame] + 0.1)) + + return ax + +def draw_stickfigure3d(mocap_track, frame, data=None, joints=None, draw_names=False, ax=None, figsize=(8,8)): + from mpl_toolkits.mplot3d import Axes3D + + if ax is None: + fig = plt.figure(figsize=figsize) + ax = fig.add_subplot(111, projection='3d') + + if joints is None: + joints_to_draw = mocap_track.skeleton.keys() + else: + joints_to_draw = joints + + if data is None: + df = mocap_track.values + else: + df = data + + for joint in joints_to_draw: + parent_x = df['%s_Xposition'%joint][frame] + parent_y = df['%s_Zposition'%joint][frame] + parent_z = df['%s_Yposition'%joint][frame] + # ^ In mocaps, Y is the up-right axis + + ax.scatter(xs=parent_x, + ys=parent_y, + zs=parent_z, + alpha=0.6, c='b', marker='o') + + + children_to_draw = [c for c in mocap_track.skeleton[joint]['children'] if c in joints_to_draw] + + for c in children_to_draw: + child_x = 
df['%s_Xposition'%c][frame] + child_y = df['%s_Zposition'%c][frame] + child_z = df['%s_Yposition'%c][frame] + # ^ In mocaps, Y is the up-right axis + + ax.plot([parent_x, child_x], [parent_y, child_y], [parent_z, child_z], 'k-', lw=2, c='black') + + if draw_names: + ax.text(x=parent_x + 0.1, + y=parent_y + 0.1, + z=parent_z + 0.1, + s=joint, + color='rgba(0,0,0,0.9') + + return ax + + +def sketch_move(mocap_track, data=None, ax=None, figsize=(16,8)): + if ax is None: + fig = plt.figure(figsize=figsize) + ax = fig.add_subplot(111) + + if data is None: + data = mocap_track.values + + for frame in range(0, data.shape[0], 4): +# draw_stickfigure(mocap_track, f, data=data, ax=ax) + + for joint in mocap_track.skeleton.keys(): + children_to_draw = [c for c in mocap_track.skeleton[joint]['children']] + + parent_x = data['%s_Xposition'%joint][frame] + parent_y = data['%s_Yposition'%joint][frame] + + frame_alpha = frame/data.shape[0] + + for c in children_to_draw: + child_x = data['%s_Xposition'%c][frame] + child_y = data['%s_Yposition'%c][frame] + + ax.plot([parent_x, child_x], [parent_y, child_y], '-', lw=1, color='gray', alpha=frame_alpha) + + + +def viz_cnn_filter(feature_to_viz, mocap_track, data, gap=25): + fig = plt.figure(figsize=(16,4)) + ax = plt.subplot2grid((1,8),(0,0)) + ax.imshow(feature_to_viz.T, aspect='auto', interpolation='nearest') + + ax = plt.subplot2grid((1,8),(0,1), colspan=7) + for frame in range(feature_to_viz.shape[0]): + frame_alpha = 0.2#frame/data.shape[0] * 2 + 0.2 + + for joint_i, joint in enumerate(mocap_track.skeleton.keys()): + children_to_draw = [c for c in mocap_track.skeleton[joint]['children']] + + parent_x = data['%s_Xposition'%joint][frame] + frame * gap + parent_y = data['%s_Yposition'%joint][frame] + + ax.scatter(x=parent_x, + y=parent_y, + alpha=0.6, + cmap='RdBu', + c=feature_to_viz[frame][joint_i] * 10000, + marker='o', + s = abs(feature_to_viz[frame][joint_i] * 10000)) + plt.axis('off') + for c in children_to_draw: + child_x = data['%s_Xposition'%c][frame] + frame * gap + child_y = data['%s_Yposition'%c][frame] + + ax.plot([parent_x, child_x], [parent_y, child_y], '-', lw=1, color='gray', alpha=frame_alpha) + + +def print_skel(X): + stack = [X.root_name] + tab=0 + while stack: + joint = stack.pop() + tab = len(stack) + print('%s- %s (%s)'%('| '*tab, joint, X.skeleton[joint]['parent'])) + for c in X.skeleton[joint]['children']: + stack.append(c) + + +def nb_play_mocap_fromurl(mocap, mf, frame_time=1/30, scale=1, base_url='http://titan:8385'): + if mf == 'bvh': + bw = BVHWriter() + with open('test.bvh', 'w') as ofile: + bw.write(mocap, ofile) + + filepath = '../notebooks/test.bvh' + elif mf == 'pos': + c = list(mocap.values.columns) + + for cc in c: + if 'rotation' in cc: + c.remove(cc) + mocap.values.to_csv('test.csv', index=False, columns=c) + + filepath = '../notebooks/test.csv' + else: + return + + url = '%s/mocapplayer/player.html?data_url=%s&scale=%f&cz=200&order=xzyi&frame_time=%f'%(base_url, filepath, scale, frame_time) + iframe = '' + link = 'New Window'%url + return IPython.display.HTML(iframe+link) + +def nb_play_mocap(mocap, mf, meta=None, frame_time=1/30, scale=1, camera_z=500, base_url=None): + data_template = 'var dataBuffer = `$$DATA$$`;' + data_template += 'var metadata = $$META$$;' + data_template += 'start(dataBuffer, metadata, $$CZ$$, $$SCALE$$, $$FRAMETIME$$);' + dir_path = os.path.dirname(os.path.realpath(__file__)) + + + if base_url is None: + base_url = os.path.join(dir_path, 'mocapplayer/playBuffer.html') + + # print(dir_path) + + 
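+    # for 'pos' input: dump the position channels to CSV, fill the player template, and write it to mocapplayer/data.js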
if mf == 'bvh': + pass + elif mf == 'pos': + cols = list(mocap.values.columns) + for c in cols: + if 'rotation' in c: + cols.remove(c) + + data_csv = mocap.values.to_csv(index=False, columns=cols) + + if meta is not None: + lines = [','.join(item) for item in meta.astype('str')] + meta_csv = '[' + ','.join('[%s]'%l for l in lines) +']' + else: + meta_csv = '[]' + + data_assigned = data_template.replace('$$DATA$$', data_csv) + data_assigned = data_assigned.replace('$$META$$', meta_csv) + data_assigned = data_assigned.replace('$$CZ$$', str(camera_z)) + data_assigned = data_assigned.replace('$$SCALE$$', str(scale)) + data_assigned = data_assigned.replace('$$FRAMETIME$$', str(frame_time)) + + else: + return + + + + with open(os.path.join(dir_path, 'mocapplayer/data.js'), 'w') as oFile: + oFile.write(data_assigned) + + url = '%s?&cz=200&order=xzyi&frame_time=%f&scale=%f'%(base_url, frame_time, scale) + iframe = '' + link = 'New Window'%url + return IPython.display.HTML(iframe+link) \ No newline at end of file diff --git a/evaluation_metric/pymo/writers.py b/evaluation_metric/pymo/writers.py new file mode 100644 index 0000000..630dbdf --- /dev/null +++ b/evaluation_metric/pymo/writers.py @@ -0,0 +1,70 @@ +import numpy as np +import pandas as pd + +class BVHWriter(): + def __init__(self): + pass + + def write(self, X, ofile, framerate=-1): + + # Writing the skeleton info + ofile.write('HIERARCHY\n') + + self.motions_ = [] + self._printJoint(X, X.root_name, 0, ofile) + + # Writing the motion header + ofile.write('MOTION\n') + ofile.write('Frames: %d\n'%X.values.shape[0]) + + if framerate > 0: + ofile.write('Frame Time: %f\n'%float(1.0/framerate)) + else: + ofile.write('Frame Time: %f\n'%X.framerate) + + # Writing the data + self.motions_ = np.asarray(self.motions_).T + lines = [" ".join(item) for item in self.motions_.astype(str)] + ofile.write("".join("%s\n"%l for l in lines)) + + def _printJoint(self, X, joint, tab, ofile): + + if X.skeleton[joint]['parent'] == None: + ofile.write('ROOT %s\n'%joint) + elif len(X.skeleton[joint]['children']) > 0: + ofile.write('%sJOINT %s\n'%('\t'*(tab), joint)) + else: + ofile.write('%sEnd site\n'%('\t'*(tab))) + + ofile.write('%s{\n'%('\t'*(tab))) + + ofile.write('%sOFFSET %3.5f %3.5f %3.5f\n'%('\t'*(tab+1), + X.skeleton[joint]['offsets'][0], + X.skeleton[joint]['offsets'][1], + X.skeleton[joint]['offsets'][2])) + rot_order = X.skeleton[joint]['order'] + + #print("rot_order = " + rot_order) + channels = X.skeleton[joint]['channels'] + rot = [c for c in channels if ('rotation' in c)] + pos = [c for c in channels if ('position' in c)] + + n_channels = len(rot) +len(pos) + ch_str = '' + if n_channels > 0: + for ci in range(len(pos)): + cn = pos[ci] + self.motions_.append(np.asarray(X.values['%s_%s'%(joint,cn)].values)) + ch_str = ch_str + ' ' + cn + for ci in range(len(rot)): + cn = '%srotation'%(rot_order[ci]) + self.motions_.append(np.asarray(X.values['%s_%s'%(joint,cn)].values)) + ch_str = ch_str + ' ' + cn + if len(X.skeleton[joint]['children']) > 0: + #ch_str = ''.join(' %s'*n_channels%tuple(channels)) + ofile.write('%sCHANNELS %d%s\n' %('\t'*(tab+1), n_channels, ch_str)) + + for c in X.skeleton[joint]['children']: + self._printJoint(X, c, tab+1, ofile) + + ofile.write('%s}\n'%('\t'*(tab))) diff --git a/evaluation_metric/train_AE.py b/evaluation_metric/train_AE.py new file mode 100644 index 0000000..5dbd323 --- /dev/null +++ b/evaluation_metric/train_AE.py @@ -0,0 +1,150 @@ +import glob +import os + +import numpy as np +import torch +import 
torch.nn.functional as F +from torch import optim +from torch.utils.data import TensorDataset, DataLoader + +from evaluation_metric.embedding_net import EmbeddingNet +from tqdm import tqdm + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self, name, fmt=':f'): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + return fmtstr.format(**self.__dict__) + + +def train_iter(target_data, net, optim): + # zero gradients + optim.zero_grad() + + # reconstruction loss + feat, recon_data = net(target_data) + recon_loss = F.l1_loss(recon_data, target_data, reduction='none') + recon_loss = torch.mean(recon_loss, dim=(1, 2)) + + if True: # use pose diff + target_diff = target_data[:, 1:] - target_data[:, :-1] + recon_diff = recon_data[:, 1:] - recon_data[:, :-1] + recon_loss += torch.mean(F.l1_loss(recon_diff, target_diff, reduction='none'), dim=(1, 2)) + + recon_loss = torch.sum(recon_loss) + + recon_loss.backward() + optim.step() + + ret_dict = {'loss': recon_loss.item()} + return ret_dict + + +def make_tensor(path, n_frames, stride=None, max_files=64, n_chunks=None): + # use calculate_mean_std.py to get mean/std values + #mean_vec = np.array([0.06847747184608119,0.0357063580846738,0.058578770784627525,0.06847747184608119,183.03570635799917,0.058578770784627525,16.923291745240235,168.35926808117108,0.041955713046869694,22.215592763704525,96.00786795875209,7.367188758228035,-17.300431170968334,168.94294163004272,-0.24727338642555735,-20.8715950848037,96.38631298614739,5.896949875381165,0.20216133904775593,186.95854383542118,-6.432579379678317,0.6098873474970944,208.91504190095858,-4.915889192590782,1.0909151123302776,229.8609722171563,-12.6793862809497,1.928591364103774,266.5012367723191,-15.976334397752815,7.714763567890464,280.55824733756486,1.9359663647433214,39.544279960636416,270.51838703291963,-9.270711837971861,39.544279960636416,270.51838703291963,-9.270711837971861,50.81093482089116,222.26787089133526,-4.277154500217943,50.17939487042671,194.5012435696503,21.649902584660502,47.55546776177503,193.07937915800173,24.449814516984887,51.68258742769814,183.1130622846742,28.598284340608124,52.31257982349485,178.7300103572332,29.752167196655417,52.11476016985293,176.2234409034844,29.338322795250296,49.57245060363262,183.6685185194845,30.764422876178404,49.51180617225469,177.9781510143326,31.345989263749335,48.984835786867286,174.72433725238744,30.80040989078595,46.938570200832714,184.18559365577823,32.38094931234402,46.63205116111747,177.99271292951627,33.3617908178335,45.70316718692669,175.00854125031574,32.71919596837255,43.53471060814086,184.72586361214883,32.65900465541941,43.144235682959376,179.3784708387911,34.567474274114595,42.55463405258724,176.60663073315192,34.61292128747094,46.18735692733591,192.1666386454192,25.212175451995105,42.415178023338235,190.40999141207104,24.34708837726526,38.476911084063474,186.96434952919546,26.5919414956809,36.13208594764687,183.66904789495973,28.92036860488664,-3.260172112764866,280.83999461595937,1.9124809183428453,-35.07458946647343,272.1779993206412,-10.188934519524462,-35.07458946647343,272.1779993206412,-10.188934519524462,-44.22143064
2610315,224.12986592661338,-2.1810619636426996,-48.22807185664309,197.1836151601496,21.820415520252148,-46.33444054575024,195.86804954902195,25.039544449623758,-45.13668197604728,194.955477133584,26.116667176956216,-41.092202689799706,193.1729063777209,26.266446799572222,-37.40708496857298,189.6782655958083,29.307581254228797,-35.636081168567486,186.31843980348182,31.96766893075243,-50.79671007842462,186.06135889526658,28.08238050295619,-51.43370660186226,181.74414430313067,29.05281490814699,-50.935496917944754,179.24853570625916,28.80991130269415,-47.240733921717286,187.10254399331873,32.83320587793653,-46.90782281703879,180.84945866262322,34.0909044604069,-45.500171172703396,177.78021908057028,33.90137244712666,-49.34991372818972,186.60529710991787,30.652207359792364,-49.09096325437287,180.95734931633348,31.311876315818232,-48.05107944797273,177.7834688982274,31.007569225066618,-44.02123442944895,187.60657246793522,33.93921950361445,-43.76597286425863,182.3069165226035,36.05010015011614,-42.954543530752,179.48150399466223,36.4037713410187,2.5187557703509666,291.83247385777486,-9.389368690830807,2.4539357753355553,304.03709615227046,-2.881607696619349]) + #std_vec = np.array([8.255541100752382,0.6939720601761805,5.222482738827612,8.255541100752382,0.693972060176339,5.222482738827612,9.42100963255725,2.188372644619285,6.361481612445378,6.176403474432683,2.03443820178515,11.830267060993494,9.348297312453099,1.9986317079084224,6.001779897668299,8.162462099200143,2.2708856073173265,12.143746363244206,8.07002281773558,1.146760813248187,5.218745502524028,7.302334393532628,1.0869661835365756,5.640455884072193,7.477263866275783,1.2241260130341458,5.704034599390462,8.93140892031836,1.3043997705897612,6.246232138662631,9.656330236153833,1.7514046003701789,6.8102931179285715,9.369187221685337,3.7379417578069427,9.746046092317828,9.369187221685337,3.7379417578069427,9.746046092317828,10.227838670518729,7.286263937298487,13.107272130744121,17.056602336480278,28.775355928036703,17.65156414660871,17.616759573364362,31.57948329721122,17.58583620094351,22.409700875985976,36.64470600046972,22.894290518164947,24.53008313723881,39.07577209642486,25.295457925596633,25.690655337228268,40.177065667200765,26.808543783812727,22.2591494319307,38.401747944200125,21.950631662738534,24.854161656529712,41.50067854859389,25.042011259666936,26.473295311534336,43.197994006554005,27.221170133903826,22.10861269363438,40.10343062678119,21.14882281947506,25.013596742327397,43.87399009904381,24.52999507148788,26.52013512923851,45.59851255330215,26.774862236150504,21.725195479879996,41.3712584397316,20.536614369930735,24.39943341761268,45.32943762847826,23.38680984964916,25.786178720369854,47.14171501427746,25.319483052572473,17.914690428682025,32.89781536653379,17.7403567588838,18.181332875330988,34.64247649156103,18.418610850063928,19.771894506185426,39.40693746620956,19.868089428475518,21.55081438729135,43.62997102172089,21.5116223452909,9.702982004349828,1.7826399048001014,6.831726011211061,9.990627742798965,4.4897264818483995,9.924705741596368,9.990627742798965,4.4897264818483995,9.924705741596368,11.786532337428321,9.522123886195812,12.65776670636571,17.031859201169226,33.584280601824254,18.38674479122489,17.777049225781234,36.38401573078446,18.127686404978764,18.17511401763394,37.691189634409334,18.084351816160314,18.551140073870705,39.31562974080336,18.02247544203335,20.28661077866561,43.99125147305048,18.701801984691905,22.30030442409078,48.11743158915539,19.99370126076239,21.89724529987023,42.49461135019713,23.7937048
82062592,23.74127743788964,45.32675754902733,26.24880800837308,24.594991051498763,46.72175188102939,27.67702463406083,22.183607917647112,45.59612353311565,21.613578453862164,24.6049569241885,50.12226277439626,24.67927794755758,25.732481639988514,52.301113166777355,26.560623330610035,22.023636510505785,44.0923860001992,22.687839346541782,24.11276937550739,47.8236277337321,25.586912334755322,25.285414927567018,49.8865610793381,27.509628579848478,22.140064803278623,46.596465179326735,20.51837959275402,24.421617248855064,51.09661706141259,22.99069540500737,25.553481257475042,53.30714083193669,24.588120394659704,10.509501956123042,1.3845656231709993,7.232725520260621,10.735209001060486,1.9048225062899646,7.924571615243941]) + mean_vec = np.load('./dataset/Genea2023/trn/main-agent/rotpos_Mean.npy') + std_vec = np.load('./dataset/Genea2023/trn/main-agent/rotpos_Std.npy') + idx_positions = np.asarray([ [i*6+3, i*6+4, i*6+5] for i in range(int(len(mean_vec)/6)) ]).flatten() + mean_vec = mean_vec[idx_positions] + std_vec = std_vec[idx_positions] + + if os.path.isdir(path): + files = glob.glob(os.path.join(path, '*.npy')) + else: + files = [path] + + files.sort() + + # Make sure we don't run out of memory + max_files = max_files if max_files < len(files) else len(files) + + samples = [] + stride = n_frames // 2 if stride is None else stride + print('Preparing data...') + for file in files[:max_files]: + print('Loading {}'.format(file)) + data = np.load(file) #Should be shape [frames, features (joint rotations)] + data = data[:, idx_positions] + for i in range(0, len(data) - n_frames, stride)[:n_chunks]: + sample = data[i:i+n_frames] + sample = (sample - mean_vec) / std_vec + samples.append(sample) + + print('Converting to tensor...') + return torch.Tensor(samples) + + +def main(n_frames): + #https://github.com/genea-workshop/genea_challenge_2023/tree/main/evaluation_metric + # dataset + train_dataset = TensorDataset(make_tensor(f'./dataset/Genea2023/trn/main-agent/motion_npy_rotpos', n_frames)) + print('Done') + train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True, drop_last=True) + + # train + loss_meters = [AverageMeter('loss')] + + # interval params + print_interval = int(len(train_loader) / 5) + + # init model and optimizer + pose_dim = 249 + generator = EmbeddingNet(pose_dim, n_frames).to(device) + gen_optimizer = optim.Adam(generator.parameters(), lr=0.001, betas=(0.5, 0.999)) + + print('Training...') + # training + for epoch in range(100): + for iter_idx, target in enumerate(train_loader, 0): + target = target[0] + batch_size = target.size(0) + target_vec = target.to(device) + loss = train_iter(target_vec, generator, gen_optimizer) + + # loss values + for loss_meter in loss_meters: + name = loss_meter.name + if name in loss: + loss_meter.update(loss[name], batch_size) + + # print training status + if (iter_idx + 1) % print_interval == 0: + print_summary = 'EP {} ({:3d}) | '.format(epoch, iter_idx + 1) + for loss_meter in loss_meters: + if loss_meter.count > 0: + print_summary += '{}: {:.3f}, '.format(loss_meter.name, loss_meter.avg) + loss_meter.reset() + print(print_summary) + + # save model + gen_state_dict = generator.state_dict() + save_name = f'./evaluation_metric/output/model_checkpoint_{n_frames}.bin' + torch.save({'pose_dim': pose_dim, 'n_frames': n_frames, 'gen_dict': gen_state_dict}, save_name) + + +if __name__ == '__main__': + n_frames = 120 + print('Using n_frames: {}'.format(n_frames)) + main(n_frames) diff --git a/ggvad_container.sh b/ggvad_container.sh 
new file mode 100644 index 0000000..4e44b9c --- /dev/null +++ b/ggvad_container.sh @@ -0,0 +1,11 @@ +while getopts g:n:p: flag +do + case "${flag}" in + g) gpu=${OPTARG};; + n) number=${OPTARG};; + p) port=${OPTARG};; + esac +done +echo "Running container ggvad_container_$number on gpu $gpu and port $port"; + +docker run --rm -it --gpus device=$gpu --userns=host --shm-size 64G -v /work/rodolfo.tonoli/ggvad-genea2023:/workspace/ggvad/ -p $port --name ggvad_container$number ggvad:latest /bin/bash \ No newline at end of file diff --git a/model/local_attention_diffstylegest.py b/model/local_attention_diffstylegest.py new file mode 100644 index 0000000..f9d10f0 --- /dev/null +++ b/model/local_attention_diffstylegest.py @@ -0,0 +1,172 @@ +import torch +import math +import torch.nn.functional as F +from torch import nn, einsum + +from einops import rearrange, repeat, pack, unpack + +TOKEN_SELF_ATTN_VALUE = -5e4 + +def exists(val): + return val is not None + +def default(value, d): + return d if not exists(value) else value + +def to(t): + return {'device': t.device, 'dtype': t.dtype} + +def max_neg_value(tensor): + return -torch.finfo(tensor.dtype).max + +def l2norm(tensor): + dtype = tensor.dtype + normed = F.normalize(tensor, dim = -1) + return normed.type(dtype) + +def pad_to_multiple(tensor, multiple, dim=-1, value=0): + seqlen = tensor.shape[dim] + m = seqlen / multiple + if m.is_integer(): + return False, tensor + remainder = math.ceil(m) * multiple - seqlen + pad_offset = (0,) * (-1 - dim) * 2 + return True, F.pad(tensor, (*pad_offset, 0, remainder), value = value) + +def look_around(x, backward = 1, forward = 0, pad_value = -1, dim = 2): + t = x.shape[1] + dims = (len(x.shape) - dim) * (0, 0) + padded_x = F.pad(x, (*dims, backward, forward), value = pad_value) + tensors = [padded_x[:, ind:(ind + t), ...] for ind in range(forward + backward + 1)] + return torch.cat(tensors, dim = dim) + +class SinusoidalEmbeddings(nn.Module): + def __init__(self, dim): + super().__init__() + inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + + def forward(self, x): + n = x.shape[-2] + t = torch.arange(n, device = x.device).type_as(self.inv_freq) + freqs = torch.einsum('i , j -> i j', t, self.inv_freq) + return torch.cat((freqs, freqs), dim=-1) + +def rotate_half(x): + x = rearrange(x, 'b ... (r d) -> b (...) 
r d', r = 2) + x1, x2 = x.unbind(dim = -2) + return torch.cat((-x2, x1), dim = -1) + +def apply_rotary_pos_emb(q, k, freqs): + q, k = map(lambda t: (t * freqs.cos()) + (rotate_half(t) * freqs.sin()), (q, k)) + return q, k + + +class LocalAttention(nn.Module): + def __init__( + self, + window_size, + causal = False, + look_backward = 1, + look_forward = None, + dropout = 0., + shared_qk = False, + exact_windowsize = False + ): + super().__init__() + look_forward = default(look_forward, 0 if causal else 1) + assert not (causal and look_forward > 0), 'you cannot look forward if causal' + + self.window_size = window_size + self.exact_windowsize = exact_windowsize + + self.causal = causal + + self.look_backward = look_backward + self.look_forward = look_forward + + self.dropout = nn.Dropout(dropout) + + self.shared_qk = shared_qk + + def forward(self, q, k, v, packed_shape, mask = None, input_mask = None): + mask = default(mask, input_mask) + + pad_value, window_size, causal, look_backward, look_forward, shared_qk = -1, self.window_size, self.causal, self.look_backward, self.look_forward, self.shared_qk + + b, n, dim_head, device, dtype = *q.shape, q.device, q.dtype + scale = dim_head ** -0.5 + + if n % window_size != 0: + print('sequence length must be divisible by window size for local attention', n, window_size) + assert True + + windows = n // window_size + + if shared_qk: + k = l2norm(k) + + seq = torch.arange(n, device = device) + b_t = rearrange(seq, '(w n) -> 1 w n', w = windows, n = window_size) + + bq, bk, bv = map(lambda t: rearrange(t, 'b (w n) d -> b w n d', w = windows), (q, k, v)) + + look_around_kwargs = dict( + backward = look_backward, + forward = look_forward, + pad_value = pad_value + ) + + bk = look_around(bk, **look_around_kwargs) + bv = look_around(bv, **look_around_kwargs) + + bq_t = b_t + bq_k = look_around(b_t, **look_around_kwargs) + + bq_t = rearrange(bq_t, '... i -> ... i 1') + bq_k = rearrange(bq_k, '... j -> ... 1 j') + + sim = einsum('b h i e, b h j e -> b h i j', bq, bk) * scale + + mask_value = max_neg_value(sim) + + if shared_qk: + self_mask = bq_t == bq_k + sim = sim.masked_fill(self_mask, TOKEN_SELF_ATTN_VALUE) + del self_mask + + if causal: + causal_mask = bq_t < bq_k + + if self.exact_windowsize: + max_causal_window_size = (self.window_size * self.look_backward) + causal_mask = causal_mask | (bq_t > (bq_k + max_causal_window_size)) + + sim = sim.masked_fill(causal_mask, mask_value) + del causal_mask + + if exists(mask): + batch = mask.shape[0] + assert (b % batch) == 0 + + h = b // mask.shape[0] + + mask = rearrange(mask, '... (w n) -> (...) w n', w = windows, n = window_size) + mask = look_around(mask, **{**look_around_kwargs, 'pad_value': False}) + mask = rearrange(mask, '... j -> ... 1 j') + mask = repeat(mask, 'b ... 
-> (b h) ...', h = h) + sim = sim.masked_fill(~mask, mask_value) + del mask + + # attention + + attn = sim.softmax(dim = -1) + attn = self.dropout(attn) + + # aggregation + + out = einsum('b h i j, b h j e -> b h i e', attn, bv) + out = rearrange(out, 'b w n d -> b (w n) d') + + out, *_ = unpack(out, packed_shape, '* n d') + return out \ No newline at end of file diff --git a/model/mdm.py b/model/mdm.py new file mode 100644 index 0000000..35beba9 --- /dev/null +++ b/model/mdm.py @@ -0,0 +1,481 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import clip +from model.local_attention_diffstylegest import SinusoidalEmbeddings, apply_rotary_pos_emb +from model.local_attention_diffstylegest import LocalAttention + +class MDM(nn.Module): + def __init__(self, njoints, nfeats, pose_rep, data_rep, latent_dim=256, text_dim=64, ff_size=1024, + num_layers=8, num_heads=4, dropout=0.1, activation="gelu", + dataset='amass', clip_dim=512, clip_version=None, **kargs): + super().__init__() + print('Using MDM V2 (w/ CrossAtt+RPM)') + + # General Configs + self.dataset = dataset + self.pose_rep = pose_rep + self.data_rep = data_rep + self.njoints = njoints + self.nfeats = nfeats + self.input_feats = self.njoints * self.nfeats + self.latent_dim = latent_dim + self.dropout = dropout + + # Timestep Network + self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout) + self.embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder) + + # Text Encoder + self.use_text = kargs.get('use_text', False) + self.cond_mask_prob = kargs.get('cond_mask_prob', 0.) + self.text_dim = text_dim + self.clip_dim = clip_dim + if self.use_text: + self.embed_text = nn.Linear(self.clip_dim, self.text_dim) + print('Using Text') + print('Loading CLIP...') + self.clip_version = clip_version + self.clip_model = self.load_and_freeze_clip(clip_version) + + # VAD + self.use_vad = kargs.get('use_vad', False) + if self.use_vad: + vad_lat_dim = int(self.latent_dim) + self.vad_lookup = nn.Embedding(2, vad_lat_dim) + print('Using VAD') + + # Seed Pose Encoder + self.seed_poses = kargs.get('seed_poses', 0) + print('Using {} Seed Poses.'.format(self.seed_poses)) + if self.seed_poses > 0: + if self.use_text: + self.seed_pose_encoder = SeedPoseEncoder(self.njoints, self.seed_poses, self.latent_dim - self.text_dim) + else: + self.seed_pose_encoder = SeedPoseEncoder(self.njoints, self.seed_poses, self.latent_dim) + + # Audio Encoder + self.mfcc_input = kargs.get('mfcc_input', False) + self.use_wav_enc = kargs.get('use_wav_enc', False) + self.use_wavlm = kargs.get('use_wavlm', False) + print('Using Audio Features:') + if self.mfcc_input: + self.mfcc_dim = 26 + self.audio_feat_dim = 64 + self.wavlm_encoder = nn.Linear(26, self.audio_feat_dim) + print('Selected Features: MFCCs') + if self.use_wav_enc: + self.wav_enc_dim = 32 + self.audio_feat_dim = self.wav_enc_dim + print('Selected Features: WavEncoder Representations') + self.wav_encoder = WavEncoder() + if self.use_wavlm: + self.wavlm_proj_dim = 64 + self.audio_feat_dim = self.wavlm_proj_dim + self.wavlm_encoder = nn.Linear(768, self.audio_feat_dim) + print('Selected Features: WavLM Representations') + + # Pose Encoder + self.input_process = InputProcess(self.data_rep, self.input_feats, self.latent_dim) + + # Cross-Local Attention + self.cl_head=8 + if self.use_vad: + self.project_to_lat = nn.Linear(self.latent_dim * 3 + self.audio_feat_dim, self.latent_dim) + #self.project_to_lat = nn.Linear(vad_lat_dim + 
self.audio_feat_dim + self.latent_dim*2, self.latent_dim) + else: + self.project_to_lat = nn.Linear(self.latent_dim * 2 + self.audio_feat_dim, self.latent_dim) + self.cross_local_attention = LocalAttention( + # dim=32, # dimension of each head (you need to pass this in for relative positional encoding) + window_size=10, + causal=True, + look_backward=1, + look_forward=0, + dropout=0.1, + exact_windowsize=False + ) + + # Positional Encodings + self.rel_pos = SinusoidalEmbeddings(self.latent_dim // self.cl_head) + + # Self-Attention + self.num_heads = num_heads + self.ff_size = ff_size + self.activation = activation + self.num_layers = num_layers + self.seqTransEncoder = nn.TransformerEncoder(nn.TransformerEncoderLayer( + d_model=self.latent_dim, + nhead=self.num_heads, + dim_feedforward=self.ff_size, + dropout=self.dropout, + activation=self.activation), + num_layers=self.num_layers) + + # Project Representation to Output Pose + self.output_process = OutputProcess(self.data_rep, self.input_feats, self.latent_dim, self.njoints, + self.nfeats) + + self.log_train = False + self.batch_log = {'text': [], + 'vad': [], + 'seed': [], + 'timestep': [], + 'audio': [], + 'poses': [], + 'fg_embs': [], + 'coa_embs': [], + 'embs': [], + 'audiovad': []} + #self.log_seed = [] + #self.log_text = [] + #self.log_timestep = [] + #self.log_audio = [] + #self.log_vad = [] + #self.log_poses = [] + #self.log_fg_embs = [] + #self.log_coa_embs = [] + #self.log_embs = [] + + def forward(self, x, timesteps, y=None): + """ + x: [batch_size, njoints, nfeats, max_frames], denoted x_t in the paper + timesteps: [batch_size] (int) + """ + # Sizes + bs, njoints, nfeats, nframes = x.shape # [BS, POSE_DIM, 1, CHUNK_LEN] + force_mask = y.get('uncond', False) # TODO: UNDERSTAND MASK + + ############################# + #### FEATURE CALCULATION #### + ############################# + + # Text Embeddings + if self.use_text: + enc_text = self.encode_text(y['text']) + emb_text = self.embed_text(self.mask_cond(enc_text, force_mask=force_mask)) # [1, BS, TEXT_DIM] + emb_text = emb_text.squeeze(0) # [BS, TEXT_DIM] + + # Seed Poses Embeddings + flat_seed = y['seed'].squeeze(2).reshape(bs, -1) # [BS, POSE_DIM, 1, SEED_POSES] -> [BS, POSE_DIM, SEED_POSES] -> [BS, POSE_DIM * SEED_POSES] + #emb_seed = self.seed_pose_encoder(flat_seed) + emb_seed = self.seed_pose_encoder(self.mask_cond(flat_seed, force_mask=force_mask)) # [BS, LAT_DIM] or [BS, LAT_DIM - TEXT_DIM] + + # VAD Embeddings + if self.use_vad: + vad_vals = y['vad'] # [BS, CHUNK_LEN] + emb_vad = self.vad_lookup(vad_vals) # [BS, CHUNK_LEN, LAT_DIM] + emb_vad = emb_vad.permute(1, 0, 2) # [CHUNK_LEN, BS, LAT_DIM] + + # Timesteps Embeddings + emb_t = self.embed_timestep(timesteps) # [1, BS, LAT_DIM] + + # Audio Embeddings + if self.mfcc_input: # TODO: is it actually the raw mfccs? 
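+            # Each enabled audio branch projects its features to audio_feat_dim and should
+            # yield emb_audio shaped [BS, AUDIO_FEAT_DIM, 1, CHUNK_LEN]. In this MFCC branch
+            # the 26-dim features go through the Linear(26, 64) defined in __init__ (which
+            # reuses the name self.wavlm_encoder), so the "768" in the shape comments below
+            # applies only to the WavLM branch; the WavEncoder branch still raises
+            # NotImplementedError.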
+ emb_audio = y['audio_rep'] # [BS, MFCC_DIM, 1, CHUNK_LEN] + interp_reps = emb_audio.permute(0, 3, 2, 1) # [BS, CHUNK_LEN, 1, 768] + emb_audio = self.wavlm_encoder(interp_reps) # [BS, CHUNK_LEN, 1, WAVLM_PROJ_DIM] + emb_audio = emb_audio.permute(0, 3, 2, 1) # [BS, WAVLM_PROJ_DIM, 1, CHUNK_LEN] + elif self.use_wav_enc: + emb_audio = self.wav_encoder(y['audio']) # [BS, WAV_ENC_DIM, 1, CHUNK_LEN] + raise NotImplementedError # TODO: Resolve CNNs + elif self.use_wavlm: + interp_reps = y['audio_rep'] # [BS, 768, 1, CHUNK_LEN] + interp_reps = interp_reps.permute(0, 3, 2, 1) # [BS, CHUNK_LEN, 1, 768] + emb_audio = self.wavlm_encoder(interp_reps) # [BS, CHUNK_LEN, 1, WAVLM_PROJ_DIM] + emb_audio = emb_audio.permute(0, 3, 2, 1) # [BS, WAVLM_PROJ_DIM, 1, CHUNK_LEN] + else: + raise NotImplementedError + emb_audio = emb_audio.squeeze(2) # [BS, AUDIO_DIM, CHUNK_LEN], (AUDIO_DIM = MFCC_DIM or WAV_ENC_DIM or WAVLM_PROJ_DIM) + emb_audio = emb_audio.permute((2, 0, 1)) # [CHUNK_LEN, BS, AUDIO_DIM] + + # Pose Embeddings + emb_pose = self.input_process(x) # [CHUNK_LEN, BS, LAT_DIM] + + ############################# + #### FEATURE AGGREGATION #### + ############################# + + # Cat Pose w/ Audio (Fine-Grained) Embeddings + if self.use_vad: + fg_embs = torch.cat((emb_pose, emb_audio, emb_vad), axis=2) # [CHUNK_LEN, BS, LAT_DIM + AUDIO_DIM + LAT_DIM] + else: + fg_embs = torch.cat((emb_pose, emb_audio), axis=2) # [CHUNK_LEN, BS, LAT_DIM + AUDIO_DIM] + + # Cat Seed w/ Text Embeddings (if exist) + if self.use_text: + embs_stxt = torch.cat((emb_text,emb_seed),axis=1) # [BS, LAT_DIM] + else: + embs_stxt = emb_seed # [BS, LAT_DIM] + + # Sum All Coarse-Grained Embeddings (t + Seed w/ Text) + coa_embs = (embs_stxt + emb_t) # [1, BS, LAT_DIM] + + # Repeat Coarse-Grained Summation (to match chunk) + coa_embs_rep = coa_embs.repeat(nframes, 1, 1) # [CHUNK_LEN, BS, LAT_DIM] + + # Concatenate All to form feature inputs + embs = torch.cat((fg_embs, coa_embs_rep), axis=2) # [CHUNK_LEN, BS, LAT_DIM + AUDIO_DIM + LAT_DIM + LAT_DIM] of 2* LAT_DIM If no VAD + + # Project to Latent Dim + xseq = self.project_to_lat(embs) # [CHUNK_LEN, BS, LAT_DIM] + + ###################### + #### DENOISE PASS #### + ###################### + + ## Data Reshaping (Insert multiple att heads) + xseq = xseq.permute(1, 0, 2) # [BS, CHUNK_LEN, LAT_DIM] + xseq = xseq.view(bs, nframes, self.cl_head, -1) # [BS, CHUNK_LEN, CL_HEAD, LAT_DIM / CL_HEAD] + xseq = xseq.permute(0, 2, 1, 3) # [BS, CL_HEAD, CHUNK_LEN, LAT_DIM / CL_HEAD] + xseq = xseq.reshape(bs*self.cl_head, nframes, -1) # [BS * CL_HEAD, CHUNK_LEN, LAT_DIM / CL_HEAD] + + ## RPE Embeddings + pos_emb = self.rel_pos(xseq) # [CHUNK_LEN, BS] O CORRETO É [CHUNK_LEN, LAT_DIM / CL_HEAD] + xseq, _ = apply_rotary_pos_emb(xseq, xseq, pos_emb) # [LAT_DIM, CHUNK_LEN, BS] O CORRETO É [BS * CL_HEAD, CHUNK_LEN, LAT_DIM / CL_HEAD] + + ## Apply Cross Local Attention + packed_shape = [torch.Size([bs, self.cl_head])] # [1] = [torch.Size([BS, CL_HEAD]) + mask_local = torch.ones(bs, nframes).bool().to(device=xseq.device) # [BS, CHUNK_LEN] + xseq = self.cross_local_attention(xseq, xseq, xseq, + packed_shape=packed_shape, mask=mask_local) # [BS, CL_HEAD, CHUNK_LEN, LAT_DIM / CL_HEAD] + + # Data Reshaping to cat Global Information + xseq = xseq.permute(0, 2, 1, 3) # [BS, CHUNK_LEN, CL_HEAD, LAT_DIM / CL_HEAD] + xseq = xseq.reshape(bs, nframes, -1) # [BS, CHUNK_LEN, LAT_DIM] + xseq = xseq.permute(1, 0, 2) # [CHUNK_LEN, BS, LAT_DIM] + + # Concat Coarse Grained Embeddings + xseq = torch.cat((coa_embs, xseq), axis=0) # 
[CHUNK_LEN+1, BS, LAT_DIM] + + # Data Reshaping (Insert multiple att heads) + xseq = xseq.permute(1, 0, 2) # [BS, CHUNK_LEN+1, LAT_DIM] + xseq = xseq.view(bs, nframes + 1, self.cl_head, -1) # [BS, CHUNK_LEN+1, CL_HEAD, LAT_DIM / CL_HEAD] + xseq = xseq.permute(0, 2, 1, 3) # [BS, CL_HEAD, CHUNK_LEN+1, LAT_DIM / CL_HEAD] + xseq = xseq.reshape(bs*self.cl_head, nframes + 1, -1) # [BS * CL_HEAD, CHUNK_LEN+1, LAT_DIM / CL_HEAD] + + # RPE Embeddings + pos_emb = self.rel_pos(xseq) # [CHUNK_LEN+1, BS] O CORRETO É [CHUNK_LEN+1, LAT_DIM / CL_HEAD] + xseq, _ = apply_rotary_pos_emb(xseq, xseq, pos_emb) # [LAT_DIM, CHUNK_LEN+1, BS] O CORRETO É [BS * CL_HEAD, CHUNK_LEN+1, LAT_DIM / CL_HEAD] + + # Data Reshaping + xseq_rpe = xseq.reshape(bs,self.cl_head,nframes+1,-1) # [BS, CL_HEAD, CHUNK_LEN+1, LAT_DIM / CL_HEAD] + xseq = xseq_rpe.permute(0, 2, 1, 3) # [BS, CHUNK_LEN+1, CL_HEAD, LAT_DIM / CL_HEAD] + xseq = xseq.view(bs, nframes + 1, -1) # [BS, CHUNK_LEN+1, LAT_DIM] + xseq = xseq.permute(1, 0, 2) # [CHUNK_LEN+1, BS, LAT_DIM] + + # Self-Attention + output = self.seqTransEncoder(xseq) # [CHUNK_LEN+1, BS, LAT_DIM] + + # Ignore First Token + output = output[1:] # [CHUNK_LEN, BS, LAT_DIM] + + # Linear Output Feature Pass + output = self.output_process(output) # [BS, POSE_DIM, 1, CHUNK_LEN] + + if self.log_train: + + if self.use_text: + mean = torch.mean(emb_text, dim=1) #emb_text: [BS, TEXT_DIM] + self.batch_log['text'] = mean.detach().cpu().numpy() + + if self.use_vad: + mean = torch.mean(torch.mean(emb_vad, dim=0), dim=1) #emb_vad: [CHUNK_LEN, BS, LAT_DIM] + self.batch_log['vad'] = mean.detach().cpu().numpy() + + mean = torch.mean(emb_seed, dim=1) #emb_seed: [BS, LAT_DIM - TEXT_DIM] + self.batch_log['seed'] = mean.detach().cpu().numpy() + + mean = torch.mean(emb_t, dim=2) #emb_t: [1, BS, LAT_DIM] + self.batch_log['timestep'] = mean.detach().cpu().numpy() + + mean = torch.mean(torch.mean(emb_audio, dim=0), dim=1) #emb_audio: [CHUNK_LEN, BS, AUDIO_DIM] + self.batch_log['audio'] = mean.detach().cpu().numpy() + + mean = torch.mean(torch.mean(emb_pose, dim=0), dim=1) #emb_pose: [CHUNK_LEN, BS, LAT_DIM] + self.batch_log['poses'] = mean.detach().cpu().numpy() + + #mean = torch.mean(torch.mean(audiovad, dim=0), dim=1) # [CHUNK_LEN, BS, AUDIO_DIM] + #self.batch_log['audiovad'] = mean.detach().cpu().numpy() + + #std, mean = torch.std_mean(fg_embs, dim=1) # fg embeddings: [CHUNK_LEN, BS, LAT_DIM + AUDIO_DIM + LAT_DIM] + #self.log_fg_embs = [ std.detach().cpu().numpy(), mean.detach().cpu().numpy() ] + #self.batch_log['fg_embs'] = self.log_fg_embs +# + #std, mean = torch.std_mean(coa_embs, dim=1) # coa embeddings: [1, BS, LAT_DIM] + #self.log_coa_embs = [ std.detach().cpu().numpy(), mean.detach().cpu().numpy() ] + #self.batch_log['coa_embs'] = self.log_coa_embs + + std, mean = torch.std_mean(torch.mean(embs, dim=0), dim=0) # embeddings: [CHUNK_LEN, BS, LAT_DIM + AUDIO_DIM + LAT_DIM + LAT_DIM] of 2* LAT_DIM If no VAD + self.log_embs = [ std.detach().cpu().numpy(), mean.detach().cpu().numpy() ] + self.batch_log['embs'] = self.log_embs + + return output + + def parameters_wo_clip(self): + return [p for name, p in self.named_parameters() if not name.startswith('clip_model.')] + + def load_and_freeze_clip(self, clip_version): + clip_model, clip_preprocess = clip.load(clip_version, device='cpu', + jit=False) # Must set jit=False for training + clip.model.convert_weights( + clip_model) # Actually this line is unnecessary since clip by default already on float16 + + # Freeze CLIP weights + clip_model.eval() + for p in 
clip_model.parameters(): + p.requires_grad = False + + return clip_model + + def mask_cond(self, cond, force_mask=False): + bs, d = cond.shape + if force_mask: + return torch.zeros_like(cond) + elif self.training and self.cond_mask_prob > 0.: + mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_mask_prob).view(bs, 1) # 1-> use null_cond, 0-> use real cond + return cond * (1. - mask) + else: + return cond + + def encode_text(self, raw_text): + # raw_text - list (batch_size length) of strings with input text prompts + device = next(self.parameters()).device + max_text_len = 20 if self.dataset in ['humanml', 'kit'] else None # Specific hardcoding for humanml dataset + if max_text_len is not None: + default_context_length = 77 + context_length = max_text_len + 2 # start_token + 20 + end_token + assert context_length < default_context_length + texts = clip.tokenize(raw_text, context_length=context_length, truncate=True).to(device) # [bs, context_length] # if n_tokens > context_length -> will truncate + # print('texts', texts.shape) + zero_pad = torch.zeros([texts.shape[0], default_context_length-context_length], dtype=texts.dtype, device=texts.device) + texts = torch.cat([texts, zero_pad], dim=1) + # print('texts after pad', texts.shape, texts) + else: + texts = clip.tokenize(raw_text, truncate=True).to(device) # [bs, context_length] # if n_tokens > 77 -> will truncate + return self.clip_model.encode_text(texts).float() + + #def _apply(self, fn): + # super()._apply(fn) + # self.rot2xyz.smpl_model._apply(fn) +# + #def train(self, *args, **kwargs): + # super().train(*args, **kwargs) + # self.rot2xyz.smpl_model.train(*args, **kwargs) + +class PositionalEncoding(nn.Module): + def __init__(self, d_model, dropout=0.1, max_len=5000): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0).transpose(0, 1) + + self.register_buffer('pe', pe) + + def forward(self, x): + # not used in the final model + x = x + self.pe[:x.shape[0], :] + return self.dropout(x) + +class TimestepEmbedder(nn.Module): + def __init__(self, latent_dim, sequence_pos_encoder): + super().__init__() + self.latent_dim = latent_dim + self.sequence_pos_encoder = sequence_pos_encoder + + time_embed_dim = self.latent_dim + self.time_embed = nn.Sequential( + nn.Linear(self.latent_dim, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim), + ) + + def forward(self, timesteps): + return self.time_embed(self.sequence_pos_encoder.pe[timesteps]).permute(1, 0, 2) + +class WavEncoder(nn.Module): + ''' + Taken from https://github.com/ai4r/Gesture-Generation-from-Trimodal-Context/ + ''' + def __init__(self): + super().__init__() + self.feat_extractor = nn.Sequential( + nn.Conv1d(1, 16, 15, stride=5, padding=1600, dilation = 1), + nn.BatchNorm1d(16), + nn.LeakyReLU(0.3, inplace=True), + nn.Conv1d(16, 32, 15, stride=5, dilation = 4), + nn.BatchNorm1d(32), + nn.LeakyReLU(0.3, inplace=True), + nn.Conv1d(32, 64, 15, stride=5, dilation = 7), + nn.BatchNorm1d(64), + nn.LeakyReLU(0.3, inplace=True), + nn.Conv1d(64, 32, 15, stride=5, dilation = 13), + ) + + def forward(self, wav_data): # [B, 147000] + wav_data = wav_data.unsqueeze(1) # [B, 1, 147000] + out = 
self.feat_extractor(wav_data) # [B, 32, 200] + return out.unsqueeze(2) # [B, 32, 1, 200] + + def layer_output_size(self,l_in, padding, kernel_size, dilation, stride): + l_out = int(np.floor((l_in + 2*padding - dilation*(kernel_size-1) - 1)/stride + 1)) + return l_out + +class InputProcess(nn.Module): + def __init__(self, data_rep, input_feats, latent_dim): + super().__init__() + self.data_rep = data_rep + self.input_feats = input_feats + self.latent_dim = latent_dim + self.poseEmbedding = nn.Linear(self.input_feats, self.latent_dim) + if self.data_rep == 'rot_vel': + self.velEmbedding = nn.Linear(self.input_feats, self.latent_dim) + + def forward(self, x): + bs, njoints, nfeats, nframes = x.shape + x = x.permute((3, 0, 1, 2)).reshape(nframes, bs, njoints*nfeats) + + if self.data_rep in ['genea_vec', 'genea_vec+']: + x = self.poseEmbedding(x) # [seqlen, bs, d] + return x + else: + raise NotImplementedError + +class OutputProcess(nn.Module): + def __init__(self, data_rep, input_feats, latent_dim, njoints, nfeats): + super().__init__() + self.data_rep = data_rep + self.input_feats = input_feats + self.latent_dim = latent_dim + self.njoints = njoints + self.nfeats = nfeats + self.poseFinal = nn.Linear(self.latent_dim, self.input_feats) + if self.data_rep == 'rot_vel': + self.velFinal = nn.Linear(self.latent_dim, self.input_feats) + + def forward(self, output): + nframes, bs, d = output.shape + if self.data_rep in ['genea_vec', 'genea_vec+']: + output = self.poseFinal(output) # [CHUNK_LEN, BS, POSE_DIM] + else: + raise NotImplementedError + output = output.reshape(nframes, bs, self.njoints, self.nfeats) # [CHUNK_LEN, BS, POSE_DIM, 1] + output = output.permute(1, 2, 3, 0) + return output + +class SeedPoseEncoder(nn.Module): + def __init__(self, njoints, seed_poses, latent_dim): + super().__init__() + self.njoints = njoints + self.seed_poses = seed_poses + self.latent_dim = latent_dim + self.seed_embed = nn.Linear(self.njoints * self.seed_poses, self.latent_dim) + + def forward(self, x): + x = self.seed_embed(x) + return x \ No newline at end of file diff --git a/sample/generate.py b/sample/generate.py new file mode 100644 index 0000000..80e6f5a --- /dev/null +++ b/sample/generate.py @@ -0,0 +1,232 @@ +# This code is based on https://github.com/openai/guided-diffusion +""" +Generate a large batch of image samples from a model and save them as a large +numpy array. This can be used to produce samples for FID evaluation. 
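+Adapted here for co-speech gesture generation: motion is sampled chunk by chunk over the
+test takes, each chunk seeded with the final frames of the previous one; the chunks are
+then blended at the transitions, Savitzky-Golay smoothed, and saved as results.npy plus
+one BVH file per take.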
+""" +from utils.fixseed import fixseed +import os +import numpy as np +import torch +from utils.parser_util import generate_args +from utils.model_util import create_model_and_diffusion, load_model_wo_clip +from utils import dist_util +from data_loaders.get_data import get_dataset_loader +import shutil +from data_loaders.tensors import gg_collate +import bvhsdk +import utils.rotation_conversions as geometry +from scipy.signal import savgol_filter + +def main(): + args = generate_args() + fixseed(args.seed) + out_path = args.output_dir + name = os.path.basename(os.path.dirname(args.model_path)) + niter = os.path.basename(args.model_path).replace('model', '').replace('.pt', '') + if args.dataset in ['genea2023', 'genea2023+']: + fps = 30 + n_joints = 83 + #TODO: change to receive args.bvh_reference_file + bvhreference = bvhsdk.ReadFile('./dataset/Genea2023/trn/main-agent/bvh/trn_2023_v0_000_main-agent.bvh', skipmotion=True) + else: + raise NotImplementedError + dist_util.setup_dist(args.device) + if out_path == '': + out_path = os.path.join(os.path.dirname(args.model_path), + 'samples_{}_{}_seed{}'.format(name, niter, args.seed)) + if args.text_prompt != '': + out_path += '_' + args.text_prompt.replace(' ', '_').replace('.', '') + elif args.input_text != '': + out_path += '_' + os.path.basename(args.input_text).replace('.txt', '').replace(' ', '_').replace('.', '') + + # Hard-coded takes to be generated + num_samples = 70 + takes_to_generate = np.arange(num_samples) + + + args.batch_size = num_samples # Sampling a single batch from the testset, with exactly args.num_samples + + #inputs_i = [155,271,320,400,500,600,700,800,1145,1185] + + print('Loading dataset...') + data = load_dataset(args, num_samples) + + print("Creating model and diffusion...") + model, diffusion = create_model_and_diffusion(args, data) + + print(f"Loading checkpoints from [{args.model_path}]...") + state_dict = torch.load(args.model_path, map_location='cpu') + load_model_wo_clip(model, state_dict) + + model.to(dist_util.dev()) + model.eval() # disable random masking + + all_motions = [] #np.zeros(shape=(num_samples, n_joints, 3, args.num_frames*chunks_per_take)) + all_motions_rot = [] + all_lengths = [] + all_text = [] + all_audios = [] + + # dummy motion, batch_text, window, batch_audio, batch_audio_rep, dummy seed_poses, max_length + dummy_motion, data_text, chunk_len, data_audio, data_audio_rep, dummy_seed, max_length, vad_vals, takenames = data.dataset.gettestbatch(num_samples) + + chunks_per_take = int(max_length/chunk_len) + for chunk in range(chunks_per_take): # Motion is generated in chunks, for each chunk we load the corresponding data + empty = np.array([]) + inputs = [] + for take in range(num_samples): # For each take we will load the current chunk + vad = vad_vals[take][..., chunk:chunk+chunk_len] if args.use_vad else empty + inputs.append((empty, data_text[take][chunk], chunk_len, empty, data_audio_rep[take][..., chunk:chunk+chunk_len], dummy_seed, vad, takenames[take])) + + _, model_kwargs = gg_collate(inputs) # gt_motion: [num_samples(bs), njoints, 1, chunk_len] + model_kwargs['y'] = {key: val.to(dist_util.dev()) if torch.is_tensor(val) else val for key, val in model_kwargs['y'].items()} #seed: [num_samples(bs), njoints, 1, seed_len] + + if chunk == 0: + pass #send mean pose + else: + model_kwargs['y']['seed'] = sample_out[...,-args.seed_poses:] + + + + print('### Sampling chunk {} of {}'.format(chunk+1, int(max_length/chunk_len))) + + # add CFG scale to batch + if args.guidance_param != 1: # default 
2.5 + model_kwargs['y']['scale'] = torch.ones(num_samples, device=dist_util.dev()) * args.guidance_param + + sample_fn = diffusion.p_sample_loop + + sample_out = sample_fn( + model, + (num_samples, model.njoints, model.nfeats, args.num_frames), + clip_denoised=False, + model_kwargs=model_kwargs, + skip_timesteps=0, # 0 is the default value - i.e. don't skip any step + init_image=None, + progress=True, + dump_steps=None, + noise=None, + const_noise=False, + ) # [num_samples(bs), njoints, 1, chunk_len] + + sample = data.dataset.inv_transform(sample_out.cpu().permute(0, 2, 3, 1)).float() # [num_samples(bs), 1, chunk_len, njoints] + + + # Separating positions and rotations + if args.dataset == 'genea2023': + idx_positions = np.asarray([ [i*6+3, i*6+4, i*6+5] for i in range(n_joints) ]).flatten() + idx_rotations = np.asarray([ [i*6, i*6+1, i*6+2] for i in range(n_joints) ]).flatten() + sample, sample_rot = sample[..., idx_positions], sample[..., idx_rotations] + + #rotations + sample_rot = sample_rot.view(sample_rot.shape[:-1] + (-1, 3)) + sample_rot = sample_rot.view(-1, *sample_rot.shape[2:]).permute(0, 2, 3, 1) + + + elif args.dataset == 'genea2023+': + idx_rotations = np.asarray([ [i*9, i*9+1, i*9+2, i*9+3, i*9+4, i*9+5] for i in range(n_joints) ]).flatten() + idx_positions = np.asarray([ [i*9+6, i*9+7, i*9+8] for i in range(n_joints) ]).flatten() + sample, sample_rot = sample[..., idx_positions], sample[..., idx_rotations] # sample_rot: [num_samples(bs), 1, chunk_len, n_joints*6] + + #rotations + sample_rot = sample_rot.view(sample_rot.shape[:-1] + (-1, 6)) # [num_samples(bs), 1, chunk_len, n_joints, 6] + sample_rot = geometry.rotation_6d_to_matrix(sample_rot) # [num_samples(bs), 1, chunk_len, n_joints, 3, 3] + sample_rot = geometry.matrix_to_euler_angles(sample_rot, "ZXY")[..., [1, 2, 0] ]*180/np.pi # [num_samples(bs), 1, chunk_len, n_joints, 3] + sample_rot = sample_rot.view(-1, *sample_rot.shape[2:]).permute(0, 2, 3, 1) # [num_samples(bs)*chunk_len, n_joints, 3] + + else: + raise ValueError(f'Unknown dataset: {args.dataset}') + + #positions + sample = sample.view(sample.shape[:-1] + (-1, 3)) # [num_samples(bs), 1, chunk_len, n_joints/3, 3] + sample = sample.view(-1, *sample.shape[2:]).permute(0, 2, 3, 1) # [num_samples(bs), n_joints/3, 3, chunk_len] + + text_key = 'text' if 'text' in model_kwargs['y'] else 'action_text' + all_text += model_kwargs['y'][text_key] + + all_audios.append(model_kwargs['y']['audio'].cpu().numpy()) + all_motions.append(sample.cpu().numpy()) + all_motions_rot.append(sample_rot.cpu().numpy()) + all_lengths.append(model_kwargs['y']['lengths'].cpu().numpy()) + + + + all_audios = data_audio + all_motions = np.concatenate(all_motions, axis=3) + all_motions_rot = np.concatenate(all_motions_rot, axis=3) + all_lengths = np.concatenate(all_lengths, axis=0) + + # Smooth chunk transitions + inter_range = 10 #interpolation range in frames + for transition in np.arange(1, chunks_per_take-1)*args.num_frames: + all_motions[..., transition:transition+2] = np.tile(np.expand_dims(all_motions[..., transition]/2 + all_motions[..., transition-1]/2,-1),2) + all_motions_rot[..., transition:transition+2] = np.tile(np.expand_dims(all_motions_rot[..., transition]/2 + all_motions_rot[..., transition-1]/2,-1),2) + for i, s in enumerate(np.linspace(0, 1, inter_range-1)): + forward = transition-inter_range+i + backward = transition+inter_range-i + all_motions[..., forward] = all_motions[..., forward]*(1-s) + all_motions[:, :, :, transition-1]*s + all_motions[..., backward] = all_motions[..., 
backward]*(1-s) + all_motions[:, :, :, transition]*s + all_motions_rot[..., forward] = all_motions_rot[..., forward]*(1-s) + all_motions_rot[:, :, :, transition-1]*s + all_motions_rot[..., backward] = all_motions_rot[..., backward]*(1-s) + all_motions_rot[:, :, :, transition]*s + + all_motions = savgol_filter(all_motions, 9, 3, axis=-1) + all_motions_rot = savgol_filter(all_motions_rot, 9, 3, axis=-1) + + if os.path.exists(out_path): + shutil.rmtree(out_path) + os.makedirs(out_path) + print(f"saving results to [{out_path}]") + + npy_path = os.path.join(out_path, 'results.npy') + + np.save(npy_path, + {'motion': all_motions, 'text': all_text, 'lengths': all_lengths, + 'num_samples': len(takes_to_generate), 'num_chunks': chunks_per_take}) + with open(npy_path.replace('.npy', '.txt'), 'w') as fw: + fw.write('\n'.join(all_text)) + with open(npy_path.replace('.npy', '_len.txt'), 'w') as fw: + fw.write('\n'.join([str(l) for l in all_lengths])) + + + for i, take in enumerate(takes_to_generate): + final_frame = data.dataset.frames[i] + save_file = data.dataset.takes[take][0] + print('Saving take {}: {}'.format(i, save_file)) + positions = all_motions[i] + positions = positions[..., :final_frame] + positions = positions.transpose(2, 0, 1) + + # Saving generated motion as bvh file + rotations = all_motions_rot[i] # [njoints/3, 3, chunk_len*chunks] + rotations = rotations[..., :final_frame] + rotations = rotations.transpose(2, 0, 1) # [chunk_len*chunks, njoints/3, 3] + bvhreference.frames = rotations.shape[0] + for j, joint in enumerate(bvhreference.getlistofjoints()): + joint.rotation = rotations[:, j, :] + joint.translation = np.tile(joint.offset, (bvhreference.frames, 1)) + bvhreference.root.translation = positions[:, 0, :] + #bvhreference.root.children[0].translation = positions[:, 1, :] + print('Saving bvh file...') + bvhsdk.WriteBVH(bvhreference, path=out_path, name=save_file, frametime=1/fps, refTPose=False) + + abs_path = os.path.abspath(out_path) + print(f'[Done] Results are at [{abs_path}]') + + +def load_dataset(args, batch_size): + data = get_dataset_loader(name=args.dataset, + data_dir=args.data_dir, + batch_size=batch_size, + num_frames=args.num_frames, + split='tst', + hml_mode='text_only', + step = args.num_frames, + use_wavlm=args.use_wavlm, + use_vad = args.use_vad, + vadfromtext = args.vadfromtext,) + #data.fixed_length = n_frames + return data + + +if __name__ == "__main__": + main() diff --git a/speechbrain/Dockerfile b/speechbrain/Dockerfile new file mode 100644 index 0000000..5190c1d --- /dev/null +++ b/speechbrain/Dockerfile @@ -0,0 +1,28 @@ +FROM pytorch/pytorch:1.6.0-cuda10.1-cudnn7-devel + +ENV PATH="/root/miniconda3/bin:${PATH}" +ARG PATH="/root/miniconda3/bin:${PATH}" + +RUN rm /etc/apt/sources.list.d/cuda.list + +RUN apt-get update +RUN apt-get install -y wget git nano ffmpeg + +RUN wget \ + https://repo.anaconda.com/miniconda/Miniconda3-py37_4.8.3-Linux-x86_64.sh \ + && mkdir /root/.conda \ + && bash Miniconda3-py37_4.8.3-Linux-x86_64.sh -b \ + && rm -f Miniconda3-py37_4.8.3-Linux-x86_64.sh + +RUN conda --version + +WORKDIR /root +COPY environment.yml /root + +RUN conda install tqdm -f +RUN conda update conda +RUN conda install pip +RUN conda --version + +RUN rm -rf $(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())")/ruamel* +RUN pip install speechbrain diff --git a/speechbrain/environment.yml b/speechbrain/environment.yml new file mode 100644 index 0000000..a299bb4 --- /dev/null +++ b/speechbrain/environment.yml @@ -0,0 +1,8 @@ +name: 
speechbrain +channels: + - pytorch + - anaconda + - conda-forge + - defaults + - pip: + - speechbrain diff --git a/speechbrain_vad_container.sh b/speechbrain_vad_container.sh new file mode 100644 index 0000000..a273ff8 --- /dev/null +++ b/speechbrain_vad_container.sh @@ -0,0 +1,11 @@ +while getopts g:n:p: flag +do + case "${flag}" in + g) gpu=${OPTARG};; + n) number=${OPTARG};; + p) port=${OPTARG};; + esac +done +echo "Running container speechbrain_container_$number on gpu $gpu and port $port"; + +nvidia-docker run --rm -it -e NVIDIA_VISIBLE_DEVICES=$gpu --runtime=nvidia --userns=host --shm-size 64G -v /work/rodolfo.tonoli/ggvad-genea2023:/workspace/ggvad/ -p $port --name ggvad_container$number speechbrain_vad:latest /bin/bash \ No newline at end of file diff --git a/train/train_mdm.py b/train/train_mdm.py new file mode 100644 index 0000000..4b27b6c --- /dev/null +++ b/train/train_mdm.py @@ -0,0 +1,58 @@ +# This code is based on https://github.com/openai/guided-diffusion +""" +Train a diffusion model on images. +""" + +import os +import json +from utils.fixseed import fixseed +from utils.parser_util import train_args +from utils import dist_util +from train.training_loop import TrainLoop +from data_loaders.get_data import get_dataset_loader +from utils.model_util import create_model_and_diffusion +import numpy as np + +def main(): + args = train_args() + fixseed(args.seed) + + if args.save_dir is None: + raise FileNotFoundError('save_dir was not specified.') + elif os.path.exists(args.save_dir) and not args.overwrite: + raise FileExistsError('save_dir [{}] already exists.'.format(args.save_dir)) + elif not os.path.exists(args.save_dir): + os.makedirs(args.save_dir) + args_path = os.path.join(args.save_dir, 'args.json') + with open(args_path, 'w') as fw: + json.dump(vars(args), fw, indent=4, sort_keys=True) + + if args.wandb: + projectname = os.path.basename(os.path.normpath(args.save_dir)) + import wandb + wandb.login(anonymous="allow") + wandb.init(project='ggvad-genea2023', config=vars(args)) + args.wandb = wandb + + dist_util.setup_dist(args.device) + + print("creating data loader...") + data = get_dataset_loader(name=args.dataset, + data_dir=args.data_dir, + batch_size=args.batch_size, + num_frames=args.num_frames, + step=args.step, + use_wavlm=args.use_wavlm, + use_vad=args.use_vad, + vadfromtext=args.vadfromtext) + + print("creating model and diffusion...") + model, diffusion = create_model_and_diffusion(args, data) + model.to(dist_util.dev()) + + print('Total params: %.2fM' % (sum(p.numel() for p in model.parameters_wo_clip()) / 1000000.0)) + print("Training...") + TrainLoop(args, model, diffusion, data).run_loop() + +if __name__ == "__main__": + main() diff --git a/train/training_loop.py b/train/training_loop.py new file mode 100644 index 0000000..ca527e0 --- /dev/null +++ b/train/training_loop.py @@ -0,0 +1,376 @@ +import copy +import functools +import os +import time +from types import SimpleNamespace +import numpy as np + +import blobfile as bf +import torch +from torch.optim import AdamW + +from diffusion import logger +from utils import dist_util +from diffusion.fp16_util import MixedPrecisionTrainer +from diffusion.resample import LossAwareSampler, UniformSampler +from tqdm import tqdm +from diffusion.resample import create_named_schedule_sampler +from eval import eval_genea +from data_loaders.get_data import get_dataset_loader +import utils.rotation_conversions as geometry + + +# For ImageNet experiments, this was a good default value. 
+# We found that the lg_loss_scale quickly climbed to +# 20-21 within the first ~1K steps of training. +INITIAL_LOG_LOSS_SCALE = 20.0 + + +class TrainLoop: + def __init__(self, args, model, diffusion, data): + self.args = args + self.dataset = args.dataset + self.model = model + self.diffusion = diffusion + self.data = data + self.batch_size = args.batch_size + self.microbatch = args.batch_size # deprecating this option + self.lr = args.lr + self.log_interval = args.log_interval + self.save_interval = args.save_interval + self.resume_checkpoint = args.resume_checkpoint + self.use_fp16 = False # deprecating this option + self.fp16_scale_growth = 1e-3 # deprecating this option + self.weight_decay = args.weight_decay + self.lr_anneal_steps = args.lr_anneal_steps + self.log_wandb = args.wandb + if self.log_wandb: + self.genea_evaluator = eval_genea.GeneaEvaluator(args, self.model, self.diffusion) + + self.step = 0 + self.resume_step = 0 + self.global_batch = self.batch_size # * dist.get_world_size() + self.num_steps = args.num_steps + self.num_epochs = self.num_steps // len(self.data) + 1 + + self.sync_cuda = torch.cuda.is_available() + + self._load_and_sync_parameters() + self.mp_trainer = MixedPrecisionTrainer( + model=self.model, + use_fp16=self.use_fp16, + fp16_scale_growth=self.fp16_scale_growth, + ) + + self.save_dir = args.save_dir + self.overwrite = args.overwrite + + self.opt = AdamW( + self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay + ) + if self.resume_step: + self._load_optimizer_state() + # Model was resumed, either due to a restart or a checkpoint + # being specified at the command line. + + self.device = torch.device("cpu") + if torch.cuda.is_available() and dist_util.dev() != 'cpu': + self.device = torch.device(dist_util.dev()) + + self.schedule_sampler_type = 'uniform' + self.schedule_sampler = create_named_schedule_sampler(self.schedule_sampler_type, diffusion) + self.eval_wrapper, self.eval_data, self.eval_gt_data = None, None, None + self.use_ddp = False + self.ddp_model = self.model + + def _load_and_sync_parameters(self): + resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint + + if resume_checkpoint: + self.resume_step = parse_resume_step_from_filename(resume_checkpoint) + logger.log(f"loading model from checkpoint: {resume_checkpoint}...") + check = dist_util.load_state_dict(resume_checkpoint, map_location=dist_util.dev()) + missing_keys, unexpected_keys = self.model.load_state_dict(check, strict=False) + assert len(unexpected_keys) == 0 + assert all([k.startswith('clip_model.') for k in missing_keys]) + #self.model.load_state_dict( + # dist_util.load_state_dict( + # resume_checkpoint, map_location=dist_util.dev() + # ) + #) + + def _load_optimizer_state(self): + main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint + opt_checkpoint = bf.join( + bf.dirname(main_checkpoint), f"opt{self.resume_step:09}.pt" + ) + if bf.exists(opt_checkpoint): + logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}") + state_dict = dist_util.load_state_dict( + opt_checkpoint, map_location=dist_util.dev() + ) + self.opt.load_state_dict(state_dict) + + def run_loop(self): + + for epoch in range(self.num_epochs): + print(f'Starting epoch {epoch}') + + if self.log_wandb: + self.model.log_train = True + size = len(self.data)*self.batch_size + dictlog = {'text': np.zeros(size), + 'vad': np.zeros(size), + 'seed': np.zeros(size), + 'timestep':np.zeros(size), + 'audio': np.zeros(size), + 'poses': np.zeros(size)} + + for 
stepcount, (motion, cond) in enumerate(tqdm(self.data)): + if not (not self.lr_anneal_steps or self.step + self.resume_step < self.lr_anneal_steps): + break + + motion = motion.to(self.device) + cond['y'] = {key: val.to(self.device) if torch.is_tensor(val) else val for key, val in cond['y'].items()} + + self.run_step(motion, cond) + + if self.log_wandb: + i = self.batch_size*stepcount + e = i + self.batch_size + dictlog['text'][i:e] = self.model.batch_log['text'] if self.model.batch_log['text'] != [] else np.zeros(self.batch_size) + dictlog['vad'][i:e] = self.model.batch_log['vad'] if self.model.batch_log['vad'] != [] else np.zeros(self.batch_size) + dictlog['seed'][i:e] = self.model.batch_log['seed'] + dictlog['timestep'][i:e] = self.model.batch_log['timestep'] + dictlog['audio'][i:e] = self.model.batch_log['audio'] + dictlog['poses'][i:e] = self.model.batch_log['poses'] + + if self.step % self.log_interval == 0 and self.log_wandb: + + mean_, std_ = self.model.batch_log['embs'][1], self.model.batch_log['embs'][0] + mean = [ [str(i), v] for i,v in enumerate(mean_)] + std = [ [str(i), v] for i,v in enumerate(std_)] + + table_mean = self.log_wandb.wandb.Table(data=mean, columns=['dim', 'mean']) + table_std = self.log_wandb.wandb.Table(data=std, columns=['dim', 'std']) + + mean_scatter = self.log_wandb.wandb.plot.scatter(table_mean, x='dim', y='mean', title='embs mean') + std_scatter = self.log_wandb.wandb.plot.scatter(table_std, x='dim', y='std', title='embs std') + + self.log_wandb.wandb.log({'embs_mean_plot': mean_scatter, 'embs_std_plot': std_scatter}) + for k,v in logger.get_current().name2val.items(): + if k == 'loss': + print('step[{}]: loss[{:0.5f}]'.format(self.step+self.resume_step, v)) + if self.log_wandb: + self.log_wandb.wandb.log({'loss': v, 'step': self.step+self.resume_step}) + + if k in ['step', 'samples'] or '_q' in k: + continue + #else: + # self.train_platform.report_scalar(name=k, value=v, iteration=self.step, group_name='Loss') + + if self.step % self.save_interval == 0: + self.save() + + if self.log_wandb: + print('Logging epoch wandb') + stds = np.zeros(len(dictlog)) + + for i, (k,v) in enumerate(dictlog.items()): + self.log_wandb.wandb.log({k+'_mean': v}) + + stds = [ [str(i), np.std(v)] for i,v in enumerate(dictlog.values())] + table_std = self.log_wandb.wandb.Table(data=stds, columns=['dim', 'std']) + std_scatter = self.log_wandb.wandb.plot.scatter(table_std, x='dim', y='std', title='trn data emb std over batch') + self.log_wandb.wandb.log({'epoch': epoch, 'trn_data_emb_std_plot': std_scatter}) + + self.model.eval() + self.valwandb() + self.model.train() + + self.step += 1 + + if not (not self.lr_anneal_steps or self.step + self.resume_step < self.lr_anneal_steps): + break + + # Save the last checkpoint if it wasn't already saved. 
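+        # (self.step has already been advanced past the final batch, so self.step - 1 is
+        # the index of the last completed training step.)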
+ if (self.step - 1) % self.save_interval != 0: + self.save() + #self.evaluate() + + def valwandb(self): + assert self.log_wandb + fgd, histfig = self.genea_evaluator.eval() + self.log_wandb.wandb.log({'FGD Validation': fgd, 'Rot Vel Hist': self.log_wandb.wandb.Image(histfig)}) + + + def run_debugemb(self): + print(f'Starting debug embedding') + batchs = 10 + for i, (motion, cond) in enumerate(tqdm(self.data)): + motion = motion.to(self.device) + cond['y'] = {key: val.to(self.device) if torch.is_tensor(val) else val for key, val in cond['y'].items()} + + self.run_step(motion, cond) + if i>= batchs: + break + return self.model.debug_seed,self.model.debug_text,self.model.debug_timestep,self.model.debug_audio,self.model.debug_vad,self.model.debug_poses + + + #def evaluate(self): + # if not self.args.eval_during_training: + # return + # start_eval = time.time() + # if self.eval_wrapper is not None: + # print('Running evaluation loop: [Should take about 90 min]') + # log_file = os.path.join(self.save_dir, f'eval_humanml_{(self.step + self.resume_step):09d}.log') + # diversity_times = 300 + # mm_num_times = 0 # mm is super slow hence we won't run it during training + # eval_dict = eval_humanml.evaluation( + # self.eval_wrapper, self.eval_gt_data, self.eval_data, log_file, + # replication_times=self.args.eval_rep_times, diversity_times=diversity_times, mm_num_times=mm_num_times, run_mm=False) + # print(eval_dict) + # for k, v in eval_dict.items(): + # if k.startswith('R_precision'): + # for i in range(len(v)): + # self.train_platform.report_scalar(name=f'top{i + 1}_' + k, value=v[i], + # iteration=self.step + self.resume_step, + # group_name='Eval') + # else: + # self.train_platform.report_scalar(name=k, value=v, iteration=self.step + self.resume_step, + # group_name='Eval') +# + # elif self.dataset in ['humanact12', 'uestc']: + # eval_args = SimpleNamespace(num_seeds=self.args.eval_rep_times, num_samples=self.args.eval_num_samples, + # batch_size=self.args.eval_batch_size, device=self.device, guidance_param = 1, + # dataset=self.dataset, unconstrained=self.args.unconstrained, + # model_path=os.path.join(self.save_dir, self.ckpt_file_name())) + # eval_dict = eval_humanact12_uestc.evaluate(eval_args, model=self.model, diffusion=self.diffusion, data=self.data.dataset) + # print(f'Evaluation results on {self.dataset}: {sorted(eval_dict["feats"].items())}') + # for k, v in eval_dict["feats"].items(): + # if 'unconstrained' not in k: + # self.train_platform.report_scalar(name=k, value=np.array(v).astype(float).mean(), iteration=self.step, group_name='Eval') + # else: + # self.train_platform.report_scalar(name=k, value=np.array(v).astype(float).mean(), iteration=self.step, group_name='Eval Unconstrained') +# + # end_eval = time.time() + # print(f'Evaluation time: {round(end_eval-start_eval)/60}min') + + + def run_step(self, batch, cond): + self.forward_backward(batch, cond) + self.mp_trainer.optimize(self.opt) + self._anneal_lr() + self.log_step() + + def forward_backward(self, batch, cond): + self.mp_trainer.zero_grad() + for i in range(0, batch.shape[0], self.microbatch): + # Eliminates the microbatch feature + assert i == 0 + assert self.microbatch == self.batch_size + micro = batch + micro_cond = cond + last_batch = (i + self.microbatch) >= batch.shape[0] + t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev()) + + compute_losses = functools.partial( + self.diffusion.training_losses, + self.ddp_model, + micro, # [bs, ch, image_size, image_size] + t, # [bs](int) sampled timesteps 
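+                # (the "[bs, ch, image_size, image_size]" note above is inherited from the
+                # image-diffusion codebase; for this repo, micro is presumably a motion
+                # batch shaped [bs, njoints, nfeats, nframes])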
+ model_kwargs=micro_cond, + dataset=self.data.dataset + ) + + if last_batch or not self.use_ddp: + losses = compute_losses() + else: + with self.ddp_model.no_sync(): + losses = compute_losses() + + if isinstance(self.schedule_sampler, LossAwareSampler): + self.schedule_sampler.update_with_local_losses( + t, losses["loss"].detach() + ) + + loss = (losses["loss"] * weights).mean() + log_loss_dict( + self.diffusion, t, {k: v * weights for k, v in losses.items()} + ) + self.mp_trainer.backward(loss) + + def _anneal_lr(self): + if not self.lr_anneal_steps: + return + frac_done = (self.step + self.resume_step) / self.lr_anneal_steps + lr = self.lr * (1 - frac_done) + for param_group in self.opt.param_groups: + param_group["lr"] = lr + + def log_step(self): + logger.logkv("step", self.step + self.resume_step) + logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch) + + + def ckpt_file_name(self): + return f"model{(self.step+self.resume_step):09d}.pt" + + + def save(self): + def save_checkpoint(params): + state_dict = self.mp_trainer.master_params_to_state_dict(params) + + # Do not save CLIP weights + clip_weights = [e for e in state_dict.keys() if e.startswith('clip_model.')] + for e in clip_weights: + del state_dict[e] + + logger.log(f"saving model...") + filename = self.ckpt_file_name() + with bf.BlobFile(bf.join(self.save_dir, filename), "wb") as f: + torch.save(state_dict, f) + + save_checkpoint(self.mp_trainer.master_params) + + with bf.BlobFile( + bf.join(self.save_dir, f"opt{(self.step+self.resume_step):09d}.pt"), + "wb", + ) as f: + torch.save(self.opt.state_dict(), f) + + +def parse_resume_step_from_filename(filename): + """ + Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the + checkpoint's number of steps. + """ + split = filename.split("model") + if len(split) < 2: + return 0 + split1 = split[-1].split(".")[0] + try: + return int(split1) + except ValueError: + return 0 + + +def get_blob_logdir(): + # You can change this to be a separate path to save checkpoints to + # a blobstore or some external drive. + return logger.get_dir() + + +def find_resume_checkpoint(): + # On your infrastructure, you may want to override this to automatically + # discover the latest checkpoint on your blob storage, etc. + return None + + +def log_loss_dict(diffusion, ts, losses): + for key, values in losses.items(): + logger.logkv_mean(key, values.mean().item()) + # Log the quantiles (four quartiles, in particular). + for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()): + quartile = int(4 * sub_t / diffusion.num_timesteps) + logger.logkv_mean(f"{key}_q{quartile}", sub_loss) diff --git a/utils/PYTORCH3D_LICENSE b/utils/PYTORCH3D_LICENSE new file mode 100644 index 0000000..bed0ceb --- /dev/null +++ b/utils/PYTORCH3D_LICENSE @@ -0,0 +1,30 @@ +BSD License + +For PyTorch3D software + +Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. 
+ + * Neither the name Facebook nor the names of its contributors may be used to + endorse or promote products derived from this software without specific + prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/utils/dist_util.py b/utils/dist_util.py new file mode 100644 index 0000000..9f5580a --- /dev/null +++ b/utils/dist_util.py @@ -0,0 +1,77 @@ +""" +Helpers for distributed training. +""" + +import socket + +import torch as th +import torch.distributed as dist + +# Change this to reflect your cluster layout. +# The GPU for a given rank is (rank % GPUS_PER_NODE). +GPUS_PER_NODE = 8 + +SETUP_RETRY_COUNT = 3 + +used_device = 0 + +def setup_dist(device=0): + """ + Setup a distributed process group. + """ + global used_device + used_device = device + if dist.is_initialized(): + return + # os.environ["CUDA_VISIBLE_DEVICES"] = str(device) # f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}" + + # comm = MPI.COMM_WORLD + # backend = "gloo" if not th.cuda.is_available() else "nccl" + + # if backend == "gloo": + # hostname = "localhost" + # else: + # hostname = socket.gethostbyname(socket.getfqdn()) + # os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0) + # os.environ["RANK"] = str(comm.rank) + # os.environ["WORLD_SIZE"] = str(comm.size) + + # port = comm.bcast(_find_free_port(), root=used_device) + # os.environ["MASTER_PORT"] = str(port) + # dist.init_process_group(backend=backend, init_method="env://") + + +def dev(): + """ + Get the device to use for torch.distributed. + """ + global used_device + if th.cuda.is_available() and used_device>=0: + return th.device(f"cuda:{used_device}") + return th.device("cpu") + + +def load_state_dict(path, **kwargs): + """ + Load a PyTorch file without redundant fetches across MPI ranks. + """ + return th.load(path, **kwargs) + + +def sync_params(params): + """ + Synchronize a sequence of Tensors across ranks from rank 0. 
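+
+    Note: this relies on torch.distributed being initialized; setup_dist() above
+    currently leaves process-group initialization commented out, so in
+    single-process runs this helper is effectively unused.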
+ """ + for p in params: + with th.no_grad(): + dist.broadcast(p, 0) + + +def _find_free_port(): + try: + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] + finally: + s.close() diff --git a/utils/fixseed.py b/utils/fixseed.py new file mode 100644 index 0000000..6f44f6c --- /dev/null +++ b/utils/fixseed.py @@ -0,0 +1,18 @@ +import numpy as np +import torch +import random + + +def fixseed(seed): + torch.backends.cudnn.benchmark = False + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + +# SEED = 10 +# EVALSEED = 0 +# # Provoc warning: not fully functionnal yet +# # torch.set_deterministic(True) +# torch.backends.cudnn.benchmark = False +# fixseed(SEED) diff --git a/utils/model_util.py b/utils/model_util.py new file mode 100644 index 0000000..0b16a17 --- /dev/null +++ b/utils/model_util.py @@ -0,0 +1,76 @@ +from model.mdm import MDM +from diffusion import gaussian_diffusion as gd +from diffusion.respace import SpacedDiffusion, space_timesteps + + +def load_model_wo_clip(model, state_dict): + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + assert len(unexpected_keys) == 0 + assert all([k.startswith('clip_model.') for k in missing_keys]) + + +def create_model_and_diffusion(args, data): + model = MDM(**get_model_args(args, data)) + diffusion = create_gaussian_diffusion(args) + return model, diffusion + + +def get_model_args(args, data): + + # default args + clip_version = 'ViT-B/32' + + if args.dataset in ['genea2023']: + data_rep = 'genea_vec' + njoints = 498 + nfeats = 1 + elif args.dataset in ['genea2023+']: + data_rep = 'genea_vec+' + njoints = 1245 + nfeats = 1 + + return {'modeltype': '', 'njoints': njoints, 'nfeats': nfeats, + 'translation': True, 'pose_rep': 'rot6d', 'glob': True, 'glob_rot': True, + 'latent_dim': args.latent_dim, 'ff_size': 1024, 'num_layers': args.layers, 'num_heads': 4, + 'dropout': 0.1, 'activation': "gelu", 'data_rep': data_rep, 'cond_mask_prob': args.cond_mask_prob, + 'clip_version': clip_version, 'dataset': args.dataset, + 'use_text': args.use_text, 'mfcc_input': args.mfcc_input, 'use_wavlm': args.use_wavlm, 'use_vad':args.use_vad, + 'use_wav_enc':args.use_wav_enc, 'seed_poses': args.seed_poses} + + +def create_gaussian_diffusion(args): + # default params + predict_xstart = True # we always predict x_start (a.k.a. x0), that's our deal! + steps = 1000 + scale_beta = 1. # no scaling + timestep_respacing = '' # can be used for ddim sampling, we don't use it. 
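+    # Note (illustrative): 'cosine' conventionally denotes the Improved-DDPM schedule,
+    # alpha_bar(t) = cos(((t / T + 0.008) / 1.008) * pi / 2) ** 2, with
+    # betas[i] = min(1 - alpha_bar((i + 1) / T) / alpha_bar(i / T), 0.999);
+    # the exact definition is whatever gd.get_named_beta_schedule('cosine', ...) implements.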
+ learn_sigma = False + rescale_timesteps = False + + betas = gd.get_named_beta_schedule(args.noise_schedule, steps, scale_beta) + loss_type = gd.LossType.MSE + + if not timestep_respacing: + timestep_respacing = [steps] + + return SpacedDiffusion( + use_timesteps=space_timesteps(steps, timestep_respacing), + betas=betas, + model_mean_type=( + gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X + ), + model_var_type=( + ( + gd.ModelVarType.FIXED_LARGE + if not args.sigma_small + else gd.ModelVarType.FIXED_SMALL + ) + if not learn_sigma + else gd.ModelVarType.LEARNED_RANGE + ), + loss_type=loss_type, + rescale_timesteps=rescale_timesteps, + lambda_vel=args.lambda_vel, + lambda_rcxyz=args.lambda_rcxyz, + lambda_fc=args.lambda_fc, + ) \ No newline at end of file diff --git a/utils/parser_util.py b/utils/parser_util.py new file mode 100644 index 0000000..2ed3c1d --- /dev/null +++ b/utils/parser_util.py @@ -0,0 +1,249 @@ +from argparse import ArgumentParser +import argparse +import os +import json + + +def parse_and_load_from_model(parser): + # args according to the loaded model + # do not try to specify them from cmd line since they will be overwritten + add_data_options(parser) + add_model_options(parser) + add_diffusion_options(parser) + args = parser.parse_args() + args_to_overwrite = [] + for group_name in ['dataset', 'model', 'diffusion']: + args_to_overwrite += get_args_per_group_name(parser, args, group_name) + + # load args from model + model_path = get_model_path_from_args() + args_path = os.path.join(os.path.dirname(model_path), 'args.json') + assert os.path.exists(args_path), 'Arguments json file was not found!' + with open(args_path, 'r') as fr: + model_args = json.load(fr) + + for a in args_to_overwrite: + if a in model_args.keys(): + setattr(args, a, model_args[a]) + else: + print('Warning: was not able to load [{}], using default value [{}] instead.'.format(a, args.__dict__[a])) + + if args.cond_mask_prob == 0: + args.guidance_param = 1 + return args + + +def get_args_per_group_name(parser, args, group_name): + for group in parser._action_groups: + if group.title == group_name: + group_dict = {a.dest: getattr(args, a.dest, None) for a in group._group_actions} + return list(argparse.Namespace(**group_dict).__dict__.keys()) + return ValueError('group_name was not found.') + +def get_model_path_from_args(): + try: + dummy_parser = ArgumentParser() + dummy_parser.add_argument('model_path') + dummy_args, _ = dummy_parser.parse_known_args() + return dummy_args.model_path + except: + raise ValueError('model_path argument must be specified.') + + +def add_base_options(parser): + group = parser.add_argument_group('base') + group.add_argument("--cuda", default=True, type=bool, help="Use cuda device, otherwise use CPU.") + group.add_argument("--device", default=0, type=int, help="Device id to use.") + group.add_argument("--seed", default=10, type=int, help="For fixing random seed.") + group.add_argument("--batch_size", default=64, type=int, help="Batch size during training.") + group.add_argument("--wandb", default=False, type=bool, help="Track training with wandb.") + + +def add_diffusion_options(parser): + group = parser.add_argument_group('diffusion') + group.add_argument("--noise_schedule", default='cosine', choices=['linear', 'cosine'], type=str, + help="Noise schedule type") + group.add_argument("--diffusion_steps", default=1000, type=int, + help="Number of diffusion steps (denoted T in the paper)") + group.add_argument("--sigma_small", default=True, 
type=bool, help="Use smaller sigma values.") + + +def add_model_options(parser): + group = parser.add_argument_group('model') + group.add_argument("--arch", default='trans_enc', + choices=['trans_enc', 'trans_dec', 'gru'], type=str, + help="Architecture types as reported in the paper.") + group.add_argument("--emb_trans_dec", default=False, type=bool, + help="For trans_dec architecture only, if true, will inject condition as a class token" + " (in addition to cross-attention).") + group.add_argument("--layers", default=8, type=int, + help="Number of layers.") + group.add_argument("--latent_dim", default=256, type=int, + help="Transformer/GRU width.") + group.add_argument("--cond_mask_prob", default=.1, type=float, + help="The probability of masking the condition during training." + " For classifier-free guidance learning.") + group.add_argument("--lambda_rcxyz", default=0.0, type=float, help="Joint positions loss.") + group.add_argument("--lambda_vel", default=0.0, type=float, help="Joint velocity loss.") + group.add_argument("--lambda_fc", default=0.0, type=float, help="Foot contact loss.") + group.add_argument("--unconstrained", action='store_true', + help="Model is trained unconditionally. That is, it is constrained by neither text nor action. " + "Currently tested on HumanAct12 only.") + group.add_argument("--use_text", action='store_true', help="Unlock text for any usage.") + group.add_argument("--mfcc_input", action='store_true', help="Agregate audio mfcc features in the motion.") + group.add_argument("--use_wav_enc", action='store_true', help="Agregate audio representation extracted w/ conv encoder in the motion.") + group.add_argument("--seed_poses", type=int, default = 10, help="Number of seed poses to condition the beginning of generated motion.") + + + +def add_data_options(parser): + group = parser.add_argument_group('dataset') + group.add_argument("--dataset", default='humanml', choices=['genea2023+','genea2023'], type=str, + help="Dataset name (choose from list).") + group.add_argument("--data_dir", default="./dataset/Genea2023/", type=str, + help="If empty, will use defaults according to the specified dataset.") + group.add_argument("--num_frames", default=120, type=int, + help="Window length to be used in the dataset.") + group.add_argument("--step", default=30, type=int, + help="Step taken to get next window in the take (overlap between successive samples is equal to num_frames - step).") + group.add_argument("--use_wavlm", default=False, type=bool, + help="Use wavlm representations.") + group.add_argument("--use_vad", default=False, type=bool, + help="Use vad speech indicator values.") + group.add_argument("--vadfromtext", default=False, type=bool, + help="Get vad speech indicator values from text.") + group.add_argument("--bvh_reference_file", default='./dataset/Genea2023/trn/main-agent/bvh/trn_2023_v0_000_main-agent.bvh', type=str, + help="BVH file reference. 
Used for extracting joint length and hierarchy during evaluation and generation.") + group.add_argument("--fgd_embedding", default='./evaluation_metric/output/model_checkpoint_120.bin', type=str, + help="Embedding Space Evaluator network for computing FGD metric.") + +def add_training_options(parser): + group = parser.add_argument_group('training') + group.add_argument("--save_dir", required=True, type=str, + help="Path to save checkpoints and results.") + group.add_argument("--overwrite", action='store_true', + help="If True, will enable to use an already existing save_dir.") + group.add_argument("--train_platform_type", default='NoPlatform', choices=['NoPlatform', 'ClearmlPlatform', 'TensorboardPlatform'], type=str, + help="Choose platform to log results. NoPlatform means no logging.") + group.add_argument("--lr", default=1e-4, type=float, help="Learning rate.") + group.add_argument("--weight_decay", default=0.0, type=float, help="Optimizer weight decay.") + group.add_argument("--lr_anneal_steps", default=0, type=int, help="Number of learning rate anneal steps.") + group.add_argument("--eval_batch_size", default=32, type=int, + help="Batch size during evaluation loop. Do not change this unless you know what you are doing. " + "T2m precision calculation is based on fixed batch size 32.") + group.add_argument("--eval_split", default='test', choices=['val', 'test'], type=str, + help="Which split to evaluate on during training.") + group.add_argument("--eval_during_training", action='store_true', + help="If True, will run evaluation during training.") + group.add_argument("--eval_rep_times", default=3, type=int, + help="Number of repetitions for evaluation loop during training.") + group.add_argument("--eval_num_samples", default=1_000, type=int, + help="If -1, will use all samples in the specified split.") + group.add_argument("--log_interval", default=1_000, type=int, + help="Log losses each N steps") + group.add_argument("--save_interval", default=50_000, type=int, + help="Save checkpoints and run evaluation each N steps") + group.add_argument("--num_steps", default=600_000, type=int, + help="Training will stop after the specified number of steps.") + group.add_argument("--resume_checkpoint", default="", type=str, + help="If not empty, will start from the specified checkpoint (path to model###.pt file).") + + +def add_sampling_options(parser): + group = parser.add_argument_group('sampling') + group.add_argument("--model_path", required=True, type=str, + help="Path to model####.pt file to be sampled.") + group.add_argument("--output_dir", default='', type=str, + help="Path to results dir (auto created by the script). " + "If empty, will create dir in parallel to checkpoint.") + group.add_argument("--num_samples", default=10, type=int, + help="Maximal number of prompts to sample, " + "if loading dataset from file, this field will be ignored.") + group.add_argument("--num_repetitions", default=3, type=int, + help="Number of repetitions, per sample (text prompt/action)") + group.add_argument("--guidance_param", default=2.5, type=float, + help="For classifier-free sampling - specifies the s parameter, as defined in the paper.") + + +def add_generate_options(parser): + group = parser.add_argument_group('generate') + group.add_argument("--motion_length", default=6.0, type=float, + help="The length of the sampled motion [in seconds]. 
" + "Maximum is 9.8 for HumanML3D (text-to-motion), and 2.0 for HumanAct12 (action-to-motion)") + group.add_argument("--input_text", default='', type=str, + help="Path to a text file lists text prompts to be synthesized. If empty, will take text prompts from dataset.") + group.add_argument("--action_file", default='', type=str, + help="Path to a text file that lists names of actions to be synthesized. Names must be a subset of dataset/uestc/info/action_classes.txt if sampling from uestc, " + "or a subset of [warm_up,walk,run,jump,drink,lift_dumbbell,sit,eat,turn steering wheel,phone,boxing,throw] if sampling from humanact12. " + "If no file is specified, will take action names from dataset.") + group.add_argument("--text_prompt", default='', type=str, + help="A text prompt to be generated. If empty, will take text prompts from dataset.") + group.add_argument("--action_name", default='', type=str, + help="An action name to be generated. If empty, will take text prompts from dataset.") + + +def add_edit_options(parser): + group = parser.add_argument_group('edit') + group.add_argument("--edit_mode", default='in_between', choices=['in_between', 'upper_body'], type=str, + help="Defines which parts of the input motion will be edited.\n" + "(1) in_between - suffix and prefix motion taken from input motion, " + "middle motion is generated.\n" + "(2) upper_body - lower body joints taken from input motion, " + "upper body is generated.") + group.add_argument("--text_condition", default='', type=str, + help="Editing will be conditioned on this text prompt. " + "If empty, will perform unconditioned editing.") + group.add_argument("--prefix_end", default=0.25, type=float, + help="For in_between editing - Defines the end of input prefix (ratio from all frames).") + group.add_argument("--suffix_start", default=0.75, type=float, + help="For in_between editing - Defines the start of input suffix (ratio from all frames).") + + +def add_evaluation_options(parser): + group = parser.add_argument_group('eval') + group.add_argument("--model_path", required=True, type=str, + help="Path to model####.pt file to be sampled.") + group.add_argument("--eval_mode", default='wo_mm', choices=['wo_mm', 'mm_short', 'debug', 'full'], type=str, + help="wo_mm (t2m only) - 20 repetitions without multi-modality metric; " + "mm_short (t2m only) - 5 repetitions with multi-modality metric; " + "debug - short run, less accurate results." 
+ "full (a2m only) - 20 repetitions.") + group.add_argument("--guidance_param", default=2.5, type=float, + help="For classifier-free sampling - specifies the s parameter, as defined in the paper.") + + +def train_args(): + parser = ArgumentParser() + add_base_options(parser) + add_data_options(parser) + add_model_options(parser) + add_diffusion_options(parser) + add_training_options(parser) + return parser.parse_args() + + +def generate_args(): + parser = ArgumentParser() + # args specified by the user: (all other will be loaded from the model) + add_base_options(parser) + add_sampling_options(parser) + add_generate_options(parser) + args = parse_and_load_from_model(parser) + return args + + +def edit_args(): + parser = ArgumentParser() + # args specified by the user: (all other will be loaded from the model) + add_base_options(parser) + add_sampling_options(parser) + add_edit_options(parser) + return parse_and_load_from_model(parser) + + +def evaluation_parser(): + parser = ArgumentParser() + # args specified by the user: (all other will be loaded from the model) + add_base_options(parser) + add_evaluation_options(parser) + return parse_and_load_from_model(parser) diff --git a/utils/rotation_conversions.py b/utils/rotation_conversions.py new file mode 100644 index 0000000..210ae1f --- /dev/null +++ b/utils/rotation_conversions.py @@ -0,0 +1,552 @@ +# This code is based on https://github.com/Mathux/ACTOR.git +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# Check PYTORCH3D_LICENCE before use + +import functools +from typing import Optional + +import torch +import torch.nn.functional as F + + +""" +The transformation matrices returned from the functions in this file assume +the points on which the transformation will be applied are column vectors. +i.e. the R matrix is structured as + + R = [ + [Rxx, Rxy, Rxz], + [Ryx, Ryy, Ryz], + [Rzx, Rzy, Rzz], + ] # (3, 3) + +This matrix can be applied to column vectors by post multiplication +by the points e.g. + + points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point + transformed_points = R * points + +To apply the same matrix to points which are row vectors, the R matrix +can be transposed and pre multiplied by the points: + +e.g. + points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point + transformed_points = points * R.transpose(1, 0) +""" + + +def quaternion_to_matrix(quaternions): + """ + Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + r, i, j, k = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def _copysign(a, b): + """ + Return a tensor where each element has the absolute value taken from the, + corresponding element of a, with sign taken from the corresponding + element of b. This is like the standard copysign floating-point operation, + but is not careful about negative 0 and NaN. + + Args: + a: source tensor. + b: tensor whose signs will be used, of the same shape as a. + + Returns: + Tensor of the same shape as a with the signs of b. 
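+
+    Example (illustrative):
+        >>> _copysign(torch.tensor([1.0, -2.0]), torch.tensor([-1.0, 1.0]))
+        tensor([-1.,  2.])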
+ """ + signs_differ = (a < 0) != (b < 0) + return torch.where(signs_differ, -a, a) + + +def _sqrt_positive_part(x): + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + ret[positive_mask] = torch.sqrt(x[positive_mask]) + return ret + + +def matrix_to_quaternion(matrix): + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") + m00 = matrix[..., 0, 0] + m11 = matrix[..., 1, 1] + m22 = matrix[..., 2, 2] + o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22) + x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22) + y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22) + z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22) + o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2]) + o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0]) + o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1]) + return torch.stack((o0, o1, o2, o3), -1) + + +def _axis_angle_rotation(axis: str, angle): + """ + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + + Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + + cos = torch.cos(angle) + sin = torch.sin(angle) + one = torch.ones_like(angle) + zero = torch.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + if axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + if axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + + return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) + + +def euler_angles_to_matrix(euler_angles, convention: str): + """ + Convert rotations given as Euler angles in radians to rotation matrices. + + Args: + euler_angles: Euler angles in radians as tensor of shape (..., 3). + convention: Convention string of three uppercase letters from + {"X", "Y", and "Z"}. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1)) + return functools.reduce(torch.matmul, matrices) + + +def _angle_from_tan( + axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool +): + """ + Extract the first or third Euler angle from the two members of + the matrix which are positive constant times its sine and cosine. + + Args: + axis: Axis label "X" or "Y or "Z" for the angle we are finding. + other_axis: Axis label "X" or "Y or "Z" for the middle axis in the + convention. + data: Rotation matrices as tensor of shape (..., 3, 3). 
+ horizontal: Whether we are looking for the angle for the third axis, + which means the relevant entries are in the same row of the + rotation matrix. If not, they are in the same column. + tait_bryan: Whether the first and third axes in the convention differ. + + Returns: + Euler Angles in radians for each matrix in dataset as a tensor + of shape (...). + """ + + i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] + if horizontal: + i2, i1 = i1, i2 + even = (axis + other_axis) in ["XY", "YZ", "ZX"] + if horizontal == even: + return torch.atan2(data[..., i1], data[..., i2]) + if tait_bryan: + return torch.atan2(-data[..., i2], data[..., i1]) + return torch.atan2(data[..., i2], -data[..., i1]) + + +def _index_from_letter(letter: str): + if letter == "X": + return 0 + if letter == "Y": + return 1 + if letter == "Z": + return 2 + + +def matrix_to_euler_angles(matrix, convention: str): + """ + Convert rotations given as rotation matrices to Euler angles in radians. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + convention: Convention string of three uppercase letters. + + Returns: + Euler angles in radians as tensor of shape (..., 3). + """ + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") + i0 = _index_from_letter(convention[0]) + i2 = _index_from_letter(convention[2]) + tait_bryan = i0 != i2 + if tait_bryan: + central_angle = torch.asin( + matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0) + ) + else: + central_angle = torch.acos(matrix[..., i0, i0]) + + o = ( + _angle_from_tan( + convention[0], convention[1], matrix[..., i2], False, tait_bryan + ), + central_angle, + _angle_from_tan( + convention[2], convention[1], matrix[..., i0, :], True, tait_bryan + ), + ) + return torch.stack(o, -1) + + +def random_quaternions( + n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate random quaternions representing rotations, + i.e. versors with nonnegative real part. + + Args: + n: Number of quaternions in a batch to return. + dtype: Type to return. + device: Desired device of returned tensor. Default: + uses the current device for the default tensor type. + requires_grad: Whether the resulting tensor should have the gradient + flag set. + + Returns: + Quaternions as tensor of shape (N, 4). + """ + o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad) + s = (o * o).sum(1) + o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None] + return o + + +def random_rotations( + n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate random rotations as 3x3 rotation matrices. + + Args: + n: Number of rotation matrices in a batch to return. + dtype: Type to return. + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type. + requires_grad: Whether the resulting tensor should have the gradient + flag set. + + Returns: + Rotation matrices as tensor of shape (n, 3, 3). 
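+
+    Example (illustrative):
+        >>> R = random_rotations(4)
+        >>> R.shape
+        torch.Size([4, 3, 3])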
+ """ + quaternions = random_quaternions( + n, dtype=dtype, device=device, requires_grad=requires_grad + ) + return quaternion_to_matrix(quaternions) + + +def random_rotation( + dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate a single random 3x3 rotation matrix. + + Args: + dtype: Type to return + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type + requires_grad: Whether the resulting tensor should have the gradient + flag set + + Returns: + Rotation matrix as tensor of shape (3, 3). + """ + return random_rotations(1, dtype, device, requires_grad)[0] + + +def standardize_quaternion(quaternions): + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). + """ + return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) + + +def quaternion_raw_multiply(a, b): + """ + Multiply two quaternions. + Usual torch rules for broadcasting apply. + + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. + + Returns: + The product of a and b, a tensor of quaternions shape (..., 4). + """ + aw, ax, ay, az = torch.unbind(a, -1) + bw, bx, by, bz = torch.unbind(b, -1) + ow = aw * bw - ax * bx - ay * by - az * bz + ox = aw * bx + ax * bw + ay * bz - az * by + oy = aw * by - ax * bz + ay * bw + az * bx + oz = aw * bz + ax * by - ay * bx + az * bw + return torch.stack((ow, ox, oy, oz), -1) + + +def quaternion_multiply(a, b): + """ + Multiply two quaternions representing rotations, returning the quaternion + representing their composition, i.e. the versor with nonnegative real part. + Usual torch rules for broadcasting apply. + + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. + + Returns: + The product of a and b, a tensor of quaternions of shape (..., 4). + """ + ab = quaternion_raw_multiply(a, b) + return standardize_quaternion(ab) + + +def quaternion_invert(quaternion): + """ + Given a quaternion representing rotation, get the quaternion representing + its inverse. + + Args: + quaternion: Quaternions as tensor of shape (..., 4), with real part + first, which must be versors (unit quaternions). + + Returns: + The inverse, a tensor of quaternions of shape (..., 4). + """ + + return quaternion * quaternion.new_tensor([1, -1, -1, -1]) + + +def quaternion_apply(quaternion, point): + """ + Apply the rotation given by a quaternion to a 3D point. + Usual torch rules for broadcasting apply. + + Args: + quaternion: Tensor of quaternions, real part first, of shape (..., 4). + point: Tensor of 3D points of shape (..., 3). + + Returns: + Tensor of rotated points of shape (..., 3). + """ + if point.size(-1) != 3: + raise ValueError(f"Points are not in 3D, f{point.shape}.") + real_parts = point.new_zeros(point.shape[:-1] + (1,)) + point_as_quaternion = torch.cat((real_parts, point), -1) + out = quaternion_raw_multiply( + quaternion_raw_multiply(quaternion, point_as_quaternion), + quaternion_invert(quaternion), + ) + return out[..., 1:] + + +def axis_angle_to_matrix(axis_angle): + """ + Convert rotations given as axis/angle to rotation matrices. 
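+    (Composed as quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)); e.g. the
+    vector [0, 0, pi/2] maps to a 90-degree rotation about the Z axis.)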
+ + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) + + +def matrix_to_axis_angle(matrix): + """ + Convert rotations given as rotation matrices to axis/angle. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + return quaternion_to_axis_angle(matrix_to_quaternion(matrix)) + + +def axis_angle_to_quaternion(axis_angle): + """ + Convert rotations given as axis/angle to quaternions. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) + half_angles = 0.5 * angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + quaternions = torch.cat( + [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 + ) + return quaternions + + +def quaternion_to_axis_angle(quaternions): + """ + Convert rotations given as quaternions to axis/angle. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True) + half_angles = torch.atan2(norms, quaternions[..., :1]) + angles = 2 * half_angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + return quaternions[..., 1:] / sin_half_angles_over_angles + + +def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: + """ + Converts 6D rotation representation by Zhou et al. [1] to rotation matrix + using Gram--Schmidt orthogonalisation per Section B of [1]. + Args: + d6: 6D rotation representation, of size (*, 6) + + Returns: + batch of rotation matrices of size (*, 3, 3) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. 
+ Retrieved from http://arxiv.org/abs/1812.07035 + """ + + a1, a2 = d6[..., :3], d6[..., 3:] + b1 = F.normalize(a1, dim=-1) + b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 + b2 = F.normalize(b2, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + return torch.stack((b1, b2, b3), dim=-2) + + +def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor: + """ + Converts rotation matrices to 6D rotation representation by Zhou et al. [1] + by dropping the last row. Note that 6D representation is not unique. + Args: + matrix: batch of rotation matrices of size (*, 3, 3) + + Returns: + 6D rotation representation, of size (*, 6) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6) diff --git a/wavlm/WavLM.py b/wavlm/WavLM.py new file mode 100644 index 0000000..c111c30 --- /dev/null +++ b/wavlm/WavLM.py @@ -0,0 +1,743 @@ +# -------------------------------------------------------- +# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf) +# Github source: https://github.com/microsoft/unilm/tree/master/wavlm +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import math +import logging +from typing import List, Optional, Tuple + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import LayerNorm +from wavlm.modules import ( + Fp32GroupNorm, + Fp32LayerNorm, + GradMultiply, + MultiheadAttention, + SamePad, + init_bert_params, + get_activation_fn, + TransposeLast, + GLU_Linear, +) + +logger = logging.getLogger(__name__) + + +def compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape + + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. 
mask is min 1 element + poisson = sample from possion distribution with lambda = mask length + min_masks: minimum number of masked spans + no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + """ + + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + sz = all_sz + num_mask = all_num_mask + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type == "normal": + lengths = np.random.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type == "poisson": + lengths = np.random.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + mask_type) + + if sum(lengths) == 0: + lengths[0] = min(mask_length, sz - 1) + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = np.random.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - keep_length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + lens = np.fromiter( + (e - s if e - s >= length + min_space else 0 for s, e in parts), + np.int, + ) + l_sum = np.sum(lens) + if l_sum == 0: + break + probs = lens / np.sum(lens) + c = np.random.choice(len(parts), p=probs) + s, e = parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + + mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) + + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in range(len(mask_idc)) + for offset in range(lengths[j]) + ] + ) + + mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) + + min_len = min([len(m) for m in mask_idcs]) + for i, mask_idc in enumerate(mask_idcs): + if len(mask_idc) > min_len: + mask_idc = np.random.choice(mask_idc, min_len, replace=False) + mask[i, mask_idc] = True + + return mask + + +class WavLMConfig: + def __init__(self, cfg=None): + self.extractor_mode: str = "default" # mode for feature extractor. 
default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True) + self.encoder_layers: int = 12 # num encoder layers in the transformer + + self.encoder_embed_dim: int = 768 # encoder embedding dimension + self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN + self.encoder_attention_heads: int = 12 # num encoder attention heads + self.activation_fn: str = "gelu" # activation function to use + + self.layer_norm_first: bool = False # apply layernorm first in the transformer + self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...] + self.conv_bias: bool = False # include bias in conv encoder + self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this + + self.normalize: bool = False # normalize input to have 0 mean and unit variance during training + + # dropouts + self.dropout: float = 0.1 # dropout probability for the transformer + self.attention_dropout: float = 0.1 # dropout probability for attention weights + self.activation_dropout: float = 0.0 # dropout probability after activation in FFN + self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer + self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr) + self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr) + + # masking + self.mask_length: int = 10 # mask length + self.mask_prob: float = 0.65 # probability of replacing a token with mask + self.mask_selection: str = "static" # how to choose mask length + self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh + self.no_mask_overlap: bool = False # whether to allow masks to overlap + self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled) + + # channel masking + self.mask_channel_length: int = 10 # length of the mask for features (channels) + self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0 + self.mask_channel_selection: str = "static" # how to choose mask length for channel masking + self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices + self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap + self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled) + + # positional embeddings + self.conv_pos: int = 128 # number of filters for convolutional positional embeddings + self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding + + # relative position embedding + self.relative_position_embedding: bool = False # apply relative position embedding + self.num_buckets: int = 320 # number of buckets for relative position embedding + self.max_distance: int = 1280 # maximum distance for relative position embedding + self.gru_rel_pos: bool = False # apply gated relative position embedding + + if cfg is not None: + self.update(cfg) + + def update(self, cfg: dict): + self.__dict__.update(cfg) + + +class WavLM(nn.Module): + def __init__( + self, + cfg: WavLMConfig, + ) -> None: + super().__init__() + logger.info(f"WavLM Config: {cfg.__dict__}") + + self.cfg = cfg + feature_enc_layers = eval(cfg.conv_feature_layers) + 
self.embed = feature_enc_layers[-1][0] + + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0.0, + mode=cfg.extractor_mode, + conv_bias=cfg.conv_bias, + ) + + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim + else None + ) + + self.mask_prob = cfg.mask_prob + self.mask_selection = cfg.mask_selection + self.mask_other = cfg.mask_other + self.mask_length = cfg.mask_length + self.no_mask_overlap = cfg.no_mask_overlap + self.mask_min_space = cfg.mask_min_space + + self.mask_channel_prob = cfg.mask_channel_prob + self.mask_channel_selection = cfg.mask_channel_selection + self.mask_channel_other = cfg.mask_channel_other + self.mask_channel_length = cfg.mask_channel_length + self.no_mask_channel_overlap = cfg.no_mask_channel_overlap + self.mask_channel_min_space = cfg.mask_channel_min_space + + self.dropout_input = nn.Dropout(cfg.dropout_input) + self.dropout_features = nn.Dropout(cfg.dropout_features) + + self.feature_grad_mult = cfg.feature_grad_mult + + self.mask_emb = nn.Parameter( + torch.FloatTensor(cfg.encoder_embed_dim).uniform_() + ) + + self.encoder = TransformerEncoder(cfg) + self.layer_norm = LayerNorm(self.embed) + + def apply_mask(self, x, padding_mask): + B, T, C = x.shape + if self.mask_prob > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x[mask_indices] = self.mask_emb + else: + mask_indices = None + + if self.mask_channel_prob > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x[mask_channel_indices] = 0 + + return x, mask_indices + + def forward_padding_mask( + self, features: torch.Tensor, padding_mask: torch.Tensor, + ) -> torch.Tensor: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view( + padding_mask.size(0), features.size(1), -1 + ) + padding_mask = padding_mask.all(-1) + return padding_mask + + def extract_features( + self, + source: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + mask: bool = False, + ret_conv: bool = False, + output_layer: Optional[int] = None, + ret_layer_results: bool = False, + ): + + if self.feature_grad_mult > 0: + features = self.feature_extractor(source) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with torch.no_grad(): + features = self.feature_extractor(source) + + features = features.transpose(1, 2) + features = self.layer_norm(features) + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, padding_mask) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + features = self.dropout_input(features) + + if mask: + x, mask_indices = self.apply_mask( + features, padding_mask + ) + else: + x = features + + # feature: (B, T, D), float + # target: (B, T), long + # x: (B, T, D), float + # padding_mask: (B, T), 
bool + # mask_indices: (B, T), bool + x, layer_results = self.encoder( + x, + padding_mask=padding_mask, + layer=None if output_layer is None else output_layer - 1 + ) + + res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results} + + feature = res["features"] if ret_conv else res["x"] + if ret_layer_results: + feature = (feature, res["layer_results"]) + return feature, res["padding_mask"] + + +class ConvFeatureExtractionModel(nn.Module): + def __init__( + self, + conv_layers: List[Tuple[int, int, int]], + dropout: float = 0.0, + mode: str = "default", + conv_bias: bool = False, + conv_type: str = "default" + ): + super().__init__() + + assert mode in {"default", "layer_norm"} + + def block( + n_in, + n_out, + k, + stride, + is_layer_norm=False, + is_group_norm=False, + conv_bias=False, + ): + def make_conv(): + conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) + nn.init.kaiming_normal_(conv.weight) + return conv + + assert ( + is_layer_norm and is_group_norm + ) == False, "layer norm and group norm are exclusive" + + if is_layer_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + nn.Sequential( + TransposeLast(), + Fp32LayerNorm(dim, elementwise_affine=True), + TransposeLast(), + ), + nn.GELU(), + ) + elif is_group_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + Fp32GroupNorm(dim, dim, affine=True), + nn.GELU(), + ) + else: + return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU()) + + self.conv_type = conv_type + if self.conv_type == "default": + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3, "invalid conv definition: " + str(cl) + (dim, k, stride) = cl + + self.conv_layers.append( + block( + in_d, + dim, + k, + stride, + is_layer_norm=mode == "layer_norm", + is_group_norm=mode == "default" and i == 0, + conv_bias=conv_bias, + ) + ) + in_d = dim + elif self.conv_type == "conv2d": + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3 + (dim, k, stride) = cl + + self.conv_layers.append( + torch.nn.Conv2d(in_d, dim, k, stride) + ) + self.conv_layers.append(torch.nn.ReLU()) + in_d = dim + elif self.conv_type == "custom": + in_d = 1 + idim = 80 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3 + (dim, k, stride) = cl + self.conv_layers.append( + torch.nn.Conv2d(in_d, dim, k, stride, padding=1) + ) + self.conv_layers.append( + torch.nn.LayerNorm([dim, idim]) + ) + self.conv_layers.append(torch.nn.ReLU()) + in_d = dim + if (i + 1) % 2 == 0: + self.conv_layers.append( + torch.nn.MaxPool2d(2, stride=2, ceil_mode=True) + ) + idim = int(math.ceil(idim / 2)) + else: + pass + + def forward(self, x, mask=None): + + # BxT -> BxCxT + x = x.unsqueeze(1) + if self.conv_type == "custom": + for conv in self.conv_layers: + if isinstance(conv, nn.LayerNorm): + x = x.transpose(1, 2) + x = conv(x).transpose(1, 2) + else: + x = conv(x) + x = x.transpose(2, 3).contiguous() + x = x.view(x.size(0), -1, x.size(-1)) + else: + for conv in self.conv_layers: + x = conv(x) + if self.conv_type == "conv2d": + b, c, t, f = x.size() + x = x.transpose(2, 3).contiguous().view(b, c * f, t) + return x + + +class TransformerEncoder(nn.Module): + def __init__(self, args): + super().__init__() + + self.dropout = args.dropout + self.embedding_dim = args.encoder_embed_dim + + self.pos_conv = nn.Conv1d( + self.embedding_dim, + self.embedding_dim, + 
kernel_size=args.conv_pos, + padding=args.conv_pos // 2, + groups=args.conv_pos_groups, + ) + dropout = 0 + std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim)) + nn.init.normal_(self.pos_conv.weight, mean=0, std=std) + nn.init.constant_(self.pos_conv.bias, 0) + + self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2) + self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU()) + + if hasattr(args, "relative_position_embedding"): + self.relative_position_embedding = args.relative_position_embedding + self.num_buckets = args.num_buckets + self.max_distance = args.max_distance + else: + self.relative_position_embedding = False + self.num_buckets = 0 + self.max_distance = 0 + + self.layers = nn.ModuleList( + [ + TransformerSentenceEncoderLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + has_relative_attention_bias=(self.relative_position_embedding and i == 0), + num_buckets=self.num_buckets, + max_distance=self.max_distance, + gru_rel_pos=args.gru_rel_pos, + ) + for i in range(args.encoder_layers) + ] + ) + + self.layer_norm_first = args.layer_norm_first + self.layer_norm = LayerNorm(self.embedding_dim) + self.layerdrop = args.encoder_layerdrop + + self.apply(init_bert_params) + + def forward(self, x, padding_mask=None, streaming_mask=None, layer=None): + x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer) + + if self.layer_norm_first and layer is None: + x = self.layer_norm(x) + + return x, layer_results + + def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None): + + if padding_mask is not None: + x[padding_mask] = 0 + + x_conv = self.pos_conv(x.transpose(1, 2)) + x_conv = x_conv.transpose(1, 2) + x = x + x_conv + + if not self.layer_norm_first: + x = self.layer_norm(x) + + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + layer_results = [] + z = None + if tgt_layer is not None: + layer_results.append((x, z)) + r = None + pos_bias = None + for i, layer in enumerate(self.layers): + dropout_probability = np.random.random() + if not self.training or (dropout_probability > self.layerdrop): + x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, + self_attn_mask=streaming_mask, pos_bias=pos_bias) + if tgt_layer is not None: + layer_results.append((x, z)) + if i == tgt_layer: + r = x + break + + if r is not None: + x = r + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x, layer_results + + +class TransformerSentenceEncoderLayer(nn.Module): + """ + Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained + models. 
+ """ + + def __init__( + self, + embedding_dim: float = 768, + ffn_embedding_dim: float = 3072, + num_attention_heads: float = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + layer_norm_first: bool = False, + has_relative_attention_bias: bool = False, + num_buckets: int = 0, + max_distance: int = 0, + rescale_init: bool = False, + gru_rel_pos: bool = False, + ) -> None: + + super().__init__() + # Initialize parameters + self.embedding_dim = embedding_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + + # Initialize blocks + self.activation_name = activation_fn + self.activation_fn = get_activation_fn(activation_fn) + self.self_attn = MultiheadAttention( + self.embedding_dim, + num_attention_heads, + dropout=attention_dropout, + self_attention=True, + has_relative_attention_bias=has_relative_attention_bias, + num_buckets=num_buckets, + max_distance=max_distance, + rescale_init=rescale_init, + gru_rel_pos=gru_rel_pos, + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(self.activation_dropout) + self.dropout3 = nn.Dropout(dropout) + + self.layer_norm_first = layer_norm_first + + # layer norm associated with the self attention layer + self.self_attn_layer_norm = LayerNorm(self.embedding_dim) + + if self.activation_name == "glu": + self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish") + else: + self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) + self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) + + # layer norm associated with the position wise feed-forward NN + self.final_layer_norm = LayerNorm(self.embedding_dim) + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + need_weights: bool = False, + pos_bias=None + ): + """ + LayerNorm is applied either before or after the self-attention/ffn + modules similar to the original Transformer imlementation. 
+ """ + residual = x + + if self.layer_norm_first: + x = self.self_attn_layer_norm(x) + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=False, + attn_mask=self_attn_mask, + position_bias=pos_bias + ) + x = self.dropout1(x) + x = residual + x + + residual = x + x = self.final_layer_norm(x) + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + else: + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=need_weights, + attn_mask=self_attn_mask, + position_bias=pos_bias + ) + + x = self.dropout1(x) + x = residual + x + + x = self.self_attn_layer_norm(x) + + residual = x + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + x = self.final_layer_norm(x) + + return x, attn, pos_bias + diff --git a/wavlm/modules.py b/wavlm/modules.py new file mode 100644 index 0000000..1dcfc6f --- /dev/null +++ b/wavlm/modules.py @@ -0,0 +1,827 @@ +# -------------------------------------------------------- +# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf) +# Github source: https://github.com/microsoft/unilm/tree/master/wavlm +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import math +import warnings +from typing import Dict, Optional, Tuple +import torch +from torch import Tensor, nn +from torch.nn import Parameter +import torch.nn.functional as F + + +class TransposeLast(nn.Module): + def __init__(self, deconstruct_idx=None): + super().__init__() + self.deconstruct_idx = deconstruct_idx + + def forward(self, x): + if self.deconstruct_idx is not None: + x = x[self.deconstruct_idx] + return x.transpose(-2, -1) + + +class Fp32LayerNorm(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.layer_norm( + input.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +class Fp32GroupNorm(nn.GroupNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.group_norm( + input.float(), + self.num_groups, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +class GradMultiply(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale): + ctx.scale = scale + res = x.new(x) + return res + + @staticmethod + def backward(ctx, grad): + return grad * ctx.scale, None + + +class SamePad(nn.Module): + def __init__(self, kernel_size, causal=False): + super().__init__() + if causal: + self.remove = kernel_size - 1 + else: + self.remove = 1 if kernel_size % 2 == 0 else 0 + + def forward(self, x): + if self.remove > 0: + x = x[:, :, : -self.remove] + return x + + +class Swish(nn.Module): + """Swish function + """ + + def __init__(self): + """Construct an MultiHeadedAttention 
object.""" + super(Swish, self).__init__() + self.act = torch.nn.Sigmoid() + + def forward(self, x): + return x * self.act(x) + + +class GLU_Linear(nn.Module): + def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True): + super(GLU_Linear, self).__init__() + + self.glu_type = glu_type + self.output_dim = output_dim + + if glu_type == "sigmoid": + self.glu_act = torch.nn.Sigmoid() + elif glu_type == "swish": + self.glu_act = Swish() + elif glu_type == "relu": + self.glu_act = torch.nn.ReLU() + elif glu_type == "gelu": + self.glu_act = torch.nn.GELU() + + if bias_in_glu: + self.linear = nn.Linear(input_dim, output_dim * 2, True) + else: + self.linear = nn.Linear(input_dim, output_dim * 2, False) + + def forward(self, x): + # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case + x = self.linear(x) + + if self.glu_type == "bilinear": + x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2]) + else: + x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2])) + + return x + + +def gelu_accurate(x): + if not hasattr(gelu_accurate, "_a"): + gelu_accurate._a = math.sqrt(2 / math.pi) + return ( + 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) + ) + + +def gelu(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(x.float()).type_as(x) + + +def get_activation_fn(activation: str): + """Returns the activation function corresponding to `activation`""" + + if activation == "relu": + return F.relu + elif activation == "gelu": + return gelu + elif activation == "gelu_fast": + warnings.warn( + "--activation-fn=gelu_fast has been renamed to gelu_accurate" + ) + return gelu_accurate + elif activation == "gelu_accurate": + return gelu_accurate + elif activation == "tanh": + return torch.tanh + elif activation == "linear": + return lambda x: x + elif activation == "glu": + return lambda x: x + else: + raise RuntimeError("--activation-fn {} not supported".format(activation)) + + +def init_bert_params(module): + """ + Initialize the weights specific to the BERT Model. + This overrides the default initializations depending on the specified arguments. + 1. If normal_init_linear_weights is set then weights of linear + layer will be initialized using the normal distribution and + bais will be set to the specified value. + 2. If normal_init_embed_weights is set then weights of embedding + layer will be initialized using the normal distribution. + 3. If normal_init_proj_weights is set then weights of + in_project_weight for MultiHeadAttention initialized using + the normal distribution (to be validated). 
+ """ + + def normal_(data): + # with FSDP, module params will be on CUDA, so we cast them back to CPU + # so that the RNG is consistent with and without FSDP + data.copy_( + data.cpu().normal_(mean=0.0, std=0.02).to(data.device) + ) + + if isinstance(module, nn.Linear): + normal_(module.weight.data) + if module.bias is not None: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + normal_(module.weight.data) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + if isinstance(module, MultiheadAttention): + normal_(module.q_proj.weight.data) + normal_(module.k_proj.weight.data) + normal_(module.v_proj.weight.data) + + +def quant_noise(module, p, block_size): + """ + Wraps modules and applies quantization noise to the weights for + subsequent quantization with Iterative Product Quantization as + described in "Training with Quantization Noise for Extreme Model Compression" + + Args: + - module: nn.Module + - p: amount of Quantization Noise + - block_size: size of the blocks for subsequent quantization with iPQ + + Remarks: + - Module weights must have the right sizes wrt the block size + - Only Linear, Embedding and Conv2d modules are supported for the moment + - For more detail on how to quantize by blocks with convolutional weights, + see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks" + - We implement the simplest form of noise here as stated in the paper + which consists in randomly dropping blocks + """ + + # if no quantization noise, don't register hook + if p <= 0: + return module + + # supported modules + assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) + + # test whether module.weight has the right sizes wrt block_size + is_conv = module.weight.ndim == 4 + + # 2D matrix + if not is_conv: + assert ( + module.weight.size(1) % block_size == 0 + ), "Input features must be a multiple of block sizes" + + # 4D matrix + else: + # 1x1 convolutions + if module.kernel_size == (1, 1): + assert ( + module.in_channels % block_size == 0 + ), "Input channels must be a multiple of block sizes" + # regular convolutions + else: + k = module.kernel_size[0] * module.kernel_size[1] + assert k % block_size == 0, "Kernel size must be a multiple of block size" + + def _forward_pre_hook(mod, input): + # no noise for evaluation + if mod.training: + if not is_conv: + # gather weight and sizes + weight = mod.weight + in_features = weight.size(1) + out_features = weight.size(0) + + # split weight matrix into blocks and randomly drop selected blocks + mask = torch.zeros( + in_features // block_size * out_features, device=weight.device + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) + + else: + # gather weight and sizes + weight = mod.weight + in_channels = mod.in_channels + out_channels = mod.out_channels + + # split weight matrix into blocks and randomly drop selected blocks + if mod.kernel_size == (1, 1): + mask = torch.zeros( + int(in_channels // block_size * out_channels), + device=weight.device, + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels) + else: + mask = torch.zeros( + weight.size(0), weight.size(1), device=weight.device + ) + mask.bernoulli_(p) + mask = ( + mask.unsqueeze(2) + .unsqueeze(3) + .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) + ) + + # scale weights and apply mask + mask = mask.to( + torch.bool + ) # x.bool() is not currently supported in TorchScript + s = 1 / (1 - p) + mod.weight.data = s * 
weight.masked_fill(mask, 0) + + module.register_forward_pre_hook(_forward_pre_hook) + return module + + +class MultiheadAttention(nn.Module): + """Multi-headed attention. + + See "Attention Is All You Need" for more details. + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + has_relative_attention_bias=False, + num_buckets=32, + max_distance=128, + gru_rel_pos=False, + rescale_init=False, + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout_module = nn.Dropout(dropout) + + self.has_relative_attention_bias = has_relative_attention_bias + self.num_buckets = num_buckets + self.max_distance = max_distance + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(num_buckets, num_heads) + + self.head_dim = embed_dim // num_heads + self.q_head_dim = self.head_dim + self.k_head_dim = self.head_dim + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim ** -0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, ( + "Self-attention requires query, key and " "value to be of the same size" + ) + + k_bias = True + if rescale_init: + k_bias = False + + k_embed_dim = embed_dim + q_embed_dim = embed_dim + + self.k_proj = quant_noise( + nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size + ) + self.v_proj = quant_noise( + nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size + ) + self.q_proj = quant_noise( + nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size + ) + + self.out_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size + ) + + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.gru_rel_pos = gru_rel_pos + if self.gru_rel_pos: + self.grep_linear = nn.Linear(self.q_head_dim, 8) + self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1)) + + self.reset_parameters() + + def reset_parameters(self): + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + else: + nn.init.xavier_uniform_(self.k_proj.weight) + nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.q_proj.weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + if self.has_relative_attention_bias: + nn.init.xavier_normal_(self.relative_attention_bias.weight) + + def _relative_positions_bucket(self, relative_positions, bidirectional=True): + num_buckets = 
self.num_buckets + max_distance = self.max_distance + relative_buckets = 0 + + if bidirectional: + num_buckets = num_buckets // 2 + relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets + relative_positions = torch.abs(relative_positions) + else: + relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions)) + + max_exact = num_buckets // 2 + is_small = relative_positions < max_exact + + relative_postion_if_large = max_exact + ( + torch.log(relative_positions.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length): + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position + relative_position_bucket = self._relative_positions_bucket( + relative_position, + bidirectional=True + ) + relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device) + values = self.relative_attention_bias(relative_position_bucket) + values = values.permute([2, 0, 1]) + return values + + def forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + position_bias: Optional[Tensor] = None + ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. + """ + if need_head_weights: + need_weights = True + + is_tpu = query.device.type == "xla" + + tgt_len, bsz, embed_dim = query.size() + src_len = tgt_len + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + if key is not None: + src_len, key_bsz, _ = key.size() + if not torch.jit.is_scripting(): + assert key_bsz == bsz + assert value is not None + assert src_len, bsz == value.shape[:2] + + if self.has_relative_attention_bias and position_bias is None: + position_bias = self.compute_bias(tgt_len, src_len) + position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len) + + if ( + not is_tpu # don't use PyTorch version on TPUs + and incremental_state is None + and not static_kv + # A workaround for quantization to work. 
Otherwise JIT compilation + # treats bias in linear module as method. + and not torch.jit.is_scripting() + and self.q_head_dim == self.head_dim + ): + assert key is not None and value is not None + assert attn_mask is None + + attn_mask_rel_pos = None + if position_bias is not None: + attn_mask_rel_pos = position_bias + if self.gru_rel_pos: + query_layer = query.transpose(0, 1) + new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1) + query_layer = query_layer.view(*new_x_shape) + query_layer = query_layer.permute(0, 2, 1, 3) + _B, _H, _L, __ = query_layer.size() + + gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( + _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias + + attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len)) + k_proj_bias = self.k_proj.bias + if k_proj_bias is None: + k_proj_bias = torch.zeros_like(self.q_proj.bias) + + x, attn = F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + torch.empty([0]), + torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout_module.p, + self.out_proj.weight, + self.out_proj.bias, + self.training, + # self.training or self.dropout_module.apply_during_inference, + key_padding_mask, + need_weights, + attn_mask_rel_pos, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + ) + return x, attn, position_bias + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + key_padding_mask.new_zeros(key_padding_mask.size(0), 1), + ], + dim=1, + ) + + q = ( + q.contiguous() + .view(tgt_len, bsz * self.num_heads, self.q_head_dim) + .transpose(0, 1) + ) + if k is not None: + k = ( + k.contiguous() + .view(-1, bsz * self.num_heads, self.k_head_dim) + .transpose(0, 1) + ) + if v is not None: + v = ( + v.contiguous() + .view(-1, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + prev_key = _prev_key.view(bsz * 
self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + src_len = k.size(1) + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[Tensor] = None + if "prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + + saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + assert k.size(1) == src_len + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) + v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + torch.zeros(key_padding_mask.size(0), 1).type_as( + key_padding_mask + ), + ], + dim=1, + ) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + if not is_tpu: + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf"), + ) + else: + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v, position_bias + + if position_bias is not None: + if self.gru_rel_pos == 1: + query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) + _B, _H, _L, __ = query_layer.size() + gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( + _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias + + position_bias = position_bias.view(attn_weights.size()) + + attn_weights = 
attn_weights + position_bias + + attn_weights_float = F.softmax( + attn_weights, dim=-1 + ) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = self.dropout_module(attn_weights) + + assert v is not None + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + attn_weights: Optional[Tensor] = None + if need_weights: + attn_weights = attn_weights_float.view( + bsz, self.num_heads, tgt_len, src_len + ).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights, position_bias + + @staticmethod + def _append_prev_key_padding_mask( + key_padding_mask: Optional[Tensor], + prev_key_padding_mask: Optional[Tensor], + batch_size: int, + src_len: int, + static_kv: bool, + ) -> Optional[Tensor]: + # saved key padding masks have shape (bsz, seq_len) + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 + ) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will be a time when prev or current + # is None + elif prev_key_padding_mask is not None: + if src_len > prev_key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - prev_key_padding_mask.size(1)), + device=prev_key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), filler.float()], dim=1 + ) + else: + new_key_padding_mask = prev_key_padding_mask.float() + elif key_padding_mask is not None: + if src_len > key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - key_padding_mask.size(1)), + device=key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [filler.float(), key_padding_mask.float()], dim=1 + ) + else: + new_key_padding_mask = key_padding_mask.float() + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + ) -> Dict[str, Optional[Tensor]]: + result = self.get_incremental_state(incremental_state, "attn_state") + if result is not None: + return result + else: + empty_result: Dict[str, Optional[Tensor]] = {} + return empty_result + + def _set_input_buffer( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + buffer: Dict[str, Optional[Tensor]], + ): + return self.set_incremental_state(incremental_state, "attn_state", buffer) + + def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): + return attn_weights
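
Not part of the patch above: a minimal usage sketch of the vendored WavLM model whose `extract_features` method appears in this hunk, following the upstream WavLM loading pattern. It assumes the `WavLM` and `WavLMConfig` classes defined earlier in `wavlm/WavLM.py` and a downloaded checkpoint placed in `./wavlm/` as described in the README; the exact checkpoint filename is an assumption.

```python
# Sketch only: obtain per-frame WavLM representations from a 16 kHz waveform.
# Assumes ./wavlm/WavLM-Base+.pt exists and wavlm/WavLM.py defines WavLM/WavLMConfig.
import torch
import torch.nn.functional as F

from wavlm.WavLM import WavLM, WavLMConfig

# Load the pretrained checkpoint and rebuild the model from its stored config.
checkpoint = torch.load("./wavlm/WavLM-Base+.pt", map_location="cpu")
cfg = WavLMConfig(checkpoint["cfg"])
model = WavLM(cfg)
model.load_state_dict(checkpoint["model"])
model.eval()

# 16 kHz mono audio, shape (batch, samples); one second of silence as a stand-in.
wav = torch.zeros(1, 16000)
if cfg.normalize:
    wav = F.layer_norm(wav, wav.shape)

with torch.no_grad():
    # extract_features returns (features, padding_mask);
    # features has shape (B, T', D) with D = cfg.encoder_embed_dim.
    reps, _ = model.extract_features(wav)

print(reps.shape)  # roughly 50 frames per second, e.g. torch.Size([1, 49, 768]) for Base
```

These are the representations the data-preparation step stores as npy arrays before training.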