-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathmjsim.pyx
383 lines (334 loc) · 15.3 KB
/
mjsim.pyx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
from xml.dom import minidom
from mujoco_py.utils import remove_empty_lines
from mujoco_py.builder import build_callback_fn
from threading import Lock
# Module-level lock shared by every MjSim instance. MjSim.render() holds it
# while creating/selecting the offscreen render context, rendering, and
# reading the pixels back, so offscreen renders from different sims are
# serialized.
_MjSim_render_lock = Lock()
# C function-pointer type for a substep callback taking (model, data).
# NOTE(review): not referenced by MjSim in this file (substep_callback()
# casts the stored pointer to mjfGeneric instead) — confirm it is still needed.
ctypedef void (*substep_udd_t)(const mjModel* m, mjData* d)
cdef class MjSim(object):
    """MjSim represents a running simulation including its state.

    Similar to Gym's ``MujocoEnv``, it internally wraps a :class:`.PyMjModel`
    and a :class:`.PyMjData`.

    Parameters
    ----------
    model : :class:`.PyMjModel`
        The model to simulate.
    data : :class:`.PyMjData`
        Optional container for the simulation state. Will be created if ``None``.
    nsubsteps : int
        Optional number of MuJoCo steps to run for every call to :meth:`.step`.
        Buffers will be swapped only once per step.
    udd_callback : fn(:class:`.MjSim`) -> dict
        Optional callback for user-defined dynamics. At every call to
        :meth:`.step`, it receives an MjSim object ``sim`` containing the
        current user-defined dynamics state in ``sim.udd_state``, and returns the
        next ``udd_state`` after applying the user-defined dynamics. This is
        useful e.g. for reward functions that operate over functions of historical
        state.
    substep_callback : str or int or None
        This uses a compiled C function as user-defined dynamics in substeps.
        If given as a string, it's compiled as a C function and set as pointer.
        If given as int, it's interpreted as a function pointer.
        See :meth:`.set_substep_callback` for detailed info.
    userdata_names : list of strings or None
        This is a convenience parameter which is just set on the model.
        Equivalent to calling ``model.set_userdata_names``
    render_callback : callable or None
        Optional callback for rendering, stored as ``sim.render_callback``.
        MjSim itself never calls it (presumably the render contexts invoke it
        before rendering — confirm in the render-context code).
    """
    # All render contexts registered through add_render_context().
    cdef readonly list render_contexts
    # First registered on-screen (window) render context, or None.
    cdef readonly object _render_context_window
    # First registered offscreen render context, or None.
    cdef readonly object _render_context_offscreen
    # MuJoCo model
    cdef readonly PyMjModel model
    # MuJoCo data
    cdef readonly PyMjData data
    # Number of substeps (mj_step calls) per call to .step
    cdef public int nsubsteps
    # User defined state.
    cdef public dict udd_state
    # User defined dynamics callback
    cdef readonly object _udd_callback
    # Allows to store extra information in MjSim.
    cdef readonly dict extras
    # Function pointer for substep callback, stored as uintptr (0 = disabled)
    cdef readonly uintptr_t substep_callback_ptr
    # Callback executed before rendering.
    cdef public object render_callback

    def __cinit__(self, PyMjModel model, PyMjData data=None, int nsubsteps=1,
                  udd_callback=None, substep_callback=None, userdata_names=None,
                  render_callback=None):
        self.nsubsteps = nsubsteps
        self.model = model
        if data is None:
            with wrap_mujoco_warning():
                _data = mj_makeData(self.model.ptr)
            if _data == NULL:
                raise Exception('mj_makeData failed!')
            self.data = WrapMjData(_data, self.model)
        else:
            self.data = data

        self.render_contexts = []
        self._render_context_offscreen = None
        self._render_context_window = None
        self.udd_state = None
        # Assigning through the udd_callback property also computes the
        # initial udd_state (the setter calls step_udd()).
        self.udd_callback = udd_callback
        self.render_callback = render_callback
        self.extras = {}
        self.set_substep_callback(substep_callback, userdata_names)

    def reset(self):
        """
        Resets the simulation data via ``mj_resetData`` and recomputes
        ``udd_state`` from scratch via :meth:`.step_udd`.
        """
        with wrap_mujoco_warning():
            mj_resetData(self.model.ptr, self.data.ptr)
        self.udd_state = None
        self.step_udd()

    def forward(self):
        """
        Recomputes the quantities derived from the current state by calling
        ``mj_forward`` (the forward computation of ``mj_step`` without
        integrating in time).
        """
        with wrap_mujoco_warning():
            mj_forward(self.model.ptr, self.data.ptr)

    def step(self, with_udd=True):
        """
        Advances the simulation by calling ``mj_step``.

        If ``with_udd`` is true, :meth:`.step_udd` runs once first. Then, for
        each of ``nsubsteps`` iterations: the substep callback (if any) runs,
        the window render context's perturbation (``pert``, if a window
        context is registered) is applied, and ``mj_step`` is called.

        If ``qpos`` or ``qvel`` have been modified directly, the user is required to call
        :meth:`.forward` before :meth:`.step` if their ``udd_callback`` requires access to MuJoCo state
        set during the forward dynamics.
        """
        cdef PyMjvPerturb perturb = None

        if with_udd:
            self.step_udd()

        # Perturbation owned by the on-screen viewer, if one is registered.
        # Looked up once here rather than on every substep.
        if self._render_context_window is not None:
            perturb = self._render_context_window.pert

        with wrap_mujoco_warning():
            for _ in range(self.nsubsteps):
                self.substep_callback()
                if perturb is not None:
                    # NOTE(review): xfrc_applied is never zeroed here (MuJoCo's
                    # `simulate` sample clears it before these two calls), so a
                    # force from an earlier perturbation may linger after the
                    # perturbation ends — confirm this is intended.
                    mjv_applyPerturbPose(self.model.ptr, self.data.ptr, perturb.ptr, 0)
                    mjv_applyPerturbForce(self.model.ptr, self.data.ptr, perturb.ptr)
                mj_step(self.model.ptr, self.data.ptr)

    def render(self, width=None, height=None, *, camera_name=None, depth=False,
               mode='offscreen', device_id=-1):
        """
        Renders view from a camera and returns image as an `numpy.ndarray`.

        Args:
        - width (int): desired image width.
        - height (int): desired image height.
        - camera_name (str): name of camera in model. If None, the free
          camera will be used.
        - depth (bool): if True, also return depth buffer
        - mode (str): 'offscreen' (render and read pixels back) or 'window'
          (draw into the on-screen viewer).
        - device_id (int): device to use for rendering (only for GPU-backed
          rendering).

        Returns (offscreen mode only; window mode returns None):
        - rgb (uint8 array): image buffer from camera
        - depth (float array): depth buffer from camera (only returned
          if depth=True)

        Raises:
        - ValueError: if ``mode`` is neither 'window' nor 'offscreen'.
        """
        if camera_name is None:
            camera_id = None
        else:
            camera_id = self.model.camera_name2id(camera_name)

        if mode == 'offscreen':
            # Offscreen rendering is serialized across all MjSim instances.
            with _MjSim_render_lock:
                if self._render_context_offscreen is None:
                    # The new context registers itself via add_render_context().
                    render_context = MjRenderContextOffscreen(
                        self, device_id=device_id)
                else:
                    render_context = self._render_context_offscreen

                render_context.render(
                    width=width, height=height, camera_id=camera_id)
                return render_context.read_pixels(
                    width, height, depth=depth)
        elif mode == 'window':
            if self._render_context_window is None:
                from mujoco_py.mjviewer import MjViewer
                render_context = MjViewer(self)
            else:
                render_context = self._render_context_window

            render_context.render()
        else:
            raise ValueError("Mode must be either 'window' or 'offscreen'.")

    def add_render_context(self, render_context):
        """
        Registers ``render_context``. The first offscreen context and the first
        window context registered become the defaults used by :meth:`.render`.
        """
        self.render_contexts.append(render_context)
        if render_context.offscreen and self._render_context_offscreen is None:
            self._render_context_offscreen = render_context
        elif not render_context.offscreen and self._render_context_window is None:
            self._render_context_window = render_context

    @property
    def udd_callback(self):
        """The user-defined dynamics callback (or None)."""
        return self._udd_callback

    @udd_callback.setter
    def udd_callback(self, value):
        # Replacing the callback discards the old udd_state and immediately
        # computes a fresh one from the new callback.
        self._udd_callback = value
        self.udd_state = None
        self.step_udd()

    cpdef substep_callback(self):
        """Invokes the compiled substep callback, if one is set."""
        if self.substep_callback_ptr:
            (<mjfGeneric>self.substep_callback_ptr)(self.model.ptr, self.data.ptr)

    def set_substep_callback(self, substep_callback, userdata_names=None):
        '''
        Set a substep callback function.

        Parameters :
            substep_callback : str or int or None
                If `substep_callback` is a string, compile to function pointer and set.
                    See `builder.build_callback_fn()` for documentation.
                If `substep_callback` is an int, we interpret it as a function pointer.
                If `substep_callback` is None, we disable substep_callbacks.
            userdata_names : list of strings or None
                This is a convenience parameter, if not None, this is passed
                onto ``model.set_userdata_names()``.

        Raises :
            TypeError if `substep_callback` is of any other type.
        '''
        if userdata_names is not None:
            self.model.set_userdata_names(userdata_names)

        if substep_callback is None:
            self.substep_callback_ptr = 0
        elif isinstance(substep_callback, int):
            self.substep_callback_ptr = substep_callback
        elif isinstance(substep_callback, str):
            self.substep_callback_ptr = build_callback_fn(substep_callback,
                                                          self.model.userdata_names)
        else:
            raise TypeError('invalid: {}'.format(type(substep_callback)))

    def step_udd(self):
        """
        Advances ``udd_state`` by one step.

        Without a udd_callback, ``udd_state`` becomes ``{}``. Otherwise the
        callback's return value becomes the new ``udd_state``, and (when a
        previous state exists) it is checked to have the same keys, and for
        each key the same kind of value (number vs. numpy array, arrays with
        the same shape) as the previous state.
        """
        if self._udd_callback is None:
            self.udd_state = {}
        else:
            schema_example = self.udd_state
            self.udd_state = self._udd_callback(self)
            # Check to make sure the udd_state has consistent keys and dimension across steps
            if schema_example is not None:
                keys = set(schema_example.keys()) | set(self.udd_state.keys())
                for key in keys:
                    assert key in schema_example, "Keys cannot be added to udd_state between steps."
                    assert key in self.udd_state, "Keys cannot be dropped from udd_state between steps."
                    if isinstance(schema_example[key], Number):
                        assert isinstance(self.udd_state[key], Number), \
                            "Every value in udd_state must be either a number or a numpy array"
                    else:
                        assert isinstance(self.udd_state[key], np.ndarray), \
                            "Every value in udd_state must be either a number or a numpy array"
                        assert self.udd_state[key].shape == schema_example[key].shape, \
                            "Numpy array values in udd_state must keep the same dimension across steps."

    def get_state(self):
        """ Returns a copy of the simulator state (an :class:`MjSimState`).

        ``act`` is None when the model has no activations (``model.na == 0``).
        """
        qpos = np.copy(self.data.qpos)
        qvel = np.copy(self.data.qvel)
        if self.model.na == 0:
            act = None
        else:
            act = np.copy(self.data.act)
        udd_state = copy.deepcopy(self.udd_state)

        return MjSimState(self.data.time, qpos, qvel, act, udd_state)

    def set_state(self, value):
        """
        Sets the state from an MjSimState.

        If the MjSimState was previously unflattened from a numpy array, consider
        set_state_from_flattened, as the defensive copy is a substantial overhead
        in an inner loop.

        Args:
        - value (MjSimState): the desired state. ``time``, ``qpos``, ``qvel``,
          ``act`` (only when the model has activations) and ``udd_state`` are
          copied into the simulation. ``forward()`` is not called.
        """
        self.data.time = value.time
        self.data.qpos[:] = np.copy(value.qpos)
        self.data.qvel[:] = np.copy(value.qvel)
        if self.model.na != 0:
            self.data.act[:] = np.copy(value.act)
        self.udd_state = copy.deepcopy(value.udd_state)

    def set_state_from_flattened(self, value):
        """ This helper method sets the state from an array without requiring a defensive copy."""
        state = MjSimState.from_flattened(value, self)

        self.data.time = state.time
        self.data.qpos[:] = state.qpos
        self.data.qvel[:] = state.qvel
        if self.model.na != 0:
            self.data.act[:] = state.act
        self.udd_state = state.udd_state

    def save(self, file, format='xml', keep_inertials=False):
        """
        Saves the simulator model and state to a file as either
        a MuJoCo XML or MJB file. The current state is saved as
        a keyframe in the model file. This is useful for debugging
        using MuJoCo's `simulate` utility.

        Note that this doesn't save the UDD-state which is
        part of MjSimState, since that's not supported natively
        by MuJoCo. If you want to save the model together with
        the UDD-state, you should use the `get_xml` or `get_mjb`
        methods on `MjModel` together with `MjSim.get_state` and
        save them with e.g. pickle.

        Args:
        - file (IO stream): stream to write model to.
        - format: format to use (either 'xml' or 'mjb')
        - keep_inertials (bool): if False, removes all <inertial>
          properties derived automatically for geoms by MuJoCo. Note
          that this removes ones that were provided by the user
          as well.

        Raises:
        - ValueError: if ``format`` is neither 'xml' nor 'mjb'.
        """
        # Validate up front so an unsupported format fails before any work.
        if format not in ('xml', 'mjb'):
            raise ValueError("Unsupported format. Valid ones are 'xml' and 'mjb'")

        xml_str = self.model.get_xml()
        dom = minidom.parseString(xml_str)

        mujoco_node = dom.childNodes[0]
        assert mujoco_node.tagName == 'mujoco'

        # Append <keyframe><key .../></keyframe> holding the current state.
        keyframe_el = dom.createElement('keyframe')
        key_el = dom.createElement('key')
        keyframe_el.appendChild(key_el)
        mujoco_node.appendChild(keyframe_el)

        def str_array(arr):
            return " ".join(map(str, arr))

        key_el.setAttribute('time', str(self.data.time))
        key_el.setAttribute('qpos', str_array(self.data.qpos))
        key_el.setAttribute('qvel', str_array(self.data.qvel))
        if self.data.act is not None:
            key_el.setAttribute('act', str_array(self.data.act))

        if not keep_inertials:
            # getElementsByTagName returns a separate list, so removing
            # nodes while iterating over it is safe.
            for element in dom.getElementsByTagName('inertial'):
                element.parentNode.removeChild(element)

        result_xml = remove_empty_lines(dom.toprettyxml(indent=" " * 4))

        if format == 'xml':
            file.write(result_xml)
        else:
            # format == 'mjb': recompile the edited XML and write the binary.
            new_model = load_model_from_xml(result_xml)
            file.write(new_model.get_mjb())

    def ray(self,
            np.ndarray[np.float64_t, mode="c", ndim=1] pnt,
            np.ndarray[np.float64_t, mode="c", ndim=1] vec,
            include_static_geoms=True, exclude_body=-1):
        """
        Cast a ray into the scene, and return the first valid geom it intersects.
            pnt - origin point of the ray in world coordinates (X Y Z)
            vec - direction of the ray in world coordinates (X Y Z)
            include_static_geoms - if False, we exclude geoms that are children of worldbody.
            exclude_body - if this is a body ID, we exclude all children geoms of this body.
        Returns (distance, geom_id) where
            distance - distance along ray until first collision with geom
            geom_id - id of the geom the ray collided with
        If no collision was found in the scene, both values are -1
        (``mj_ray`` reports a miss as distance -1 and geom id -1).

        Raises TypeError if ``pnt`` or ``vec`` is None, and ValueError if
        either has fewer than 3 elements (``mj_ray`` reads exactly 3 from each).

        NOTE: sometimes self.forward() needs to be called before self.ray().
        """
        cdef int geom_id = -1
        cdef mjtNum distance

        # mj_ray dereferences 3 doubles from each pointer; reject inputs that
        # would make it read past the end of the buffer.
        if pnt is None or vec is None:
            raise TypeError("pnt and vec must be float64 numpy arrays, not None.")
        if pnt.shape[0] < 3 or vec.shape[0] < 3:
            raise ValueError("pnt and vec must each contain 3 elements (X Y Z).")

        cdef mjtNum[::view.contiguous] pnt_view = pnt
        cdef mjtNum[::view.contiguous] vec_view = vec

        distance = mj_ray(self.model.ptr, self.data.ptr,
                          &pnt_view[0], &vec_view[0], NULL,
                          1 if include_static_geoms else 0,
                          exclude_body,
                          &geom_id)
        return (distance, geom_id)