diff --git a/Dockerfile b/Dockerfile index e290f98..92ddb56 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,8 @@ -FROM nvidia/cuda:12.2.0-devel-ubuntu22.04 +FROM pytorch/pytorch:1.6.0-cuda10.1-cudnn7-devel +RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub + +RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub ENV PATH="/root/miniconda3/bin:${PATH}" ARG PATH="/root/miniconda3/bin:${PATH}" @@ -7,12 +10,6 @@ ARG PATH="/root/miniconda3/bin:${PATH}" RUN apt-get update RUN apt-get install -y wget git nano ffmpeg -RUN wget \ - https://repo.anaconda.com/miniconda/Miniconda3-py37_4.8.3-Linux-x86_64.sh \ - && mkdir /root/.conda \ - && bash Miniconda3-py37_4.8.3-Linux-x86_64.sh -b \ - && rm -f Miniconda3-py37_4.8.3-Linux-x86_64.sh - RUN conda --version WORKDIR /root @@ -24,5 +21,11 @@ RUN conda install pip RUN conda --version RUN conda env create -f environment.yml -SHELL ["conda", "run", "-n", "ggvad", "/bin/bash", "-c"] -RUN pip install git+https://github.com/openai/CLIP.git +SHELL ["conda", "run", "-n", "stylistic-env", "/bin/bash", "-c"] +RUN python -m spacy download en_core_web_sm +RUN pip install blobfile +RUN pip install PyYAML +RUN pip install librosa +RUN pip install python_speech_features +RUN pip install einops +RUN pip install wandb diff --git a/README.md b/README.md index fff6ad0..ef5f166 100644 --- a/README.md +++ b/README.md @@ -8,61 +8,45 @@ Official repository for the paper Stylistic Co-Speech Gesture Generation: Modeli 2. Enter the repo and create docker image using ```sh -docker build -t ggvad . +docker build -t stylistic-gesture . ``` 3. Run container using ```sh -docker run --rm -it --gpus device=GPU_NUMBER --userns=host --shm-size 64G -v /MY_DIR/ggvad-genea2023:/workspace/ggvad/ -p PORT_NUMBR --name CONTAINER_NAME ggvad:latest /bin/bash +nvidia-docker run --rm -it -e NVIDIA_VISIBLE_DEVICES={GPU} --runtime=nvidia --userns=host --shm-size 64G -v {LOCAL_DIR}:{CONTAINER_DIR} -p {PORT} --name {CONTAINER_NAME} stylistic-gesture:latest /bin/bash ``` for example: ```sh -docker run --rm -it --gpus device=0 --userns=host --shm-size 64G -v C:\ProgramFiles\ggvad-genea2023:/workspace/my_repo -p '8888:8888' --name my_container ggvad:latest /bin/bash -``` - -> ### Cuda version < 12.0: -> -> If you have a previous cuda or nvcc release version you will need to adjust the Dockerfile. Change the first line to `FROM pytorch/pytorch:1.6.0-cuda10.1-cudnn7-devel` and remove lines 10-14 (conda is already installed in the pythorch image). Then, run container using: -> -> ```sh -> nvidia-docker run --rm -it -e NVIDIA_VISIBLE_DEVICES=GPU_NUMBER --runtime=nvidia --userns=host --shm-size 64G -v /work/rodolfo.tonoli/GestureDiffusion:/workspace/gesture-diffusion/ -p $port --name gestdiff_container$number multimodal-research-group-mdm:latest /bin/bash -> ``` - - -OR use the shell script ggvad_container.sh (don't forget to change the volume) using the flags -g, -n, and -p -example: -```sh -sh ggvad_container.sh -g 0 -n my_container -p '8888:8888' +docker run --rm -it --gpus device=0 --userns=host --shm-size 64G -v C:\ProgramFiles\stylistic-gesture:/workspace/stylistic-gesture -p '8888:8888' --name stylistic-gesture-container stylistic-gesture:latest /bin/bash ``` 4. Activate cuda environment: ```sh -source activate ggvad +source activate stylistic-env ``` ## Data pre-processing -1. 
Get the GENEA Challenge 2023 dataset and put it into `./dataset/`
-(Our system is monadic so you'll only need the main-agent's data)
+1. Get the BRG-Unicamp dataset following the instructions from [here](https://ai-unicamp.github.io/BRG-Unicamp/) and put it into `./dataset/`

2. Download the [WavLM Base +](https://github.com/microsoft/unilm/tree/master/wavlm) and put it into the folder `/wavlm/`

-3. Inside the folder `/workspace/ggvad`, run
+3. In the container, with the environment active, enter the folder `/workspace/stylistic-gesture` and run

```sh
-python -m data_loaders.gesture.scripts.genea_prep
+python -m data_loaders.gesture.scripts.ptbrgesture_prep
```

This will convert the bvh files to npy representations, downsample wav files to 16 kHz and save them as npy arrays, and convert these arrays to WavLM representations. The VAD data must be processed separately due to Python library incompatibilities.

4. (Optional) Process VAD data

-We provide the speech activity information (from speechbrain's VAD) data, but if you wish to process them yourself you should redo the steps of "Preparing environment" as before, but for the speechbrain environment: Build the image using the Dockerfile inside speechbrain (`docker build -t speechbrain .`), run the container (`docker run ... --name CONTAINER_NAME speechbrain:latest /bin/bash`) and run:
+BRG-Unicamp provides the speech activity information (from SpeechBrain's VAD), but if you wish to process it yourself, repeat the "Preparing environment" steps above for the speechbrain environment: build the image using the Dockerfile inside speechbrain (`docker build -t speechbrain .`), run the container (`docker run ... --name CONTAINER_NAME speechbrain:latest /bin/bash`), and run:

```sh
-python -m data_loaders.gesture.scripts.genea_prep_vad
+python -m data_loaders.gesture.scripts.ptbrgesture_prep_vad
```

## Train model
@@ -70,20 +54,20 @@ python -m data_loaders.gesture.scripts.genea_prep_vad

To train the model described in the paper, use the following command inside the repo:

```sh
-python -m train.train_mdm --save_dir save/my_model_run --dataset genea2023+ --step 10 --use_text --use_vad True --use_wavlm True
+python -m train.train_mdm --save_dir save/my_model_run --dataset ptbr --step 10 --use_vad True --use_wavlm True --use_style_enc True
```

## Gesture Generation

-Generate motion using the trained model by running the following command. If you wish to generate gestures with the pretrained model of the Genea Challenge, use `--model_path ./save/default_vad_wavlm/model000290000.pt`
+Generate motion using the trained model by running the following command. If you wish to generate gestures with our pretrained model, use `--model_path ./save/stylistic-gesture/model000600000.pt`

```sh
-python -m sample.generate --model_path ./save/my_model_run/model000XXXXXX.pt
+python -m sample.ptbrgenerate --model_path ./save/my_model_run/model000XXXXXX.pt
```

## Render

-To render the official Genea 2023 visualizations follow the instructions provided [here](https://github.com/TeoNikolov/genea_visualizer/)
+In our perceptual evaluation, we used the render procedure from the official GENEA Challenge 2023 visualizations.
Instructions provided [here](https://github.com/TeoNikolov/genea_visualizer/) ## Cite diff --git a/data_loaders/gesture/data/dataset.py b/data_loaders/gesture/data/dataset.py deleted file mode 100644 index 72c5839..0000000 --- a/data_loaders/gesture/data/dataset.py +++ /dev/null @@ -1,337 +0,0 @@ -import torch -from torch.utils import data -import csv -import os -import numpy as np -from python_speech_features import mfcc -import librosa -import torch.nn.functional as F - -class Genea2023(data.Dataset): - def __init__(self, name, split='trn', datapath='./dataset/Genea2023/', step=30, window=80, fps=30, sr=22050, n_seed_poses=10, use_wavlm=False, use_vad=False, vadfromtext=False): - - self.split = split - if self.split not in ['trn', 'val', 'tst']: - raise ValueError('Split not recognized') - srcpath = os.path.join(datapath, self.split, 'main-agent/') - - if use_wavlm: - self.sr = 16000 - self.audiopath = os.path.join(srcpath, 'audio16k_npy') - else: - self.sr = sr - self.audiopath = os.path.join(srcpath, 'audio_npy') - - self.name = name - self.step = step - - self.datapath = datapath - self.window=window - self.fps = fps - self.n_seed_poses = n_seed_poses - - self.loadstats(os.path.join(datapath, 'trn/main-agent/')) - self.std = np.array([ item if item != 0 else 1 for item in self.std ]) - self.vel_std = np.array([ item if item != 0 else 1 for item in self.vel_std ]) - self.rot6dpos_std = np.array([ item if item != 0 else 1 for item in self.rot6dpos_std ]) - - if self.split in ['trn', 'val']: - self.motionpath = os.path.join(srcpath, 'motion_npy_rotpos') - self.motionpath_rot6d = os.path.join(srcpath, 'motion_npy_rot6dpos') - - - with open(os.path.join(srcpath, '../metadata.csv')) as csvfile: - reader = csv.reader(csvfile, delimiter=',') - self.takes = [take for take in reader] - self.takes = self.takes[1:] - for take in self.takes: - take[0] += '_main-agent' - - # Here we use the motion files to get the number of frames per take (there are some takes that audio files are longer for the trn and val sets) - self.frames = [] - if self.split in ['trn', 'val']: - for motionfile in self.takes: - motion = np.load(os.path.join(self.motionpath_rot6d, motionfile[0] + '.npy')) - self.frames.append(motion.shape[0]) - elif self.split in ['tst']: - for audiofile in os.listdir(self.audiopath): - if audiofile.endswith('.npy'): - audio = np.load(os.path.join(self.audiopath, audiofile)) - self.frames.append( int(audio.shape[0]/self.sr*self.fps)) - self.frames = np.array(self.frames) - self.frames = np.array(self.frames) - - self.samples_per_file = [int(np.floor( (n - self.window ) / self.step)) for n in self.frames] - self.samples_cumulative = [np.sum(self.samples_per_file[:i+1]) for i in range(len(self.samples_per_file))] - self.length = self.samples_cumulative[-1] - self.textpath = os.path.join(srcpath, 'tsv') - - self.use_wavlm = use_wavlm - if self.use_wavlm: - self.wavlm_rep_path = os.path.join(srcpath, 'wavlm_representations') - - self.use_vad = use_vad - if self.use_vad: - self.vad_path = os.path.join(srcpath, "vad") - self.vadfromtext = vadfromtext - if self.vadfromtext: print('Getting speech activity from text') - - self.alljoints = 
{'body_world':0,'b_root':1,'b_spine0':2,'b_spine1':3,'b_spine2':4,'b_spine3':5,'b_neck0':6,'b_head':7,'b_head_null':8,'b_l_eye':9,'b_r_eye':10,'b_jaw':11,'b_jaw_null':12,'b_teeth':13,'b_tongue0':14,'b_tongue1':15,'b_tongue2':16,'b_tongue3':17,'b_tongue4':18,'b_l_tongue4_1':19,'b_r_tongue4_1':20,'b_l_tongue3_1':21,'b_r_tongue3_1':22,'b_l_tongue2_1':23,'b_r_tongue2_1':24,'b_r_tongue1_1':25,'b_l_tongue1_1':26,'b_r_shoulder':27,'p_r_scap':28,'b_r_arm':29,'b_r_arm_twist':30,'b_r_forearm':31,'b_r_wrist_twist':32,'b_r_wrist':33,'b_r_index1':34,'b_r_index2':35,'b_r_index3':36,'b_r_ring1':37,'b_r_ring2':38,'b_r_ring3':39,'b_r_middle1':40,'b_r_middle2':41,'b_r_middle3':42,'b_r_pinky1':43,'b_r_pinky2':44,'b_r_pinky3':45,'b_r_thumb0':46,'b_r_thumb1':47,'b_r_thumb2':48,'b_r_thumb3':49,'b_l_shoulder':50,'p_l_delt':51,'p_l_scap':52,'b_l_arm':53,'b_l_arm_twist':54,'b_l_forearm':55,'b_l_wrist_twist':56,'b_l_wrist':57,'b_l_thumb0':58,'b_l_thumb1':59,'b_l_thumb2':60,'b_l_thumb3':61,'b_l_index1':62,'b_l_index2':63,'b_l_index3':64,'b_l_middle1':65,'b_l_middle2':66,'b_l_middle3':67,'b_l_ring1':68,'b_l_ring2':69,'b_l_ring3':70,'b_l_pinky1':71,'b_l_pinky2':72,'b_l_pinky3':73,'p_navel':74,'b_r_upleg':75,'b_r_leg':76,'b_r_foot_twist':77,'b_r_foot':78,'b_l_upleg':79,'b_l_leg':80,'b_l_foot_twist':81,'b_l_foot':82} - - if False: - for take in self.takes: - name = take[0] - m = os.path.join(self.motionpath, name+'.npy') - a = os.path.join(self.audiopath, name+'.npy') - t = os.path.join(self.textpath, name+'.tsv') - assert os.path.isfile( m ), "Motion file {} not found".format(m) - assert os.path.isfile( a ), "Audio file {} not found".format(a) - assert os.path.isfile( t ), "Text file {} not found".format(t) - - def __getitem__(self, idx): - if self.split == 'tst': - raise ValueError('Test set does should not use __getitem__(), use gettestbatch() instead') - # find the file that the sample belongs two - file_idx = np.searchsorted(self.samples_cumulative, idx+1, side='left') - # find sample's index - if file_idx > 0: - sample = idx - self.samples_cumulative[file_idx-1] - else: - sample = idx - take_name = self.takes[file_idx][0] - motion, seed_poses = self.__getmotion( file_idx, sample) - audio, audio_rep = self.__getaudiofeats(file_idx, sample) - n_text, text, tokens, vad = self.__gettext(file_idx, sample) - if self.use_vad: - if not self.vadfromtext: - vad = self.__getvad(file_idx, sample) - else: - vad = np.ones(int(self.window)) # Dummy - return motion, text, self.window, audio, audio_rep, seed_poses, vad, take_name - - def __len__(self): - return self.length - - def __getvad(self, file, sample): - # Cut Chunk - vad_file = np.load(os.path.join(self.vad_path,self.takes[file][0]+'.npy')) - vad_vals = vad_file[sample*self.step: sample*self.step + self.window] # [CHUNK_LEN, ] - - # Reshape - vad_vals = np.expand_dims(vad_vals, 1) # [CHUNK_LEN, 1] - vad_vals = np.transpose(vad_vals, (1,0)) # [1, CHUNK_LEN] - return vad_vals - - def __getmotion(self, file, sample): - if self.name == 'genea2023+': - # loading rot6d and position representations - rot6dpos_file = np.load(os.path.join(self.motionpath_rot6d,self.takes[file][0]+'.npy')) - rot6dpos = (rot6dpos_file[sample*self.step: sample*self.step + self.window,:] - self.rot6dpos_mean) / self.rot6dpos_std - - # loading rotpos representation and computing velocity - rotpos_file = np.load(os.path.join(self.motionpath,self.takes[file][0]+'.npy')) - rotpos_file[1:,:] = rotpos_file[1:,:] - rotpos_file[:-1,:] - rotpos_file[0,:] = np.zeros(rotpos_file.shape[1]) - rotpos = 
(rotpos_file[sample*self.step: sample*self.step + self.window,:] - self.vel_mean) / self.vel_std - if sample*self.step - self.n_seed_poses < 0: - rot6dpos_seed = np.zeros((self.n_seed_poses, rot6dpos.shape[1])) - rotpos_seed = np.zeros((self.n_seed_poses, rotpos.shape[1])) - else: - rot6dpos_seed = (rot6dpos_file[sample*self.step - self.n_seed_poses: sample*self.step ,:] - self.rot6dpos_mean) / self.rot6dpos_std - rotpos_seed = (rotpos_file[sample*self.step - self.n_seed_poses: sample*self.step,:] - self.vel_mean) / self.vel_std - - motion = np.concatenate((rot6dpos, rotpos), axis=1) - seed_poses = np.concatenate((rot6dpos_seed, rotpos_seed), axis=1) - - else: - motion_file = np.load(os.path.join(self.motionpath,self.takes[file][0]+'.npy')) - motion = (motion_file[sample*self.step: sample*self.step + self.window,:] - self.mean) / self.std - if sample*self.step - self.n_seed_poses < 0: - seed_poses = np.zeros((self.n_seed_poses, motion.shape[1])) - else: - seed_poses = (motion_file[sample*self.step - self.n_seed_poses: sample*self.step,:] - self.mean) / self.std - - return motion, seed_poses - - def __getaudiofeats(self, file, sample): - - # Load Audio - signal = np.load(os.path.join(self.motionpath,'..', 'audio16k_npy',self.takes[file][0]+'.npy')) - - # Cut Chunk - i = sample*16000*self.step/self.fps - signal = signal[ int(i) : int(i+self.window*16000/self.fps) ] - - if self.use_wavlm: - # Cut Chunk - representation_file = np.load(os.path.join(self.wavlm_rep_path,self.takes[file][0]+'.npy')) - wavlm_reps = representation_file[sample*self.step: sample*self.step + self.window,:] # [CHUNK_LEN, WAVLM_DIM] - - # Reshape - wavlm_reps = np.transpose(wavlm_reps, (1,0)) # [WAVLM_DIM, CHUNK_LEN] - wavlm_reps = np.expand_dims(wavlm_reps, 1) # [WAVLM_DIM, 1, CHUNK_LEN] - wavlm_reps = np.expand_dims(wavlm_reps, 0) # [1, WAVLM_DIM, 1, CHUNK_LEN] - return signal, wavlm_reps - else: - return self.__compute_audiofeats(signal) - - def __compute_audiofeats(self, signal): - - # MFCCs - mfcc_vectors = mfcc(signal, winlen=0.06, winstep= (1/self.fps), samplerate=16000, numcep=27, nfft=5000) - - # Normalize - #mfcc_vectors = (mfcc_vectors - self.mfcc_mean) / self.mfcc_std - - # Format - mfcc_vectors = mfcc_vectors.T - mfcc_vectors = np.expand_dims(mfcc_vectors, 1) - mfcc_vectors = np.expand_dims(mfcc_vectors, 0) # should be [1, MFCC_DIM, 1, CHUNK_LEN] - return signal, mfcc_vectors - - def __gettext(self, file, sample): - with open(os.path.join(self.textpath, self.takes[file][0]+'.tsv')) as tsv: - reader = csv.reader(tsv, delimiter='\t') - file = [ [float(word[0])*self.fps, float(word[1])*self.fps, word[2]] for word in reader] - begin = self.search_time(file, sample*self.step) - end = self.search_time(file, sample*self.step + self.window) - text = [ word[-1] for word in file[begin: end] ] - tokens = self.__gentokens(text) - vad = None - if self.vadfromtext: - times = [(np.floor(word[0] - sample*self.step).astype(int), np.ceil(word[1] - sample*self.step).astype(int)) for word in file[begin: end]] - vad = np.zeros(self.window) - for (i, f) in times: - vad[i:f] = 1 - vad = np.expand_dims(vad, 1) # [CHUNK_LEN, 1] - vad = np.transpose(vad, (1,0)) # [1, CHUNK_LEN] - return len(text), ' '.join(text), tokens, vad - - def __gentokens(self, text): - tokens = [ word+'/OTHER' for word in text] - tokens = '_'.join(tokens) - tokens = 'sos/OTHER_' + tokens + '_eos/OTHER' - return tokens - - def search_time(self, text, frame): - for i in range(len(text)): - if frame <= text[i][0]: - return i if (frame > text[i-1][1] or i==0) else 
i-1 - - def inv_transform(self, data): - if self.name == 'genea2023': - return data * self.std + self.mean - elif self.name == 'genea2023+': - return data * np.concatenate((self.rot6dpos_std, self.vel_std)) + np.concatenate((self.rot6dpos_mean, self.vel_mean)) - else: - raise ValueError('Dataset name not recognized') - - - def gettime(self): - import time - start = time.time() - for i in range(200): - sample = self.__getitem__(i) - print(time.time()-start) - - def loadstats(self, statspath): - self.std = np.load(os.path.join(statspath, 'rotpos_Std.npy')) - self.mean = np.load(os.path.join(statspath, 'rotpos_Mean.npy')) - #self.mfcc_std = np.load(os.path.join(statspath, 'mfccs_Std.npy')) - #self.mfcc_mean = np.load(os.path.join(statspath, 'mfccs_Mean.npy')) - self.rot6dpos_std = np.load(os.path.join(statspath, 'rot6dpos_Std.npy')) - self.rot6dpos_mean = np.load(os.path.join(statspath, 'rot6dpos_Mean.npy')) - self.vel_std = np.load(os.path.join(statspath, 'velrotpos_Std.npy')) - self.vel_mean = np.load(os.path.join(statspath, 'velrotpos_Mean.npy')) - - def gettestbatch(self, num_samples): - max_length = max(self.frames[:num_samples]) - max_length = max_length + self.window - max_length%self.window # increase length so it can be divisible by window - batch_audio = [] - batch_audio_rep = [] - batch_text = [] - vad_vals = [] - for i, _ in enumerate(self.takes[:num_samples]): - # Get audio file - audio_feats = [] - signal = np.zeros(int(max_length*self.sr/self.fps)) - signal_ = np.load(os.path.join(self.audiopath,self.takes[i][0]+'.npy')) - signal[:len(signal_)] = signal_ - - if self.use_wavlm: - # Cut Chunk - wavlm_reps_ = np.load(os.path.join(self.wavlm_rep_path,self.takes[i][0]+'.npy')) - audio_feat = np.zeros((max_length, wavlm_reps_.shape[1])) - audio_feat[:wavlm_reps_.shape[0],:] = wavlm_reps_ - - # Reshape - audio_feat = np.transpose(audio_feat, (1,0)) # [WAVLM_DIM, CHUNK_LEN] - audio_feat = np.expand_dims(audio_feat, 1) # [WAVLM_DIM, 1, CHUNK_LEN] - audio_feat = np.expand_dims(audio_feat, 0) # [1, WAVLM_DIM, 1, CHUNK_LEN] - audio_feats.append(audio_feat) - - if self.use_vad: - # Cut Chunk - vad_val_ = np.load(os.path.join(self.vad_path,self.takes[i][0]+'.npy')) - vad_val = np.zeros(max_length) - vad_val[:vad_val_.shape[0]] = vad_val_ # [CHUNK_LEN, ] - - # Reshape - vad_val = np.expand_dims(vad_val, 1) # [CHUNK_LEN, 1] - vad_val = np.transpose(vad_val, (1,0)) # [1, CHUNK_LEN] - vad_vals.append(vad_val) - - # Get text file - text_feats = [] - with open(os.path.join(self.textpath, self.takes[i][0]+'.tsv')) as tsv: - reader = csv.reader(tsv, delimiter='\t') - file = [ [float(word[0])*self.fps, float(word[1])*self.fps, word[2]] for word in reader] - - for chunk in range(int(max_length/self.window)): - if not self.use_wavlm: - # Get audio features - k = chunk*self.window*self.sr/self.fps - _, audio_feat = self.__compute_audiofeats(signal[int(k) : int(k+self.window*self.sr/self.fps)]) - audio_feats.append(audio_feat) - - # Get text - begin = self.search_time(file, chunk*self.window) - end = self.search_time(file, chunk*self.window + self.window) - text = [ word[-1] for word in file[begin: end] ] if begin or end else [] - text_feats.append(' '.join(text)) - - - audio_feats = np.concatenate(audio_feats, axis=-1) - end_audio = int(len(signal_)/self.sr*self.fps) - audio_feats[..., end_audio:] = np.zeros_like(audio_feats[..., end_audio:]) # zero audio feats after end of audio - batch_audio_rep.append(audio_feats) - batch_text.append(text_feats) - batch_audio.append(signal_) - - # Dummy motions and 
seed poses - feats = 1245 if self.name == 'genea2023+' else 498 - motion, seed_poses = np.zeros((self.window, feats)), np.zeros((self.n_seed_poses, feats)) #dummy - - # Attention: this is not collate-ready! - return motion, batch_text, self.window, batch_audio, batch_audio_rep, seed_poses, max_length, vad_vals, self.takes[:num_samples] - - def getvalbatch(self, num_takes, index): - # Get batch of data from the validation set, index refer to the chunk that you want to get - # Example: index = 0 and num_takes = 10 will return the first chunk of the first 10 takes - # index = 5 and num_takes = 30 will return the moment starting at 5*num_frames (120 by default) and ending at 6*num_frames of the first 30 takes - # num_takes = batch_size - batch = [] - assert num_takes <= len(self.takes) - # for each take - for take in np.arange(num_takes): - # get the corresponding index to call __getitem__ - sampleindex = self.samples_cumulative[take-1] + index if take != 0 else index - # check if the index is from the take and call __getitem__ - out = self.__getitem__(sampleindex) if sampleindex <= self.samples_per_file[take] + sampleindex - index else None - batch.append(out) - return batch - - def getjoints(self, toget= ['b_r_forearm','b_l_forearm']): - #toget = ['b_r_shoulder','b_r_arm','b_r_arm_twist','b_r_forearm','b_r_wrist_twist','b_r_wrist', - # 'b_l_shoulder','b_l_arm','b_l_arm_twist','b_l_forearm','b_l_wrist_twist','b_l_wrist'] - return {k:self.alljoints[k] for k in self.alljoints if k in toget} \ No newline at end of file diff --git a/data_loaders/gesture/data/ptbrdataset.py b/data_loaders/gesture/data/ptbrdataset.py new file mode 100644 index 0000000..842b331 --- /dev/null +++ b/data_loaders/gesture/data/ptbrdataset.py @@ -0,0 +1,379 @@ +import os +from torch.utils import data +import csv +import numpy as np +import torch + +class PTBRGesture(data.Dataset): + def __init__(self, + name, + split, + datapath='./dataset/BRG-Unicamp', + step=10, + window=120, + fps=30, + sr=22050, + n_seed_poses=10, + use_wavlm=False, + use_vad=False, + vadfromtext=False, + bvhreference='./dataset/BRG-Unicamp/motion/bvh_twh/id01_p01_e01_f01.bvh'): + + self.name = name + self.bvhreference = bvhreference + + # Hard-coded because it IS 30 fps + self.fps = 30 + + self.window = window + self.step = step + self.n_seed_poses = n_seed_poses + + # Get all paths + audio_path, wav_path, audio16k_path, vad_path, wavlm_path, \ + motion_path, pos_path, rot3d_path, rot6d_path = self.getpaths(datapath) + # Get takes from wav path and check if all paths have the same takes + takes = self.gettakes(wav_path, [audio16k_path, vad_path, wavlm_path, pos_path, rot3d_path, rot6d_path]) + print('Data integrity check passed.') + # Register takes as Take objects + self.__registered = False # flag to check if takes are registered + self.takes = self.registertakes(takes) + #print(f'{len(self.takes)} takes registered.') + + # Get metadata and register bvh start for audio alignment + with open(os.path.join(datapath, 'meta.csv'), 'r', encoding='utf-16') as f: + reader = csv.reader(f, delimiter=',') + self.metadata = [line for line in reader] + ratio = self.fps/120 # We are applying this ratio because the bvh_start was computed for 120 fps + self.metadata = [ [self.gettake(line[0]), np.floor(int(line[1])*ratio).astype(int)] for line in self.metadata ] + for line in self.metadata: + line[0].bvh_start = line[1] + + # Hard-coded split + val_p01 = [ 0, 5, 10, 15, 20, 25, 30, 35, 40 ] + val_p02 = [ 0, 3, 6, 9, 12, 15, 27, 34, 41, 48, 55, 62, 69] + if split 
== 'trn': + p01 = [i for i in np.arange(0, 45) if i not in val_p01] + p02 = [i for i in np.arange(0, 73) if i not in val_p02] + elif split == 'val': + p01 = val_p01 + p02 = val_p02 + else: + raise ValueError('Invalid split') + + # Categorize whole dataset (without unscripted samples) + self.filtered = [ + [ take for i, take in enumerate(self.filter_style_part_id(1,1,1)) if i in p01], # id01_p01_e01_fXX + [ take for i, take in enumerate(self.filter_style_part_id(2,1,1)) if i in p01], # id01_p01_e02_fXX + [ take for i, take in enumerate(self.filter_style_part_id(3,1,1)) if i in p01], # id01_p01_e03_fXX + [ take for i, take in enumerate(self.filter_style_part_id(1,1,2)) if i in p01], # id02_p01_e01_fXX + [ take for i, take in enumerate(self.filter_style_part_id(2,1,2)) if i in p01], # id02_p01_e02_fXX + [ take for i, take in enumerate(self.filter_style_part_id(3,1,2)) if i in p01], # id02_p01_e03_fXX + [ take for i, take in enumerate(self.filter_style_part_id(1,2,1)) if i in p02], # id01_p02_e01_fXX + [ take for i, take in enumerate(self.filter_style_part_id(2,2,1)) if i in p02], # id01_p02_e02_fXX + [ take for i, take in enumerate(self.filter_style_part_id(3,2,1)) if i in p02], # id01_p02_e03_fXX + [ take for i, take in enumerate(self.filter_style_part_id(1,2,2)) if i in p02], # id02_p02_e01_fXX + [ take for i, take in enumerate(self.filter_style_part_id(2,2,2)) if i in p02], # id02_p02_e02_fXX + [ take for i, take in enumerate(self.filter_style_part_id(3,2,2)) if i in p02], # id02_p02_e03_fXX + ] + + # Get whole dataset given split + self.takes = [ take for takelist in self.filtered for take in takelist ] + + # Load dataset + print('Loading dataset...') + self.rot6d = [ np.load(os.path.join(rot6d_path, take.name+'.npy')) for take in self.takes ] + self.rot3d = [ np.load(os.path.join(rot3d_path, take.name+'.npy')) for take in self.takes ] + self.pos = [ np.load(os.path.join(pos_path, take.name+'.npy')) for take in self.takes ] + self.wavlm = [ np.load(os.path.join(wavlm_path, take.name+'.npy')) for take in self.takes ] + self.vad = [ np.load(os.path.join(vad_path, take.name+'.npy')) for take in self.takes ] + self.audio16k = [ np.load(os.path.join(audio16k_path, take.name+'.npy')) for take in self.takes ] + self.velrot3d = [ np.diff(rot3d, axis=0, append=0) for rot3d in self.rot3d ] + self.velpos = [ np.diff(pos, axis=0, append=0) for pos in self.pos ] + print('Done') + + + self.frames = [] + for index, take in enumerate(self.takes): + assert self.rot6d[index].shape[0] == self.rot3d[index].shape[0] == self.pos[index].shape[0], f'{take.name} has different lengths' + self.rot6d[index] = self.rot6d[index][take.bvh_start:] + self.rot3d[index] = self.rot3d[index][take.bvh_start:] + self.pos[index] = self.pos[index][take.bvh_start:] + e = int(self.audio16k[index].shape[0]/16000*self.fps) + # Due to audio processing, vad and wavlm arrays may have one more frame than the audio arrays + e = np.min([e, self.vad[index].shape[0], self.wavlm[index].shape[0]]) + self.rot6d[index] = self.rot6d[index][:e] + self.rot3d[index] = self.rot3d[index][:e] + self.pos[index] = self.pos[index][:e] + self.frames.append(self.rot6d[index].shape[0]) + + #self.samples_per_file = [int(np.floor( (n - self.window ) / self.step)) for n in self.frames] + self.samples_per_file = [len( [i for i in np.arange(0, n, self.step) if i + self.window <= n] ) for n in self.frames] + self.samples_cumulative = [np.sum(self.samples_per_file[:i+1]) for i in range(len(self.samples_per_file))] + self.length = self.samples_cumulative[-1] + + # 
Load mean and std for normalization + self.rot6d_mean, self.rot6d_std, \ + self.rot3d_mean, self.rot3d_std, \ + self.velrot_mean, self.velrot_std, \ + self.pos_mean, self.pos_std, \ + self.velpos_mean, self.velpos_std = self.loadstats(datapath) + # Get rid of zeros in the std + self.rot6d_std = np.where(self.rot6d_std == 0, 1, self.rot6d_std) + self.rot3d_std = np.where(self.rot3d_std == 0, 1, self.rot3d_std) + self.velrot_std = np.where(self.velrot_std == 0, 1, self.velrot_std) + self.pos_std = np.where(self.pos_std == 0, 1, self.pos_std) + self.velpos_std = np.where(self.velpos_std == 0, 1, self.velpos_std) + + #TODO: Do the normalization here since the whole dataset is being loaded into memory + + # Compute some useful info for the dataset that will be used later (in the __getitem__ primarily) + self.r6d_shape = self.rot6d[0].shape[1] + self.r3d_shape = self.rot3d[0].shape[1] + self.pos_shape = self.pos[0].shape[1] + self.motio_feat_shape = self.r6d_shape + self.r3d_shape + 2*self.pos_shape # 2* for velocity + + def inv_transform(self, data): + return data * np.concatenate((self.rot6d_std, self.pos_std, self.velrot_std, self.velpos_std)) + np.concatenate((self.rot6d_mean, self.pos_mean, self.velrot_mean, self.velpos_mean)) + + def __getitem__(self, index): + # find the file that the sample belongs two + file_idx = np.searchsorted(self.samples_cumulative, index+1, side='left') + # find sample's index in the file + sample = index - self.samples_cumulative[file_idx-1] if file_idx > 0 else index + motion, seed_poses = self._getmotion(file_idx, sample) + audio = self._getaudio(file_idx, sample) + vad = self._getvad(file_idx, sample) + wavlm = self._getwavlm(file_idx, sample) + return motion, seed_poses, audio, vad, wavlm, [self.takes[file_idx].name, file_idx, sample, self.takes[file_idx].one_hot] + + def __dummysample__(self): + motion = np.zeros(shape=(self.window, self.motio_feat_shape)) + seed = np.zeros(shape=(self.n_seed_poses, self.motio_feat_shape)) + audio = np.zeros(shape=(int(self.window/30*16000))) + vad = np.zeros(shape=(self.window, 1)) + vad = np.transpose(vad, (1,0)) + wavlm = np.zeros(shape=(1, 768, 1, self.window)) + return motion, seed, audio, vad, wavlm, ['dummy', 0, 0, np.zeros(shape=(12))] + + def _getaudio(self, file_idx, sample): + # Get audio data from file_idx and sample + b = int(sample*self.step/30*16000) + e = int((sample*self.step+self.window)/30*16000) + return self.audio16k[file_idx][b:e] + + def _getvad(self, file_idx, sample): + # Get vad data from file_idx and sample + vad_vals = self.vad[file_idx][sample*self.step:sample*self.step+self.window] + assert vad_vals.shape[0] == self.window, f'VAD shape is {vad_vals.shape[0]} instead of {self.window}' + # Reshape + vad_vals = np.expand_dims(vad_vals, 1) # [CHUNK_LEN, 1] + vad_vals = np.transpose(vad_vals, (1,0)) # [1, CHUNK_LEN] + return vad_vals + + def _getwavlm(self, file_idx, sample): + # Get wavlm data from file_idx and sample + wavlm_reps = self.wavlm[file_idx][sample*self.step:sample*self.step+self.window] + assert wavlm_reps.shape[0] == self.window, f'WAVLM shape is {wavlm_reps.shape[0]} instead of {self.window}' + # Reshape + wavlm_reps = np.transpose(wavlm_reps, (1,0)) # [WAVLM_DIM, CHUNK_LEN] + wavlm_reps = np.expand_dims(wavlm_reps, 1) # [WAVLM_DIM, 1, CHUNK_LEN] + wavlm_reps = np.expand_dims(wavlm_reps, 0) # [1, WAVLM_DIM, 1, CHUNK_LEN] + return wavlm_reps + + + def _getmotion(self, file_idx, sample): + # Get motion data from file_idx and sample + motion = np.zeros(shape=(self.window, 
self.motio_feat_shape)) + b, e = sample*self.step, sample*self.step+self.window + # Get motion data from rot6d + motion[:, :self.r6d_shape] = (self.rot6d[file_idx][b:e] - self.rot6d_mean) / self.rot6d_std + # Get motion data from rot3d + cumulative = self.r6d_shape + #motion[:, cumulative:cumulative+self.r3d_shape] = (self.rot3d[file_idx][b:e] - self.rot3d_mean) / self.rot3d_std + # Get motion data from pos + #cumulative += self.r3d_shape + motion[:, cumulative:cumulative+self.pos_shape] = (self.pos[file_idx][b:e] - self.pos_mean) / self.pos_std + # Get vel from rot3d + cumulative += self.pos_shape + motion[:, cumulative:cumulative+self.r3d_shape] = (self.velrot3d[file_idx][b:e] - self.velrot_mean) / self.velrot_std + # Get vel from pos + cumulative += self.r3d_shape + motion[:, cumulative:cumulative+self.pos_shape] = (self.velpos[file_idx][b:e] - self.velpos_mean) / self.velpos_std + # Get seed poses + seed = np.zeros(shape=(self.n_seed_poses, self.motio_feat_shape)) + # TODO: This behaves wrongly for step = 5 for example + if b - self.n_seed_poses >= 0: + seed[:, :self.r6d_shape] = (self.rot6d[file_idx][b-self.n_seed_poses:b] - self.rot6d_mean) / self.rot6d_std + cumulative = self.r6d_shape + #seed[:, cumulative:cumulative+self.r3d_shape] = (self.rot3d[file_idx][b-self.n_seed_poses:b] - self.rot3d_mean) / self.rot3d_std + #cumulative += self.r3d_shape + seed[:, cumulative:cumulative+self.pos_shape] = (self.pos[file_idx][b-self.n_seed_poses:b] - self.pos_mean) / self.pos_std + cumulative += self.pos_shape + seed[:, cumulative:cumulative+self.r3d_shape] = (self.velrot3d[file_idx][b-self.n_seed_poses:b] - self.velrot_mean) / self.velrot_std + cumulative += self.r3d_shape + seed[:, cumulative:cumulative+self.pos_shape] = (self.velpos[file_idx][b-self.n_seed_poses:b] - self.velpos_mean) / self.velpos_std + return motion, seed + + def __len__(self): + return self.length + + def getpaths(self, datapath): + # Create a list of paths to the dataset folders + # and check if all paths exist + audio_path = os.path.join(datapath, 'audio') + wav_path = os.path.join(audio_path, 'wav') + audio16k_path = os.path.join(audio_path, 'npy16k') + vad_path = os.path.join(audio_path, 'vad') + wavlm_path = os.path.join(audio_path, 'wavlm') + motion_path = os.path.join(datapath, 'motion') + pos_path = os.path.join(motion_path, 'pos') + rot3d_path = os.path.join(motion_path, 'rot3d') + rot6d_path = os.path.join(motion_path, 'rot6d') + for path in [audio_path, wav_path, audio16k_path, vad_path, wavlm_path, motion_path, pos_path, rot3d_path, rot6d_path]: + assert os.path.exists(path), f'{path} does not exist' + return audio_path, wav_path, audio16k_path, vad_path, wavlm_path, motion_path, pos_path, rot3d_path, rot6d_path + + def gettakes(self, reference_path, paths): + # Create a list of take names based on the takes in the reference path + # Also checks if all paths have the same takes + takes = [] + for file in os.listdir(reference_path): + for path in paths: + assert file[:-4] in [os.path.basename(f)[:-4] for f in os.listdir(path)], f'{file} not found in {path}' + takes.append(file[:-4]) + #print(f'Found {len(takes)} takes.') + return takes + + def registertakes(self, takes): + # Sort takes and create Take objects + takes.sort() + class_takes = [] + for take in takes: + class_takes.append(Take(take)) + self.__registered = True + return class_takes + + def filter_id(self, id, include_unscripted=False): + # Get list of takes from id + assert self.__registered, 'Takes are not registered' + takelist = [] + for take in 
self.takes: + if take.id == id: + if include_unscripted: + takelist.append(take) + else: + if take.type == 'scripted': + takelist.append(take) + assert takelist, f'No takes found for id {id}' + return takelist + + def filter_style(self, style): + # Get list of takes from style + assert self.__registered, 'Takes are not registered' + takelist = [take for take in self.takes if take.style == style] + assert takelist, f'No takes found for style {style}' + return takelist + + def filter_part_id(self, part, id, include_unscripted=False): + # Get list of takes from id, then remove takes that are not from part + assert self.__registered, 'Takes are not registered' + takelist = self.filter_id(id, include_unscripted) + takelist = [take for take in takelist if take.part == part] + assert takelist, f'No takes found for id {id} and part {part}' + return takelist + + def filter_style_part_id(self, style, part, id, include_unscripted=False): + # Get list of takes from id, then remove takes that are not from part and style + assert self.__registered, 'Takes are not registered' + takelist = self.filter_part_id(part, id, include_unscripted) + takelist = [take for take in takelist if take.style == style] + assert takelist, f'No takes found for id {id}, part {part} and style {style}' + return takelist + + def gettake(self, name): + # Get take from name + assert self.__registered, 'Takes are not registered' + for take in self.takes: + if take.name == name: + return take + assert False, f'Take {name} not found' + + def loadstats(self, statspath): + rot6d_mean = np.load(os.path.join(statspath, 'rot6d_Mean.npy')) + rot6d_std = np.load(os.path.join(statspath, 'rot6d_Std.npy')) + rot3d_mean = np.load(os.path.join(statspath, 'rot3d_Mean.npy')) + rot3d_std = np.load(os.path.join(statspath, 'rot3d_Std.npy')) + velrot_mean = np.load(os.path.join(statspath, 'velrot_Mean.npy')) + velrot_std = np.load(os.path.join(statspath, 'velrot_Std.npy')) + pos_mean = np.load(os.path.join(statspath, 'pos_Mean.npy')) + pos_std = np.load(os.path.join(statspath, 'pos_Std.npy')) + velpos_mean = np.load(os.path.join(statspath, 'velpos_Mean.npy')) + velpos_std = np.load(os.path.join(statspath, 'velpos_Std.npy')) + return rot6d_mean, rot6d_std, rot3d_mean, rot3d_std, velrot_mean, velrot_std, pos_mean, pos_std, velpos_mean, velpos_std + + + +class Take(): + def __init__(self, name): + self.name = name + splited = name.split('_') + self.id = self.getint(splited[0]) + self.type = 'unscripted' if splited[1] == 'un' else 'scripted' + if self.type == 'scripted': + self.part = self.getint(splited[1]) + self.phrase = self.getint(splited[3]) + else: + self.part = None + self.phrase = None + self.style = self.getint(splited[2]) + self.bvh_start = 0 + self.one_hot, self.class_label = self.one_hot_encode() + + def getint(self, string): + # Get integer from string. 
Note: 'something01' -> 1
+        return int(''.join(filter(str.isdigit, string)))
+
+    def one_hot_encode(self):
+        # One-hot encode the take into one of 12 (speaker id, part/style) classes
+        if self.id == 1:
+            if (self.part == 1 and self.style == 1) or (self.type == 'unscripted' and self.style == 1):
+                # ID 01, extroverted, scripted or unscripted
+                class_label = 1
+            elif (self.part == 1 and self.style == 2) or (self.type == 'unscripted' and self.style == 2):
+                # ID 01, introverted, scripted or unscripted
+                class_label = 2
+            elif (self.part == 1 and self.style == 3) or (self.type == 'unscripted' and self.style == 3):
+                # ID 01, neutral, scripted or unscripted
+                class_label = 3
+            elif self.part == 2 and self.style == 1:
+                class_label = 4
+            elif self.part == 2 and self.style == 2:
+                class_label = 5
+            elif self.part == 2 and self.style == 3:
+                class_label = 6
+        elif self.id == 2:
+            if (self.part == 1 and self.style == 1) or (self.type == 'unscripted' and self.style == 1):
+                # ID 02, extroverted, scripted or unscripted
+                class_label = 7
+            elif (self.part == 1 and self.style == 2) or (self.type == 'unscripted' and self.style == 2):
+                # ID 02, introverted, scripted or unscripted
+                class_label = 8
+            elif (self.part == 1 and self.style == 3) or (self.type == 'unscripted' and self.style == 3):
+                # ID 02, neutral, scripted or unscripted
+                class_label = 9
+            elif self.part == 2 and self.style == 1:
+                class_label = 10
+            elif self.part == 2 and self.style == 2:
+                class_label = 11
+            elif self.part == 2 and self.style == 3:
+                class_label = 12
+            else:
+                raise ValueError(f'Invalid part {self.part} and style {self.style} for take {self.name}')
+        else:
+            raise ValueError(f'Invalid id {self.id}')
+        token = np.zeros(shape=(12))
+        token[class_label - 1] = 1
+        return token, class_label
+
+    
\ No newline at end of file
diff --git a/data_loaders/gesture/scripts/genea_prep.py b/data_loaders/gesture/scripts/genea_prep.py
deleted file mode 100644
index 8c8dd13..0000000
--- a/data_loaders/gesture/scripts/genea_prep.py
+++ /dev/null
@@ -1,154 +0,0 @@
-from argparse import ArgumentParser
-import os
-import numpy as np
-import librosa
-import torch
-from wavlm.WavLM import WavLM, WavLMConfig
-import torch.nn.functional as F
-from data_loaders.gesture.scripts.motion_process import bvh2representations2
-import bvhsdk
-from tqdm import tqdm
-
-def main(args):
-    #paths_check(args.data_dir)
-    assert args.split in ['all', 'trn', 'tst', 'val'], f"Split {args.split} not recognized. Options: \'all\', \'trn\', \'tst\', \'val\'" # Check if user is trying to process a split that does not exist
-    splits = [args.split] if args.split != 'all' else ['trn', 'tst', 'val']
-    assert args.step in ['all', 'bvh', 'wav', 'wavlm'], f"Step {args.step} not recognized. 
Options: \'all\', \'bvh\', \'wav\', \'wavlm\'" # Check if user is trying to process a step that does not exist - steps = [args.step] if args.step != 'all' else ['bvh', 'wav', 'wavlm'] - print('WARNING: Running all steps and all splits will take a long time.') - print('Processing splits: ', splits) - print('Processing steps: ', steps) - for split in splits: - print(f'Processing {split} split') - if 'bvh' in steps and split != 'tst': - print(f'Processing bvh for {split} split') - r6p, rp = process_bvh(args.data_dir, split) - statspath = os.path.join(args.data_dir, split, 'main-agent') - print(f'Computing mean and std for {split} split') - compute_meanstd(r6p, os.path.join(statspath, 'rot6dpos'), npstep=5) - compute_meanstd(rp, os.path.join(statspath, 'rotpos'), npstep=5) - compute_meanstd(rp, os.path.join(statspath, 'velrotpos'), npstep=5, vel=True) - if 'wav' in steps: - print(f'Processing wav for {split} split') - process_wav(args.data_dir, split) - if 'wavlm' in steps: - print(f'Processing wavlm for {split} split') - process_wavlm(args.data_dir, split) - -def process_bvh(path, split): - sourcepath = os.path.join(path, split, 'main-agent', 'bvh') - savepathrot6d = os.path.join(path, split, 'main-agent', 'motion_npy_rot6dpos') - savepathrot = os.path.join(path, split, 'main-agent', 'motion_npy_rotpos') - assert not os.path.exists(savepathrot6d), f"motion_npy_rot6dpos already exists in {savepathrot6d}. Delete it to process again." - assert not os.path.exists(savepathrot), f"motion_npy_rotpos already exists in {savepathrot}. Delete it to process again." - if not os.path.exists(savepathrot6d): - os.mkdir(savepathrot6d) - if not os.path.exists(savepathrot): - os.mkdir(savepathrot) - for file in tqdm(os.listdir(sourcepath)): - #if not os.path.exists(os.path.join(savepathrot6d, file[:-4] + '.npy')) or not os.path.exists(os.path.join(savepathrot, file[:-4] + '.npy')): - anim = bvhsdk.ReadFile(os.path.join(sourcepath, file)) - rot6dpos, rotpos = bvh2representations2(anim) - np.save(os.path.join(savepathrot6d, file[:-4]), rot6dpos) - np.save(os.path.join(savepathrot, file[:-4]), rotpos) - return savepathrot6d, savepathrot - -def compute_meanstd(path, savepath, npstep=1, vel=False): - all_data = [] - for f in os.listdir(path)[::npstep]: - data = np.load(os.path.join(path, f)) - if vel: - data = data[1:,:] - data[:-1,:] - data[0,:] = np.zeros(data.shape[1]) - all_data.append(data) - all_data = np.vstack(all_data) - mean = np.mean(all_data, axis=0) - std = np.std(all_data, axis=0) - np.save(savepath + '_Mean.npy', mean) - np.save(savepath + '_Std.npy', std) - - -def process_wav(path, split, sr=16000): - sourcepath = os.path.join(path, split, 'main-agent', 'wav') - savepath = os.path.join(path, split, 'main-agent', 'audio16k_npy') - assert not os.path.exists(savepath), f"audio_16k_npy already exists in {savepath}. Delete it to process again." 
- if not os.path.exists(savepath): - os.mkdir(savepath) - for file in tqdm(os.listdir(sourcepath)): - #if not os.path.exists(os.path.join(savepath, file[:-4] + '.npy')): - signal, _sr = librosa.load(os.path.join(sourcepath, file), mono=True, sr=sr) - assert _sr == sr - np.save(os.path.join(savepath, file[:-4]+'.npy'), signal) - return savepath - -def process_wavlm(path, split): - wavlm_layer = 11 - fps=30 - sr=16000 - device = 'cuda' if torch.cuda.is_available() else 'cpu' - sourcepath = os.path.join(path, split, 'main-agent', 'audio16k_npy') - savepath = os.path.join(path, split, 'main-agent', 'wavlm_representations') - #assert os.path.exists(sourcepath), f"audio16k_npy not found in {sourcepath}. Required to process wavlm representations, make sure wav files were processed first." - #assert os.path.exists(savepath), f"wavlm model directory not found in current directory." - if not os.path.exists(savepath): - os.mkdir(savepath) - checkpoint = torch.load('./wavlm/WavLM-Base+.pt') - wavlm_cfg = WavLMConfig(checkpoint['cfg']) - wavlm = WavLM(wavlm_cfg) - #wavlm.to(device) - wavlm.load_state_dict(checkpoint['model']) - wavlm.eval() - for file in tqdm(os.listdir(sourcepath)): - if not os.path.exists(os.path.join(savepath, file)): - audio_path = os.path.join(sourcepath, file) - # Load with Numpy - signal = np.load(audio_path) - # Set to model innput format - signal = torch.tensor(signal).unsqueeze(0)#.to(device) - # Normalize - if wavlm_cfg.normalize: - signal_norm = torch.nn.functional.layer_norm(signal , signal.shape) - else: - signal_norm = signal - # Run Model (rep=Desired Layer, layer_results=all layers) - rep, layer_results = wavlm.extract_features(signal_norm, output_layer=wavlm_layer, ret_layer_results=True)[0] - layer_reps = [x.transpose(0, 1) for x, _ in layer_results] # fix shape - # Get Number of Seconds of Audio File - n_secs = signal.shape[1] / sr - # Get Number of poses equivalent to audio file duration, given fps (alignment len) - n_pose = n_secs * fps - # Interpolate number of representations to match number of poses corresponding to audio file - interp_reps = F.interpolate(rep.transpose(1, 2), size=int(n_pose), align_corners=True, mode='linear') - # Prepare to save - interp_reps = interp_reps.squeeze(0).transpose(0,1).cpu().detach().data.cpu().numpy() - # Double check dimension - assert (interp_reps.shape[0] == int(np.ceil(n_pose)) or interp_reps.shape[0] == int(np.floor(n_pose))) - np.save(os.path.join(savepath, file), interp_reps) - - -def paths_check(data_dir): - # First check if everything is in place - for split in ['trn', 'tst', 'val']: - split_dir = os.path.join(data_dir, split) - assert os.path.exists(split_dir), f"Split {split} not found in {data_dir}" - main_agent_dir = os.path.join(split_dir, 'main-agent') - assert os.path.exists(main_agent_dir), f"main_agent not found in {split_dir}" - tsv_dir = os.path.join(main_agent_dir, 'tsv') - wav_dir = os.path.join(main_agent_dir, 'wav') - assert os.path.exists(tsv_dir), f"tsv not found in {main_agent_dir}" - assert os.path.exists(wav_dir), f"wav not found in {main_agent_dir}" - assert len(os.listdir(tsv_dir)) == len(os.listdir(wav_dir)), f"tsv and wav have different number of files in {main_agent_dir}" - if split != 'tst': - bvh_dir = os.path.join(main_agent_dir, 'bvh') - assert os.path.exists(bvh_dir), f"bvhs not found in {main_agent_dir}" - assert len(os.listdir(tsv_dir)) == len(os.listdir(bvh_dir)), f"tsv and bvh have different number of files in {main_agent_dir}" - print('Data paths and files seem correct') - - -if 
__name__ == '__main__': - parser = ArgumentParser() - parser.add_argument('--data_dir', type=str, default='./dataset/Genea2023', help='path to the dataset directory') - parser.add_argument('--split', type=str, default='all', help='Which split to process. Use \'all\' to process all splits') - parser.add_argument('--step', type=str, default='all', help='Which step to process. Use \'all\' to process all steps. Options: \'bvh\', \'wav\', \'wavlm\'') - args = parser.parse_args() - main(args) \ No newline at end of file diff --git a/data_loaders/gesture/scripts/genea_prep_vad.py b/data_loaders/gesture/scripts/genea_prep_vad.py deleted file mode 100644 index cff5fc6..0000000 --- a/data_loaders/gesture/scripts/genea_prep_vad.py +++ /dev/null @@ -1,72 +0,0 @@ -from argparse import ArgumentParser -import os -import numpy as np -import torch -from speechbrain.inference.VAD import VAD -import torchaudio -from scipy.signal import resample -from tqdm import tqdm - - - -def main(args): - #paths_check(args.data_dir) - assert args.split in ['all', 'trn', 'tst', 'val'], f"Split {args.split} not recognized. Options: \'all\', \'trn\', \'tst\', \'val\'" # Check if user is trying to process a split that does not exist - splits = [args.split] if args.split != 'all' else ['trn', 'tst', 'val'] - - print('Processing VAD.') - print('Processing splits: ', splits) - for split in splits: - print(f'Processing vad for {split} split') - process_vad(args.data_dir, split) - - - -def process_vad(path, split): - sr=16000 - fps=30 - sourcepath = os.path.join(path, split, 'main-agent', 'wav') - savepathrot = os.path.join(path, split, 'main-agent', 'vad') - _VAD = VAD.from_hparams(source= "speechbrain/vad-crdnn-libriparty", savedir= os.path.join(path, '..','..','speechbrain', 'pretrained_models', 'vad-crdnn-libriparty')) - #assert not os.path.exists(savepathrot), f"vad already exists in {savepathrot}. Delete it to process again." - if not os.path.exists(savepathrot): - os.mkdir(savepathrot) - # VAD requires a torch tensor with sample rate = 16k. This process saves a temporary wav file with 16k sr. It can be deleted after processing. 
- for file in tqdm(os.listdir(sourcepath)): - audio, old_sr = torchaudio.load(os.path.join(sourcepath,file)) - audio = torchaudio.functional.resample(audio, orig_freq=44100, new_freq=sr) - tmpfile = "tmp.wav" - torchaudio.save( - tmpfile , audio, sr - ) - boundaries = _VAD.get_speech_prob_file(audio_file=tmpfile, large_chunk_size=4, small_chunk_size=0.2) - boundaries = resample(boundaries[0,:,0], int(boundaries.shape[1]*fps/100)) - boundaries[boundaries>=0.5] = 1 - boundaries[boundaries<0.5] = 0 - np.save(os.path.join(savepathrot, file[:-4]+'.npy'), boundaries) - -def paths_check(data_dir): - # First check if everything is in place - for split in ['trn', 'tst', 'val']: - split_dir = os.path.join(data_dir, split) - assert os.path.exists(split_dir), f"Split {split} not found in {data_dir}" - main_agent_dir = os.path.join(split_dir, 'main-agent') - assert os.path.exists(main_agent_dir), f"main_agent not found in {split_dir}" - tsv_dir = os.path.join(main_agent_dir, 'tsv') - wav_dir = os.path.join(main_agent_dir, 'wav') - assert os.path.exists(tsv_dir), f"tsv not found in {main_agent_dir}" - assert os.path.exists(wav_dir), f"wav not found in {main_agent_dir}" - assert len(os.listdir(tsv_dir)) == len(os.listdir(wav_dir)), f"tsv and wav have different number of files in {main_agent_dir}" - if split != 'tst': - bvh_dir = os.path.join(main_agent_dir, 'bvh') - assert os.path.exists(bvh_dir), f"bvhs not found in {main_agent_dir}" - assert len(os.listdir(tsv_dir)) == len(os.listdir(bvh_dir)), f"tsv and bvh have different number of files in {main_agent_dir}" - print('Data paths and files seem correct') - - -if __name__ == '__main__': - parser = ArgumentParser() - parser.add_argument('--data_dir', type=str, default='./dataset/Genea2023', help='path to the dataset directory') - parser.add_argument('--split', type=str, default='all', help='Which split to process. 
Use \'all\' to process all splits') - args = parser.parse_args() - main(args) \ No newline at end of file diff --git a/data_loaders/gesture/scripts/motion_process.py b/data_loaders/gesture/scripts/motion_process.py index 2c9b819..1898c6a 100644 --- a/data_loaders/gesture/scripts/motion_process.py +++ b/data_loaders/gesture/scripts/motion_process.py @@ -23,11 +23,10 @@ def split_pos_rot(dataset, data): idx_positions, idx_rotations = get_indexes(dataset) return data[..., idx_positions], data[..., idx_rotations] -def rot6d_to_euler(data): +def rot6d_to_euler(data, n_joints): # Convert numpy array to euler angles # Shape expected [num_samples(bs), 1, chunk_len, 498] # Output shape [num_samples(bs) * chunk_len, n_joints, 3] - n_joints = 83 assert data.shape[-1] == n_joints*6 sample_rot = data.view(data.shape[:-1] + (-1, 6)) # [num_samples(bs), 1, chunk_len, n_joints, 6] sample_rot = geometry.rotation_6d_to_matrix(sample_rot) # [num_samples(bs), 1, chunk_len, n_joints, 3, 3] @@ -88,8 +87,9 @@ def np_matrix_to_rotation_6d(matrix: np.ndarray) -> np.ndarray: return matrix[..., :2, :].copy().reshape(6) def bvh2representations2(anim: bvhsdk.Animation): - # Converts bvh to two representations: 6d rotations and 3d positions - # And 3d rotations (euler angles) and 3d positions + # Converts bvh to two representations: + # 1st - 6d rotations and 3d positions + # 2nd - And 3d rotations (euler angles) and 3d positions # The 3d positions of both representations are the same (duplicated data) # This representation is used in the genea challenge njoints = len(anim.getlistofjoints()) @@ -108,7 +108,10 @@ def bvh2representations2(anim: bvhsdk.Animation): return npyrot6dpos, npyrotpos def bvh2representations1(anim: bvhsdk.Animation): - # Converts bvh to three representations: 6d rotations, 3d positions (euler angles) and 3d positions + # Converts bvh to three representations: + # 1st - 6d rotations + # 2nd - 3d rotations (euler angles) and + # 3rd - 3d positions njoints = len(anim.getlistofjoints()) npyrot6d = np.empty(shape=(anim.frames, 6*njoints)) npyrot = np.empty(shape=(anim.frames, 3*njoints)) diff --git a/data_loaders/gesture/scripts/ptbrgesture_prep.py b/data_loaders/gesture/scripts/ptbrgesture_prep.py new file mode 100644 index 0000000..2d01411 --- /dev/null +++ b/data_loaders/gesture/scripts/ptbrgesture_prep.py @@ -0,0 +1,190 @@ +from argparse import ArgumentParser +import os +import numpy as np +import librosa +import torch +from wavlm.WavLM import WavLM, WavLMConfig +import torch.nn.functional as F +from data_loaders.gesture.scripts.motion_process import bvh2representations1 +import bvhsdk +from tqdm import tqdm + +def main(args): + bvhpath, wavpath, rot6dpath, rot3dpath, pospath, npy16k, wavlmpath = paths_get_and_check(args.data_dir) + takes = takes_get_and_check(bvhpath, wavpath) + assert args.step in ['all', 'bvh', 'wav', 'wavlm'], f"Step {args.step} not recognized. 
Options: \'all\', \'bvh\', \'wav\', \'wavlm\'" # Check if user is trying to process a step that does not exist + steps = [args.step] if args.step != 'all' else ['bvh', 'wav', 'wavlm'] + if 'bvh' in steps: + print('Processing bvh') + process_bvh(bvhpath, rot6dpath, rot3dpath, pospath, takes) + print('Computing mean and std') + compute_meanstd(rot6dpath, os.path.join(args.data_dir, 'rot6d'), npstep=1) + compute_meanstd(rot3dpath, os.path.join(args.data_dir, 'rot3d'), npstep=1) + compute_meanstd(pospath, os.path.join(args.data_dir, 'pos'), npstep=1) + compute_meanstd(rot3dpath, os.path.join(args.data_dir, 'velrot'), npstep=1, vel=True) + compute_meanstd(pospath, os.path.join(args.data_dir, 'velpos'), npstep=1, vel=True) + if 'wav' in steps: + print('Processing wav') + process_wav(wavpath, npy16k) + if 'wavlm' in steps: + print('Processing wavlm') + process_wavlm(npy16k, wavlmpath) + +def process_wavlm(sourcepath, savepath): + wavlm_layer = 11 + fps=30 + sr=16000 + device = 'cuda' if torch.cuda.is_available() else 'cpu' + assert os.path.exists(sourcepath), f"audio16k_npy not found in {sourcepath}. Required to process wavlm representations, make sure wav files were processed first." + #assert not os.path.exists(savepath), f"wavlm model directory already exists." + if not os.path.exists(savepath): + os.mkdir(savepath) + checkpoint = torch.load('./wavlm/WavLM-Base+.pt') + wavlm_cfg = WavLMConfig(checkpoint['cfg']) + wavlm = WavLM(wavlm_cfg) + wavlm.to(device) + wavlm.load_state_dict(checkpoint['model']) + wavlm.eval() + with torch.no_grad(): + for file in tqdm(os.listdir(sourcepath)): + if not os.path.exists(os.path.join(savepath, file)): + audio_path = os.path.join(sourcepath, file) + # Load with Numpy + signal = np.load(audio_path) + if signal.shape[0] < 960000: #1 minute + interp_reps = getwavlmrep(signal, wavlm, device, wavlm_layer, wavlm_cfg, sr=sr, fps=fps) + else: #Break the file into smaller chunks to avoid memory issues + interp_reps = [] + for subsignal in np.array_split(signal, np.ceil(signal.shape[0]/960000)): + subinterp_reps = getwavlmrep(subsignal, wavlm, device, wavlm_layer, wavlm_cfg, sr=sr, fps=fps) + interp_reps.append(subinterp_reps) + interp_reps = np.vstack(interp_reps) + np.save(os.path.join(savepath, file), interp_reps) + +def getwavlmrep(signal, wavlm, device, wavlm_layer, wavlm_cfg, sr=16000, fps=30): + # Set to model innput format + signal = torch.tensor(signal).unsqueeze(0).to(device) + # Normalize + if wavlm_cfg.normalize: + signal_norm = torch.nn.functional.layer_norm(signal , signal.shape) + else: + signal_norm = signal + # Run Model (rep=Desired Layer, layer_results=all layers) + rep, layer_results = wavlm.extract_features(signal_norm, output_layer=wavlm_layer, ret_layer_results=True)[0] + layer_reps = [x.transpose(0, 1) for x, _ in layer_results] # fix shape + # Get Number of Seconds of Audio File + n_secs = signal.shape[1] / sr + # Get Number of poses equivalent to audio file duration, given fps (alignment len) + n_pose = n_secs * fps + # Interpolate number of representations to match number of poses corresponding to audio file + interp_reps = F.interpolate(rep.transpose(1, 2), size=int(n_pose), align_corners=True, mode='linear') + # Prepare to save + interp_reps = interp_reps.squeeze(0).transpose(0,1).cpu().detach().data.cpu().numpy() + # Double check dimension + assert (interp_reps.shape[0] == int(np.ceil(n_pose)) or interp_reps.shape[0] == int(np.floor(n_pose))) + return interp_reps + +def process_wav(sourcepath, savepath, sr=16000): + #assert not 
os.path.exists(savepath), f"audio_16k_npy already exists in {savepath}. Delete it to process again." + if not os.path.exists(savepath): + os.mkdir(savepath) + for file in tqdm(os.listdir(sourcepath)): + if not os.path.exists(os.path.join(savepath, file[:-4] + '.npy')): + signal, _sr = librosa.load(os.path.join(sourcepath, file), mono=True, sr=sr) + assert _sr == sr + np.save(os.path.join(savepath, file[:-4]+'.npy'), signal) + return savepath + +def process_bvh(bvhpath, rot6dpath, rot3dpath, pospath, takes): + # Create paths + for path in [rot6dpath, rot3dpath, pospath]: + if not os.path.exists(path): + os.mkdir(path) + for file in tqdm(os.listdir(bvhpath)): + if not os.path.exists(os.path.join(rot6dpath, file[:-4] + '.npy')): + anim = bvhsdk.ReadFile(os.path.join(bvhpath, file)) + npyrot6d, npyrot, npypos = bvh2representations1(anim) + np.save(os.path.join(rot6dpath, file[:-4]), npyrot6d) + np.save(os.path.join(rot3dpath, file[:-4]), npyrot) + np.save(os.path.join(pospath, file[:-4]), npypos) + +def compute_meanstd(path, savepath, npstep=1, vel=False): + all_data = [] + for f in os.listdir(path)[::npstep]: + data = np.load(os.path.join(path, f)) + if vel: + data = data[1:,:] - data[:-1,:] + data[0,:] = np.zeros(data.shape[1]) + all_data.append(data) + all_data = np.vstack(all_data) + mean = np.mean(all_data, axis=0) + std = np.std(all_data, axis=0) + np.save(savepath + '_Mean.npy', mean) + np.save(savepath + '_Std.npy', std) + +def paths_get_and_check(data_dir): + assert os.path.exists(data_dir), 'Data directory does not exist' + motionpath = os.path.join(data_dir, 'motion') + audiopath = os.path.join(data_dir, 'audio') + bvhpath = os.path.join(motionpath, 'bvh_twh') + wavpath = os.path.join(audiopath, 'wav') + assert os.path.exists(motionpath), 'Motion directory does not exist' + assert os.path.exists(audiopath), 'Audio directory does not exist' + assert os.path.exists(bvhpath), 'BVH directory does not exist' + assert os.path.exists(wavpath), 'WAV directory does not exist' + rot6dpath = os.path.join(motionpath, 'rot6d') + rot3dpath = os.path.join(motionpath, 'rot3d') + pospath = os.path.join(motionpath, 'pos') + npy16k = os.path.join(audiopath, 'npy16k') + wavlmpath = os.path.join(audiopath, 'wavlm') + #assert not os.path.exists(rot6dpath), 'rot6d directory already exists' + #assert not os.path.exists(rot3dpath), 'rot3d directory already exists' + #assert not os.path.exists(pospath), 'pos directory already exists' + #assert not os.path.exists(npy16k), 'npy16k directory already exists' + #assert not os.path.exists(wavlmpath), 'wavlm directory already exists' + return bvhpath, wavpath, rot6dpath, rot3dpath, pospath, npy16k, wavlmpath + +def takes_get_and_check(bvhpath, wavpath): + takes = [] + assert len(os.listdir(bvhpath)) == len(os.listdir(wavpath)), 'Number of BVH files does not match number of WAV files' + for take in os.listdir(bvhpath): + takes.append(take[:-4]) + for take in os.listdir(wavpath): + assert take[:-4] in takes, 'WAV file {} does not have a corresponding BVH file'.format(take[:-4]) + return takes + +def addBodyWorld(): + """ + Adds body and world rotation and translation to the bvh files. + This is NOT a data processing method, it is a BVH preparation method. This model requires BVH files to have a body world joint as the root joint. + If you are trying to use this model with a new dataset, you will need to add a body world joint to your BVH files. + This method is provided as an example of how to do it. 
+ """ + #a = bvhsdk.ReadFile('.\\dataset\\BRG-Unicamp\\motion\\bvh_twh\\newtess_id01_p01_e01_f01.bvh') + b = bvhsdk.ReadFile('.\\dataset\\Genea2023\\trn\\main-agent\\bvh\\trn_2023_v0_000_main-agent.bvh') + + path = '.\\dataset\\BRG-Unicamp\\motion\\bvh_twh' + savepath = '.\\dataset\\BRG-Unicamp\\motion\\bvh_twh\\with_body_world' + for f in os.listdir(path): + if f.endswith('.bvh'): + a = bvhsdk.ReadFile(os.path.join(path, f)) + + for j1,j2 in zip(a.getlistofjoints(), b.getlistofjoints()[1:]): + j2.rotation = j1.rotation + j2.translation = j1.translation + + b.frametime = a.frametime + b.root.translation = b.root.children[0].translation*[1,0,1] + b.root.children[0].translation *= [0,1,0] + b.frames = a.frames + b.root.rotation = np.zeros(shape=(b.frames, 3)) + + bvhsdk.WriteBVH(b, path=savepath, name=f.replace('.bvh', ''), frametime=b.frametime) + + +if __name__ == '__main__': + parser = ArgumentParser() + parser.add_argument('--data_dir', type=str, default='./dataset/BRG-Unicamp', help='path to the dataset directory') + parser.add_argument('--step', type=str, default='all', help='Which step to process. Use \'all\' to process all steps. Options: \'bvh\', \'wav\', \'wavlm\'') + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/data_loaders/gesture/scripts/ptbrgesture_prep_vad.py b/data_loaders/gesture/scripts/ptbrgesture_prep_vad.py new file mode 100644 index 0000000..01308e4 --- /dev/null +++ b/data_loaders/gesture/scripts/ptbrgesture_prep_vad.py @@ -0,0 +1,47 @@ +from argparse import ArgumentParser +import os +import numpy as np +import torch +from speechbrain.pretrained import VAD +import torchaudio +from scipy.signal import resample +from tqdm import tqdm + + + +def main(args): + #paths_check(args.data_dir) + + sourcepath = os.path.join(args.data_dir, 'audio', 'wav') + savepath = os.path.join(args.data_dir, 'audio', 'vad') + print('Processing VAD.') + process_vad(sourcepath, savepath, args.data_dir) + + +def process_vad(sourcepath, savepath, datadir): + sr=16000 + fps=30 + _VAD = VAD.from_hparams(source= "speechbrain/vad-crdnn-libriparty", savedir= os.path.join(datadir, '..','..','speechbrain', 'pretrained_models', 'vad-crdnn-libriparty')) + #assert not os.path.exists(savepathrot), f"vad already exists in {savepathrot}. Delete it to process again." + if not os.path.exists(savepath): + os.mkdir(savepath) + # VAD requires a torch tensor with sample rate = 16k. This process saves a temporary wav file with 16k sr. It can be deleted after processing. 
+ for file in tqdm(os.listdir(sourcepath)): + if not os.path.exists(os.path.join(savepath, file[:-4]+'.npy')): + audio, old_sr = torchaudio.load(os.path.join(sourcepath,file)) + audio = torchaudio.functional.resample(audio, orig_freq=old_sr, new_freq=sr) + tmpfile = "tmp.wav" + torchaudio.save( + tmpfile , audio, sr + ) + boundaries = _VAD.get_speech_prob_file(audio_file=tmpfile, large_chunk_size=4, small_chunk_size=0.2) + boundaries = resample(boundaries[0,:,0], int(boundaries.shape[1]*fps/100)) + boundaries[boundaries>=0.5] = 1 + boundaries[boundaries<0.5] = 0 + np.save(os.path.join(savepath, file[:-4]+'.npy'), boundaries) + +if __name__ == '__main__': + parser = ArgumentParser() + parser.add_argument('--data_dir', type=str, default='./dataset/BRG-Unicamp', help='path to the dataset directory') + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/data_loaders/get_data.py b/data_loaders/get_data.py index 30f2053..a0f767e 100644 --- a/data_loaders/get_data.py +++ b/data_loaders/get_data.py @@ -1,28 +1,33 @@ from torch.utils.data import DataLoader from data_loaders.tensors import collate as all_collate -from data_loaders.tensors import gg_collate +from data_loaders.tensors import gg_collate, ptbr_collate def get_dataset_class(name): if name in ["genea2023", "genea2023+"]: from data_loaders.gesture.data.dataset import Genea2023 return Genea2023 + elif name in ['ptbr']: + from data_loaders.gesture.data.ptbrdataset import PTBRGesture + return PTBRGesture else: raise ValueError(f'Unsupported dataset name [{name}]') def get_collate_fn(name, hml_mode='train'): if name in ["genea2023", "genea2023+"]: return gg_collate + elif name in ['ptbr']: + return ptbr_collate else: raise ValueError(f'Unsupported dataset name [{name}]') -def get_dataset(name, data_dir,num_frames, seed_poses, step, use_wavlm, use_vad, vadfromtext, split='trn', hml_mode='train', ): +def get_dataset(name, num_frames, seed_poses, step, use_wavlm, use_vad, vadfromtext, split='trn', hml_mode='train', ): DATA = get_dataset_class(name) - dataset = DATA(name=name, datapath=data_dir, split=split, window=num_frames, n_seed_poses=seed_poses, step=step, use_wavlm=use_wavlm, use_vad=use_vad, vadfromtext=vadfromtext) + dataset = DATA(name=name, split=split, window=num_frames, n_seed_poses=seed_poses, step=step, use_wavlm=use_wavlm, use_vad=use_vad, vadfromtext=vadfromtext) return dataset -def get_dataset_loader(name, data_dir, batch_size, num_frames, step, use_wavlm, use_vad, vadfromtext, split='trn', hml_mode='train', seed_poses=10): - dataset = get_dataset(name, data_dir,num_frames, seed_poses, step, use_wavlm, use_vad, vadfromtext, split, hml_mode) +def get_dataset_loader(name, batch_size, num_frames, step, use_wavlm, use_vad, vadfromtext, split='trn', hml_mode='train', seed_poses=10): + dataset = get_dataset(name, num_frames, seed_poses, step, use_wavlm, use_vad, vadfromtext, split, hml_mode) collate = get_collate_fn(name, hml_mode) shuffled = True if split == 'trn' else False diff --git a/data_loaders/tensors.py b/data_loaders/tensors.py index 1688f69..8788564 100644 --- a/data_loaders/tensors.py +++ b/data_loaders/tensors.py @@ -54,9 +54,10 @@ def collate(batch): vadbatch = [b['vad'] for b in notnone_batches] vadbatch = torch.cat(vadbatch, dim=0) cond['y'].update({'vad': vadbatch}) - if 'takename' in notnone_batches[0]: - takenamebatch = [b['takename'] for b in notnone_batches] - cond['y'].update({'takename': takenamebatch}) + if 'onehot' in notnone_batches[0]: + onehotbatch = [b['onehot'] for b in 
notnone_batches] + onehotbatch = torch.cat(onehotbatch, dim=0) + cond['y'].update({'onehot': onehotbatch}) return motion, cond # an adapter to our collate func @@ -70,7 +71,17 @@ def gg_collate(batch): 'audio_rep': torch.from_numpy(b[4]).float(), # [1, AUDIO_HID_DIM, 1, CHUNK_LEN] , (AUDIO_HID_DIM = MFCC_DIM or 768) 'seed': torch.tensor(b[5].T).float().unsqueeze(1), # [n_seed_poses, J] -> [J, 1, n_seed_poses] 'vad': torch.from_numpy(b[6]).long(), # [1, CHUNK_LEN] - 'takename': b[7] } for b in batch] return collate(adapted_batch) +def ptbr_collate(batch): + adapted_batch = [{ + 'inp': torch.from_numpy(b[0].T).float().unsqueeze(1), #b[0] # motion [frames, motion_dim] + 'seed': torch.from_numpy(b[1].T).float().unsqueeze(1), #b[1] # seed poses [n_seed_poses, motion_dim] + 'audio': torch.from_numpy(b[2]).unsqueeze(0), #b[2] # audio [frames] + 'vad': torch.from_numpy(b[3]).long(), #b[3] # vad [frames] + 'audio_rep': torch.from_numpy(b[4]).float(), #b[4] # wavlm + 'sample_data': b[5], #b[5] # sample_data + 'onehot': torch.from_numpy(b[5][3]).float().unsqueeze(0), #b[6] # onehot + } for b in batch] + return collate(adapted_batch) \ No newline at end of file diff --git a/evaluation_metric/output/.gitkeep b/dataset/BRG-Unicamp/.gitkeep similarity index 100% rename from evaluation_metric/output/.gitkeep rename to dataset/BRG-Unicamp/.gitkeep diff --git a/environment.yml b/environment.yml index d8bc72d..9676e70 100644 --- a/environment.yml +++ b/environment.yml @@ -1,8 +1,8 @@ -name: ggvad +name: stylistic-env channels: - pytorch - - conda-forge - anaconda + - conda-forge - defaults dependencies: - _libgcc_mutex=0.1=main @@ -115,103 +115,22 @@ dependencies: - zlib=1.2.12=h5eee18b_3 - zstd=1.4.9=haebb681_0 - pip: - - anyio==3.7.1 - - appdirs==1.4.4 - - argon2-cffi==21.3.0 - - argon2-cffi-bindings==21.2.0 - - attrs==23.1.0 - - audioread==3.0.0 - - backcall==0.2.0 - - bleach==6.0.0 - blis==0.7.8 - - blobfile==2.0.2 - chumpy==0.70 - click==8.1.3 - - comm==0.1.4 - confection==0.0.2 - - debugpy==1.6.7.post1 - - decorator==5.1.1 - - defusedxml==0.7.1 - - docker-pycreds==0.4.0 - - einops==0.6.1 - - entrypoints==0.4 - - exceptiongroup==1.1.2 - - fastjsonschema==2.18.0 - ftfy==6.1.1 - - gitdb==4.0.10 - - gitpython==3.1.32 - importlib-metadata==5.0.0 - - importlib-resources==5.12.0 - - ipykernel==6.16.2 - - ipython==7.34.0 - - ipython-genutils==0.2.0 - - ipywidgets==8.1.0 - - jedi==0.19.0 - - jsonschema==4.17.3 - - jupyter==1.0.0 - - jupyter-client==7.4.9 - - jupyter-console==6.6.3 - - jupyter-core==4.12.0 - - jupyter-server==1.24.0 - - jupyterlab-pygments==0.2.2 - - jupyterlab-widgets==3.0.8 - - lazy-loader==0.2 - - librosa==0.10.0.post2 - - llvmlite==0.39.1 - lxml==4.9.1 - - matplotlib-inline==0.1.6 - - mistune==3.0.1 - - msgpack==1.0.5 - murmurhash==1.0.8 - - nbclassic==1.0.0 - - nbclient==0.7.4 - - nbconvert==7.6.0 - - nbformat==5.8.0 - - nest-asyncio==1.5.7 - - notebook==6.5.5 - - notebook-shim==0.2.3 - - numba==0.56.4 - - pandocfilters==1.5.0 - - parso==0.8.3 - - pathtools==0.1.2 - - pexpect==4.8.0 - - pickleshare==0.7.5 - - pkgutil-resolve-name==1.3.10 - - pooch==1.6.0 - preshed==3.0.7 - - prometheus-client==0.17.1 - - prompt-toolkit==3.0.39 - - protobuf==4.24.0 - - psutil==5.9.5 - - ptyprocess==0.7.0 - pycryptodomex==3.15.0 - - pygments==2.16.1 - - pyrsistent==0.19.3 - - python-speech-features==0.6 - - pyyaml==6.0 - - pyzmq==24.0.1 - - qtconsole==5.4.3 - - qtpy==2.3.1 - regex==2022.9.13 - - send2trash==1.8.2 - - sentry-sdk==1.29.2 - - setproctitle==1.3.2 - - smmap==5.0.0 - smplx==0.1.28 - - 
sniffio==1.3.0 - - soundfile==0.12.1 - - soxr==0.3.5 - srsly==2.4.4 - - terminado==0.17.1 - thinc==8.0.17 - - tinycss2==1.2.1 - - traitlets==5.9.0 - typing-extensions==4.1.1 - urllib3==1.26.12 - - wandb==0.15.8 - wasabi==0.10.1 - wcwidth==0.2.5 - - webencodings==0.5.1 - - websocket-client==1.6.1 - - widgetsnbextension==4.0.8 -prefix: /root/miniconda3/envs/ggvad \ No newline at end of file +prefix: /disk2/guytevet/anaconda3/envs/mdm-venv diff --git a/eval/eval_genea.py b/eval/eval_genea.py deleted file mode 100644 index 52d0d84..0000000 --- a/eval/eval_genea.py +++ /dev/null @@ -1,200 +0,0 @@ -from data_loaders.gesture.scripts import motion_process as mp -from data_loaders.get_data import get_dataset_loader -import numpy as np -from tqdm import tqdm -from utils import dist_util -import torch -import bvhsdk -from evaluation_metric.embedding_space_evaluator import EmbeddingSpaceEvaluator -from evaluation_metric.train_AE import make_tensor -import matplotlib.pyplot as plt - -# Imports for calling from command line -from utils.parser_util import generate_args -from utils.fixseed import fixseed -from utils.model_util import create_model_and_diffusion, load_model_wo_clip - - -class GeneaEvaluator: - def __init__(self, args, model, diffusion): - self.args = args - self.model = model - self.diffusion = diffusion - self.dataloader = get_dataset_loader(name=args.dataset, - batch_size=args.batch_size, - data_dir=args.data_dir, - num_frames=args.num_frames, - step=args.num_frames, #no overlap - use_wavlm=args.use_wavlm, - use_vad=True, #Hard-coded to get vad from files but the model will not use it since args.use_vad=False - vadfromtext=args.vadfromtext, - split='val') - self.data = self.dataloader.dataset - self.bvhreference = bvhsdk.ReadFile(args.bvh_reference_file, skipmotion=True) - self.idx_positions, self.idx_rotations = mp.get_indexes('genea2023') # hard-coded 'genea2023' because std and mean vec are computed for this representation - self.fgd_evaluator = EmbeddingSpaceEvaluator(args.fgd_embedding, args.num_frames, dist_util.dev()) - self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - - - def eval(self, samples=None): - print('Starting evaluation...') - n_samples = samples if samples else len(self.data.takes) - n_chunks = np.min(self.data.samples_per_file) - rot, gt_rot, pos, gt_pos, vad = self.sampleval(n_samples, n_chunks) - pos, rot = mp.filter_and_interp(rot, pos, num_frames=self.args.num_frames) - - listpos, listposgt = [], [] - print('Converting to BVH and back to get positions...') - for i in tqdm(range(len(pos))): - # Transform to BVH and get positions of sampled motion - bvhreference = mp.tobvh(self.bvhreference, rot[i], pos[i]) - listpos.append(mp.posfrombvh(bvhreference)) - # Transform to BVH and get positions of ground truth motion - # This is just a sanity check since we could get directly from the npy files - #bvhreference = mp.tobvh(self.bvhreference, gt_rot[i], gt_pos[i]) - #listposgt.append(mp.posfrombvh(bvhreference)) - - # Compute FGD - fgd_on_feat = self.fgd(listpos, listposgt, n_samples, stride=40) - - histfig = self.getvelhist(rot, vad) - - return fgd_on_feat, histfig - - def sampleval(self, samples=None, chunks=None): - assert chunks <= np.min(self.data.samples_per_file) # assert that we don't go over the number of chunks per file - allsampledmotion = [] - allsampleposition = [] - allgtmotion = [] - allgtposition = [] - allvad = [] - print('Evaluating validation set') - for idx in range(chunks): - print('### Sampling chunk {} of {}'.format(idx+1, 
chunks)) - batch = self.data.getvalbatch(num_takes=samples, index=idx) - gt_motion, model_kwargs = self.dataloader.collate_fn(batch) # gt_motion: [num_samples(bs), njoints, 1, chunk_len] - model_kwargs['y'] = {key: val.to(dist_util.dev()) if torch.is_tensor(val) else val for key, val in model_kwargs['y'].items()} #seed: [num_samples(bs), njoints, 1, seed_len] - if idx > 0 and self.args.seed_poses > 0: - model_kwargs['y']['seed'] = sample_out[...,-self.data.n_seed_poses:] - sample_fn = self.diffusion.p_sample_loop - sample_out = sample_fn( - self.model, - (samples, self.model.njoints, self.model.nfeats, self.args.num_frames), - clip_denoised=False, - model_kwargs=model_kwargs, - skip_timesteps=0, # 0 is the default value - i.e. don't skip any step - init_image=None, - progress=True, - dump_steps=None, - noise=None, - const_noise=False, - ) # [num_samples(bs), njoints, 1, chunk_len] - - sample = self.data.inv_transform(sample_out.cpu().permute(0, 2, 3, 1)).float() # [num_samples(bs), 1, chunk_len, njoints] - gt_motion = self.data.inv_transform(gt_motion.cpu().permute(0, 2, 3, 1)).float() # [num_samples(bs), 1, chunk_len, njoints] - - # Split the data into positions and rotations - gt_pos, gt_rot = mp.split_pos_rot(self.args.dataset, gt_motion) - sample_pos, sample_rot = mp.split_pos_rot(self.args.dataset, sample) - - # Convert numpy array to euler angles - if self.args.dataset == 'genea2023+': - gt_rot = mp.rot6d_to_euler(gt_rot) - sample_rot = mp.rot6d_to_euler(sample_rot) - - sample_pos = sample_pos.view(sample_pos.shape[:-1] + (-1, 3)) # [num_samples(bs), 1, chunk_len, n_joints/3, 3] - sample_pos = sample_pos.view(-1, *sample_pos.shape[2:]).permute(0, 2, 3, 1) - gt_pos = gt_pos.view(gt_pos.shape[:-1] + (-1, 3)) # [num_samples(bs), 1, chunk_len, n_joints/3, 3] - gt_pos = gt_pos.view(-1, *gt_pos.shape[2:]).permute(0, 2, 3, 1) - - allsampledmotion.append(sample_rot.cpu().numpy()) - allgtmotion.append(gt_rot.cpu().numpy()) - allsampleposition.append(sample_pos.squeeze().cpu().numpy()) - allgtposition.append(gt_pos.squeeze().cpu().numpy()) - allvad.append(model_kwargs['y']['vad'].cpu().numpy()) - - allsampledmotion = np.concatenate(allsampledmotion, axis=3) - allgtmotion = np.concatenate(allgtmotion, axis=3) - allsampleposition = np.concatenate(allsampleposition, axis=3) - allgtposition = np.concatenate(allgtposition, axis=3) - allvad = np.concatenate(allvad, axis=1) - - return allsampledmotion, allgtmotion, allsampleposition, allgtposition, allvad - - def getvelhist(self, motion, vad): - joints = self.data.getjoints() - fvad = vad[:,1:].flatten() - wvad, wovad = [], [] - for joint, index in joints.items(): - vels = np.sum(np.abs((motion[:,index,:, 1:] - motion[:,index,:, :-1])), axis=1).flatten() - wvad += list(vels[fvad==1]) - wovad += list(vels[fvad==0]) - n_bins =200 - fig, axs = plt.subplots(1, 1, sharex = True, tight_layout=True, figsize=(20,20)) - axs.hist(wvad, bins = n_bins, histtype='step', label='VAD = 1', linewidth=4, color='red') - axs.hist(wovad, bins = n_bins, histtype='step', label='VAD = 0', linewidth=4, color='black') - axs.set_yscale('log') - return fig - - def fgd(self, listpos, listposgt=None, n_samples=100, stride=None): - # "Direct" ground truth positions - real_val = make_tensor(f'./dataset/Genea2023/val/main-agent/motion_npy_rotpos', self.args.num_frames, max_files=n_samples, n_chunks=None, stride=stride).to(self.device) - real_trn = make_tensor(f'./dataset/Genea2023/trn/main-agent/motion_npy_rotpos', self.args.num_frames, max_files=n_samples, n_chunks=None, 
stride=stride).to(self.device) - - #gt_data = self.fgd_prep(listposgt).to(self.device) - test_data = self.fgd_prep(listpos, stride=stride).to(self.device) - - print('Samples shape:') - print(test_data.shape) - print('Validation shape:') - print(real_val.shape) - print('Train shape:') - print(real_trn.shape) - - fgd_on_feat = self.run_fgd(real_val, test_data) - print(f'Sampled to validation: {fgd_on_feat:8.3f}') - - fgd_on_feat_ = self.run_fgd(real_trn, test_data) - print(f'Sampled to train: {fgd_on_feat_:8.3f}') - #fgd_on_feat = self.run_fgd(gt_data, test_data) - #print(f'Sampled to validation from pipeline: {fgd_on_feat:8.3f}') - - #fgd_on_feat = self.run_fgd(real_val, gt_data) - #print(f'Validation from pipeline to validation (should be zero): {fgd_on_feat:8.3f}') - return fgd_on_feat - - def fgd_prep(self, data, n_frames=120, stride=None): - # Prepare samples for FGD evaluation - samples = [] - stride = n_frames // 2 if stride is None else stride - for take in data: - for i in range(0, len(take) - n_frames, stride): - sample = take[i:i+n_frames] - sample = (sample - self.data.mean[self.idx_positions]) / self.data.std[self.idx_positions] - samples.append(sample) - return torch.Tensor(samples) - - def run_fgd(self, gt_data, test_data): - # Run FGD evaluation on the given data - self.fgd_evaluator.reset() - self.fgd_evaluator.push_real_samples(gt_data) - self.fgd_evaluator.push_generated_samples(test_data) - fgd_on_feat = self.fgd_evaluator.get_fgd(use_feat_space=True) - return fgd_on_feat - -def main(): - args = generate_args() - fixseed(args.seed) - dist_util.setup_dist(args.device) - print("Creating model and diffusion...") - model, diffusion = create_model_and_diffusion(args, None) - print(f"Loading checkpoints from [{args.model_path}]...") - state_dict = torch.load(args.model_path, map_location='cpu') - load_model_wo_clip(model, state_dict) - model.to(dist_util.dev()) - model.eval() # disable random masking - GeneaEvaluator(args, model, diffusion).eval() - - -if __name__ == '__main__': - main() \ No newline at end of file diff --git a/eval/eval_geneaval.py b/eval/eval_geneaval.py deleted file mode 100644 index e85dc61..0000000 --- a/eval/eval_geneaval.py +++ /dev/null @@ -1,62 +0,0 @@ -from data_loaders.gesture.scripts import motion_process as mp -from data_loaders.get_data import get_dataset_loader -import numpy as np -from tqdm import tqdm -from utils import dist_util -import torch -import bvhsdk -from evaluation_metric.embedding_space_evaluator import EmbeddingSpaceEvaluator -from evaluation_metric.train_AE import make_tensor -import matplotlib.pyplot as plt - -# Imports for calling from command line -from utils.parser_util import generate_args -from utils.fixseed import fixseed -from utils.model_util import create_model_and_diffusion, load_model_wo_clip - - -class GeneaEvaluator: - def __init__(self): - self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - self.fgd_evaluator = EmbeddingSpaceEvaluator('./evaluation_metric/output/model_checkpoint_120.bin', 120, self.device) - - def eval(self, samples=None, chunks=None): - print('Starting evaluation...') - - # Compute FGD - fgd_on_feat = self.fgd(40, None) - - def fgd(self, n_samples=100, n_chunks=1): - # "Direct" ground truth positions - real_val = make_tensor(f'./dataset/Genea2023/trn/main-agent/motion_npy_rotpos', 120, max_files=n_samples, n_chunks=n_chunks, stride=40).to(self.device) - test_data = make_tensor(f'./dataset/Genea2023/val/main-agent/motion_npy_rotpos', 120, max_files=n_samples, 
n_chunks=n_chunks, stride=40).to(self.device) - - fgd_on_feat = self.run_fgd(real_val, test_data) - print(f'Validation to train: {fgd_on_feat:8.3f}') - return fgd_on_feat - - def fgd_prep(self, data, n_frames=120, stride=None): - # Prepare samples for FGD evaluation - samples = [] - stride = n_frames // 2 if stride is None else stride - for take in data: - for i in range(0, len(take) - n_frames, stride): - sample = take[i:i+n_frames] - sample = (sample - self.data.mean[self.idx_positions]) / self.data.std[self.idx_positions] - samples.append(sample) - return torch.Tensor(samples) - - def run_fgd(self, gt_data, test_data): - # Run FGD evaluation on the given data - self.fgd_evaluator.reset() - self.fgd_evaluator.push_real_samples(gt_data) - self.fgd_evaluator.push_generated_samples(test_data) - fgd_on_feat = self.fgd_evaluator.get_fgd(use_feat_space=True) - return fgd_on_feat - -def main(): - GeneaEvaluator().eval() - - -if __name__ == '__main__': - main() \ No newline at end of file diff --git a/eval/eval_ptbrgestures.py b/eval/eval_ptbrgestures.py new file mode 100644 index 0000000..7843f08 --- /dev/null +++ b/eval/eval_ptbrgestures.py @@ -0,0 +1,163 @@ +from data_loaders.gesture.scripts import motion_process as mp +from data_loaders.get_data import get_dataset_loader +import numpy as np +from tqdm import tqdm +from utils import dist_util +import torch +import bvhsdk +from evaluation_metric.embedding_space_evaluator import EmbeddingSpaceEvaluator +from evaluation_metric.train_AE import make_tensor, files_to_tensor +from sample import ptbrgenerate +import matplotlib.pyplot as plt +import os, glob + +# Imports for calling from command line +from utils.parser_util import generate_args +from utils.fixseed import fixseed +from utils.model_util import create_model_and_diffusion, load_model_wo_clip + + +class PTBREvaluator: + def __init__(self, args, model, diffusion): + self.args = args + self.model = model + self.diffusion = diffusion + self.dataloader = get_dataset_loader(name=args.dataset, + batch_size=args.batch_size, + num_frames=args.num_frames, + step=args.num_frames, #no overlap + use_wavlm=args.use_wavlm, + use_vad=True, #Hard-coded to get vad from files but the model will not use it since args.use_vad=False + vadfromtext=args.vadfromtext, + split='val') + self.data = self.dataloader.dataset + self.bvhreference = bvhsdk.ReadFile(args.bvh_reference_file, skipmotion=True) + #self.idx_positions, self.idx_rotations = mp.get_indexes('genea2023') # hard-coded 'genea2023' because std and mean vec are computed for this representation + self.fgd_evaluator = EmbeddingSpaceEvaluator('./evaluation_metric/output/model_checkpoint_120_ptbr.bin', args.num_frames, dist_util.dev()) + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + self.ground_truth_path = './dataset/PTBRGestures/motion/pos' + assert os.path.exists(self.ground_truth_path), f"Ground truth path {self.ground_truth_path} does not exist. Required to compute FGD." 
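Aside: the FGD reported below is the Fréchet distance between Gaussian fits of real and generated motion embeddings (see `EmbeddingSpaceEvaluator` further down in this diff). A short, self-contained illustration of that distance on toy feature matrices; this `frechet_distance` is a simplified stand-in, not the repository's class:

```python
import numpy as np
from scipy import linalg

def frechet_distance(feats_a: np.ndarray, feats_b: np.ndarray, eps: float = 1e-6) -> float:
    """Fréchet distance between Gaussian fits of two [N, D] feature sets."""
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    cov_a = np.cov(feats_a, rowvar=False)
    cov_b = np.cov(feats_b, rowvar=False)
    covmean, _ = linalg.sqrtm(cov_a.dot(cov_b), disp=False)
    if not np.isfinite(covmean).all():                       # regularize a singular product
        offset = np.eye(cov_a.shape[0]) * eps
        covmean = linalg.sqrtm((cov_a + offset).dot(cov_b + offset))
    if np.iscomplexobj(covmean):                             # drop tiny imaginary parts
        covmean = covmean.real
    diff = mu_a - mu_b
    return float(diff.dot(diff) + np.trace(cov_a) + np.trace(cov_b) - 2 * np.trace(covmean))

# toy usage: identical sets give ~0, shifted sets give a positive distance
rng = np.random.default_rng(0)
real = rng.normal(size=(500, 32))
fake = rng.normal(loc=0.5, size=(500, 32))
print(frechet_distance(real, real))   # ~0
print(frechet_distance(real, fake))   # > 0
```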
+ + + def eval(self, samples=None, chunks=None): + print('Starting evaluation...') + n_samples = samples if samples else len(self.data.takes) + n_chunks = chunks if chunks else np.min(self.data.samples_per_file) + #rot, gt_rot, pos, gt_pos, vad = self.sampleval(n_samples, n_chunks) + n_joints = 83 + pos, rot, sample_names = ptbrgenerate.sample(self.args, self.model, self.diffusion, self.dataloader, self.dataloader.collate_fn, n_joints) + pos, rot = mp.filter_and_interp(rot, pos, num_frames=self.args.num_frames) + + listpos, listposgt = [], [] + print('Converting to BVH and back to get positions...') + for i in tqdm(range(len(pos))): + # Transform to BVH and get positions of sampled motion + bvhreference = mp.tobvh(self.bvhreference, rot[i], pos[i]) + listpos.append(mp.posfrombvh(bvhreference)) + # Transform to BVH and get positions of ground truth motion + # This is just a sanity check since we could get directly from the npy files + #bvhreference = mp.tobvh(self.bvhreference, gt_rot[i], gt_pos[i]) + #listposgt.append(mp.posfrombvh(bvhreference)) + + # Compute cross-FGD + cross_fgd = self.cross_fgd(listpos, sample_names ) + + # Compute FGD (whole test set versus whole train set) + fgd_on_feat = self.fgd(listpos, n_samples=n_samples, n_chunks=n_chunks) + + #histfig = self.getvelhist(rot, vad) + + return fgd_on_feat, None, cross_fgd + + def cross_fgd(self, listpos, sample_names): + # Prepare ground truth data + std_vec = np.load('./dataset/PTBRGestures/pos_Std.npy') + mean_vec = np.load('./dataset/PTBRGestures/pos_Mean.npy') + idx_positions = np.arange(len(mean_vec)) + std_vec[std_vec==0] = 1 + files = glob.glob(os.path.join(self.ground_truth_path, '*.npy')) + files = [file for file in files if '_un_' not in os.path.basename(file)] + div_ground_truth, div_test = [], [] + styles = ['p01_e01', 'p01_e02', 'p01_e03', 'p02_e01', 'p02_e02', 'p02_e03'] + for style in styles: + div_files = [file for file in files if style in os.path.basename(file)] + div_ground_truth.append(files_to_tensor(div_files, mean_vec, std_vec, n_frames=self.args.num_frames, max_files=1000).to(self.device)) + + div_samples = [listpos[i] for i, name in enumerate(sample_names) if style in name] + div_test.append(self.fgd_prep(div_samples, n_frames=self.args.num_frames).to(self.device)) + + cross_fgds = {} + for gt_style, style_in_gt in zip(div_ground_truth, styles): + self.fgd_evaluator.reset() + + self.fgd_evaluator.push_real_samples(gt_style) + for test_style, style_in_test in zip(div_test, styles): + self.fgd_evaluator.push_generated_samples(test_style) + fgd_on_feat = self.fgd_evaluator.get_fgd(use_feat_space=True) + print(f'Cross-FGD gt {style_in_gt} vs test {style_in_test}: {fgd_on_feat:8.3f}') + cross_fgds.update({'gt {} vs test {}'.format(style_in_gt, style_in_test): fgd_on_feat}) + + + return cross_fgds + + def getvelhist(self, motion, vad): + joints = self.data.getjoints() + fvad = vad[:,1:].flatten() + wvad, wovad = [], [] + for joint, index in joints.items(): + vels = np.sum(np.abs((motion[:,index,:, 1:] - motion[:,index,:, :-1])), axis=1).flatten() + wvad += list(vels[fvad==1]) + wovad += list(vels[fvad==0]) + n_bins =200 + fig, axs = plt.subplots(1, 1, sharex = True, tight_layout=True, figsize=(20,20)) + axs.hist(wvad, bins = n_bins, histtype='step', label='VAD = 1', linewidth=4, color='red') + axs.hist(wovad, bins = n_bins, histtype='step', label='VAD = 0', linewidth=4, color='black') + axs.set_yscale('log') + return fig + + def fgd(self, listpos, listposgt=None, n_samples=100, n_chunks=1): + # "Direct" ground 
truth positions + real_val = make_tensor(self.ground_truth_path, self.args.num_frames, dataset='ptbr',max_files=n_samples, n_chunks=None).to(self.device) + + #gt_data = self.fgd_prep(listposgt).to(self.device) + test_data = self.fgd_prep(listpos).to(self.device) + + fgd_on_feat = self.run_fgd(real_val, test_data) + print(f'Sampled to validation: {fgd_on_feat:8.3f}') + return fgd_on_feat + + def fgd_prep(self, data, n_frames=120, stride=None): + # Prepare samples for FGD evaluation + samples = [] + stride = n_frames // 2 if stride is None else stride + for take in data: + for i in range(0, len(take) - n_frames, stride): + sample = take[i:i+n_frames] + sample = (sample - self.data.pos_mean) / self.data.pos_std + samples.append(sample) + return torch.Tensor(samples) + + def run_fgd(self, gt_data, test_data): + # Run FGD evaluation on the given data + self.fgd_evaluator.reset() + self.fgd_evaluator.push_real_samples(gt_data) + self.fgd_evaluator.push_generated_samples(test_data) + fgd_on_feat = self.fgd_evaluator.get_fgd(use_feat_space=True) + return fgd_on_feat + +def main(): + args = generate_args() + fixseed(args.seed) + dist_util.setup_dist(args.device) + print("Creating model and diffusion...") + model, diffusion = create_model_and_diffusion(args, None) + print(f"Loading checkpoints from [{args.model_path}]...") + state_dict = torch.load(args.model_path, map_location='cpu') + load_model_wo_clip(model, state_dict) + model.to(dist_util.dev()) + model.eval() # disable random masking + GeneaEvaluator(args, model, diffusion).eval() + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/evaluation_metric/embedding_space_evaluator.py b/evaluation_metric/embedding_space_evaluator.py index 6909909..edac9dd 100644 --- a/evaluation_metric/embedding_space_evaluator.py +++ b/evaluation_metric/embedding_space_evaluator.py @@ -9,15 +9,16 @@ class EmbeddingSpaceEvaluator: - def __init__(self, embed_net_path, n_frames, device): - # init embed net - ckpt = torch.load(embed_net_path, map_location=device) - self.pose_dim = ckpt['pose_dim'] - self.net = EmbeddingNet(self.pose_dim, n_frames).to(device) - self.net.load_state_dict(ckpt['gen_dict']) - self.net.train(False) + def __init__(self, embed_net_path, n_frames, device, dummy=False): + if not dummy: + # init embed net + ckpt = torch.load(embed_net_path, map_location=device) + self.pose_dim = ckpt['pose_dim'] + self.net = EmbeddingNet(self.pose_dim, n_frames).to(device) + self.net.load_state_dict(ckpt['gen_dict']) + self.net.train(False) - self.reset() + self.reset() def reset(self): self.real_samples = [] @@ -55,7 +56,7 @@ def frechet_distance(self, samples_A, samples_B): B_mu = np.mean(samples_B, axis=0) B_sigma = np.cov(samples_B, rowvar=False) try: - print('Computing frechet distance') + #print('Computing frechet distance') frechet_dist = self.calculate_frechet_distance(A_mu, A_sigma, B_mu, B_sigma) except ValueError: print('Something went wrong') @@ -100,7 +101,6 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): if not np.isfinite(covmean).all(): msg = ('fid calculation produces singular product; ' 'adding %s to diagonal of cov estimates') % eps - print(msg) offset = np.eye(sigma1.shape[0]) * eps covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) # Numerical error might give slight imaginary component diff --git a/evaluation_metric/output/model_checkpoint_120.bin b/evaluation_metric/output/model_checkpoint_120.bin deleted file mode 100644 index 12a59e4..0000000 Binary files 
a/evaluation_metric/output/model_checkpoint_120.bin and /dev/null differ diff --git a/evaluation_metric/output/model_checkpoint_120_ptbr.bin b/evaluation_metric/output/model_checkpoint_120_ptbr.bin new file mode 100644 index 0000000..7e9f178 Binary files /dev/null and b/evaluation_metric/output/model_checkpoint_120_ptbr.bin differ diff --git a/evaluation_metric/output/model_checkpoint_30.bin b/evaluation_metric/output/model_checkpoint_30.bin deleted file mode 100644 index 6adf8c2..0000000 Binary files a/evaluation_metric/output/model_checkpoint_30.bin and /dev/null differ diff --git a/evaluation_metric/train_AE.py b/evaluation_metric/train_AE.py index 5dbd323..41be5b3 100644 --- a/evaluation_metric/train_AE.py +++ b/evaluation_metric/train_AE.py @@ -60,26 +60,36 @@ def train_iter(target_data, net, optim): return ret_dict -def make_tensor(path, n_frames, stride=None, max_files=64, n_chunks=None): +def make_tensor(path, n_frames, dataset, stride=None, max_files=64, n_chunks=None): # use calculate_mean_std.py to get mean/std values #mean_vec = np.array([0.06847747184608119,0.0357063580846738,0.058578770784627525,0.06847747184608119,183.03570635799917,0.058578770784627525,16.923291745240235,168.35926808117108,0.041955713046869694,22.215592763704525,96.00786795875209,7.367188758228035,-17.300431170968334,168.94294163004272,-0.24727338642555735,-20.8715950848037,96.38631298614739,5.896949875381165,0.20216133904775593,186.95854383542118,-6.432579379678317,0.6098873474970944,208.91504190095858,-4.915889192590782,1.0909151123302776,229.8609722171563,-12.6793862809497,1.928591364103774,266.5012367723191,-15.976334397752815,7.714763567890464,280.55824733756486,1.9359663647433214,39.544279960636416,270.51838703291963,-9.270711837971861,39.544279960636416,270.51838703291963,-9.270711837971861,50.81093482089116,222.26787089133526,-4.277154500217943,50.17939487042671,194.5012435696503,21.649902584660502,47.55546776177503,193.07937915800173,24.449814516984887,51.68258742769814,183.1130622846742,28.598284340608124,52.31257982349485,178.7300103572332,29.752167196655417,52.11476016985293,176.2234409034844,29.338322795250296,49.57245060363262,183.6685185194845,30.764422876178404,49.51180617225469,177.9781510143326,31.345989263749335,48.984835786867286,174.72433725238744,30.80040989078595,46.938570200832714,184.18559365577823,32.38094931234402,46.63205116111747,177.99271292951627,33.3617908178335,45.70316718692669,175.00854125031574,32.71919596837255,43.53471060814086,184.72586361214883,32.65900465541941,43.144235682959376,179.3784708387911,34.567474274114595,42.55463405258724,176.60663073315192,34.61292128747094,46.18735692733591,192.1666386454192,25.212175451995105,42.415178023338235,190.40999141207104,24.34708837726526,38.476911084063474,186.96434952919546,26.5919414956809,36.13208594764687,183.66904789495973,28.92036860488664,-3.260172112764866,280.83999461595937,1.9124809183428453,-35.07458946647343,272.1779993206412,-10.188934519524462,-35.07458946647343,272.1779993206412,-10.188934519524462,-44.221430642610315,224.12986592661338,-2.1810619636426996,-48.22807185664309,197.1836151601496,21.820415520252148,-46.33444054575024,195.86804954902195,25.039544449623758,-45.13668197604728,194.955477133584,26.116667176956216,-41.092202689799706,193.1729063777209,26.266446799572222,-37.40708496857298,189.6782655958083,29.307581254228797,-35.636081168567486,186.31843980348182,31.96766893075243,-50.79671007842462,186.06135889526658,28.08238050295619,-51.43370660186226,181.74414430313067,29.05281490814699,-50.9354969179
44754,179.24853570625916,28.80991130269415,-47.240733921717286,187.10254399331873,32.83320587793653,-46.90782281703879,180.84945866262322,34.0909044604069,-45.500171172703396,177.78021908057028,33.90137244712666,-49.34991372818972,186.60529710991787,30.652207359792364,-49.09096325437287,180.95734931633348,31.311876315818232,-48.05107944797273,177.7834688982274,31.007569225066618,-44.02123442944895,187.60657246793522,33.93921950361445,-43.76597286425863,182.3069165226035,36.05010015011614,-42.954543530752,179.48150399466223,36.4037713410187,2.5187557703509666,291.83247385777486,-9.389368690830807,2.4539357753355553,304.03709615227046,-2.881607696619349]) #std_vec = np.array([8.255541100752382,0.6939720601761805,5.222482738827612,8.255541100752382,0.693972060176339,5.222482738827612,9.42100963255725,2.188372644619285,6.361481612445378,6.176403474432683,2.03443820178515,11.830267060993494,9.348297312453099,1.9986317079084224,6.001779897668299,8.162462099200143,2.2708856073173265,12.143746363244206,8.07002281773558,1.146760813248187,5.218745502524028,7.302334393532628,1.0869661835365756,5.640455884072193,7.477263866275783,1.2241260130341458,5.704034599390462,8.93140892031836,1.3043997705897612,6.246232138662631,9.656330236153833,1.7514046003701789,6.8102931179285715,9.369187221685337,3.7379417578069427,9.746046092317828,9.369187221685337,3.7379417578069427,9.746046092317828,10.227838670518729,7.286263937298487,13.107272130744121,17.056602336480278,28.775355928036703,17.65156414660871,17.616759573364362,31.57948329721122,17.58583620094351,22.409700875985976,36.64470600046972,22.894290518164947,24.53008313723881,39.07577209642486,25.295457925596633,25.690655337228268,40.177065667200765,26.808543783812727,22.2591494319307,38.401747944200125,21.950631662738534,24.854161656529712,41.50067854859389,25.042011259666936,26.473295311534336,43.197994006554005,27.221170133903826,22.10861269363438,40.10343062678119,21.14882281947506,25.013596742327397,43.87399009904381,24.52999507148788,26.52013512923851,45.59851255330215,26.774862236150504,21.725195479879996,41.3712584397316,20.536614369930735,24.39943341761268,45.32943762847826,23.38680984964916,25.786178720369854,47.14171501427746,25.319483052572473,17.914690428682025,32.89781536653379,17.7403567588838,18.181332875330988,34.64247649156103,18.418610850063928,19.771894506185426,39.40693746620956,19.868089428475518,21.55081438729135,43.62997102172089,21.5116223452909,9.702982004349828,1.7826399048001014,6.831726011211061,9.990627742798965,4.4897264818483995,9.924705741596368,9.990627742798965,4.4897264818483995,9.924705741596368,11.786532337428321,9.522123886195812,12.65776670636571,17.031859201169226,33.584280601824254,18.38674479122489,17.777049225781234,36.38401573078446,18.127686404978764,18.17511401763394,37.691189634409334,18.084351816160314,18.551140073870705,39.31562974080336,18.02247544203335,20.28661077866561,43.99125147305048,18.701801984691905,22.30030442409078,48.11743158915539,19.99370126076239,21.89724529987023,42.49461135019713,23.793704882062592,23.74127743788964,45.32675754902733,26.24880800837308,24.594991051498763,46.72175188102939,27.67702463406083,22.183607917647112,45.59612353311565,21.613578453862164,24.6049569241885,50.12226277439626,24.67927794755758,25.732481639988514,52.301113166777355,26.560623330610035,22.023636510505785,44.0923860001992,22.687839346541782,24.11276937550739,47.8236277337321,25.586912334755322,25.285414927567018,49.8865610793381,27.509628579848478,22.140064803278623,46.596465179326735,20.51837959275402,24.421617
248855064,51.09661706141259,22.99069540500737,25.553481257475042,53.30714083193669,24.588120394659704,10.509501956123042,1.3845656231709993,7.232725520260621,10.735209001060486,1.9048225062899646,7.924571615243941]) - mean_vec = np.load('./dataset/Genea2023/trn/main-agent/rotpos_Mean.npy') - std_vec = np.load('./dataset/Genea2023/trn/main-agent/rotpos_Std.npy') - idx_positions = np.asarray([ [i*6+3, i*6+4, i*6+5] for i in range(int(len(mean_vec)/6)) ]).flatten() - mean_vec = mean_vec[idx_positions] - std_vec = std_vec[idx_positions] + if dataset == 'genea': + mean_vec = np.load('./dataset/Genea2023/trn/main-agent/rotpos_Mean.npy') + std_vec = np.load('./dataset/Genea2023/trn/main-agent/rotpos_Std.npy') + idx_positions = np.asarray([ [i*6+3, i*6+4, i*6+5] for i in range(int(len(mean_vec)/6)) ]).flatten() + mean_vec = mean_vec[idx_positions] + std_vec = std_vec[idx_positions] + elif dataset == 'ptbr': + std_vec = np.load('./dataset/PTBRGestures/pos_Std.npy') + mean_vec = np.load('./dataset/PTBRGestures/pos_Mean.npy') + idx_positions = np.arange(len(mean_vec)) + else: + raise ValueError('Unknown dataset: {}'.format(dataset)) + + std_vec[std_vec==0] = 1 if os.path.isdir(path): files = glob.glob(os.path.join(path, '*.npy')) else: - files = [path] + raise ValueError('Unknown path: {}'.format(path)) - files.sort() + return files_to_tensor(files, mean_vec, std_vec, idx_positions=idx_positions, n_frames=n_frames, n_chunks=n_chunks, stride=stride, max_files=max_files) +def files_to_tensor(files, mean_vec, std_vec, idx_positions=None, n_frames=120, n_chunks=None, stride=None, max_files=None): # Make sure we don't run out of memory max_files = max_files if max_files < len(files) else len(files) - + idx_positions = np.arange(len(mean_vec)) if idx_positions is None else idx_positions samples = [] stride = n_frames // 2 if stride is None else stride print('Preparing data...') @@ -96,10 +106,20 @@ def make_tensor(path, n_frames, stride=None, max_files=64, n_chunks=None): return torch.Tensor(samples) -def main(n_frames): +def main(n_frames, dataset='genea'): #https://github.com/genea-workshop/genea_challenge_2023/tree/main/evaluation_metric # dataset - train_dataset = TensorDataset(make_tensor(f'./dataset/Genea2023/trn/main-agent/motion_npy_rotpos', n_frames)) + if dataset == 'genea': + motion_source = './dataset/Genea2023/trn/main-agent/motion_npy_rotpos' + max_files = 372 + stride = 60 + elif dataset == 'ptbr': + motion_source = './dataset/PTBRGestures/motion/pos' + max_files = 716 + stride = 30 + else: + raise ValueError('Unknown dataset: {}'.format(dataset)) + train_dataset = TensorDataset(make_tensor(motion_source, n_frames, dataset=dataset, max_files=max_files, stride=stride)) print('Done') train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True, drop_last=True) @@ -140,11 +160,11 @@ def main(n_frames): # save model gen_state_dict = generator.state_dict() - save_name = f'./evaluation_metric/output/model_checkpoint_{n_frames}.bin' + save_name = f'./evaluation_metric/output/model_checkpoint_{n_frames}_{dataset}.bin' torch.save({'pose_dim': pose_dim, 'n_frames': n_frames, 'gen_dict': gen_state_dict}, save_name) if __name__ == '__main__': n_frames = 120 print('Using n_frames: {}'.format(n_frames)) - main(n_frames) + main(n_frames, dataset='ptbr') diff --git a/ggvad_container.sh b/ggvad_container.sh deleted file mode 100644 index 4e44b9c..0000000 --- a/ggvad_container.sh +++ /dev/null @@ -1,11 +0,0 @@ -while getopts g:n:p: flag -do - case "${flag}" in - g) gpu=${OPTARG};; - n) 
number=${OPTARG};; - p) port=${OPTARG};; - esac -done -echo "Running container ggvad_container_$number on gpu $gpu and port $port"; - -docker run --rm -it --gpus device=$gpu --userns=host --shm-size 64G -v /work/rodolfo.tonoli/ggvad-genea2023:/workspace/ggvad/ -p $port --name ggvad_container$number ggvad:latest /bin/bash \ No newline at end of file diff --git a/model/local_attention_diffstylegest.py b/model/local_attention.py similarity index 98% rename from model/local_attention_diffstylegest.py rename to model/local_attention.py index f9d10f0..1b7cb44 100644 --- a/model/local_attention_diffstylegest.py +++ b/model/local_attention.py @@ -5,6 +5,8 @@ from einops import rearrange, repeat, pack, unpack +# From DiffuseStyleGesture https://github.com/youngseng/diffusestylegesture + TOKEN_SELF_ATTN_VALUE = -5e4 def exists(val): diff --git a/model/mdm.py b/model/mod_mdm.py similarity index 93% rename from model/mdm.py rename to model/mod_mdm.py index 35beba9..4da2f79 100644 --- a/model/mdm.py +++ b/model/mod_mdm.py @@ -3,8 +3,9 @@ import torch.nn as nn import torch.nn.functional as F import clip -from model.local_attention_diffstylegest import SinusoidalEmbeddings, apply_rotary_pos_emb -from model.local_attention_diffstylegest import LocalAttention +#from model.rotation2xyz import Rotation2xyz +from model.local_attention import SinusoidalEmbeddings, apply_rotary_pos_emb +from model.local_attention import LocalAttention class MDM(nn.Module): def __init__(self, njoints, nfeats, pose_rep, data_rep, latent_dim=256, text_dim=64, ff_size=1024, @@ -41,10 +42,18 @@ def __init__(self, njoints, nfeats, pose_rep, data_rep, latent_dim=256, text_dim # VAD self.use_vad = kargs.get('use_vad', False) + self.use_style_enc = kargs.get('use_style_enc', False) if self.use_vad: - vad_lat_dim = int(self.latent_dim) - self.vad_lookup = nn.Embedding(2, vad_lat_dim) - print('Using VAD') + if self.use_style_enc: + vad_lat_dim = int(self.latent_dim) + self.vad_lookup = nn.Embedding(2, int(vad_lat_dim/2)) + print('Using VAD') + self.style_lookup = nn.Linear(12, int(vad_lat_dim/2)) + print('Using Style Encoder') + else: + vad_lat_dim = int(self.latent_dim) + self.vad_lookup = nn.Embedding(2, vad_lat_dim) + print('Using VAD') # Seed Pose Encoder self.seed_poses = kargs.get('seed_poses', 0) @@ -116,6 +125,9 @@ def __init__(self, njoints, nfeats, pose_rep, data_rep, latent_dim=256, text_dim self.output_process = OutputProcess(self.data_rep, self.input_feats, self.latent_dim, self.njoints, self.nfeats) + # Unused? 
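Aside on the new VAD + style conditioning in `model/mod_mdm.py`: with `use_vad` and `use_style_enc` both enabled, the per-frame speech-activity flag goes through an `nn.Embedding(2, latent_dim/2)`, the 12-dimensional style one-hot goes through an `nn.Linear(12, latent_dim/2)` and is repeated over the chunk, and both are concatenated onto the per-frame pose/audio embeddings. A self-contained sketch of that wiring (module and variable names here are illustrative, not the repository's):

```python
import torch
import torch.nn as nn

class VadStyleConditioner(nn.Module):
    """Illustrative per-frame conditioning: VAD embedding plus broadcast style embedding."""
    def __init__(self, latent_dim: int = 256, style_dim: int = 12):
        super().__init__()
        self.vad_lookup = nn.Embedding(2, latent_dim // 2)       # speech / no-speech
        self.style_lookup = nn.Linear(style_dim, latent_dim // 2)

    def forward(self, fg_embs, vad, onehot):
        # fg_embs: [CHUNK_LEN, BS, D_pose_audio], vad: [BS, CHUNK_LEN], onehot: [BS, style_dim]
        nframes = fg_embs.shape[0]
        emb_vad = self.vad_lookup(vad).permute(1, 0, 2)             # [CHUNK_LEN, BS, LAT_DIM/2]
        emb_style = self.style_lookup(onehot)                       # [BS, LAT_DIM/2]
        emb_style = emb_style.unsqueeze(0).repeat(nframes, 1, 1)    # [CHUNK_LEN, BS, LAT_DIM/2]
        return torch.cat((fg_embs, emb_vad, emb_style), dim=2)      # concat on the feature axis

# shape check with dummy tensors (batch of 4, 120-frame chunk)
cond = VadStyleConditioner(latent_dim=256)
fg = torch.randn(120, 4, 320)                      # pose + audio embeddings per frame
vad = torch.randint(0, 2, (4, 120))
onehot = torch.zeros(4, 12)
onehot[:, 3] = 1.0
print(cond(fg, vad, onehot).shape)                 # torch.Size([120, 4, 576])
```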
+ #self.rot2xyz = Rotation2xyz(device='cpu', dataset=self.dataset) + self.log_train = False self.batch_log = {'text': [], 'vad': [], @@ -165,7 +177,12 @@ def forward(self, x, timesteps, y=None): if self.use_vad: vad_vals = y['vad'] # [BS, CHUNK_LEN] emb_vad = self.vad_lookup(vad_vals) # [BS, CHUNK_LEN, LAT_DIM] - emb_vad = emb_vad.permute(1, 0, 2) # [CHUNK_LEN, BS, LAT_DIM] + emb_vad = emb_vad.permute(1, 0, 2) # [CHUNK_LEN, BS, LAT_DIM or LAT_DIM/2] + + if self.use_style_enc: + style_vals = y['onehot'] # [BS, STYLE_DIM] + style_vals = style_vals.unsqueeze(0).repeat(nframes, 1, 1) # [CHUNK_LEN, BS, STYLE_DIM] + emb_style = self.style_lookup(style_vals) # [CHUNK_LEN, BS, LAT_DIM/2] # Timesteps Embeddings emb_t = self.embed_timestep(timesteps) # [1, BS, LAT_DIM] @@ -198,7 +215,10 @@ def forward(self, x, timesteps, y=None): # Cat Pose w/ Audio (Fine-Grained) Embeddings if self.use_vad: - fg_embs = torch.cat((emb_pose, emb_audio, emb_vad), axis=2) # [CHUNK_LEN, BS, LAT_DIM + AUDIO_DIM + LAT_DIM] + if self.use_style_enc: + fg_embs = torch.cat((emb_pose, emb_audio, emb_vad, emb_style), axis=2) # [CHUNK_LEN, BS, LAT_DIM + AUDIO_DIM + LAT_DIM/2 + LAT_DIM/2] + else: + fg_embs = torch.cat((emb_pose, emb_audio, emb_vad), axis=2) # [CHUNK_LEN, BS, LAT_DIM + AUDIO_DIM + LAT_DIM] else: fg_embs = torch.cat((emb_pose, emb_audio), axis=2) # [CHUNK_LEN, BS, LAT_DIM + AUDIO_DIM] diff --git a/sample/generate.py b/sample/generate.py deleted file mode 100644 index 80e6f5a..0000000 --- a/sample/generate.py +++ /dev/null @@ -1,232 +0,0 @@ -# This code is based on https://github.com/openai/guided-diffusion -""" -Generate a large batch of image samples from a model and save them as a large -numpy array. This can be used to produce samples for FID evaluation. -""" -from utils.fixseed import fixseed -import os -import numpy as np -import torch -from utils.parser_util import generate_args -from utils.model_util import create_model_and_diffusion, load_model_wo_clip -from utils import dist_util -from data_loaders.get_data import get_dataset_loader -import shutil -from data_loaders.tensors import gg_collate -import bvhsdk -import utils.rotation_conversions as geometry -from scipy.signal import savgol_filter - -def main(): - args = generate_args() - fixseed(args.seed) - out_path = args.output_dir - name = os.path.basename(os.path.dirname(args.model_path)) - niter = os.path.basename(args.model_path).replace('model', '').replace('.pt', '') - if args.dataset in ['genea2023', 'genea2023+']: - fps = 30 - n_joints = 83 - #TODO: change to receive args.bvh_reference_file - bvhreference = bvhsdk.ReadFile('./dataset/Genea2023/trn/main-agent/bvh/trn_2023_v0_000_main-agent.bvh', skipmotion=True) - else: - raise NotImplementedError - dist_util.setup_dist(args.device) - if out_path == '': - out_path = os.path.join(os.path.dirname(args.model_path), - 'samples_{}_{}_seed{}'.format(name, niter, args.seed)) - if args.text_prompt != '': - out_path += '_' + args.text_prompt.replace(' ', '_').replace('.', '') - elif args.input_text != '': - out_path += '_' + os.path.basename(args.input_text).replace('.txt', '').replace(' ', '_').replace('.', '') - - # Hard-coded takes to be generated - num_samples = 70 - takes_to_generate = np.arange(num_samples) - - - args.batch_size = num_samples # Sampling a single batch from the testset, with exactly args.num_samples - - #inputs_i = [155,271,320,400,500,600,700,800,1145,1185] - - print('Loading dataset...') - data = load_dataset(args, num_samples) - - print("Creating model and diffusion...") - model, 
diffusion = create_model_and_diffusion(args, data) - - print(f"Loading checkpoints from [{args.model_path}]...") - state_dict = torch.load(args.model_path, map_location='cpu') - load_model_wo_clip(model, state_dict) - - model.to(dist_util.dev()) - model.eval() # disable random masking - - all_motions = [] #np.zeros(shape=(num_samples, n_joints, 3, args.num_frames*chunks_per_take)) - all_motions_rot = [] - all_lengths = [] - all_text = [] - all_audios = [] - - # dummy motion, batch_text, window, batch_audio, batch_audio_rep, dummy seed_poses, max_length - dummy_motion, data_text, chunk_len, data_audio, data_audio_rep, dummy_seed, max_length, vad_vals, takenames = data.dataset.gettestbatch(num_samples) - - chunks_per_take = int(max_length/chunk_len) - for chunk in range(chunks_per_take): # Motion is generated in chunks, for each chunk we load the corresponding data - empty = np.array([]) - inputs = [] - for take in range(num_samples): # For each take we will load the current chunk - vad = vad_vals[take][..., chunk:chunk+chunk_len] if args.use_vad else empty - inputs.append((empty, data_text[take][chunk], chunk_len, empty, data_audio_rep[take][..., chunk:chunk+chunk_len], dummy_seed, vad, takenames[take])) - - _, model_kwargs = gg_collate(inputs) # gt_motion: [num_samples(bs), njoints, 1, chunk_len] - model_kwargs['y'] = {key: val.to(dist_util.dev()) if torch.is_tensor(val) else val for key, val in model_kwargs['y'].items()} #seed: [num_samples(bs), njoints, 1, seed_len] - - if chunk == 0: - pass #send mean pose - else: - model_kwargs['y']['seed'] = sample_out[...,-args.seed_poses:] - - - - print('### Sampling chunk {} of {}'.format(chunk+1, int(max_length/chunk_len))) - - # add CFG scale to batch - if args.guidance_param != 1: # default 2.5 - model_kwargs['y']['scale'] = torch.ones(num_samples, device=dist_util.dev()) * args.guidance_param - - sample_fn = diffusion.p_sample_loop - - sample_out = sample_fn( - model, - (num_samples, model.njoints, model.nfeats, args.num_frames), - clip_denoised=False, - model_kwargs=model_kwargs, - skip_timesteps=0, # 0 is the default value - i.e. 
don't skip any step - init_image=None, - progress=True, - dump_steps=None, - noise=None, - const_noise=False, - ) # [num_samples(bs), njoints, 1, chunk_len] - - sample = data.dataset.inv_transform(sample_out.cpu().permute(0, 2, 3, 1)).float() # [num_samples(bs), 1, chunk_len, njoints] - - - # Separating positions and rotations - if args.dataset == 'genea2023': - idx_positions = np.asarray([ [i*6+3, i*6+4, i*6+5] for i in range(n_joints) ]).flatten() - idx_rotations = np.asarray([ [i*6, i*6+1, i*6+2] for i in range(n_joints) ]).flatten() - sample, sample_rot = sample[..., idx_positions], sample[..., idx_rotations] - - #rotations - sample_rot = sample_rot.view(sample_rot.shape[:-1] + (-1, 3)) - sample_rot = sample_rot.view(-1, *sample_rot.shape[2:]).permute(0, 2, 3, 1) - - - elif args.dataset == 'genea2023+': - idx_rotations = np.asarray([ [i*9, i*9+1, i*9+2, i*9+3, i*9+4, i*9+5] for i in range(n_joints) ]).flatten() - idx_positions = np.asarray([ [i*9+6, i*9+7, i*9+8] for i in range(n_joints) ]).flatten() - sample, sample_rot = sample[..., idx_positions], sample[..., idx_rotations] # sample_rot: [num_samples(bs), 1, chunk_len, n_joints*6] - - #rotations - sample_rot = sample_rot.view(sample_rot.shape[:-1] + (-1, 6)) # [num_samples(bs), 1, chunk_len, n_joints, 6] - sample_rot = geometry.rotation_6d_to_matrix(sample_rot) # [num_samples(bs), 1, chunk_len, n_joints, 3, 3] - sample_rot = geometry.matrix_to_euler_angles(sample_rot, "ZXY")[..., [1, 2, 0] ]*180/np.pi # [num_samples(bs), 1, chunk_len, n_joints, 3] - sample_rot = sample_rot.view(-1, *sample_rot.shape[2:]).permute(0, 2, 3, 1) # [num_samples(bs)*chunk_len, n_joints, 3] - - else: - raise ValueError(f'Unknown dataset: {args.dataset}') - - #positions - sample = sample.view(sample.shape[:-1] + (-1, 3)) # [num_samples(bs), 1, chunk_len, n_joints/3, 3] - sample = sample.view(-1, *sample.shape[2:]).permute(0, 2, 3, 1) # [num_samples(bs), n_joints/3, 3, chunk_len] - - text_key = 'text' if 'text' in model_kwargs['y'] else 'action_text' - all_text += model_kwargs['y'][text_key] - - all_audios.append(model_kwargs['y']['audio'].cpu().numpy()) - all_motions.append(sample.cpu().numpy()) - all_motions_rot.append(sample_rot.cpu().numpy()) - all_lengths.append(model_kwargs['y']['lengths'].cpu().numpy()) - - - - all_audios = data_audio - all_motions = np.concatenate(all_motions, axis=3) - all_motions_rot = np.concatenate(all_motions_rot, axis=3) - all_lengths = np.concatenate(all_lengths, axis=0) - - # Smooth chunk transitions - inter_range = 10 #interpolation range in frames - for transition in np.arange(1, chunks_per_take-1)*args.num_frames: - all_motions[..., transition:transition+2] = np.tile(np.expand_dims(all_motions[..., transition]/2 + all_motions[..., transition-1]/2,-1),2) - all_motions_rot[..., transition:transition+2] = np.tile(np.expand_dims(all_motions_rot[..., transition]/2 + all_motions_rot[..., transition-1]/2,-1),2) - for i, s in enumerate(np.linspace(0, 1, inter_range-1)): - forward = transition-inter_range+i - backward = transition+inter_range-i - all_motions[..., forward] = all_motions[..., forward]*(1-s) + all_motions[:, :, :, transition-1]*s - all_motions[..., backward] = all_motions[..., backward]*(1-s) + all_motions[:, :, :, transition]*s - all_motions_rot[..., forward] = all_motions_rot[..., forward]*(1-s) + all_motions_rot[:, :, :, transition-1]*s - all_motions_rot[..., backward] = all_motions_rot[..., backward]*(1-s) + all_motions_rot[:, :, :, transition]*s - - all_motions = savgol_filter(all_motions, 9, 3, axis=-1) - 
all_motions_rot = savgol_filter(all_motions_rot, 9, 3, axis=-1) - - if os.path.exists(out_path): - shutil.rmtree(out_path) - os.makedirs(out_path) - print(f"saving results to [{out_path}]") - - npy_path = os.path.join(out_path, 'results.npy') - - np.save(npy_path, - {'motion': all_motions, 'text': all_text, 'lengths': all_lengths, - 'num_samples': len(takes_to_generate), 'num_chunks': chunks_per_take}) - with open(npy_path.replace('.npy', '.txt'), 'w') as fw: - fw.write('\n'.join(all_text)) - with open(npy_path.replace('.npy', '_len.txt'), 'w') as fw: - fw.write('\n'.join([str(l) for l in all_lengths])) - - - for i, take in enumerate(takes_to_generate): - final_frame = data.dataset.frames[i] - save_file = data.dataset.takes[take][0] - print('Saving take {}: {}'.format(i, save_file)) - positions = all_motions[i] - positions = positions[..., :final_frame] - positions = positions.transpose(2, 0, 1) - - # Saving generated motion as bvh file - rotations = all_motions_rot[i] # [njoints/3, 3, chunk_len*chunks] - rotations = rotations[..., :final_frame] - rotations = rotations.transpose(2, 0, 1) # [chunk_len*chunks, njoints/3, 3] - bvhreference.frames = rotations.shape[0] - for j, joint in enumerate(bvhreference.getlistofjoints()): - joint.rotation = rotations[:, j, :] - joint.translation = np.tile(joint.offset, (bvhreference.frames, 1)) - bvhreference.root.translation = positions[:, 0, :] - #bvhreference.root.children[0].translation = positions[:, 1, :] - print('Saving bvh file...') - bvhsdk.WriteBVH(bvhreference, path=out_path, name=save_file, frametime=1/fps, refTPose=False) - - abs_path = os.path.abspath(out_path) - print(f'[Done] Results are at [{abs_path}]') - - -def load_dataset(args, batch_size): - data = get_dataset_loader(name=args.dataset, - data_dir=args.data_dir, - batch_size=batch_size, - num_frames=args.num_frames, - split='tst', - hml_mode='text_only', - step = args.num_frames, - use_wavlm=args.use_wavlm, - use_vad = args.use_vad, - vadfromtext = args.vadfromtext,) - #data.fixed_length = n_frames - return data - - -if __name__ == "__main__": - main() diff --git a/sample/ptbrgenerate.py b/sample/ptbrgenerate.py new file mode 100644 index 0000000..8a56e11 --- /dev/null +++ b/sample/ptbrgenerate.py @@ -0,0 +1,279 @@ +# This code is based on https://github.com/openai/guided-diffusion +""" +Generate a large batch of image samples from a model and save them as a large +numpy array. This can be used to produce samples for FID evaluation. 
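Aside: the deleted `sample/generate.py` above smooths the seams between consecutively sampled chunks by averaging the frames that meet at each chunk boundary, cross-fading the neighboring frames toward that average, and finally running a Savitzky-Golay filter over time; the new `ptbrgenerate.py` below calls an `interpolate()` helper for the same purpose. A stripped-down, illustrative version of that blending (simplified from the deleted code, not a drop-in replacement):

```python
import numpy as np
from scipy.signal import savgol_filter

def smooth_chunk_transitions(motion: np.ndarray, chunk_len: int, inter_range: int = 10) -> np.ndarray:
    """motion: [..., total_frames] assembled from consecutive chunks of length chunk_len."""
    motion = motion.copy()
    n_chunks = motion.shape[-1] // chunk_len
    for t in np.arange(1, n_chunks) * chunk_len:            # frame index of each chunk boundary
        mid = motion[..., t - 1] / 2 + motion[..., t] / 2   # average of the two meeting frames
        motion[..., t - 1] = mid
        motion[..., t] = mid
        # linearly blend the surrounding frames toward the boundary value
        for i, s in enumerate(np.linspace(0, 1, inter_range - 1)):
            motion[..., t - inter_range + i] = motion[..., t - inter_range + i] * (1 - s) + mid * s
            motion[..., t + inter_range - i] = motion[..., t + inter_range - i] * (1 - s) + mid * s
    # final low-pass smoothing over time, as in the deleted script
    return savgol_filter(motion, window_length=9, polyorder=3, axis=-1)

# toy usage: two 120-frame chunks of one coordinate with a hard jump at the boundary
seq = np.concatenate([np.zeros(120), np.ones(120)])[None, :]
print(smooth_chunk_transitions(seq, chunk_len=120).shape)   # (1, 240)
```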
+""" +from utils.fixseed import fixseed +import os +import numpy as np +import torch +from utils.parser_util import generate_args +from utils.model_util import create_model_and_diffusion, load_model_wo_clip +from utils import dist_util +from model.cfg_sampler import ClassifierFreeSampleModel +from data_loaders.get_data import get_dataset_loader +from data_loaders.gesture.scripts.motion_process import rot6d_to_euler +import shutil +from data_loaders.tensors import gg_collate, ptbr_collate +from soundfile import write as wavwrite +import bvhsdk +import utils.rotation_conversions as geometry +from scipy.signal import savgol_filter + +def main(): + args = generate_args() + fixseed(args.seed) + out_path = args.output_dir + name = os.path.basename(os.path.dirname(args.model_path)) + niter = os.path.basename(args.model_path).replace('model', '').replace('.pt', '') + if args.dataset in ['ptbr']: + fps = 30 + n_joints = 83 + collate_fn = ptbr_collate + split = 'val' + # iterate over samples in a take + else: + raise NotImplementedError + dist_util.setup_dist(args.device) + if out_path == '': + out_path = os.path.join(os.path.dirname(args.model_path), + 'samples_{}_{}_seed{}'.format(name, niter, args.seed)) + if args.text_prompt != '': + out_path += '_' + args.text_prompt.replace(' ', '_').replace('.', '') + elif args.input_text != '': + out_path += '_' + os.path.basename(args.input_text).replace('.txt', '').replace(' ', '_').replace('.', '') + + print('Loading dataset...') + data = load_dataset(args, args.batch_size, split) + + print("Creating model and diffusion...") + model, diffusion = create_model_and_diffusion(args, data) + + print(f"Loading checkpoints from [{args.model_path}]...") + state_dict = torch.load(args.model_path, map_location='cpu') + load_model_wo_clip(model, state_dict) + + #args.guidance_param = 1 + #if args.guidance_param != 1: + # model = ClassifierFreeSampleModel(model) # wrapping model with the classifier-free sampler + model.to(dist_util.dev()) + model.eval() # disable random masking + all_motions, all_motions_rot, _ = sample(args, model, diffusion, data, collate_fn, n_joints) + all_motions, all_motions_rot = interpolate(args, all_motions, all_motions_rot, max_chunks_in_take=np.max(data.dataset.samples_per_file)) + savebvh(data, all_motions, all_motions_rot, out_path, fps, data.dataset.bvhreference) + + +def sample(args, model, diffusion, data, collate_fn, n_joints): + + total_batches = int(np.round(len(data.dataset.samples_per_file)/args.batch_size)) + assert total_batches >= int(np.round(len(data.dataset.samples_per_file)/args.batch_size)) + max_chunks_in_take = np.max(data.dataset.samples_per_file) + + all_motions = np.zeros(shape=(total_batches*args.batch_size, n_joints, 3, args.num_frames*max_chunks_in_take)) + all_motions_rot = np.zeros(shape=(total_batches*args.batch_size, n_joints, 3, args.num_frames*max_chunks_in_take)) + all_audios = [] + files = [] + + for batch_count in range(total_batches): + print('### Sampling batch {} of {}'.format(batch_count+1, total_batches)) + + chunked_motions = [] + chunked_motions_rot = [] + chunked_audios = [] + for chunk in range(max_chunks_in_take): + batch = [] + # iterate over each take and append the sample (chunk) to the batch + first_batch_take = batch_count * args.batch_size + last_batch_take = (batch_count + 1) * args.batch_size + for file_idx in range(first_batch_take, last_batch_take): + # Append dummy samples (1) if the take has less chunks than the max + # or (2) if we are in the last batch and the number of takes is 
smaller than the batch size + if file_idx < len(data.dataset.samples_per_file): + if chunk < data.dataset.samples_per_file[file_idx]: + item = chunk + data.dataset.samples_cumulative[file_idx-1] if file_idx > 0 else chunk + batch.append(data.dataset.__getitem__(item)) + continue + batch.append(data.dataset.__dummysample__()) + + + _, model_kwargs = collate_fn(batch) # gt_motion: [num_samples(bs), njoints, 1, chunk_len] + model_kwargs['y'] = {key: val.to(dist_util.dev()) if torch.is_tensor(val) else val for key, val in model_kwargs['y'].items()} #seed: [num_samples(bs), njoints, 1, seed_len] + + if chunk == 0: + pass #send mean pose + else: + model_kwargs['y']['seed'] = sample_out[...,-args.seed_poses:] + + print('### Sampling chunk {} of {}'.format(chunk+1, max_chunks_in_take)) + + # add CFG scale to batch + #if args.guidance_param != 1: # default 2.5 + # model_kwargs['y']['scale'] = torch.ones(num_samples, device=dist_util.dev()) * args.guidance_param + + sample_fn = diffusion.p_sample_loop + sample_out = sample_fn( + model, + (args.batch_size, model.njoints, model.nfeats, args.num_frames), + clip_denoised=False, + model_kwargs=model_kwargs, + skip_timesteps=0, # 0 is the default value - i.e. don't skip any step + init_image=None, + progress=True, + dump_steps=None, + noise=None, + const_noise=False, + ) # [num_samples(bs), njoints, 1, chunk_len] + + sample = data.dataset.inv_transform(sample_out.cpu().permute(0, 2, 3, 1)).float() # [num_samples(bs), 1, chunk_len, njoints] + + + # Separating positions and rotations + if args.dataset in ['ptbr']: + idx_rotations = np.asarray([ [i*6, i*6+1, i*6+2, i*6+3, i*6+4, i*6+5] for i in range(n_joints) ]).flatten() + idx_positions = np.asarray([ [498 + i*3, 498 + i*3+1, 498 + i*3+2] for i in range(n_joints) ]).flatten() + + sample, sample_rot = sample[..., idx_positions], sample[..., idx_rotations] # sample_rot: [num_samples(bs), 1, chunk_len, n_joints*6] + + #rotations + sample_rot = rot6d_to_euler(sample_rot, n_joints) # [num_samples(bs)*chunk_len, n_joints, 3] + + else: + raise ValueError(f'Unknown dataset: {args.dataset}') + + #positions + sample = sample.view(sample.shape[:-1] + (-1, 3)) # [num_samples(bs), 1, chunk_len, n_joints, 3] + sample = sample.view(-1, *sample.shape[2:]).permute(0, 2, 3, 1) # [num_samples(bs), n_joints, 3, chunk_len] + + #rot2xyz_pose_rep = 'xyz' + #rot2xyz_mask = None if rot2xyz_pose_rep == 'xyz' else model_kwargs['y']['mask'].reshape(args.batch_size, n_frames).bool() + #sample = model.rot2xyz(x=sample, mask=rot2xyz_mask, pose_rep=rot2xyz_pose_rep, glob=True, translation=True, + # jointstype='smpl', vertstrans=True, betas=None, beta=0, glob_rot=None, + # get_rotations_back=False) + + #text_key = 'text' if 'text' in model_kwargs['y'] else 'action_text' + #all_text += model_kwargs['y'][text_key] + + chunked_audios.append(model_kwargs['y']['audio'].cpu().numpy()) + chunked_motions.append(sample.cpu().numpy()) + chunked_motions_rot.append(sample_rot.cpu().numpy()) + if chunk == 0: + for sample in batch: + files.append(sample[-1][0]) + #total_num_samples = num_samples * chunks_per_take + #all_audios = np.concatenate(all_audios, axis=1) + #all_audios = audio + b,e = batch_count*args.batch_size, (batch_count+1)*args.batch_size + all_motions[b:e] = np.concatenate(chunked_motions, axis=3) + #all_motions = all_motions[:total_num_samples] # [num_samples(bs), njoints/3, 3, chunk_len*chunks] + all_motions_rot[b:e] = np.concatenate(chunked_motions_rot, axis=3) + #all_motions_rot = all_motions_rot[:total_num_samples] # 
[num_samples(bs), njoints/3, 3, chunk_len*chunks] + #all_text = all_text[:total_num_samples] + #all_lengths = np.concatenate(all_lengths, axis=0) + return all_motions, all_motions_rot, files + +def interpolate(args, all_motions, all_motions_rot, max_chunks_in_take): + # Smooth chunk transitions + inter_range = 10 #interpolation range in frames + for transition in np.arange(1, max_chunks_in_take-1)*args.num_frames: + all_motions[..., transition:transition+2] = np.tile(np.expand_dims(all_motions[..., transition]/2 + all_motions[..., transition-1]/2,-1),2) + all_motions_rot[..., transition:transition+2] = np.tile(np.expand_dims(all_motions_rot[..., transition]/2 + all_motions_rot[..., transition-1]/2,-1),2) + for i, s in enumerate(np.linspace(0, 1, inter_range-1)): + forward = transition-inter_range+i + backward = transition+inter_range-i + all_motions[..., forward] = all_motions[..., forward]*(1-s) + all_motions[:, :, :, transition-1]*s + all_motions[..., backward] = all_motions[..., backward]*(1-s) + all_motions[:, :, :, transition]*s + all_motions_rot[..., forward] = all_motions_rot[..., forward]*(1-s) + all_motions_rot[:, :, :, transition-1]*s + all_motions_rot[..., backward] = all_motions_rot[..., backward]*(1-s) + all_motions_rot[:, :, :, transition]*s + + all_motions = savgol_filter(all_motions, 9, 3, axis=-1) + all_motions_rot = savgol_filter(all_motions_rot, 9, 3, axis=-1) + + return all_motions, all_motions_rot + +def savebvh(data, all_motions, all_motions_rot, out_path, fps, bvhreference_path): + + if os.path.exists(out_path): + shutil.rmtree(out_path) + os.makedirs(out_path) + + npy_path = os.path.join(out_path, 'results.npy') + print(f"saving results file to [{npy_path}]") + np.save(npy_path, + {'motion': all_motions}) + #with open(npy_path.replace('.npy', '.txt'), 'w') as fw: + # fw.write('\n'.join(all_text)) + #with open(npy_path.replace('.npy', '_len.txt'), 'w') as fw: + # fw.write('\n'.join([str(l) for l in all_lengths])) + + print(f"saving visualizations to [{out_path}]...") + #if args.dataset in ['genea2023+', 'genea2023']: + # skeleton = paramUtil.genea2022_kinematic_chain + #else: + # raise NotImplementedError + + sample_files = [] + num_samples_in_out_file = 7 + + #sample_print_template, row_print_template, all_print_template, \ + #sample_file_template, row_file_template, all_file_template = construct_template_variables() + + bvhreference = bvhsdk.ReadFile(bvhreference_path, skipmotion=True) + + for i, take in enumerate(range(len(data.dataset.samples_per_file))): + final_frame = data.dataset.frames[i] + save_file = 'gen_' + data.dataset.takes[take].name + print('Saving take {}: {}'.format(i, save_file)) + animation_save_path = os.path.join(out_path, save_file) + caption = '' # since we are generating a ~1 min long take the caption would be too long + positions = all_motions[i] + positions = positions[..., :final_frame] + positions = positions.transpose(2, 0, 1) + #plot_3d_motion(animation_save_path + '.mp4', skeleton, positions, dataset=args.dataset, title=caption, fps=fps) + # Credit for visualization: https://github.com/EricGuo5513/text-to-motion + + #saving samples with seed + #aux_positions = all_sample_with_seed[i] + #aux_positions = aux_positions.transpose(2, 0, 1) + #plot_3d_motion(animation_save_path + '_with_seed.mp4', skeleton, aux_positions, dataset=args.dataset, title=caption, fps=fps) + + # Saving generated motion as bvh file + rotations = all_motions_rot[i] # [njoints/3, 3, chunk_len*chunks] + rotations = rotations[..., :final_frame] + rotations = 
rotations.transpose(2, 0, 1) # [chunk_len*chunks, njoints/3, 3] + bvhreference.frames = rotations.shape[0] + for j, joint in enumerate(bvhreference.getlistofjoints()): + joint.rotation = rotations[:, j, :] + joint.translation = np.tile(joint.offset, (bvhreference.frames, 1)) + bvhreference.root.translation = positions[:, 0, :] + bvhreference.root.children[0].translation[:, 1] = positions[:, 1, 1] + + print('Saving bvh file...') + bvhsdk.WriteBVH(bvhreference, path=animation_save_path, name=None, frametime=1/fps, refTPose=False) + + # Saving audio and joinning it with the mp4 file of generated motion + #wavfile = animation_save_path + '.wav' + #mp4file = wavfile.replace('.wav', '.mp4') + #wavwrite( wavfile, samplerate= 22050, data = all_audios[i]) + #joinaudio = f'ffmpeg -y -loglevel warning -i {mp4file} -i {wavfile} -c:v copy -map 0:v:0 -map 1:a:0 -c:a aac -b:a 192k {mp4file[:-4]}_audio.mp4' + #os.system(joinaudio) + + abs_path = os.path.abspath(out_path) + print(f'[Done] Results are at [{abs_path}]') + + +def load_dataset(args, batch_size, split='tst'): + data = get_dataset_loader(name=args.dataset, + batch_size=batch_size, + num_frames=args.num_frames, + split=split, + hml_mode='text_only', + step = args.num_frames, + use_wavlm=args.use_wavlm, + use_vad = args.use_vad, + vadfromtext = args.vadfromtext,) + #data.fixed_length = n_frames + return data + + +if __name__ == "__main__": + main() diff --git a/save/stylistic-gesture/args.json b/save/stylistic-gesture/args.json new file mode 100644 index 0000000..5101432 --- /dev/null +++ b/save/stylistic-gesture/args.json @@ -0,0 +1,48 @@ +{ + "arch": "trans_enc", + "batch_size": 64, + "bvh_reference_file": "./dataset/Genea2023/trn/main-agent/bvh/trn_2023_v0_000_main-agent.bvh", + "cond_mask_prob": 0.1, + "cuda": true, + "data_dir": "", + "dataset": "ptbr", + "device": 0, + "diffusion_steps": 1000, + "emb_trans_dec": false, + "eval_batch_size": 32, + "eval_during_training": false, + "eval_num_samples": 1000, + "eval_rep_times": 3, + "eval_split": "test", + "fgd_embedding": "./evaluation_metric/output/model_checkpoint_120.bin", + "lambda_fc": 0.0, + "lambda_rcxyz": 0.0, + "lambda_vel": 0.0, + "latent_dim": 256, + "layers": 8, + "log_interval": 1000, + "lr": 0.0001, + "lr_anneal_steps": 0, + "mfcc_input": false, + "noise_schedule": "cosine", + "num_frames": 120, + "num_steps": 600000, + "overwrite": false, + "resume_checkpoint": "", + "save_dir": "save/6_ptbr", + "save_interval": 50000, + "seed": 10, + "seed_poses": 10, + "sigma_small": true, + "step": 10, + "train_platform_type": "NoPlatform", + "unconstrained": false, + "use_style_enc": true, + "use_text": false, + "use_vad": true, + "use_wav_enc": false, + "use_wavlm": true, + "vadfromtext": false, + "wandb": true, + "weight_decay": 0.0 +} \ No newline at end of file diff --git a/save/stylistic-gesture/model000600000.pt b/save/stylistic-gesture/model000600000.pt new file mode 100644 index 0000000..58b6b3d Binary files /dev/null and b/save/stylistic-gesture/model000600000.pt differ diff --git a/train/train_mdm.py b/train/train_mdm.py index 4b27b6c..cb9ef86 100644 --- a/train/train_mdm.py +++ b/train/train_mdm.py @@ -31,20 +31,13 @@ def main(): projectname = os.path.basename(os.path.normpath(args.save_dir)) import wandb wandb.login(anonymous="allow") - wandb.init(project='ggvad-genea2023', config=vars(args)) + wandb.init(project='ptbrgesture', config=vars(args)) args.wandb = wandb dist_util.setup_dist(args.device) print("creating data loader...") - data = get_dataset_loader(name=args.dataset, - 
data_dir=args.data_dir, - batch_size=args.batch_size, - num_frames=args.num_frames, - step=args.step, - use_wavlm=args.use_wavlm, - use_vad=args.use_vad, - vadfromtext=args.vadfromtext) + data = get_dataset_loader(name=args.dataset, batch_size=args.batch_size, num_frames=args.num_frames, step=args.step, use_wavlm=args.use_wavlm, use_vad=args.use_vad, vadfromtext=args.vadfromtext) print("creating model and diffusion...") model, diffusion = create_model_and_diffusion(args, data) diff --git a/train/training_loop.py b/train/training_loop.py index ca527e0..ac82d0e 100644 --- a/train/training_loop.py +++ b/train/training_loop.py @@ -15,7 +15,7 @@ from diffusion.resample import LossAwareSampler, UniformSampler from tqdm import tqdm from diffusion.resample import create_named_schedule_sampler -from eval import eval_genea +from eval import eval_ptbrgestures from data_loaders.get_data import get_dataset_loader import utils.rotation_conversions as geometry @@ -45,7 +45,8 @@ def __init__(self, args, model, diffusion, data): self.lr_anneal_steps = args.lr_anneal_steps self.log_wandb = args.wandb if self.log_wandb: - self.genea_evaluator = eval_genea.GeneaEvaluator(args, self.model, self.diffusion) + if args.dataset == 'ptbr': + self.evaluator = eval_ptbrgestures.PTBREvaluator(args, self.model, self.diffusion) self.step = 0 self.resume_step = 0 @@ -80,6 +81,24 @@ def __init__(self, args, model, diffusion, data): self.schedule_sampler_type = 'uniform' self.schedule_sampler = create_named_schedule_sampler(self.schedule_sampler_type, diffusion) self.eval_wrapper, self.eval_data, self.eval_gt_data = None, None, None + if args.dataset in ['kit', 'humanml'] and args.eval_during_training: + mm_num_samples = 0 # mm is super slow hence we won't run it during training + mm_num_repeats = 0 # mm is super slow hence we won't run it during training + gen_loader = get_dataset_loader(name=args.dataset, batch_size=args.eval_batch_size, num_frames=None, + split=args.eval_split, + hml_mode='eval') + + self.eval_gt_data = get_dataset_loader(name=args.dataset, batch_size=args.eval_batch_size, num_frames=None, + split=args.eval_split, + hml_mode='gt') + self.eval_wrapper = EvaluatorMDMWrapper(args.dataset, dist_util.dev()) + self.eval_data = { + 'test': lambda: eval_humanml.get_mdm_loader( + model, diffusion, args.eval_batch_size, + gen_loader, mm_num_samples, mm_num_repeats, gen_loader.dataset.opt.max_motion_length, + args.eval_num_samples, scale=1., + ) + } self.use_ddp = False self.ddp_model = self.model @@ -125,7 +144,6 @@ def run_loop(self): 'timestep':np.zeros(size), 'audio': np.zeros(size), 'poses': np.zeros(size)} - for stepcount, (motion, cond) in enumerate(tqdm(self.data)): if not (not self.lr_anneal_steps or self.step + self.resume_step < self.lr_anneal_steps): break @@ -166,8 +184,6 @@ def run_loop(self): if k in ['step', 'samples'] or '_q' in k: continue - #else: - # self.train_platform.report_scalar(name=k, value=v, iteration=self.step, group_name='Loss') if self.step % self.save_interval == 0: self.save() @@ -189,7 +205,6 @@ def run_loop(self): self.model.train() self.step += 1 - if not (not self.lr_anneal_steps or self.step + self.resume_step < self.lr_anneal_steps): break @@ -200,8 +215,12 @@ def run_loop(self): def valwandb(self): assert self.log_wandb - fgd, histfig = self.genea_evaluator.eval() - self.log_wandb.wandb.log({'FGD Validation': fgd, 'Rot Vel Hist': self.log_wandb.wandb.Image(histfig)}) + fgd, histfig, cross_fgd = self.evaluator.eval() + if histfig is not None: + self.log_wandb.wandb.log({'FGD 
Validation': fgd, 'Rot Vel Hist': self.log_wandb.wandb.Image(histfig)}) + if cross_fgd is not None: + self.log_wandb.wandb.log(cross_fgd) + self.log_wandb.wandb.log({'FGD Validation': fgd}) def run_debugemb(self): @@ -217,45 +236,6 @@ def run_debugemb(self): return self.model.debug_seed,self.model.debug_text,self.model.debug_timestep,self.model.debug_audio,self.model.debug_vad,self.model.debug_poses - #def evaluate(self): - # if not self.args.eval_during_training: - # return - # start_eval = time.time() - # if self.eval_wrapper is not None: - # print('Running evaluation loop: [Should take about 90 min]') - # log_file = os.path.join(self.save_dir, f'eval_humanml_{(self.step + self.resume_step):09d}.log') - # diversity_times = 300 - # mm_num_times = 0 # mm is super slow hence we won't run it during training - # eval_dict = eval_humanml.evaluation( - # self.eval_wrapper, self.eval_gt_data, self.eval_data, log_file, - # replication_times=self.args.eval_rep_times, diversity_times=diversity_times, mm_num_times=mm_num_times, run_mm=False) - # print(eval_dict) - # for k, v in eval_dict.items(): - # if k.startswith('R_precision'): - # for i in range(len(v)): - # self.train_platform.report_scalar(name=f'top{i + 1}_' + k, value=v[i], - # iteration=self.step + self.resume_step, - # group_name='Eval') - # else: - # self.train_platform.report_scalar(name=k, value=v, iteration=self.step + self.resume_step, - # group_name='Eval') -# - # elif self.dataset in ['humanact12', 'uestc']: - # eval_args = SimpleNamespace(num_seeds=self.args.eval_rep_times, num_samples=self.args.eval_num_samples, - # batch_size=self.args.eval_batch_size, device=self.device, guidance_param = 1, - # dataset=self.dataset, unconstrained=self.args.unconstrained, - # model_path=os.path.join(self.save_dir, self.ckpt_file_name())) - # eval_dict = eval_humanact12_uestc.evaluate(eval_args, model=self.model, diffusion=self.diffusion, data=self.data.dataset) - # print(f'Evaluation results on {self.dataset}: {sorted(eval_dict["feats"].items())}') - # for k, v in eval_dict["feats"].items(): - # if 'unconstrained' not in k: - # self.train_platform.report_scalar(name=k, value=np.array(v).astype(float).mean(), iteration=self.step, group_name='Eval') - # else: - # self.train_platform.report_scalar(name=k, value=np.array(v).astype(float).mean(), iteration=self.step, group_name='Eval Unconstrained') -# - # end_eval = time.time() - # print(f'Evaluation time: {round(end_eval-start_eval)/60}min') - def run_step(self, batch, cond): self.forward_backward(batch, cond) diff --git a/utils/misc.py b/utils/misc.py new file mode 100644 index 0000000..abe0cdc --- /dev/null +++ b/utils/misc.py @@ -0,0 +1,40 @@ +import torch + + +def to_numpy(tensor): + if torch.is_tensor(tensor): + return tensor.cpu().numpy() + elif type(tensor).__module__ != 'numpy': + raise ValueError("Cannot convert {} to numpy array".format( + type(tensor))) + return tensor + + +def to_torch(ndarray): + if type(ndarray).__module__ == 'numpy': + return torch.from_numpy(ndarray) + elif not torch.is_tensor(ndarray): + raise ValueError("Cannot convert {} to torch tensor".format( + type(ndarray))) + return ndarray + + +def cleanexit(): + import sys + import os + try: + sys.exit(0) + except SystemExit: + os._exit(0) + +def load_model_wo_clip(model, state_dict): + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + assert len(unexpected_keys) == 0 + assert all([k.startswith('clip_model.') for k in missing_keys]) + +def freeze_joints(x, joints_to_freeze): + # Freezes 
selected joint *rotations* as they appear in the first frame + # x [bs, [root+n_joints], joint_dim(6), seqlen] + frozen = x.detach().clone() + frozen[:, joints_to_freeze, :, :] = frozen[:, joints_to_freeze, :, :1] + return frozen diff --git a/utils/model_util.py b/utils/model_util.py index 0b16a17..5139644 100644 --- a/utils/model_util.py +++ b/utils/model_util.py @@ -1,4 +1,5 @@ -from model.mdm import MDM +#from model.mdm_old import MDM_Old as MDM +from model.mod_mdm import MDM from diffusion import gaussian_diffusion as gd from diffusion.respace import SpacedDiffusion, space_timesteps @@ -20,11 +21,11 @@ def get_model_args(args, data): # default args clip_version = 'ViT-B/32' - if args.dataset in ['genea2023']: + if args.dataset in ['genea2022', 'genea2023']: data_rep = 'genea_vec' njoints = 498 nfeats = 1 - elif args.dataset in ['genea2023+']: + elif args.dataset in ['genea2023+', 'ptbr']: data_rep = 'genea_vec+' njoints = 1245 nfeats = 1 @@ -35,7 +36,7 @@ def get_model_args(args, data): 'dropout': 0.1, 'activation': "gelu", 'data_rep': data_rep, 'cond_mask_prob': args.cond_mask_prob, 'clip_version': clip_version, 'dataset': args.dataset, 'use_text': args.use_text, 'mfcc_input': args.mfcc_input, 'use_wavlm': args.use_wavlm, 'use_vad':args.use_vad, - 'use_wav_enc':args.use_wav_enc, 'seed_poses': args.seed_poses} + 'use_wav_enc':args.use_wav_enc, 'seed_poses': args.seed_poses, 'use_style_enc': args.use_style_enc, 'dataloader': data,} def create_gaussian_diffusion(args): diff --git a/utils/parser_util.py b/utils/parser_util.py index 2ed3c1d..a00c07b 100644 --- a/utils/parser_util.py +++ b/utils/parser_util.py @@ -98,9 +98,9 @@ def add_model_options(parser): def add_data_options(parser): group = parser.add_argument_group('dataset') - group.add_argument("--dataset", default='humanml', choices=['genea2023+','genea2023'], type=str, + group.add_argument("--dataset", default='ptbr', choices=['genea2023+','genea2023', 'ptbr'], type=str, help="Dataset name (choose from list).") - group.add_argument("--data_dir", default="./dataset/Genea2023/", type=str, + group.add_argument("--data_dir", default="", type=str, help="If empty, will use defaults according to the specified dataset.") group.add_argument("--num_frames", default=120, type=int, help="Window length to be used in the dataset.") @@ -110,9 +110,11 @@ def add_data_options(parser): help="Use wavlm representations.") group.add_argument("--use_vad", default=False, type=bool, help="Use vad speech indicator values.") + group.add_argument("--use_style_enc", default=False, type=bool, + help="Use motion style indicator values.") group.add_argument("--vadfromtext", default=False, type=bool, help="Get vad speech indicator values from text.") - group.add_argument("--bvh_reference_file", default='./dataset/Genea2023/trn/main-agent/bvh/trn_2023_v0_000_main-agent.bvh', type=str, + group.add_argument("--bvh_reference_file", default='', type=str, help="BVH file reference. Used for extracting joint length and hierarchy during evaluation and generation.") group.add_argument("--fgd_embedding", default='./evaluation_metric/output/model_checkpoint_120.bin', type=str, help="Embedding Space Evaluator network for computing FGD metric.")
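
The chunk-stitching performed by `interpolate()` in `sample/ptbrgenerate.py` can be hard to follow from the patch alone. Below is a minimal, self-contained sketch of the same idea, assuming a `[joints, 3, frames]` motion array; the function name and the toy shapes are illustrative and not part of the repository.

```python
import numpy as np
from scipy.signal import savgol_filter


def smooth_chunk_transitions(motion, chunk_len, inter_range=10):
    """Cross-fade consecutive chunks of a [joints, 3, frames] motion array.

    Mirrors the idea in interpolate(): average the two frames at each chunk
    boundary, linearly blend a small window on each side of it, then apply a
    Savitzky-Golay filter along time. Shapes and names here are illustrative.
    """
    motion = motion.copy()
    n_frames = motion.shape[-1]
    for t in range(chunk_len, n_frames - 1, chunk_len):
        # Replace the two frames at the boundary by the average of the frames around it.
        boundary = 0.5 * (motion[..., t - 1] + motion[..., t])
        motion[..., t:t + 2] = boundary[..., None]
        # Linearly blend the frames on each side toward the frame at the boundary.
        for i, s in enumerate(np.linspace(0.0, 1.0, inter_range - 1)):
            fwd, bwd = t - inter_range + i, t + inter_range - i
            motion[..., fwd] = (1 - s) * motion[..., fwd] + s * motion[..., t - 1]
            motion[..., bwd] = (1 - s) * motion[..., bwd] + s * motion[..., t]
    # Final low-pass over time to remove residual jitter.
    return savgol_filter(motion, window_length=9, polyorder=3, axis=-1)


if __name__ == "__main__":
    toy = np.random.randn(83, 3, 4 * 120)  # 83 joints, 4 chunks of 120 frames
    smoothed = smooth_chunk_transitions(toy, chunk_len=120)
    print(smoothed.shape)  # (83, 3, 480)
```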
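
The sampling loop in `sample()` generates each take chunk by chunk and feeds the tail of the previously generated chunk back as the seed of the next one (the first chunk keeps whatever seed the collate function provides). The sketch below isolates that autoregressive pattern under stated assumptions: `model_kwargs_per_chunk` is a hypothetical list that already carries the per-chunk audio/VAD/style conditioning, and only `diffusion.p_sample_loop` and its arguments are taken from the patch.

```python
import torch


def sample_take_autoregressively(diffusion, model, model_kwargs_per_chunk,
                                 batch_size, num_frames, seed_poses=10):
    """Sketch of the chunked, seed-conditioned sampling used in ptbrgenerate.py.

    Each chunk of num_frames frames is denoised independently, but the last
    `seed_poses` frames of the previous chunk are passed as the seed of the
    next one, so a full take is generated autoregressively.
    """
    chunks = []
    prev = None
    for model_kwargs in model_kwargs_per_chunk:
        if prev is not None:
            # Seed the next chunk with the tail of the previously generated one.
            model_kwargs['y']['seed'] = prev[..., -seed_poses:]
        prev = diffusion.p_sample_loop(
            model,
            (batch_size, model.njoints, model.nfeats, num_frames),
            clip_denoised=False,
            model_kwargs=model_kwargs,
        )  # [batch_size, njoints, 1, num_frames]
        chunks.append(prev)
    return torch.cat(chunks, dim=-1)  # [batch_size, njoints, 1, total_frames]
```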
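
For the rotation stream, the new script keeps the 6D rotation representation and converts it back to Euler angles before writing BVH (via `rot6d_to_euler`). A hedged sketch of that conversion is shown below, reusing the `utils.rotation_conversions` calls that appear in the deleted `sample/generate.py`; the helper name `rot6d_to_bvh_euler` and the axis-reordering comment are illustrative, not the repository's exact implementation.

```python
import numpy as np
import utils.rotation_conversions as geometry


def rot6d_to_bvh_euler(sample_rot, n_joints):
    """Sketch: 6D rotations to Euler angles in degrees for BVH export.

    Input is assumed to be [bs, 1, chunk_len, n_joints*6], as produced after
    inv_transform and the rotation/position split; output is
    [bs, n_joints, 3, chunk_len], matching what savebvh() consumes.
    """
    sample_rot = sample_rot.view(sample_rot.shape[:-1] + (-1, 6))       # [bs, 1, chunk_len, n_joints, 6]
    sample_rot = geometry.rotation_6d_to_matrix(sample_rot)             # [bs, 1, chunk_len, n_joints, 3, 3]
    sample_rot = geometry.matrix_to_euler_angles(sample_rot, "ZXY")     # radians, (Z, X, Y) order
    sample_rot = sample_rot[..., [1, 2, 0]] * 180.0 / np.pi             # reorder channels and convert to degrees
    return sample_rot.view(-1, *sample_rot.shape[2:]).permute(0, 2, 3, 1)  # [bs, n_joints, 3, chunk_len]
```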