Skip to content

Commit

Permalink
finish up most of the python bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
Swatinem committed Jan 7, 2025
1 parent 90c0a19 commit 618e703
Show file tree
Hide file tree
Showing 6 changed files with 179 additions and 51 deletions.
1 change: 1 addition & 0 deletions benches/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ fn binary(c: &mut Criterion) {
b.iter(|| {
let parsed = TestAnalytics::parse(&buf, 0).unwrap();
for test in parsed.tests(0..60, None).unwrap() {
let test = test.unwrap();
let _name = black_box(test.name().unwrap());
let _aggregates = black_box(test.aggregates());
}
Expand Down
108 changes: 92 additions & 16 deletions src/binary/bindings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,33 @@ use super::{TestAnalytics, TestAnalyticsWriter};

#[pyclass]
pub struct BinaryFormatWriter {
writer: TestAnalyticsWriter,
writer: Option<TestAnalyticsWriter>,
}

#[pymethods]
impl BinaryFormatWriter {
#[new]
pub fn new() -> Self {
Self {
writer: TestAnalyticsWriter::new(60),
writer: Some(TestAnalyticsWriter::new(60)),
}
}

#[staticmethod]
pub fn open(buffer: &[u8]) -> anyhow::Result<Self> {
let format = TestAnalytics::parse(buffer, 0)?;
let writer = TestAnalyticsWriter::from_existing_format(&format)?;
Ok(Self {
writer: Some(writer),
})
}

pub fn add_testruns(
&mut self,
timestamp: u32,
commit_hash: &str,
flags: &[&str],
testruns: &[Testrun],
flags: Vec<String>,
testruns: Vec<Testrun>,
) -> anyhow::Result<()> {
let commit_hash_base16 = if commit_hash.len() > 40 {
commit_hash
Expand All @@ -35,29 +47,61 @@ impl BinaryFormatWriter {
let mut commit_hash = super::CommitHash::default();
base16ct::mixed::decode(commit_hash_base16, &mut commit_hash.0)?;

let mut session = self.writer.start_session(timestamp, commit_hash, flags);
let writer = self
.writer
.as_mut()
.context("writer was already serialized")?;

let flags: Vec<_> = flags.iter().map(|s| s.as_str()).collect();
let mut session = writer.start_session(timestamp, commit_hash, &flags);
for test in testruns {
session.insert(test);
session.insert(&test);
}
Ok(())
}

pub fn serialize(self) -> anyhow::Result<Vec<u8>> {
pub fn serialize(&mut self) -> anyhow::Result<Vec<u8>> {
let writer = self
.writer
.take()
.context("writer was already serialized")?;
let mut buffer = vec![];
self.writer.serialize(&mut buffer)?;
writer.serialize(&mut buffer)?;
Ok(buffer)
}
}

#[pyclass]
pub struct AggregationReader {
buffer: Vec<u8>,
_buffer: Vec<u8>,
format: TestAnalytics<'static>,
}

#[pyclass]
#[pyclass(get_all)]
pub struct TestAggregate {
// TODO
pub name: String,
// TODO:
pub test_id: String,

pub testsuite: Option<String>,
pub flags: Vec<String>,

pub failure_rate: f32,
pub flake_rate: f32,

// TODO:
pub updated_at: u32,
pub avg_duration: f64,

pub total_fail_count: u32,
pub total_flaky_fail_count: u32,
pub total_pass_count: u32,
pub total_skip_count: u32,

pub commits_where_fail: usize,

// TODO:
pub last_duration: f32,
}

#[pymethods]
Expand All @@ -69,16 +113,48 @@ impl AggregationReader {
// which we do not mutate, and which outlives the parsed format.
let format = unsafe { transmute(format) };

Ok(Self { buffer, format })
Ok(Self {
_buffer: buffer,
format,
})
}

#[pyo3(signature = (interval_start, interval_end, flag=None))]
#[pyo3(signature = (interval_start, interval_end, flags=None))]
pub fn get_test_aggregates(
&self,
interval_start: usize,
interval_end: usize,
flag: Option<&str>,
) -> Vec<TestAggregate> {
vec![]
flags: Option<Vec<String>>,
) -> anyhow::Result<Vec<TestAggregate>> {
let flags: Option<Vec<_>> = flags
.as_ref()
.map(|flags| flags.iter().map(|flag| flag.as_str()).collect());
let desired_range = interval_start..interval_end;

let tests = self.format.tests(desired_range, flags.as_deref())?;
let mut collected_tests = vec![];

for test in tests {
let test = test?;

collected_tests.push(TestAggregate {
name: test.name()?.into(),
test_id: "TODO".into(),
testsuite: Some(test.testsuite()?.into()),
flags: test.flags()?.into_iter().map(|s| s.into()).collect(),
failure_rate: test.aggregates().failure_rate,
flake_rate: test.aggregates().flake_rate,
updated_at: 0, // TODO
avg_duration: test.aggregates().avg_duration,
total_fail_count: test.aggregates().total_fail_count,
total_flaky_fail_count: test.aggregates().total_flaky_fail_count,
total_pass_count: test.aggregates().total_pass_count,
total_skip_count: test.aggregates().total_skip_count,
commits_where_fail: test.aggregates().failing_commits,
last_duration: 0., // TODO
});
}

Ok(collected_tests)
}
}
27 changes: 14 additions & 13 deletions src/binary/format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,18 @@ impl<'data> TestAnalytics<'data> {
pub fn tests(
&self,
desired_range: Range<usize>,
flag: Option<&str>,
) -> Result<impl Iterator<Item = Test<'data, '_>> + '_, TestAnalyticsError> {
let matching_flags_sets = if let Some(flag) = flag {
flags: Option<&[&str]>,
) -> Result<
impl Iterator<Item = Result<Test<'data, '_>, TestAnalyticsError>> + '_,
TestAnalyticsError,
> {
let matching_flags_sets = if let Some(flags) = flags {
let flag_sets = self.flags_set.iter(self.string_bytes);

let mut matching_flags_sets: SmallVec<u32, 4> = Default::default();
for res in flag_sets {
let (offset, flags) = res?;
if flags.contains(&flag) {
let (offset, flag_set) = res?;
if flags.iter().any(|flag| flag_set.contains(&flag.as_ref())) {
matching_flags_sets.push(offset);
}
}
Expand Down Expand Up @@ -132,11 +135,11 @@ impl<'data> TestAnalytics<'data> {
&self.testdata[adjusted_range],
);

Some(Test {
Some(aggregates.map(|aggregates| Test {
container: self,
data: test,
aggregates,
})
}))
});
Ok(tests)
}
Expand Down Expand Up @@ -211,7 +214,7 @@ impl Aggregates {
commithashes_bytes: &[u8],
all_failing_commits: &mut HashSet<CommitHash>,
data: &[raw::TestData],
) -> Self {
) -> Result<Self, TestAnalyticsError> {
let mut total_pass_count = 0;
let mut total_fail_count = 0;
let mut total_skip_count = 0;
Expand All @@ -225,10 +228,8 @@ impl Aggregates {
total_flaky_fail_count += testdata.total_flaky_fail_count as u32;
total_duration += testdata.total_duration as f64;

// TODO: make sure we validate this data ahead of time!
let failing_commits =
CommitHashesSet::read_raw(commithashes_bytes, testdata.failing_commits_set)
.unwrap();
CommitHashesSet::read_raw(commithashes_bytes, testdata.failing_commits_set)?;
all_failing_commits.extend(failing_commits);
}

Expand All @@ -246,7 +247,7 @@ impl Aggregates {
(0., 0., 0.)
};

Aggregates {
Ok(Aggregates {
total_pass_count,
total_fail_count,
total_skip_count,
Expand All @@ -258,6 +259,6 @@ impl Aggregates {
avg_duration,

failing_commits,
}
})
}
}
Loading

0 comments on commit 618e703

Please sign in to comment.