-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlibtorch.h
40 lines (29 loc) · 1.22 KB
/
libtorch.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#pragma once
#include <atomic>
#include <condition_variable>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include <torch/script.h>  // One-stop header.

#include "GameField.h"
/// Asynchronous front-end to a TorchScript model.
///
/// Callers submit a GameField through commit() and receive a std::future that
/// is fulfilled with the network output (a vector of double vectors). Pending
/// work is kept in `tasks` as (input tensor, promise) pairs, guarded by `lock`
/// and signalled through `cv`; `infer()` is run by the `loop` thread.
/// NOTE(review): how inputs are encoded, how batches are formed from `tasks`
/// and how the worker is stopped/joined live in the .cpp and are not visible
/// here -- confirm against the implementation.
///
/// The object owns a mutex, a condition variable and a worker thread, so it
/// is neither copyable nor movable.
class NeuralNetwork {
 public:
  /// Output type delivered through the future returned by commit().
  using return_type = std::vector<std::vector<double>>;

  /// @param model_path path of the model to load (presumably a TorchScript
  ///                   file -- TODO confirm against the .cpp).
  /// @param use_gpu    whether inference should run on the GPU.
  /// @param batch_size maximum number of queued tasks per inference call
  ///                   (assumed; verify in infer()).
  NeuralNetwork(std::string model_path, bool use_gpu, unsigned int batch_size);
  ~NeuralNetwork();

  // Already implicitly unavailable (std::mutex member, user-declared
  // destructor); spelled out so the intent is explicit.
  NeuralNetwork(const NeuralNetwork&) = delete;
  NeuralNetwork& operator=(const NeuralNetwork&) = delete;
  NeuralNetwork(NeuralNetwork&&) = delete;
  NeuralNetwork& operator=(NeuralNetwork&&) = delete;

  /// Enqueues an evaluation request for `game_field` and returns a future for
  /// its result.
  std::future<return_type> commit(GameField* game_field);

  /// Changes the batch size. Safe to call while the worker thread is running:
  /// `batch_size` is atomic, so this write does not race with its readers.
  void set_batch_size(unsigned int batch_size) {
    this->batch_size.store(batch_size);
  }

  /// A queued request: the model input paired with the promise to fulfil.
  using task_type = std::pair<torch::Tensor, std::promise<return_type>>;

  /// Runs one inference pass over queued tasks (called repeatedly by `loop`).
  void infer();

  // NOTE(review): the members below are public, presumably so the .cpp and
  // the worker can reach them; consider making them private.
  std::unique_ptr<std::thread> loop;  // worker thread that calls infer() in a loop
  std::atomic<bool> running;          // loop-continue flag shared with the worker thread;
                                      // atomic because it is read and written on different threads
  std::queue<task_type> tasks;        // pending requests; guarded by `lock`
  std::mutex lock;                    // protects `tasks`
  std::condition_variable cv;         // signals changes to `tasks`
  std::shared_ptr<torch::jit::script::Module> module;  // loaded torch module
  std::atomic<unsigned int> batch_size;  // batch size; atomic because set_batch_size()
                                         // may run concurrently with the worker
  bool use_gpu;                          // run inference on the GPU
};