Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

How to show the t-SNE? #18

Open
ZhihaoChe opened this issue Oct 20, 2024 · 1 comment
Open

How to show the t-SNE? #18

ZhihaoChe opened this issue Oct 20, 2024 · 1 comment

Comments

@ZhihaoChe
Copy link

I am very excited to see such an excellent article. Could you please provide the code for the t-SNE graph?

@viyjy
Copy link
Collaborator

viyjy commented Oct 21, 2024

Hi, I probably used the following code.

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.manifold import TSNE
import matplotlib.cm as cm
import argparse
import os

class FeatureVisualize(object):
    '''
    Visualize features by TSNE.

    Stores a feature matrix together with one label per sample and renders
    a 2-D t-SNE scatter plot in which every point is coloured by its label.
    '''

    # Palette used whenever every label is a valid index into it. This keeps
    # the original two-class colouring: label 0 -> blue, label 1 -> red.
    _BASE_COLORS = ("b", "r")

    def __init__(self, features, labels):
        '''
        features: (m,n)
        labels: (m,)
        '''
        self.features = features
        self.labels = labels

    @staticmethod
    def _min_max_scale(points):
        ''' Rescale every column of `points` to the range [0, 1].

        A column whose values are all equal is mapped to 0 instead of
        dividing by zero (which would produce NaN/inf coordinates).
        '''
        points = np.asarray(points)
        lowest = np.min(points, 0)
        span = np.max(points, 0) - lowest
        # Replace a zero span by 1 so the division below is always defined.
        span = np.where(span == 0, 1.0, span)
        return (points - lowest) / span

    @classmethod
    def _point_colors(cls, labels):
        ''' Return one matplotlib colour per entry of `labels`.

        Labels that are valid indices into the two-colour base palette keep
        the "b"/"r" colouring. Any other label set gets one distinct
        colour per class, sampled evenly from the "rainbow" colormap.
        '''
        # Labels loaded from .npy files may be stored as floats or as a
        # column vector; normalise to a flat integer array before indexing.
        labels = np.asarray(labels).ravel().astype(int)
        base = cls._BASE_COLORS
        if labels.size == 0 or (
                labels.min() >= -len(base) and labels.max() < len(base)):
            return [base[k] for k in labels]

        classes, inverse = np.unique(labels, return_inverse=True)
        palette = plt.get_cmap("rainbow")(np.linspace(0.0, 1.0, len(classes)))
        return palette[inverse]

    def plot_tsne(self, save_path, save_eps=False):
        ''' Plot TSNE figure.

        save_path: directory the figure is written to when saving.
        save_eps: if True, save the figure to `save_path`. Despite the
            parameter name, the file written is 'tsne_v1.png' (a PNG at
            600 dpi), not an .eps file.

        Raises ValueError if the number of labels differs from the number
        of feature rows. The figure is always shown with plt.show().
        '''
        features = np.asarray(self.features)
        labels = np.asarray(self.labels).ravel()
        n_samples = features.shape[0]
        if labels.shape[0] != n_samples:
            raise ValueError(
                "features and labels disagree on the number of samples: "
                "%d vs %d" % (n_samples, labels.shape[0]))

        # scikit-learn requires perplexity < n_samples; its default of 30
        # is kept unless the data set is too small for it.
        perplexity = min(30.0, max(n_samples - 1, 1))
        tsne = TSNE(n_components=2, init='pca', random_state=0,
                    perplexity=perplexity)
        data = self._min_max_scale(tsne.fit_transform(features))

        # One scatter call for all points instead of one artist per point.
        plt.scatter(data[:, 0], data[:, 1], c=self._point_colors(labels), s=10)

        plt.xticks([])
        plt.yticks([])
        if save_eps:
            os.makedirs(save_path, exist_ok=True)
            plt.savefig(os.path.join(save_path, 'tsne_v1.png'), dpi=600, bbox_inches='tight')
        plt.show()


def main(argv=None):
    ''' Command-line entry point: load saved features/labels and plot t-SNE.

    argv: optional list of argument strings; when None, argparse reads
        the arguments from sys.argv.

    Reads 'feature.npy' and 'label.npy' from <save_dir>/<dataset>/<name>,
    then plots them and saves the figure into that same directory.
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument("--name", required=True,
                        help="Name of this run. Used for monitoring.")
    parser.add_argument("--dataset", default="svhn2mnist",
                        help="Which downstream task.")
    parser.add_argument("--save_dir", default="./tsne_plot", type=str,
                        help="Root directory holding <dataset>/<name>/feature.npy "
                             "and label.npy; the t-SNE plot is saved next to them.")
    args = parser.parse_args(argv)

    # Inputs and the output figure share one per-run directory.
    save_path = os.path.join(args.save_dir, args.dataset, args.name)
    features = np.load(os.path.join(save_path, 'feature.npy'))
    labels = np.load(os.path.join(save_path, 'label.npy'))
    vis = FeatureVisualize(features, labels)
    vis.plot_tsne(save_path, save_eps=True)


# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants