forked from QwenLM/qwen.cpp
-
Notifications
You must be signed in to change notification settings - Fork 0
/
qwen.h
417 lines (320 loc) · 12.7 KB
/
qwen.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
#pragma once

#include "tiktoken.h"

#include <ggml.h>

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>
#include <ostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <string_view>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

#ifdef GGML_USE_CUBLAS
#include <ggml-cuda.h>
#endif
namespace qwen {
class QwenTokenizer;

// ===== common =====

// Bytes per mebibyte; used below to express buffer sizes (e.g. MEM_SIZE, SCRATCH_SIZE).
// `inline` (C++17) gives one definition program-wide instead of a per-TU copy.
inline constexpr size_t MB = 1024 * 1024;

// Pre-tokenization split pattern handed to the tiktoken tokenizer.
// `inline` so every translation unit that includes this header shares a single
// std::string object instead of constructing its own copy at static-init time.
inline const std::string PAT_STR = R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?:$|[^\S])|\s+)";
// Collects a diagnostic message and throws it as std::runtime_error when the
// temporary is destroyed at the end of the full-expression, so callers can stream
// context into the error:  QWEN_THROW << "bad value " << x;
// The message is prefixed with "file:line ".
class LogMessageFatal {
  public:
    LogMessageFatal(const char *file, int line) { oss_ << file << ':' << line << ' '; }

    // Throwing from the destructor is the whole point of this type; destructors are
    // implicitly noexcept, so noexcept(false) is required for the throw to propagate.
    [[noreturn]] ~LogMessageFatal() noexcept(false) { throw std::runtime_error(oss_.str()); }

    auto stream() -> std::ostringstream & { return oss_; }

  private:
    std::ostringstream oss_;
};

#define QWEN_THROW ::qwen::LogMessageFatal(__FILE__, __LINE__).stream()

// Throws (via QWEN_THROW) when `cond` is false; extra context may be streamed after it:
//   QWEN_CHECK(n > 0) << "n = " << n;
// Written as `if (cond) {} else ...` rather than `if (!(cond)) ...` so that the macro
// already owns its `else`: a caller's `if (a) QWEN_CHECK(b); else f();` then binds the
// `else` to the caller's `if`, not to the one hidden inside the macro (dangling else).
#define QWEN_CHECK(cond) \
    if (cond) {          \
    } else               \
        QWEN_THROW << "check failed (" #cond ") "
// ----- tensor placement helpers (defined out of line) -----
// NOTE(review): behavior is inferred from the names only; with GGML_USE_CUBLAS these
// presumably move tensor data between host and device. Confirm in the .cpp.
auto tensor_assign_buffers(ggml_tensor *tensor) -> ggml_tensor *;
auto tensor_to_device(ggml_tensor *tensor) -> ggml_tensor *;
auto tensor_to_cpu(ggml_tensor *tensor) -> ggml_tensor *;

// ----- threading helpers (defined out of line) -----
auto get_num_physical_cores() -> int;
auto get_default_num_threads() -> int;
// Deleter that lets std::unique_ptr release a ggml_context through ggml_free.
struct ggml_context_deleter_t {
    auto operator()(ggml_context *ctx) const noexcept -> void { ggml_free(ctx); }
};

// Owning handle for a ggml_context; the context is freed when the handle dies.
using unique_ggml_context_t = std::unique_ptr<ggml_context, ggml_context_deleter_t>;

// Creates a ggml context with the given pool size / external buffer / no_alloc flag
// and wraps it in an owning handle. The ggml_init result is wrapped as-is (no null check).
static inline auto make_unique_ggml_context(size_t mem_size, void *mem_buffer, bool no_alloc)
    -> unique_ggml_context_t {
    ggml_init_params params{mem_size, mem_buffer, no_alloc};
    ggml_context *raw_ctx = ggml_init(params);
    return unique_ggml_context_t(raw_ctx);
}
// One-byte element whose default constructor deliberately does nothing.
// Because the constructor is user-provided, value-initialization (as done by
// std::vector<uninitialized_char>(n) or resize(n)) runs only this empty constructor
// and does not zero `m`, so large byte buffers can be grown without a memset.
struct uninitialized_char {
    // Must stay user-provided: `= default` would make value-initialization zero the byte.
    uninitialized_char() {}

    char m;
};
// Computes `graph` using `n_threads` threads (defined out of line).
// NOTE(review): `buf` presumably serves as (and is grown into) ggml's work buffer — confirm.
auto ggml_graph_compute_helper(std::vector<uninitialized_char> &buf, ggml_cgraph *graph, int n_threads) -> void;

// Per-model ggml state: separate contexts for weights, the KV cache and temporary
// buffers, the compute graph, and the raw byte buffers those contexts work in.
// Plain C members (dtype, gf, scratch) carry initializers so that a default-constructed
// ModelContext never holds indeterminate values before it is configured.
struct ModelContext {
    ggml_type dtype = GGML_TYPE_F32; // element type used for weight tensors (see Linear, Embedding)
    unique_ggml_context_t ctx_w;     // weight
    unique_ggml_context_t ctx_kv;    // kv cache
    unique_ggml_context_t ctx_b;     // buffer
    ggml_cgraph gf{};                // compute graph (zeroed until built)
    ggml_scratch scratch{};          // zeroed: offs = 0, size = 0, data = NULL
    std::vector<uninitialized_char> compute_buffer; // BLAS buffer
    std::vector<uninitialized_char> scratch_buffer; // intermediate tensor buffer
    std::string_view weight_buffer;                 // mapped weight (non-owning view)
    std::vector<uninitialized_char> work_buffer;    // temporary buffer for graph computing

    auto init_device_context() -> void;
};
// Token embedding table allocated in the weight context with ctx->dtype.
// Tensor layout: ne[0] = embedding_dim, ne[1] = num_embeddings, i.e. one row per id.
class Embedding {
  public:
    Embedding() = default;
    Embedding(ModelContext *ctx, int num_embeddings, int embedding_dim)
        : weight(ggml_new_tensor_2d(ctx->ctx_w.get(), ctx->dtype, embedding_dim, num_embeddings)) {}

    // Defined out of line; `input` presumably holds token ids to look up.
    auto forward(ModelContext *ctx, ggml_tensor *input) const -> ggml_tensor *;

    ggml_tensor *weight = nullptr; // [num_embeddings, embedding_dim]
};
// Fully connected layer. The weight uses ctx->dtype; the optional bias is always F32.
// Both live in the weight context.
class Linear {
  public:
    Linear() = default;
    Linear(ModelContext *ctx, int in_features, int out_features, bool use_bias = true)
        : weight(ggml_new_tensor_2d(ctx->ctx_w.get(), ctx->dtype, in_features, out_features)),
          bias(use_bias ? ggml_new_tensor_1d(ctx->ctx_w.get(), GGML_TYPE_F32, out_features) : nullptr) {}

    // ggml's ne[0] is the innermost (input) dimension, ne[1] the output dimension.
    auto in_features() const -> int { return static_cast<int>(weight->ne[0]); }
    auto out_features() const -> int { return static_cast<int>(weight->ne[1]); }

    // Defined out of line.
    auto forward(ModelContext *ctx, ggml_tensor *input) const -> ggml_tensor *;

    ggml_tensor *weight = nullptr; // [out_features, in_features]
    ggml_tensor *bias = nullptr;   // [out_features]; null when built with use_bias = false
};
// RMS normalization with a learned F32 scale vector of length `normalized_shape`.
class RMSNorm {
  public:
    RMSNorm() = default;
    RMSNorm(ModelContext *ctx, int normalized_shape, bool inplace = true)
        : weight(ggml_new_tensor_1d(ctx->ctx_w.get(), GGML_TYPE_F32, normalized_shape)), inplace(inplace) {}

    // Defined out of line; `eps` defaults to 1e-5.
    auto forward(ModelContext *ctx, ggml_tensor *input, float eps = 1e-5f) const -> ggml_tensor *;

    ggml_tensor *weight = nullptr;
    bool inplace = true; // NOTE(review): presumably selects an in-place ggml op in forward()
};
// Abstract sink for token ids produced during generation.
// Implementations receive batches of ids through put() and a final end() call.
// NOTE(review): the exact call sequence is decided by the generate() implementations.
class BaseStreamer {
  public:
    // Virtual so deleting through BaseStreamer* / shared_ptr<BaseStreamer> reaches the derived dtor.
    virtual ~BaseStreamer() = default;

    virtual auto put(const std::vector<int> &output_ids) -> void = 0;
    virtual auto end() -> void = 0;
};
// Streamer that holds a list of child streamers (shared ownership).
// NOTE(review): put()/end() are defined out of line; presumably each call is forwarded
// to every entry of streamers_ in order — confirm in the .cpp.
class StreamerGroup : public BaseStreamer {
  public:
    StreamerGroup(std::vector<std::shared_ptr<BaseStreamer>> streamers)
        : streamers_(std::move(streamers)) {}

    auto put(const std::vector<int> &output_ids) -> void override;
    auto end() -> void override;

  private:
    std::vector<std::shared_ptr<BaseStreamer>> streamers_;
};
// reference: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py
// Streamer that decodes ids with `tokenizer` and writes the text to `os`.
// Neither the stream nor the tokenizer is owned; both must outlive the streamer.
// NOTE(review): is_prompt_/token_cache_/print_len_ presumably implement the HF
// behavior of skipping the prompt and printing only newly completed text — confirm in the .cpp.
class TextStreamer : public BaseStreamer {
  public:
    TextStreamer(std::ostream &os, QwenTokenizer *tokenizer) : os_(os), tokenizer_(tokenizer) {}

    auto put(const std::vector<int> &output_ids) -> void override;
    auto end() -> void override;

  private:
    std::ostream &os_;
    QwenTokenizer *tokenizer_;
    bool is_prompt_ = true;
    std::vector<int> token_cache_;
    int print_len_ = 0;
};
// Streamer that records timing statistics instead of printing tokens.
// Timestamps are microseconds from ggml_time_us(); put()/reset()/to_string() are
// defined out of line.
class PerfStreamer : public BaseStreamer {
  public:
    PerfStreamer() = default;

    auto put(const std::vector<int> &output_ids) -> void override;
    // Stamps the end-of-generation time.
    auto end() -> void override { end_us_ = ggml_time_us(); }
    auto reset() -> void;
    // NOTE(review): the trailing `const` qualifies the returned std::string, not the member
    // function, so this cannot be called on a const PerfStreamer and its result cannot be
    // moved from. Changing it to `auto to_string() const -> std::string` must be done
    // together with the out-of-line definition.
    auto to_string() -> std::string const;

    auto num_prompt_tokens() const -> int64_t { return num_prompt_tokens_; }
    auto prompt_total_time_us() const -> int64_t { return prompt_us_ - start_us_; }
    // Mean time per prompt token; 0 when no prompt tokens were recorded.
    auto prompt_token_time_us() const -> int64_t {
        if (num_prompt_tokens() == 0) {
            return 0;
        }
        return prompt_total_time_us() / num_prompt_tokens();
    }

    auto num_output_tokens() const -> int64_t { return num_output_tokens_; }
    auto output_total_time_us() const -> int64_t { return end_us_ - prompt_us_; }
    // Mean time per generated token; 0 when no output tokens were recorded.
    auto output_token_time_us() const -> int64_t {
        if (num_output_tokens() == 0) {
            return 0;
        }
        return output_total_time_us() / num_output_tokens();
    }

  private:
    int64_t start_us_ = 0;
    int64_t prompt_us_ = 0;
    int64_t end_us_ = 0;
    int64_t num_prompt_tokens_ = 0;
    int64_t num_output_tokens_ = 0;
};
// Holds a memory mapping of the file at `path`; `data`/`size` describe the mapped bytes.
// The constructor and destructor are defined out of line.
// NOTE(review): mapping flags (read-only / private) are chosen in the .cpp — confirm there.
class MappedFile {
  public:
    MappedFile(const std::string &path);
    ~MappedFile();

    // The user-declared destructor presumably unmaps `data`. An implicit memberwise copy
    // would then release the same region twice, so copying is disabled. Deleting the copy
    // operations also suppresses the implicit moves, which previously fell back to copying.
    MappedFile(const MappedFile &) = delete;
    auto operator=(const MappedFile &) -> MappedFile & = delete;

  public:
    char *data;
    size_t size;
};
// Forward reader over a model file image held in memory. The loader does not own the
// bytes: the buffer passed to the constructor must outlive it.
class ModelLoader {
  public:
    ModelLoader(std::string_view buffer) : data(buffer.data()), size(buffer.size()), ptr(buffer.data()) {}

    // Current read position, in bytes from the start of the buffer.
    auto tell() const -> int64_t { return ptr - data; }

    // Repositions the read pointer (defined out of line).
    // NOTE(review): `whence` presumably follows fseek conventions (SEEK_SET/CUR/END).
    auto seek(int64_t offset, int whence) -> void;

    // Reads one T at the current position and advances past it.
    // The bytes are copied with std::memcpy instead of dereferencing a cast pointer:
    // file offsets need not satisfy alignof(T), and a misaligned, type-punned load is
    // undefined behavior. Throws std::runtime_error if fewer than sizeof(T) bytes remain,
    // instead of reading past the end of the buffer.
    template <typename T>
    auto read_basic() -> T {
        static_assert(std::is_trivially_copyable_v<T>, "ModelLoader::read_basic requires a trivially copyable type");
        const int64_t pos = tell();
        if (pos < 0 || static_cast<size_t>(pos) > size || size - static_cast<size_t>(pos) < sizeof(T)) {
            throw std::runtime_error("ModelLoader::read_basic: unexpected end of buffer");
        }
        T obj;
        std::memcpy(&obj, ptr, sizeof(T));
        ptr += sizeof(T);
        return obj;
    }

    // Defined out of line.
    auto read_string(size_t length) -> std::string;
    auto read_tensor(const std::string &name, ggml_tensor *tensor) -> void;

  public:
    const char *const data; // start of the buffer
    size_t size;            // buffer length in bytes
    const char *ptr;        // current read position
};
// ===== generation =====
// Decoding parameters for QwenForCausalLM::generate / Pipeline::generate / Pipeline::chat.
// Every field is public and may be adjusted after construction.
// NOTE(review): semantics of sentinel defaults (top_k = 0, num_threads = 0) are defined
// by the sampling / threading code in the .cpp — confirm there.
struct GenerationConfig {
    int max_length;
    int max_context_length;
    bool do_sample;
    int top_k;
    float top_p;
    float temperature;
    float repetition_penalty;
    int num_threads;

    GenerationConfig(int max_length = 2048, int max_context_length = 512, bool do_sample = true, int top_k = 0,
                     float top_p = 0.7, float temperature = 0.95, float repetition_penalty = 1.f,
                     int num_threads = 0) {
        this->max_length = max_length;
        this->max_context_length = max_context_length;
        this->do_sample = do_sample;
        this->top_k = top_k;
        this->top_p = top_p;
        this->temperature = temperature;
        this->repetition_penalty = repetition_penalty;
        this->num_threads = num_threads;
    }
};
// A (token id, score) pair. Both comparison operators look at the score only;
// presumably this ordering is what the sampling_top_k / sampling_top_p helpers rely on.
// The defaulted constructor leaves both fields uninitialized.
struct TokenIdScore {
    int id;
    float score;

    TokenIdScore() = default;
    TokenIdScore(int id, float score) : id(id), score(score) {}

    auto operator<(const TokenIdScore &other) const -> bool { return score < other.score; }
    auto operator>(const TokenIdScore &other) const -> bool { return other.score < score; }

    friend auto operator<<(std::ostream &os, const TokenIdScore &self) -> std::ostream & {
        os << "TokenIdScore(id=" << self.id << ", score=" << self.score << ")";
        return os;
    }
};
// ===== Qwen-7B =====
// Model hyper-parameters plus the special token ids the tokenizer needs.
// NOTE(review): if this struct is read as raw bytes from the model file (e.g. via
// ModelLoader::read_basic<QwenConfig>), field order and types are part of the file
// format — do not reorder, insert, or retype fields without updating the converter.
struct QwenConfig {
// common attributes
ggml_type dtype; // presumably copied into ModelContext::dtype (weight tensor type) — confirm
int vocab_size;
int hidden_size;
int num_attention_heads;
int num_kv_heads; // passed to QwenAttention alongside num_attention_heads
int num_hidden_layers;
int intermediate_size; // QwenMLP sizes w1/w2 with intermediate_size / 2
// for sequence generation
int max_length;
// for tokenizer
int eos_token_id;
int pad_token_id;
int im_start_id;
int im_end_id;
};
// tiktoken-based tokenizer for Qwen, with the chat special-token ids taken from the config.
// All methods are defined out of line.
// NOTE(review): encode/encode_history presumably truncate to `max_length` tokens — confirm.
class QwenTokenizer {
  public:
    QwenTokenizer(const std::string &tiktoken_path, const QwenConfig &config);

    auto encode(const std::string &text, int max_length) const -> std::vector<int>;
    auto decode(const std::vector<int> &ids) const -> std::string;

    // Chat helpers: `history` is a list of turns.
    auto encode_history(const std::vector<std::string> &history, int max_length) const -> std::vector<int>;
    auto build_prompt(const std::vector<std::string> &history) const -> std::string;

    auto is_special_id(int id) const -> bool;

    tiktoken::tiktoken tokenizer;
    int eos_token_id;
    int im_start_id;
    int im_end_id;
};
// Self-attention layer: a fused projection (c_attn), an output projection (c_proj), and
// per-layer key/value cache tensors. The sized constructor and forward() are defined out of line.
class QwenAttention {
  public:
    QwenAttention() : num_attention_heads(0), num_kv_heads(0) {}
    QwenAttention(ModelContext *ctx, int hidden_size, int num_attention_heads, int num_kv_heads, int max_length);

    auto forward(ModelContext *ctx, ggml_tensor *hidden_states, ggml_tensor *KQ_pos, int n_ctx) const -> ggml_tensor *;

    int num_attention_heads;
    int num_kv_heads;
    Linear c_attn;
    Linear c_proj;
    // Default to null so a default-constructed layer never exposes indeterminate pointers.
    // The sized constructor is expected to allocate the caches.
    ggml_tensor *k_cache = nullptr; // [n_head, maxlen, head_size]
    ggml_tensor *v_cache = nullptr; // [n_head, head_size, maxlen]
};
// Feed-forward block: w1 and w2 both map hidden_size -> intermediate_size / 2 and
// c_proj maps intermediate_size / 2 -> hidden_size, all without bias.
// NOTE(review): how w1/w2 are combined (presumably a gated activation) is in forward().
class QwenMLP {
  public:
    QwenMLP() = default;
    QwenMLP(ModelContext *ctx, int hidden_size, int intermediate_size)
        : w1(ctx, hidden_size, intermediate_size / 2, false), w2(ctx, hidden_size, intermediate_size / 2, false),
          c_proj(ctx, intermediate_size / 2, hidden_size, false) {}

    auto forward(ModelContext *ctx, ggml_tensor *hidden_states) const -> ggml_tensor *;

    Linear w1;
    Linear w2;
    Linear c_proj;
};
// One transformer layer: RMSNorm -> attention, RMSNorm -> MLP.
// Both norms are built with inplace = false.
// NOTE(review): residual wiring lives in forward() (defined out of line).
class QwenBlock {
  public:
    QwenBlock() = default;
    QwenBlock(ModelContext *ctx, int hidden_size, int num_attention_heads, int num_kv_heads, int intermediate_size,
              int max_length)
        : ln_1(ctx, hidden_size, false), attn(ctx, hidden_size, num_attention_heads, num_kv_heads, max_length),
          ln_2(ctx, hidden_size, false), mlp(ctx, hidden_size, intermediate_size) {}

    auto forward(ModelContext *ctx, ggml_tensor *hidden_states, ggml_tensor *KQ_pos, int n_ctx) const -> ggml_tensor *;

    RMSNorm ln_1;
    QwenAttention attn;
    RMSNorm ln_2;
    QwenMLP mlp;
};
// Transformer body: token embedding (wte), the stack of QwenBlock layers, and a final
// RMSNorm (ln_f). The config constructor and forward() are defined out of line.
class QwenModel {
  public:
    QwenModel() = default;
    QwenModel(ModelContext *ctx, const QwenConfig &config);

    auto forward(ModelContext *ctx, ggml_tensor *input_ids, ggml_tensor *KQ_pos, int n_ctx) const -> ggml_tensor *;

    Embedding wte;
    std::vector<QwenBlock> layers;
    RMSNorm ln_f;
};
// Causal LM: QwenModel transformer plus lm_head, owning its ggml state (ctx_) and a
// name -> tensor table (state_dict_). All non-inline members are defined out of line.
// NOTE(review): state_dict_ is presumably filled by the constructor and consumed by load().
class QwenForCausalLM {
  public:
    QwenForCausalLM(const QwenConfig &config);
    ~QwenForCausalLM();

    auto load(ModelLoader &loader) -> void;

    auto forward(ModelContext *ctx, ggml_tensor *input_ids, ggml_tensor *KQ_pos, int n_ctx) const -> ggml_tensor *;

    auto generate_next_token(const std::vector<int> &input_ids, const GenerationConfig &gen_config, int n_past,
                             int n_ctx) -> int;

    auto generate(const std::vector<int> &input_ids, const GenerationConfig &gen_config,
                  BaseStreamer *streamer = nullptr) -> std::vector<int>;

    // ---- logits processor ----
    static auto sampling_repetition_penalty(float *first, float *last, const std::vector<int32_t> &input_ids,
                                            float penalty) -> void;

    // ---- logits warpers ----
    static auto sampling_temperature(float *first, float *last, float temp) -> void;
    static auto sampling_top_k(TokenIdScore *first, TokenIdScore *kth, TokenIdScore *last) -> void;
    static auto sampling_top_p(TokenIdScore *first, TokenIdScore *last, float top_p) -> TokenIdScore *;
    static auto sampling_softmax_inplace(TokenIdScore *first, TokenIdScore *last) -> void;

    static constexpr size_t MEM_SIZE = 512 * MB;      // 2k context
    static constexpr size_t SCRATCH_SIZE = 1280 * MB; // 2k context

    QwenConfig config;
    QwenModel transformer;
    Linear lm_head;

  private:
    ModelContext ctx_;
    std::vector<std::pair<std::string, ggml_tensor *>> state_dict_;
};
// ===== pipeline =====
// End-to-end entry point bundling the mapped model file, the tokenizer and the model.
// All methods are defined out of line.
class Pipeline {
  public:
    Pipeline(const std::string &path, const std::string &tiktoken_path);

    auto generate(const std::vector<int> &input_ids, const GenerationConfig &gen_config,
                  BaseStreamer *streamer = nullptr) const -> std::vector<int>;
    auto generate(const std::string &prompt, const GenerationConfig &gen_config,
                  BaseStreamer *streamer = nullptr) const -> std::string;
    auto chat(const std::vector<std::string> &history, const GenerationConfig &gen_config,
              BaseStreamer *streamer = nullptr) const -> std::string;

  public:
    // Declared first so it is destroyed last. Members are destroyed in reverse declaration
    // order, and the model's ModelContext::weight_buffer is a non-owning view of mapped
    // weights (presumably into this file). The mapping must therefore outlive the model,
    // which the previous order (mapped_file last) did not guarantee.
    std::unique_ptr<MappedFile> mapped_file;
    std::unique_ptr<QwenTokenizer> tokenizer;
    std::unique_ptr<QwenForCausalLM> model;
};
} // namespace qwen