-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.rs
111 lines (91 loc) · 2.68 KB
/
main.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
use std::io::{self, BufRead};
use memmap2::Mmap;
use std::process::exit;
use clap::Parser;
use heavykeeper::TopK;
use memchr::memchr;
// Command-line options, parsed with clap's derive API (`Args::parse()` in `main`).
// These notes are plain `//` comments on purpose: clap's derive turns `///` doc
// comments into `--help` text, which would change the program's output.
#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
struct Args {
    // `-k` (required): passed as the first argument of `TopK::new` in `main`;
    // presumably the number of top items tracked and printed — confirm against
    // the heavykeeper docs.
    #[arg(short = 'k')]
    k: usize,
    // `-w`: width argument for `TopK::new` (default 8).
    #[arg(short = 'w', default_value_t = 8)]
    width: usize,
    // `-d`: depth argument for `TopK::new` (default 2048).
    // NOTE(review): width=8 / depth=2048 looks inverted relative to typical
    // HeavyKeeper settings (few rows, many buckets per row) — verify against
    // heavykeeper's `TopK::new(k, width, depth, decay)` parameter semantics.
    #[arg(short = 'd', default_value_t = 2048)]
    depth: usize,
    // `-y`: decay argument for `TopK::new` (default 0.9).
    #[arg(short = 'y', default_value_t = 0.9)]
    decay: f64,
    // `-f`: optional input file path. When omitted, `main` reads from stdin.
    #[arg(short = 'f')]
    input: Option<String>,
}
/// Counts the most frequent words in the input and prints them.
///
/// Input comes from the file named by `-f` (memory-mapped) or, when `-f` is absent,
/// from stdin (read one line at a time into a reused buffer). Every chunk is passed to
/// `process_bytes`, which feeds a HeavyKeeper `TopK` built from the `-k`, `-w`, `-d` and
/// `-y` options. Each entry of `topk.list()` is then printed as `<item> <count>`, one per
/// line.
///
/// If the input cannot be opened, mapped or read, or the output cannot be written, a
/// message goes to stderr and the process exits with status 1. A closed output pipe
/// (e.g. `| head`) simply ends the output without an error.
fn main() {
    // Brings `writeln!` / `flush` into scope for the buffered stdout writer below.
    use std::io::Write;

    let args = Args::parse();
    let mut topk = TopK::<String>::new(args.k, args.width, args.depth, args.decay);

    match args.input {
        None => {
            // One buffer reused for every line; `read_until` keeps the trailing `\n`.
            let stdin = io::stdin();
            let mut stdin_lock = stdin.lock();
            let mut buffer = Vec::with_capacity(1024 * 1024);
            loop {
                match stdin_lock.read_until(b'\n', &mut buffer) {
                    Ok(0) => break,
                    Ok(_) => {
                        process_bytes(&buffer, &mut topk);
                        buffer.clear();
                    }
                    Err(e) => {
                        eprintln!("Error reading stdin: {}", e);
                        exit(1);
                    }
                }
            }
        }
        Some(path) => {
            let file = std::fs::File::open(&path).unwrap_or_else(|e| {
                eprintln!("Error: {}", e);
                exit(1);
            });
            // SAFETY: `Mmap::map` is unsafe because the mapped bytes may change, or
            // accesses may fault, if another process modifies or truncates the file while
            // it is mapped. This program only reads the mapping. A file being rewritten
            // concurrently with a run is treated as user error.
            let mmap = unsafe { Mmap::map(&file) }.unwrap_or_else(|e| {
                eprintln!("Error mapping file: {}", e);
                exit(1);
            });
            process_bytes(&mmap, &mut topk);
        }
    }

    // Lock stdout once and buffer the output instead of locking per `println!`, and
    // surface write errors instead of panicking on them.
    let stdout = io::stdout();
    let mut out = io::BufWriter::new(stdout.lock());
    let mut print_results = || -> io::Result<()> {
        for node in topk.list() {
            writeln!(out, "{} {}", node.item, node.count)?;
        }
        out.flush()
    };
    if let Err(e) = print_results() {
        // The reader going away early (e.g. `| head`) is not worth reporting.
        if e.kind() != io::ErrorKind::BrokenPipe {
            eprintln!("Error writing output: {}", e);
            exit(1);
        }
    }
}
/// Splits `bytes` into words and records each word in `topk`, in input order.
///
/// A word is a maximal run of bytes that are not ASCII whitespace. The separators are
/// space, `\t`, `\n`, `\r` and form feed, per `u8::is_ascii_whitespace`. Leading,
/// trailing and repeated separators never produce empty words. As a result, a stdin
/// line that still carries its trailing `\n` (or `\r\n`) yields the same words as the
/// same text inside a memory-mapped file.
///
/// Each word becomes an owned `String` via `String::from_utf8_lossy`. Valid UTF-8 is
/// copied unchanged, and invalid byte sequences are replaced with U+FFFD, so arbitrary
/// (even binary) input is handled safely. Splitting only at ASCII bytes never cuts a
/// multi-byte UTF-8 character, so words from valid UTF-8 input are never altered.
fn process_bytes(bytes: &[u8], topk: &mut TopK<String>) {
    let mut rest = bytes;
    while !rest.is_empty() {
        // Peel off one line (memchr locates the newline). The last line may lack one.
        let (line, tail) = match memchr(b'\n', rest) {
            Some(newline) => (&rest[..newline], &rest[newline + 1..]),
            None => (rest, &rest[rest.len()..]),
        };
        // Split the line on ASCII whitespace and drop the empty pieces that adjacent
        // separators produce.
        for word in line.split(u8::is_ascii_whitespace).filter(|w| !w.is_empty()) {
            topk.add(String::from_utf8_lossy(word).into_owned());
        }
        rest = tail;
    }
}