Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sdpa stride fix #2596

Closed
wants to merge 103 commits into from
Closed
Changes from 1 commit
Commits
Show all changes
103 commits
Select commit Hold shift + click to select a range
83a9e88
Mistral.rs Squash Changes (#4)
EricLBuehler May 15, 2024
4e82fab
Merge remote-tracking branch 'upstream/main'
EricLBuehler May 15, 2024
37cafcc
Merge remote-tracking branch 'upstream/main'
EricLBuehler May 16, 2024
5892fac
fix issue with cuda header file for A10G (#5)
joshpopelka20 May 16, 2024
9b151f5
Merge remote-tracking branch 'upstream/main'
EricLBuehler May 18, 2024
ea49ea2
Remove candle-layer-norm (#6)
EricLBuehler May 19, 2024
38f8d9e
Merge
EricLBuehler May 19, 2024
c10fc33
Merge
EricLBuehler May 27, 2024
527ebcc
Merge remote-tracking branch 'upstream/main'
EricLBuehler May 28, 2024
bfc197b
Merge remote-tracking branch 'upstream/main'
EricLBuehler May 29, 2024
0c2ac76
Merge remote-tracking branch 'upstream/main'
EricLBuehler May 30, 2024
cb3dbc2
Merge remote-tracking branch 'upstream/main'
EricLBuehler Jun 1, 2024
faa9435
Add a set_dtype method
EricLBuehler Jun 3, 2024
462d948
Merge remote-tracking branch 'upstream/main'
EricLBuehler Jun 3, 2024
5c06acd
Merge remote-tracking branch 'upstream/main'
EricLBuehler Jun 4, 2024
696acaa
Add more capability to slice_assign (#7)
EricLBuehler Jun 9, 2024
0936406
Implement unfold (#8)
EricLBuehler Jun 9, 2024
636de1d
Merge remote-tracking branch 'upstream/main'
EricLBuehler Jun 11, 2024
f52e234
Bump cudarc to 0.11.5 (#10)
EricLBuehler Jun 11, 2024
bb8f6f0
Add QTensor::quantize_onto (#12)
EricLBuehler Jun 29, 2024
5b04d96
implement Slice op (#2260)
shua Jun 12, 2024
f7095bb
Fix the fast bf16 gemm cublas kernels. (#2274)
LaurentMazare Jun 18, 2024
b55b360
Fix a bug in the metal implemtation of col2im1d. (#2284)
LaurentMazare Jun 22, 2024
08e93a6
Depth Anything v2 (#2279)
jeroenvlek Jun 24, 2024
5df1ae2
Adding Gemm and ArgMax operators to candle-onnx (#2231)
socathie Jun 28, 2024
0bb678c
Add DINOv2Reg4 + PlantCLEF2024 (#2293)
v-espitalier Jun 29, 2024
b438cba
make up for the missing last token output of phi2 example (#2299)
Czxck001 Jun 29, 2024
b7a3e34
Patch metal function
EricLBuehler Jun 30, 2024
c967be9
Complete merge
EricLBuehler Jul 15, 2024
9e09d7f
Expose cublas handle
EricLBuehler Jul 26, 2024
8b357f6
Merge remote-tracking branch 'upstream/main'
EricLBuehler Jul 26, 2024
2064fb0
Merge remote-tracking branch 'upstream/main'
EricLBuehler Jul 31, 2024
1a48767
Add sdpa function with cublaslt
EricLBuehler Aug 4, 2024
7bbcf00
Update docs
EricLBuehler Aug 4, 2024
1bf7101
Add matmul_bias_and_scale
EricLBuehler Aug 4, 2024
d6d3d18
Rename
EricLBuehler Aug 4, 2024
e20d85a
Add a simple test and fix for cpu
EricLBuehler Aug 4, 2024
8d2f32a
Update sdpa function
EricLBuehler Aug 4, 2024
9f144d6
Add matmul_alpha
EricLBuehler Aug 4, 2024
c830f26
Use matmul_with_alpha in sdpa
EricLBuehler Aug 4, 2024
86d0876
Add it to mistral
EricLBuehler Aug 5, 2024
8d8889c
Add it to q llama
EricLBuehler Aug 5, 2024
d18eb13
Add attention benches
EricLBuehler Aug 5, 2024
d71b7d7
Fixes
EricLBuehler Aug 5, 2024
412e9f4
Merge commit 'd71b7d78396a944817876c56f1677bd17633234d'
EricLBuehler Aug 5, 2024
27ca77e
Simplify things a bit
EricLBuehler Aug 7, 2024
7ad6494
Mistral.rs GPTQ dev PR (#14)
EricLBuehler Aug 9, 2024
6f0e190
Fix on metal
EricLBuehler Aug 14, 2024
ec55f58
Add the flux model for image generation. (#2390)
LaurentMazare Aug 4, 2024
0a146d7
Simplify handling of flux modulations. (#2394)
LaurentMazare Aug 4, 2024
0f55c37
optimize gradient for silu a bit (#2393)
MilkFather Aug 4, 2024
aef4eba
Support the flux-dev model too. (#2395)
LaurentMazare Aug 4, 2024
c301efa
Support for mistral-nemo. (#2396)
LaurentMazare Aug 4, 2024
fd0e933
add models support and example for THUDM/glm-4 (#2362)
donjuanplatinum Aug 5, 2024
f8e2b36
Add the MMDiT model of Stable Diffusion 3 (#2397)
Czxck001 Aug 5, 2024
0e78d29
Add the import script for the T5 tokenizer. (#2399)
LaurentMazare Aug 5, 2024
1b796b9
fix: usage of `actions/checkout@v2` (#2403)
hamirmahal Aug 6, 2024
c9cdd54
Fix issues in the encodec example README.md (#2407)
jnises Aug 10, 2024
283a5cf
Soft Non-Maximum Suppression (#2400)
onichmath Aug 10, 2024
de719a2
Add documentation examples for `Tensor::i` and `Tensor::narrow` metho…
csicar Aug 10, 2024
2e72a3d
Add Based LLM from Hazy Research. (#2411)
janimo Aug 12, 2024
d7a9bd0
Fix the device for the bert attention mask. (#2414)
LaurentMazare Aug 14, 2024
3d40ffc
Clippy fixes. (#2415)
LaurentMazare Aug 14, 2024
c5c5d49
Update flash_fwd_launch_template.h with fix for kernels (#16)
joshpopelka20 Aug 14, 2024
2386e4e
Build fixes
EricLBuehler Aug 14, 2024
a38053f
Merge branch 'sdpa'
EricLBuehler Aug 14, 2024
1b1974e
Add GGUF BF16 support (#17)
EricLBuehler Aug 21, 2024
36bd9f9
Merge remote-tracking branch 'upstream/main'
EricLBuehler Aug 22, 2024
6fbddd6
Complete merge
EricLBuehler Aug 22, 2024
f706ef2
Add softcapping support to flash attention (#18)
EricLBuehler Aug 22, 2024
3c8e120
Update kernels for metal bf16 (#19)
EricLBuehler Sep 2, 2024
014f140
fix(metal/accelerate): f64-f32 type mismatch (#20)
sammcj Sep 5, 2024
f317df8
Bump the version to 0.6.1. (#2438)
LaurentMazare Aug 22, 2024
8a9d2be
onnx: workaround pow with negative base (#2439)
shua Aug 22, 2024
a7142d3
onnx: support negative index in Gather (#2440)
shua Aug 22, 2024
f62d7e8
silero-vad v5 example (#2321)
shua Aug 22, 2024
ceab78e
Fix for parler-tts, do not add the last slice of padding tokens. (#2442)
LaurentMazare Aug 22, 2024
5b4c593
Add FastViT model. (#2444)
janimo Aug 23, 2024
ef9649c
fix: qwen2 lm_head loading #2443 (#2445)
ilookee Aug 23, 2024
7412bd0
Update cudarc to 0.12. (#2451)
LaurentMazare Aug 27, 2024
8e39086
FastViT fixes. (#2452)
janimo Aug 28, 2024
8632a2f
MobileCLIP models S1 and S2 (#2454)
janimo Aug 29, 2024
f492c04
Fix FLUX.1 weights (#2457)
eugenehp Aug 29, 2024
91e0c6e
Clippy fixes for 1.81.0. (#2461)
LaurentMazare Sep 5, 2024
ad84486
Improve candle_core::Error to make it more ergonomic (#21)
EricLBuehler Sep 11, 2024
7f5a470
Add API to get current device seed (#22)
EricLBuehler Sep 11, 2024
9240d03
Add QStorage::data for cuda and metal (#23)
EricLBuehler Sep 13, 2024
8a99f7c
Fix build error with seed (#25)
EricLBuehler Sep 13, 2024
9e31a19
Add the i16 dtype (2) (#26)
ro99 Sep 15, 2024
d08212c
Merge remote-tracking branch 'upstream/main'
EricLBuehler Oct 2, 2024
c04861d
Should compile now on metal
EricLBuehler Oct 2, 2024
156ebd1
Fix dtype cast
EricLBuehler Oct 2, 2024
20a57c4
Fix set_dtype
EricLBuehler Oct 3, 2024
fa4902f
Add initial f8 e4m3 dtype (#31)
EricLBuehler Oct 17, 2024
d050b60
Remove .vscode
EricLBuehler Oct 17, 2024
6287750
Fix some metal warnings
EricLBuehler Oct 17, 2024
1f8a28a
Sync ggml metal kernels (#33)
EricLBuehler Oct 25, 2024
522531d
Add some fast Metal MLX SDPA kernels (#32)
EricLBuehler Oct 26, 2024
41324ef
Merge remote-tracking branch 'upstream/main'
EricLBuehler Oct 27, 2024
aa93235
Conditional compilation for bf16
EricLBuehler Oct 28, 2024
629ec72
Conditional compilation for bf16
EricLBuehler Oct 28, 2024
2d3df4a
Patch missing seed value
EricLBuehler Oct 29, 2024
ec1c76e
Fix metal sdpa for v stride
EricLBuehler Nov 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Soft Non-Maximum Suppression (#2400)
* Soft NMS with thresholds

* NMS Test

* Soft nms w/ boxes removed below threshold

* Soft nms test

* No longer removing bounding boxes to fit Soft-NMS focus

* Initialize confidence

* Added comments

* Refactored out updating based on IOU/sigma

* Score_threshold -> confidence_threshold for clarity

* Remove bboxes below confidence threshold

* Softnms basic functionality test

* Softnms confidence decay test

* Softnms confidence threshold test

* Softnms no overlapping bbox test

* Testing confidence after no overlap test

* Single bbox and no bbox tests

* Signify test completion

* Handling result of test functions

* Checking all pairs of bboxes instead of a forward pass

* Equal confidence overlap test

* Clarified tests for implementation

* No longer dropping boxes, just setting to 0.0

* Formatted w/ cargo
onichmath authored and EricLBuehler committed Aug 14, 2024
commit 283a5cf7f446dc4b00e34353f5c77c32eee5b94c
58 changes: 58 additions & 0 deletions candle-transformers/src/object_detection.rs
Original file line number Diff line number Diff line change
@@ -50,3 +50,61 @@ pub fn non_maximum_suppression<D>(bboxes: &mut [Vec<Bbox<D>>], threshold: f32) {
bboxes_for_class.truncate(current_index);
}
}

// Updates confidences starting at highest and comparing subsequent boxes.
fn update_confidences<D>(
bboxes_for_class: &[Bbox<D>],
updated_confidences: &mut [f32],
iou_threshold: f32,
sigma: f32,
) {
let len = bboxes_for_class.len();
for current_index in 0..len {
let current_bbox = &bboxes_for_class[current_index];
for index in (current_index + 1)..len {
let iou_val = iou(current_bbox, &bboxes_for_class[index]);
if iou_val > iou_threshold {
// Decay calculation from page 4 of: https://arxiv.org/pdf/1704.04503
let decay = (-iou_val * iou_val / sigma).exp();
let updated_confidence = bboxes_for_class[index].confidence * decay;
updated_confidences[index] = updated_confidence;
}
}
}
}

// Sorts the bounding boxes by confidence and applies soft non-maximum suppression.
// This function is based on the algorithm described in https://arxiv.org/pdf/1704.04503
pub fn soft_non_maximum_suppression<D>(
bboxes: &mut [Vec<Bbox<D>>],
iou_threshold: Option<f32>,
confidence_threshold: Option<f32>,
sigma: Option<f32>,
) {
let iou_threshold = iou_threshold.unwrap_or(0.5);
let confidence_threshold = confidence_threshold.unwrap_or(0.1);
let sigma = sigma.unwrap_or(0.5);

for bboxes_for_class in bboxes.iter_mut() {
// Sort boxes by confidence in descending order
bboxes_for_class.sort_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap());
let mut updated_confidences = bboxes_for_class
.iter()
.map(|bbox| bbox.confidence)
.collect::<Vec<_>>();
update_confidences(
bboxes_for_class,
&mut updated_confidences,
iou_threshold,
sigma,
);
// Update confidences, set to 0.0 if below threshold
for (i, &confidence) in updated_confidences.iter().enumerate() {
bboxes_for_class[i].confidence = if confidence < confidence_threshold {
0.0
} else {
confidence
};
}
}
}
222 changes: 222 additions & 0 deletions candle-transformers/tests/nms_tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
use candle::Result;
use candle_transformers::object_detection::{
non_maximum_suppression, soft_non_maximum_suppression, Bbox,
};

#[test]
fn nms_basic() -> Result<()> {
// Boxes based upon https://thepythoncode.com/article/non-maximum-suppression-using-opencv-in-python
let mut bboxes = vec![vec![
Bbox {
xmin: 245.0,
ymin: 305.0,
xmax: 575.0,
ymax: 490.0,
confidence: 0.9,
data: (),
}, // Box 1
Bbox {
xmin: 235.0,
ymin: 300.0,
xmax: 485.0,
ymax: 515.0,
confidence: 0.8,
data: (),
}, // Box 2
Bbox {
xmin: 305.0,
ymin: 270.0,
xmax: 540.0,
ymax: 500.0,
confidence: 0.6,
data: (),
}, // Box 3
]];

non_maximum_suppression(&mut bboxes, 0.5);
let bboxes = bboxes.into_iter().next().unwrap();
assert_eq!(bboxes.len(), 1);
assert_eq!(bboxes[0].confidence, 0.9);

Ok(())
}

#[test]
fn softnms_basic_functionality() -> Result<()> {
let mut bboxes = vec![vec![
Bbox {
xmin: 0.0,
ymin: 0.0,
xmax: 1.0,
ymax: 1.0,
confidence: 0.5,
data: (),
},
Bbox {
xmin: 0.1,
ymin: 0.1,
xmax: 1.1,
ymax: 1.1,
confidence: 0.9,
data: (),
},
Bbox {
xmin: 0.2,
ymin: 0.2,
xmax: 1.2,
ymax: 1.2,
confidence: 0.6,
data: (),
},
]];

soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));

// Should decay boxes following highest confidence box
assert!(bboxes[0][0].confidence == 0.9);
assert!(bboxes[0][1].confidence < 0.5);
assert!(bboxes[0][2].confidence < 0.6);
Ok(())
}

#[test]
fn softnms_confidence_decay() -> Result<()> {
let mut bboxes = vec![vec![
Bbox {
xmin: 0.0,
ymin: 0.0,
xmax: 1.0,
ymax: 1.0,
confidence: 0.9,
data: (),
}, // Reference box
Bbox {
xmin: 0.1,
ymin: 0.1,
xmax: 1.1,
ymax: 1.1,
confidence: 0.8,
data: (),
}, // Overlapping box
]];

soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));

// Check that confidence of the overlapping box is decayed
assert!(bboxes[0][0].confidence == 0.9);
assert!(bboxes[0][1].confidence < 0.8);
Ok(())
}

#[test]
fn softnms_confidence_threshold() -> Result<()> {
let mut bboxes = vec![vec![
Bbox {
xmin: 0.0,
ymin: 0.0,
xmax: 1.0,
ymax: 1.0,
confidence: 0.9,
data: (),
},
Bbox {
xmin: 0.1,
ymin: 0.1,
xmax: 1.1,
ymax: 1.1,
confidence: 0.05,
data: (),
},
]];

soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));

// Box with confidence below the threshold should be removed
assert_eq!(bboxes[0].len(), 2);
assert_eq!(bboxes[0][0].confidence, 0.9);
assert_eq!(bboxes[0][1].confidence, 0.00);
Ok(())
}

#[test]
fn softnms_no_overlap() -> Result<()> {
let mut bboxes = vec![vec![
Bbox {
xmin: 0.0,
ymin: 0.0,
xmax: 1.0,
ymax: 1.0,
confidence: 0.9,
data: (),
},
Bbox {
xmin: 2.0,
ymin: 2.0,
xmax: 3.0,
ymax: 3.0,
confidence: 0.8,
data: (),
},
]];

soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));

// Both boxes should remain as they do not significantly overlap
assert_eq!(bboxes[0].len(), 2);
assert_eq!(bboxes[0][0].confidence, 0.9);
assert_eq!(bboxes[0][1].confidence, 0.8);
Ok(())
}
#[test]
fn softnms_no_bbox() -> Result<()> {
let mut bboxes: Vec<Vec<Bbox<()>>> = vec![];
soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));
assert!(bboxes.is_empty());
Ok(())
}

#[test]
fn softnms_single_bbox() -> Result<()> {
let mut bboxes = vec![vec![Bbox {
xmin: 0.0,
ymin: 0.0,
xmax: 1.0,
ymax: 1.0,
confidence: 0.9,
data: (),
}]];
soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));
assert_eq!(bboxes[0].len(), 1);
Ok(())
}

#[test]
fn softnms_equal_confidence_overlap() -> Result<()> {
let mut bboxes = vec![vec![
Bbox {
xmin: 0.0,
ymin: 0.0,
xmax: 1.0,
ymax: 1.0,
confidence: 0.5,
data: (),
},
Bbox {
xmin: 0.1,
ymin: 0.1,
xmax: 1.1,
ymax: 1.1,
confidence: 0.5,
data: (),
},
]];

soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));

// First box will be reference box, second box should be decayed
// Implementation must change to have both be decayed
assert_eq!(bboxes[0].len(), 2);
assert!(bboxes[0][0].confidence == 0.5);
assert!(bboxes[0][1].confidence < 0.5);
Ok(())
}