Skip to content

Commit

Permalink
Speed up convergence a lot by adding additional parameter per player …
Browse files Browse the repository at this point in the history
…that is shared across time. Model output is mathematically identical.
  • Loading branch information
lukaszlew committed May 12, 2024
1 parent e597650 commit b9ec5e8
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 11 deletions.
20 changes: 13 additions & 7 deletions accurating/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,12 @@ def fit(
winner_prior = config.winner_prior_rating / config.rating_difference_for_2_to_1_odds
loser_prior = config.loser_prior_rating / config.rating_difference_for_2_to_1_odds

def get_ratings(p):
return p['season_rating'] + p['shared_rating']

def model(params):
log_likelihood = 0.0
ratings = params['rating']
ratings = get_ratings(params)
assert ratings.shape == (player_count, season_count)
p1_ratings = ratings[p1s, seasons]
p2_ratings = ratings[p2s, seasons]
Expand Down Expand Up @@ -183,8 +186,9 @@ def model(params):
# return jnp.sum(winner_win_prob_log) - 0.005*jnp.sum(cons ** 2) # or mean?

# Optimize for these params:
rating = jnp.zeros([player_count, season_count], dtype=jnp.float64) + (loser_prior + winner_prior) / 2.0
params = { 'rating': rating }
shared_rating = jnp.zeros([player_count, 1], dtype=jnp.float64) + (loser_prior + winner_prior) / 2.0
season_rating = jnp.zeros([player_count, season_count], dtype=jnp.float64)
params = { 'season_rating': season_rating, 'shared_rating': shared_rating }
# 'consistency': jnp.zeros([player_count, season_count]),

# Momentum gradient descent with restarts
Expand All @@ -196,6 +200,7 @@ def model(params):
last_grad = tree_map(jnp.zeros_like, params)
last_reset_step = 0


for i in range(config.max_steps):
(eval, model_fit), grad = jax.value_and_grad(model, has_aux=True)(params)

Expand All @@ -219,12 +224,13 @@ def model(params):
params = tree_map(lambda p, m: p + lr * m, params, momentum)

max_d_rating = jnp.max(
jnp.abs(params['rating'] - last_params['rating']))
jnp.abs(get_ratings(params) - get_ratings(last_params)))

if config.do_log:
g = jnp.linalg.norm(grad['rating'])
g = get_ratings(grad)
g = jnp.sqrt(jnp.mean(g*g))
print(
f'Step {i:4}: eval={jnp.exp2(eval):0.12f} pred_power={model_fit:0.6f} lr={lr: 4.4f} grad={g:2.4f} delta={max_d_rating}')
f'Step {i:4}: eval={jnp.exp2(eval):0.12f} pred_power={model_fit:0.6f} lr={lr: 4.4f} grad={g:2.8f} delta={max_d_rating}')

if max_d_rating < 1e-15:
break
Expand All @@ -237,7 +243,7 @@ def postprocess():
for id, name in enumerate(data.player_name):
rating[name] = {}
for season in range(season_count):
rating[name][season] = float(params['rating'][id, season]) * config.rating_difference_for_2_to_1_odds
rating[name][season] = float(get_ratings(params)[id, season]) * config.rating_difference_for_2_to_1_odds
last_rating.append((rating[name][season_count - 1], name))
if config.do_log:
headers = ['Nick']
Expand Down
6 changes: 3 additions & 3 deletions accurating/tests/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_fit():

elos = elos - jnp.min(elos, axis=0, keepdims=True)
err = jnp.linalg.norm(elos - jnp.array(true_elos) * 100.0)
assert err < 0.0001, f'FAIL err={err}; results={model}'
assert err < 0.001, f'FAIL err={err}; {elos=}; results={model}'


def test_data_from_dicts():
Expand Down Expand Up @@ -104,8 +104,8 @@ def test_data_from_dicts():
assert_almost_equal(model.rating['Alusia'][0], 0.0)
assert_almost_equal(model.rating['Alusia'][1], 0.0)

v = 13.45060623432789
v2 = 26.901228584915188
v = 13.45061478482753
v2 = 26.901218694229996
assert_almost_equal(model.rating['Caesar'][0], -v)
assert_almost_equal(model.rating['Caesar'][1], v2)
assert_almost_equal(model.rating['Leon'][0], v)
Expand Down
2 changes: 1 addition & 1 deletion iglo/ratings.json

Large diffs are not rendered by default.

0 comments on commit b9ec5e8

Please sign in to comment.