Skip to content

Commit

Permalink
Implement gray scale frames
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikel committed Jul 28, 2024
1 parent d3a72a1 commit 882d0cb
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 16 deletions.
20 changes: 18 additions & 2 deletions craftium/craftium_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ class CraftiumEnv(Env):
:param pipe_proc: If `True`, the minetest process stderr and stdout will be piped into two files inside the run's directory. Otherwise, the minetest process will not be piped and its output will be shown in the terminal. This option is disabled by default to reduce verbosity, but can be useful for debugging.
:param mt_listen_timeout: Number of milliseconds to wait for MT to connect to the TCP channel. If the timeout is reached a Timeout exception is raised. **WARNING:** When using multiple (serial) MT environments, timeout can be easily reached for the last environment. In this case, you might want to increase the value of this parameter according to the number of environments.
:param mt_port: TCP port to employ for MT's internal client<->server communication. If not provided a random port in the [49152, 65535] range is used.
:params frameskip: The number of frames skipped between steps, 1 by default (disabled). Note that `max_timesteps` and `init_frames` parameters will be divided by the frameskip value.
:param frameskip: The number of frames skipped between steps, 1 by default (disabled). Note that `max_timesteps` and `init_frames` parameters will be divided by the frameskip value.
:param rgb_observations: Whether to use RGB images or gray scale images as observations. Note that RGB images are slower to send from MT to python via TCP. By default RGB images are used.
:param gray_scale_keepdim: If `True`, a singleton dimension will be added, i.e. observations are of the shape WxHx1. Otherwise, they are of shape WxH.
"""
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 30}

Expand All @@ -59,13 +61,16 @@ def __init__(
mt_listen_timeout: int = 10_000,
mt_port: Optional[int] = None,
frameskip: int = 1,
rgb_observations: bool = False,
gray_scale_keepdim: bool = False,
):
super(CraftiumEnv, self).__init__()

self.obs_width = obs_width
self.obs_height = obs_height
self.init_frames = init_frames // frameskip
self.max_timesteps = None if max_timesteps is None else max_timesteps // frameskip
self.gray_scale_keepdim = gray_scale_keepdim and (not rgb_observations)

# define the action space
action_dict = {}
Expand All @@ -75,7 +80,10 @@ def __init__(
action_dict[ACTION_ORDER[-1]] = Box(low=-1, high=1, shape=(2,), dtype=np.float32)
self.action_space = Dict(action_dict)

self.observation_space = Box(low=0, high=255, shape=(obs_width, obs_height, 3), dtype=np.uint8)
# define the observation space
n_chan = 3 if rgb_observations else 1
shape = (obs_width, obs_height, n_chan) if gray_scale_keepdim else (obs_width, obs_height)
self.observation_space = Box(low=0, high=255, shape=shape, dtype=np.uint8)

assert render_mode is None or render_mode in self.metadata["render_modes"]
self.render_mode = render_mode
Expand All @@ -86,6 +94,7 @@ def __init__(
img_height=self.obs_height,
port=tcp_port,
listen_timeout=mt_listen_timeout,
rgb_imgs=rgb_observations,
)

# handles the MT configuration and process
Expand All @@ -105,6 +114,7 @@ def __init__(
pipe_proc=pipe_proc,
mt_port=mt_port,
frameskip=frameskip,
rgb_frames=rgb_observations,
)

self.last_observation = None # used in render if "rgb_array"
Expand Down Expand Up @@ -161,6 +171,9 @@ def reset(
self.mt_chann.send([0]*21, 0, 0) # nop action

observation, _reward, _term = self.mt_chann.receive()
if not self.gray_scale_keepdim:
observation = observation[:, :, 0]

self.last_observation = observation

info = self._get_info()
Expand Down Expand Up @@ -191,6 +204,9 @@ def step(self, action):

# receive the new info from minetest
observation, reward, termination = self.mt_chann.receive()
if not self.gray_scale_keepdim:
observation = observation[:, :, 0]

self.last_observation = observation

info = self._get_info()
Expand Down
2 changes: 2 additions & 0 deletions craftium/minetest.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(
pipe_proc: bool = True,
mt_port: Optional[int] = None,
frameskip: int = 1,
rgb_frames: bool = True,
):
self.pipe_proc = pipe_proc

Expand Down Expand Up @@ -68,6 +69,7 @@ def __init__(

craftium_port=tcp_port,
frameskip=frameskip,
rgb_frames=rgb_frames,

# port used for MT's internal client<->server comm.
port=port,
Expand Down
15 changes: 12 additions & 3 deletions craftium/mt_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,14 @@


class MtChannel():
def __init__(self, img_width: int, img_height: int, port: Optional[int] = None, listen_timeout: int = 2000):
def __init__(
self,
img_width: int,
img_height: int,
port: Optional[int] = None,
listen_timeout: int = 2000,
rgb_imgs: bool = True
):
self.img_width = img_width
self.img_height = img_height
self.listen_timeout = listen_timeout
Expand All @@ -19,14 +26,16 @@ def __init__(self, img_width: int, img_height: int, port: Optional[int] = None,

# pre-compute the number of bytes that we should receive from MT.
# the RGB image + 8 bytes of the reward + 1 byte of the termination flag
self.rec_bytes = img_width*img_height*3 + 8 + 1
self.n_chan = 3 if rgb_imgs else 1
self.rec_bytes = img_width*img_height*self.n_chan + 8 + 1

def receive(self):
img, reward, termination = mt_server.server_recv(
self.connfd,
self.rec_bytes,
self.img_width,
self.img_height
self.img_height,
self.n_chan,
)
return img, reward, termination

Expand Down
8 changes: 4 additions & 4 deletions mt_server.c
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,13 @@ static PyObject* server_listen(PyObject* self, PyObject* args) {


static PyObject* server_recv(PyObject* self, PyObject* args) {
int connfd, n_bytes, obs_width, obs_height, n_read;
int connfd, n_bytes, obs_width, obs_height, n_read, n_channels;
double reward;
char *buff;

if (!PyArg_ParseTuple(args, "iiii", &connfd, &n_bytes, &obs_width, &obs_height)) {
if (!PyArg_ParseTuple(args, "iiiii", &connfd, &n_bytes, &obs_width, &obs_height, &n_channels)) {
PyErr_SetString(PyExc_TypeError,
"Arguments must be 4 integers: connection's fd, num. of bytes to read, obs. width, and obs. height.");
"Arguments must be 5 integers: connection's fd, num. of bytes to read, obs. width and height, and num. channels.");
return NULL;
}

Expand All @@ -123,7 +123,7 @@ static PyObject* server_recv(PyObject* self, PyObject* args) {
PyObject* py_reward = PyFloat_FromDouble(reward);

// Create the numpy array of the image
npy_intp dims[3] = {obs_width, obs_height, 3};
npy_intp dims[3] = {obs_width, obs_height, n_channels};
PyObject* array = PyArray_SimpleNewFromData(3, dims, NPY_UINT8, buff);
if (!array) {
PyErr_SetString(PyExc_RuntimeError, "Failed to create NumPy array");
Expand Down
28 changes: 21 additions & 7 deletions src/client/client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,11 @@ void Client::pyConnStep() {
W*H*3 for the WxH RGB image, +8 for the reward value (a double),
and +1 for the episode termination flag
*/
obs_rwd_buffer_size = W*H*3 + 8 + 1;
if (g_settings->getBool("rgb_frames")) {
obs_rwd_buffer_size = W*H*3 + 8 + 1; // full RGB images
} else {
obs_rwd_buffer_size = W*H + 8 + 1; // grayscale images
}

/* If obs_rwd_buffer is not initialized, allocate memory for it now */
if (!obs_rwd_buffer) {
Expand All @@ -235,13 +239,23 @@ void Client::pyConnStep() {

/* Copy RGB image into a flat u8 array (obs_rwd_buffer) */
int i = 0;
for (int w=0; w<W; w++) {
if (g_settings->getBool("rgb_frames")) {
for (int h=0; h<H; h++) {
for (int w=0; w<W; w++) {
c = raw_image->getPixel(w, h).color;
obs_rwd_buffer[i] = (c>>16) & 0xff; // R
obs_rwd_buffer[i+1] = (c>>8) & 0xff; // G
obs_rwd_buffer[i+2] = c & 0xff; // B
i = i + 3;
}
}
} else {
for (int h=0; h<H; h++) {
c = raw_image->getPixel(w, h).color;
obs_rwd_buffer[i] = (c>>16) & 0xff; // R
obs_rwd_buffer[i+1] = (c>>8) & 0xff; // G
obs_rwd_buffer[i+2] = c & 0xff; // B
i = i + 3;
for (int w=0; w<W; w++) {
c = raw_image->getPixel(w, h).color;
obs_rwd_buffer[i] = (((c>>16) & 0xff) / 3) + (((c>>8) & 0xff) / 3) + ((c & 0xff) / 3);
i++;
}
}
}

Expand Down
1 change: 1 addition & 0 deletions src/defaultsettings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ void set_default_settings()
settings->setDefault("chat_weblink_color", "#8888FF");
settings->setDefault("craftium_port", "55555");
settings->setDefault("frameskip", "1");
settings->setDefault("rgb_frames", "true");

// Keymap
settings->setDefault("remote_port", "30000");
Expand Down

0 comments on commit 882d0cb

Please sign in to comment.