Use per-particle pt weight in loss #383

Open · wants to merge 21 commits into base: main
15 changes: 11 additions & 4 deletions mlpf/model/losses.py
@@ -4,6 +4,8 @@
from torch.nn import functional as F
from torch import Tensor, nn

+from mlpf.model.logger import _logger


def sliced_wasserstein_loss(y_pred, y_true, num_projections=200):
# create normalized random basis vectors
@@ -74,9 +76,9 @@ def mlpf_loss(y, ypred, batch):
loss_regression_energy[batch.mask == 0] *= 0

# add weight based on target pt
-    # sqrt_target_pt = torch.sqrt(torch.exp(y["pt"]) * batch.X[:, :, 1])
-    # loss_regression_pt *= sqrt_target_pt
-    # loss_regression_energy *= sqrt_target_pt
+    sqrt_target_pt = torch.sqrt(torch.exp(y["pt"]) * batch.X[:, :, 1])
+    loss_regression_pt *= sqrt_target_pt
+    loss_regression_energy *= sqrt_target_pt

# average over all target particles
loss["Regression_pt"] = loss_regression_pt.sum() / npart
@@ -122,10 +124,15 @@ def mlpf_loss(y, ypred, batch):
+ loss["Regression_energy"]
)
loss_opt = loss["Total"]
+    if torch.isnan(loss_opt):
+        _logger.error(ypred)
+        _logger.error(sqrt_target_pt)
+        _logger.error(loss)
+        raise Exception("Loss became NaN")

# store these separately but detached
for k in loss.keys():
-        loss[k] = loss[k].detach().cpu().item()
+        loss[k] = loss[k].detach()

return loss_opt, loss

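Note on the newly enabled weighting: assuming `y["pt"]` stores the regression target in log space relative to the input element pt in `batch.X[:, :, 1]`, the product `torch.exp(y["pt"]) * batch.X[:, :, 1]` recovers the linear-scale target pt, and the square root damps its dynamic range before it multiplies the per-particle losses. A minimal sketch of the idea under that assumption (names are illustrative):

```python
import torch

def pt_weighted_loss(per_particle_loss, log_pt_ratio, elem_pt, npart):
    # recover the linear-scale target pt from the assumed log-ratio parameterization
    target_pt = torch.exp(log_pt_ratio) * elem_pt
    # sqrt emphasizes high-pt particles without letting them dominate
    # as strongly as a linear weight would
    return (per_particle_loss * torch.sqrt(target_pt)).sum() / npart

# toy check: the 100 GeV particle contributes 10x the weight of the 1 GeV one
print(pt_weighted_loss(torch.ones(2), torch.zeros(2), torch.tensor([1.0, 100.0]), 2))
```

The new `isnan` guard above logs `sqrt_target_pt` alongside the predictions, presumably so that an overflow in the weight itself can be traced when the loss goes to NaN.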
90 changes: 48 additions & 42 deletions mlpf/model/plots.py
@@ -123,45 +123,51 @@ def validation_plots(batch, ypred_raw, ytarget, ypred, tensorboard_writer, epoch
plt.xlabel("particle proba")
tensorboard_writer.add_figure("sig_proba_elemtype{}".format(int(xcls)), fig, global_step=epoch)

tensorboard_writer.add_histogram("pt_target", torch.clamp(batch.ytarget[batch.mask][:, 2], -10, 10), global_step=epoch)
tensorboard_writer.add_histogram("pt_pred", torch.clamp(ypred_raw[2][batch.mask][:, 0], -10, 10), global_step=epoch)
ratio = (ypred_raw[2][batch.mask][:, 0] / batch.ytarget[batch.mask][:, 2])[batch.ytarget[batch.mask][:, 0] != 0]
tensorboard_writer.add_histogram("pt_ratio", torch.clamp(ratio, -10, 10), global_step=epoch)

tensorboard_writer.add_histogram("eta_target", torch.clamp(batch.ytarget[batch.mask][:, 3], -10, 10), global_step=epoch)
tensorboard_writer.add_histogram("eta_pred", torch.clamp(ypred_raw[2][batch.mask][:, 1], -10, 10), global_step=epoch)
ratio = (ypred_raw[2][batch.mask][:, 1] / batch.ytarget[batch.mask][:, 3])[batch.ytarget[batch.mask][:, 0] != 0]
tensorboard_writer.add_histogram("eta_ratio", torch.clamp(ratio, -10, 10), global_step=epoch)

tensorboard_writer.add_histogram("sphi_target", torch.clamp(batch.ytarget[batch.mask][:, 4], -10, 10), global_step=epoch)
tensorboard_writer.add_histogram("sphi_pred", torch.clamp(ypred_raw[2][batch.mask][:, 2], -10, 10), global_step=epoch)
ratio = (ypred_raw[2][batch.mask][:, 2] / batch.ytarget[batch.mask][:, 4])[batch.ytarget[batch.mask][:, 0] != 0]
tensorboard_writer.add_histogram("sphi_ratio", torch.clamp(ratio, -10, 10), global_step=epoch)

tensorboard_writer.add_histogram("cphi_target", torch.clamp(batch.ytarget[batch.mask][:, 5], -10, 10), global_step=epoch)
tensorboard_writer.add_histogram("cphi_pred", torch.clamp(ypred_raw[2][batch.mask][:, 3], -10, 10), global_step=epoch)
ratio = (ypred_raw[2][batch.mask][:, 3] / batch.ytarget[batch.mask][:, 5])[batch.ytarget[batch.mask][:, 0] != 0]
tensorboard_writer.add_histogram("cphi_ratio", torch.clamp(ratio, -10, 10), global_step=epoch)

tensorboard_writer.add_histogram("energy_target", torch.clamp(batch.ytarget[batch.mask][:, 6], -10, 10), global_step=epoch)
tensorboard_writer.add_histogram("energy_pred", torch.clamp(ypred_raw[2][batch.mask][:, 4], -10, 10), global_step=epoch)
ratio = (ypred_raw[2][batch.mask][:, 4] / batch.ytarget[batch.mask][:, 6])[batch.ytarget[batch.mask][:, 0] != 0]
tensorboard_writer.add_histogram("energy_ratio", torch.clamp(ratio, -10, 10), global_step=epoch)

for attn in sorted(list(glob.glob(f"{outdir}/attn_conv_*.npz"))):
attn_name = os.path.basename(attn).split(".")[0]
attn_matrix = np.load(attn)["att"]
batch_size = min(attn_matrix.shape[0], 8)
fig, axes = plt.subplots(1, batch_size, figsize=((batch_size * 3, 1 * 3)))
if isinstance(axes, matplotlib.axes._axes.Axes):
axes = [axes]
for ibatch in range(batch_size):
plt.sca(axes[ibatch])
# plot the attention matrix of the first event in the batch
plt.imshow(attn_matrix[ibatch].T, cmap="hot", norm=matplotlib.colors.LogNorm())
plt.xticks([])
plt.yticks([])
plt.colorbar()
plt.title("event {}, m={:.2E}".format(ibatch, np.mean(attn_matrix[ibatch][attn_matrix[ibatch] > 0])))
plt.suptitle(attn_name)
tensorboard_writer.add_figure(attn_name, fig, global_step=epoch)
+    try:
+        tensorboard_writer.add_histogram("pt_target", torch.clamp(batch.ytarget[batch.mask][:, 2], -10, 10), global_step=epoch)
+        tensorboard_writer.add_histogram("pt_pred", torch.clamp(ypred_raw[2][batch.mask][:, 0], -10, 10), global_step=epoch)
+        ratio = (ypred_raw[2][batch.mask][:, 0] / batch.ytarget[batch.mask][:, 2])[batch.ytarget[batch.mask][:, 0] != 0]
+        tensorboard_writer.add_histogram("pt_ratio", torch.clamp(ratio, -10, 10), global_step=epoch)
+
+        tensorboard_writer.add_histogram("eta_target", torch.clamp(batch.ytarget[batch.mask][:, 3], -10, 10), global_step=epoch)
+        tensorboard_writer.add_histogram("eta_pred", torch.clamp(ypred_raw[2][batch.mask][:, 1], -10, 10), global_step=epoch)
+        ratio = (ypred_raw[2][batch.mask][:, 1] / batch.ytarget[batch.mask][:, 3])[batch.ytarget[batch.mask][:, 0] != 0]
+        tensorboard_writer.add_histogram("eta_ratio", torch.clamp(ratio, -10, 10), global_step=epoch)
+
+        tensorboard_writer.add_histogram("sphi_target", torch.clamp(batch.ytarget[batch.mask][:, 4], -10, 10), global_step=epoch)
+        tensorboard_writer.add_histogram("sphi_pred", torch.clamp(ypred_raw[2][batch.mask][:, 2], -10, 10), global_step=epoch)
+        ratio = (ypred_raw[2][batch.mask][:, 2] / batch.ytarget[batch.mask][:, 4])[batch.ytarget[batch.mask][:, 0] != 0]
+        tensorboard_writer.add_histogram("sphi_ratio", torch.clamp(ratio, -10, 10), global_step=epoch)
+
+        tensorboard_writer.add_histogram("cphi_target", torch.clamp(batch.ytarget[batch.mask][:, 5], -10, 10), global_step=epoch)
+        tensorboard_writer.add_histogram("cphi_pred", torch.clamp(ypred_raw[2][batch.mask][:, 3], -10, 10), global_step=epoch)
+        ratio = (ypred_raw[2][batch.mask][:, 3] / batch.ytarget[batch.mask][:, 5])[batch.ytarget[batch.mask][:, 0] != 0]
+        tensorboard_writer.add_histogram("cphi_ratio", torch.clamp(ratio, -10, 10), global_step=epoch)
+
+        tensorboard_writer.add_histogram("energy_target", torch.clamp(batch.ytarget[batch.mask][:, 6], -10, 10), global_step=epoch)
+        tensorboard_writer.add_histogram("energy_pred", torch.clamp(ypred_raw[2][batch.mask][:, 4], -10, 10), global_step=epoch)
+        ratio = (ypred_raw[2][batch.mask][:, 4] / batch.ytarget[batch.mask][:, 6])[batch.ytarget[batch.mask][:, 0] != 0]
+        tensorboard_writer.add_histogram("energy_ratio", torch.clamp(ratio, -10, 10), global_step=epoch)
+    except ValueError as e:
+        print(e)
+
+    try:
+        for attn in sorted(list(glob.glob(f"{outdir}/attn_conv_*.npz"))):
+            attn_name = os.path.basename(attn).split(".")[0]
+            attn_matrix = np.load(attn)["att"]
+            batch_size = min(attn_matrix.shape[0], 8)
+            fig, axes = plt.subplots(1, batch_size, figsize=((batch_size * 3, 1 * 3)))
+            if isinstance(axes, matplotlib.axes._axes.Axes):
+                axes = [axes]
+            for ibatch in range(batch_size):
+                plt.sca(axes[ibatch])
+                # plot the attention matrix for each event in the batch
+                plt.imshow(attn_matrix[ibatch].T, cmap="hot", norm=matplotlib.colors.LogNorm())
+                plt.xticks([])
+                plt.yticks([])
+                plt.colorbar()
+                plt.title("event {}, m={:.2E}".format(ibatch, np.mean(attn_matrix[ibatch][attn_matrix[ibatch] > 0])))
+            plt.suptitle(attn_name)
+            tensorboard_writer.add_figure(attn_name, fig, global_step=epoch)
+    except ValueError as e:
+        print(e)
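The try/except wrappers keep one bad batch (for example an empty or non-finite masked selection, which can make `add_histogram` raise `ValueError`) from aborting the whole validation pass. A minimal sketch of the same guard factored into a helper (helper name is illustrative):

```python
import torch

def safe_add_histogram(writer, tag, values, step):
    # add_histogram can raise ValueError, e.g. on empty or all-NaN input;
    # log the failure and continue instead of aborting the validation plots
    try:
        writer.add_histogram(tag, torch.clamp(values, -10, 10), global_step=step)
    except ValueError as e:
        print(f"skipping histogram '{tag}': {e}")
```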
104 changes: 42 additions & 62 deletions mlpf/model/training.py
@@ -67,57 +67,26 @@ def configure_model_trainable(model: MLPF, trainable: Union[str, List[str]], is_
model.eval()


-def train_step(batch, model, optimizer, lr_schedule, loss_fn):
-    """Single training step logic
-
-    Args:
-        batch: The input batch data
-        model: The neural network model
-        optimizer: The optimizer
-        lr_schedule: Learning rate scheduler
-        loss_fn: Loss function to use
-
-    Returns:
-        dict: Dictionary containing all computed losses with gradient detached
-    """
+def model_step(batch, model, loss_fn):
ypred_raw = model(batch.X, batch.mask)
ypred = unpack_predictions(ypred_raw)
ytarget = unpack_target(batch.ytarget, model)

loss_opt, losses_detached = loss_fn(ytarget, ypred, batch)
+    return loss_opt, losses_detached, ypred_raw, ypred, ytarget


+def optimizer_step(model, loss_opt, optimizer, lr_schedule, scaler):
# Clear gradients
for param in model.parameters():
param.grad = None

# Backward pass and optimization
-    loss_opt.backward()
-    optimizer.step()
+    scaler.scale(loss_opt).backward()
+    scaler.step(optimizer)
+    scaler.update()
if lr_schedule:
lr_schedule.step()

-    return losses_detached


-def eval_step(batch, model, loss_fn):
-    """Single evaluation step logic
-
-    Args:
-        batch: The input batch data
-        model: The neural network model
-        loss_fn: Loss function to use
-
-    Returns:
-        tuple: (losses dict, predictions dict, targets dict)
-    """
-    with torch.no_grad():
-        ypred_raw = model(batch.X, batch.mask)
-        ypred = unpack_predictions(ypred_raw)
-        ytarget = unpack_target(batch.ytarget, model)
-        _, losses_detached = loss_fn(ytarget, ypred, batch)
-
-    return losses_detached, ypred_raw, ypred, ytarget
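With this refactor `model_step` owns the forward pass and `optimizer_step` owns the backward pass, which now runs through `torch.amp.GradScaler` for mixed-precision stability. A minimal sketch of the scaler protocol on a toy model (assumes a CUDA device and a PyTorch version that exposes `torch.amp.GradScaler`):

```python
import torch

model = torch.nn.Linear(8, 1).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.amp.GradScaler()

x = torch.randn(32, 8, device="cuda")
with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = model(x).pow(2).mean()

for param in model.parameters():
    param.grad = None              # same effect as zero_grad(set_to_none=True)
scaler.scale(loss).backward()      # backward on the loss scaled up to avoid fp16 underflow
scaler.step(optimizer)             # unscales gradients; skips the step if they are inf/NaN
scaler.update()                    # adapts the scale factor for the next iteration
```

One observation: `train_epoch` defaults `scaler=None`, but `optimizer_step` calls `scaler.scale(...)` unconditionally, so callers must always pass a real scaler.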


def train_epoch(
rank: Union[int, str],
Expand All @@ -133,6 +102,7 @@ def train_epoch(
checkpoint_dir="",
device_type="cuda",
dtype=torch.float32,
+    scaler=None,
):
"""Run one training epoch

@@ -167,7 +137,9 @@ def train_epoch(
batch = batch.to(rank, non_blocking=True)

with torch.autocast(device_type=device_type, dtype=dtype, enabled=device_type == "cuda"):
-            loss = train_step(batch, model, optimizer, lr_schedule, mlpf_loss)
+            loss_opt, loss, _, _, _ = model_step(batch, model, mlpf_loss)
+
+        optimizer_step(model, loss_opt, optimizer, lr_schedule, scaler)

# Accumulate losses
for loss_name in loss:
@@ -191,14 +163,14 @@ def train_epoch(
comet_experiment.log_metric("learning_rate", lr_schedule.get_last_lr(), step=step)

# Average losses across steps
-    num_steps = len(train_loader)
+    num_steps = torch.tensor(float(len(train_loader)), device=rank, dtype=torch.float32)
if world_size > 1:
torch.distributed.all_reduce(num_steps)

for loss_name in epoch_loss:
if world_size > 1:
torch.distributed.all_reduce(epoch_loss[loss_name])
-        epoch_loss[loss_name] = epoch_loss[loss_name] / num_steps
+        epoch_loss[loss_name] = epoch_loss[loss_name].cpu().item() / num_steps.cpu().item()

if world_size > 1:
dist.barrier()
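Making `num_steps` a tensor lets the same `all_reduce` sum both the per-rank loss totals and the per-rank batch counts, so the final division is a true global mean even when ranks see different numbers of batches. Roughly (a sketch of the pattern, not the exact loop code):

```python
import torch
import torch.distributed as dist

def global_mean_loss(loss_sum: torch.Tensor, local_steps: int, world_size: int) -> float:
    # sum numerator and denominator across ranks before dividing
    num_steps = torch.tensor(float(local_steps), device=loss_sum.device)
    if world_size > 1:
        dist.all_reduce(loss_sum)   # default reduce op is SUM
        dist.all_reduce(num_steps)
    return loss_sum.cpu().item() / num_steps.cpu().item()
```

This is also why `mlpf_loss` now returns detached tensors instead of Python floats: the accumulated losses have to stay on-device for `all_reduce`.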
@@ -261,7 +233,8 @@ def eval_epoch(
set_save_attention(model, outdir, False)

with torch.autocast(device_type=device_type, dtype=dtype, enabled=device_type == "cuda"):
-            loss, ypred_raw, ypred, ytarget = eval_step(batch, model, mlpf_loss)
+            with torch.no_grad():
+                loss_opt, loss, ypred_raw, ypred, ytarget = model_step(batch, model, mlpf_loss)

# Update confusion matrices
cm_X_target += sklearn.metrics.confusion_matrix(
@@ -297,14 +270,14 @@
)

# Average losses across steps
-    num_steps = len(valid_loader)
+    num_steps = torch.tensor(float(len(valid_loader)), device=rank, dtype=torch.float32)
if world_size > 1:
torch.distributed.all_reduce(num_steps)

for loss_name in epoch_loss:
if world_size > 1:
torch.distributed.all_reduce(epoch_loss[loss_name])
-        epoch_loss[loss_name] = epoch_loss[loss_name] / num_steps
+        epoch_loss[loss_name] = epoch_loss[loss_name].cpu().item() / num_steps.cpu().item()

if world_size > 1:
dist.barrier()
@@ -383,6 +356,8 @@ def train_all_epochs(
stale_epochs = torch.tensor(0, device=rank)
best_val_loss = float("inf")

+    scaler = torch.amp.GradScaler()

for epoch in range(start_epoch, num_epochs + 1):
epoch_start_time = time.time()

@@ -401,6 +376,7 @@
checkpoint_dir=checkpoint_dir,
device_type=device_type,
dtype=dtype,
+            scaler=scaler,
)
train_time = time.time() - epoch_start_time

@@ -430,21 +406,6 @@

# Handle checkpointing and early stopping on rank 0
if (rank == 0) or (rank == "cpu"):

-        # evaluate the model at this epoch on test datasets, make plots, track metrics
-        testdir_name = f"_epoch_{epoch}"
-        for sample in config["test_dataset"]:
-            run_test(rank, world_size, config, outdir, model, sample, testdir_name, dtype)
-            plot_metrics = make_plots(outdir, sample, config["dataset"], testdir_name, config["ntest"])
-
-            # track the following jet metrics in tensorboard
-            for k in ["med", "iqr", "match_frac"]:
-                tensorboard_writer_valid.add_scalar(
-                    "epoch/{}/jet_ratio/jet_ratio_target_to_pred_pt/{}".format(sample, k),
-                    plot_metrics["jet_ratio"]["jet_ratio_target_to_pred_pt"][k],
-                    epoch,
-                )

# Log learning rate
tensorboard_writer_train.add_scalar("epoch/learning_rate", lr_schedule.get_last_lr()[0], epoch)

@@ -504,6 +465,20 @@ def train_all_epochs(
tensorboard_writer_train.flush()
tensorboard_writer_valid.flush()

+        # evaluate the model at this epoch on test datasets, make plots, track metrics
+        testdir_name = f"_epoch_{epoch}"
+        for sample in config["enabled_test_datasets"]:
+            run_test(rank, world_size, config, outdir, model, sample, testdir_name, dtype)
+            plot_metrics = make_plots(outdir, sample, config["dataset"], testdir_name, config["ntest"])
+
+            # track the following jet metrics in tensorboard
+            for k in ["med", "iqr", "match_frac"]:
+                tensorboard_writer_valid.add_scalar(
+                    "epoch/{}/jet_ratio/jet_ratio_target_to_pred_pt/{}".format(sample, k),
+                    plot_metrics["jet_ratio"]["jet_ratio_target_to_pred_pt"][k],
+                    epoch,
+                )
+
# Ray training specific logging
if use_ray:
import ray
@@ -787,14 +762,14 @@ def run(rank, world_size, config, outdir, logfile):
testdir_name = "_best_weights"

if config["test"]:
for sample in config["test_dataset"]:
for sample in config["enabled_test_datasets"]:
run_test(rank, world_size, config, outdir, model, sample, testdir_name, dtype)

# make plots only on a single machine
if (rank == 0) or (rank == "cpu"):
if config["make_plots"]:
ntest_files = -1
for sample in config["test_dataset"]:
for sample in config["enabled_test_datasets"]:
_logger.info(f"Plotting distributions for {sample}")
make_plots(outdir, sample, config["dataset"], testdir_name, ntest_files)

@@ -817,8 +792,13 @@ def override_config(config: dict, args):
for model in ["gnn_lsh", "attention", "attention", "mamba"]:
config["model"][model]["num_convs"] = args.num_convs

config["enabled_test_datasets"] = list(config["test_dataset"].keys())
if len(args.test_datasets) != 0:
config["test_dataset"] = args.test_datasets
config["enabled_test_datasets"] = args.test_datasets

config["train"] = args.train
config["test"] = args.test
config["make_plots"] = args.make_plots

return config
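`config["test_dataset"]` is a mapping from dataset name to its settings, so the old override, which replaced it with a bare list of names from the CLI, would break any later per-dataset lookup. The new `enabled_test_datasets` key keeps the mapping intact and narrows only which names actually run. For illustration (hypothetical config shape):

```python
# hypothetical config shape, for illustration only
config = {"test_dataset": {"cms_pf_ttbar": {"ntest": 10}, "cms_pf_qcd": {"ntest": 5}}}

# new behavior: the mapping survives, only the enabled names are narrowed
config["enabled_test_datasets"] = ["cms_pf_ttbar"]
for sample in config["enabled_test_datasets"]:
    settings = config["test_dataset"][sample]  # per-dataset settings still available
```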

3 changes: 2 additions & 1 deletion mlpf/plotting/plot_utils.py
@@ -136,6 +137,7 @@ def get_class_names(sample_name):
"cms_pf_single_pi0": r"single neutral pion particle gun events",
"cms_pf_single_proton": r"single proton particle gun events",
"cms_pf_single_tau": r"single tau particle gun events",
"cms_pf_single_k0": r"single K0 particle gun events",
"cms_pf_sms_t1tttt": r"sms t1tttt events",
}

@@ -418,7 +419,7 @@ def compute_3dmomentum_and_ratio(yvals):
}


-def save_img(outfile, epoch, cp_dir=None, comet_experiment=None):
+def save_img(outfile, epoch=None, cp_dir=None, comet_experiment=None):
if cp_dir:
image_path = str(cp_dir / outfile)
plt.savefig(image_path, dpi=100, bbox_inches="tight")
11 changes: 7 additions & 4 deletions mlpf/timing.py
@@ -77,11 +77,11 @@ def get_mem_mb(use_gpu):
sess_options.add_session_config_entry("session.intra_op.allow_spinning", "1")

onnx_sess = rt.InferenceSession(args.model, sess_options, providers=EP_list)
-# warmup

mem_onnx = get_mem_mb(use_gpu)
print("mem_onnx", mem_onnx)

+# warmup
X = np.array(np.random.randn(batch_size, bin_size, num_features), getattr(np, args.input_dtype))
for i in range(10):
onnx_sess.run(None, {"Xfeat_normed": X, "mask": (X[..., 0] != 0).astype(np.float32)})
@@ -103,9 +103,12 @@ def get_mem_mb(use_gpu):

# transfer data to GPU, run model, transfer data back
t0 = time.time()
-    # pred_onx = onnx_sess.run(None, {"Xfeat_normed": X, "l_mask_": X[..., 0]==0})
-    pred_onx = onnx_sess.run(None, {"Xfeat_normed": X, "mask": (X[..., 0] != 0).astype(np.float32)})
-    t1 = time.time()
+    try:
+        onnx_sess.run(None, {"Xfeat_normed": X, "mask": (X[..., 0] != 0).astype(np.float32)})
+        t1 = time.time()
+    except Exception as e:
+        print(e)
+        t1 = t0
dt = (t1 - t0) / batch_size
times.append(dt)
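One caveat on the fallback: a failed inference sets `t1 = t0`, so `dt == 0` is appended to `times` and silently drags down the reported mean. A variant that skips failed iterations instead might look like this (a sketch reusing the script's variable names):

```python
import time
import numpy as np

def timed_inference(onnx_sess, X, batch_size, times):
    # append a per-sample latency only when the run actually succeeds
    t0 = time.time()
    try:
        onnx_sess.run(None, {"Xfeat_normed": X, "mask": (X[..., 0] != 0).astype(np.float32)})
        times.append((time.time() - t0) / batch_size)
    except Exception as e:
        print(e)  # skipped iterations leave `times` untouched
```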

Expand Down
Loading