diff --git a/tupa/scripts/visualize_learning_curve.py b/tupa/scripts/visualize_learning_curve.py index f74b457c..a8055d1a 100644 --- a/tupa/scripts/visualize_learning_curve.py +++ b/tupa/scripts/visualize_learning_curve.py @@ -16,29 +16,22 @@ def load_scores(basename, div="dev"): filename = "%s.%s.csv" % (basename, div) print("Loading %s scores from '%s'" % (div, filename)) - offset = 0 try: with open(filename) as f: - if "iteration" not in f.readline(): - offset = -1 + columns = [i for i, text in enumerate(f.readline().split(",")) if text == "f"] scores = np.genfromtxt(f, delimiter=",", invalid_raise=False) except ValueError as e: raise ValueError("Failed reading '%s'" % filename) from e try: - return scores[:, [PRIMARY_F1_COLUMN + offset, REMOTE_F1_COLUMN + offset]] - except IndexError: - try: - return scores[:, PRIMARY_F1_COLUMN + offset] - except IndexError as e: - raise ValueError("Failed reading '%s'" % filename) from e + return scores[:, columns] + except IndexError as e: + raise ValueError("Failed reading '%s'" % filename) from e def visualize(scores, filename, div="dev"): plt.plot(range(1, 1 + len(scores)), scores) plt.xlabel("epochs") plt.ylabel("%s f1" % div) - if len(scores.shape) > 1: - plt.legend(["primary", "remote"]) plt.title(filename) output_file = "%s.%s.png" % (filename, div) plt.savefig(output_file)