Skip to content
This repository has been archived by the owner on Dec 14, 2020. It is now read-only.

Commit

Permalink
Update learning curve script to show f1 over frameworks
Browse files Browse the repository at this point in the history
  • Loading branch information
danielhers committed Dec 2, 2019
1 parent 674f37e commit 9c34f9f
Showing 1 changed file with 4 additions and 11 deletions.
15 changes: 4 additions & 11 deletions tupa/scripts/visualize_learning_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,22 @@
def load_scores(basename, div="dev"):
filename = "%s.%s.csv" % (basename, div)
print("Loading %s scores from '%s'" % (div, filename))
offset = 0
try:
with open(filename) as f:
if "iteration" not in f.readline():
offset = -1
columns = [i for i, text in enumerate(f.readline().split(",")) if text == "f"]
scores = np.genfromtxt(f, delimiter=",", invalid_raise=False)
except ValueError as e:
raise ValueError("Failed reading '%s'" % filename) from e
try:
return scores[:, [PRIMARY_F1_COLUMN + offset, REMOTE_F1_COLUMN + offset]]
except IndexError:
try:
return scores[:, PRIMARY_F1_COLUMN + offset]
except IndexError as e:
raise ValueError("Failed reading '%s'" % filename) from e
return scores[:, columns]
except IndexError as e:
raise ValueError("Failed reading '%s'" % filename) from e


def visualize(scores, filename, div="dev"):
plt.plot(range(1, 1 + len(scores)), scores)
plt.xlabel("epochs")
plt.ylabel("%s f1" % div)
if len(scores.shape) > 1:
plt.legend(["primary", "remote"])
plt.title(filename)
output_file = "%s.%s.png" % (filename, div)
plt.savefig(output_file)
Expand Down

0 comments on commit 9c34f9f

Please sign in to comment.