Skip to content

Commit

Permalink
err_val handling in vbd_data
Browse files Browse the repository at this point in the history
  • Loading branch information
nadarenator committed Jan 2, 2025
1 parent 4379b89 commit 4050069
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 106 deletions.
53 changes: 29 additions & 24 deletions data_utils/vbd_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,20 @@ def filter_topk_roadgraph_points(global_road_graph, reference_points, topk):
filtered_id = global_road_graph.id[0][top_idx]

# Stack the filtered attributes to form a new roadgraph tensor
filtered_tensor = torch.stack([
filtered_xy[..., 0],
filtered_xy[..., 1],
filtered_length,
filtered_width,
filtered_height,
filtered_orientation,
torch.zeros_like(filtered_length),
filtered_id,
filtered_type
], dim=-1)
filtered_tensor = torch.stack(
[
filtered_xy[..., 0],
filtered_xy[..., 1],
filtered_length,
filtered_width,
filtered_height,
filtered_orientation,
torch.zeros_like(filtered_length),
filtered_id,
filtered_type
],
dim=-1
)

return GlobalRoadGraphPoints(filtered_tensor.clone())
else:
Expand Down Expand Up @@ -194,7 +197,7 @@ def data_process_agent(
agents_interested[i] = 1

agents_type[i] = agent_type
agents_history[i] = torch.stack(
agents_history[i] = torch.column_stack(
[
log_trajectory.pos_xy[0, a, :init_steps+1, 0],
log_trajectory.pos_xy[0, a, :init_steps+1, 1],
Expand All @@ -205,21 +208,23 @@ def data_process_agent(
global_agent_obs.vehicle_width[0, a].repeat(init_steps + 1),
global_agent_obs.vehicle_height[0, a].repeat(init_steps + 1),
],
dim=-1
).numpy()

agents_history[i][~log_trajectory.valids[0, a, :init_steps+1]] = 0
mask = log_trajectory.valids[0, a, :init_steps+1].numpy()
agents_history[i] *= mask

agents_future[i] = torch.stack(
[
log_trajectory.pos_xy[0, a, init_steps:, 0],
log_trajectory.pos_xy[0, a, init_steps:, 1],
log_trajectory.yaw[0, a, init_steps:, 0],
log_trajectory.vel_xy[0, a, init_steps:, 0],
log_trajectory.vel_xy[0, a, init_steps:, 1],
],
dim=-1
).numpy()
agents_future[i] = torch.column_stack(
[
log_trajectory.pos_xy[0, a, init_steps:, 0],
log_trajectory.pos_xy[0, a, init_steps:, 1],
log_trajectory.yaw[0, a, init_steps:, 0],
log_trajectory.vel_xy[0, a, init_steps:, 0],
log_trajectory.vel_xy[0, a, init_steps:, 1],
],
).numpy()

mask = log_trajectory.valids[0, a, init_steps:].numpy()
agents_future[i] *= mask

# Type of agents: 0 for None, 1 for Vehicle, 2 for Pedestrian, 3 for Cyclist
mapped_agents_type = np.zeros_like(agents_type)
Expand Down
Binary file added gpudrive_vbd_sample_6d2a107f2e8390a.pkl
Binary file not shown.
165 changes: 83 additions & 82 deletions integrations/models/notebooks/01_features_deepdive.ipynb

Large diffs are not rendered by default.

0 comments on commit 4050069

Please sign in to comment.