forked from unixpickle/muniverse
-
Notifications
You must be signed in to change notification settings - Fork 0
/
cursor.go
101 lines (89 loc) · 2.25 KB
/
cursor.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
package muniverse
import (
	"image"
	"image/color"
	"image/draw"
	"time"

	"github.com/unixpickle/muniverse/chrome"
)
// Dimensions of the cursor sprite drawn by cursorEnv, in pixels.
const (
	// mouseSize is the height of the arrow's triangular head,
	// measured down from the tip.
	mouseSize = 13

	// mouseStem is the length of the tail hanging below the head.
	mouseStem = 5
)
// cursorEnv wraps an Env and remembers where the mouse is so that
// Observe can paint a cursor onto each frame.
type cursorEnv struct {
	Env

	// Cursor position restored by every Reset.
	initX, initY int

	// Cursor position taken from the latest mouse event seen by Step.
	curX, curY int
}
// CursorEnv wraps e so that every observation it produces has a
// mouse cursor drawn at the current mouse position.
//
// Chrome does not render a mouse pointer on its own, so the cursor
// has to be painted manually. Each Reset moves the cursor back to
// (initX, initY); until the first Reset the tracked position is
// (0, 0).
//
// Closing the returned Env also closes e.
func CursorEnv(e Env, initX, initY int) Env {
	wrapped := &cursorEnv{
		Env:   e,
		initX: initX,
		initY: initY,
	}
	return wrapped
}
// Reset moves the tracked cursor back to its initial position and
// then resets the wrapped Env, returning whatever error it reports.
func (c *cursorEnv) Reset() error {
	c.curX, c.curY = c.initX, c.initY
	return c.Env.Reset()
}
// Step passes the events through to the wrapped Env and then scans
// them for *chrome.MouseEvent values. The last such event decides
// where the cursor is drawn. The position is updated even when the
// wrapped Step reports an error.
func (c *cursorEnv) Step(t time.Duration, events ...interface{}) (reward float64,
	done bool, err error) {
	reward, done, err = c.Env.Step(t, events...)
	for i := range events {
		mouse, isMouse := events[i].(*chrome.MouseEvent)
		if !isMouse {
			continue
		}
		c.curX, c.curY = mouse.X, mouse.Y
	}
	return reward, done, err
}
// Observe fetches an observation from the wrapped Env and returns a
// copy of its image with a cursor painted at (curX, curY).
//
// The cursor is an arrow whose tip sits at (curX, curY):
//   - a triangular head, mouseSize pixels tall, which widens by one
//     pixel per side every two rows;
//   - a stem, mouseStem pixels long, below the head.
//
// The head's interior and the stem's centre column are black. The
// tip, the head's sides and bottom edge, and the stem's side columns
// and bottom row are white.
//
// The returned image always has its origin at (0, 0). Pixels are read
// relative to the source image's Bounds().Min, so sources with a
// non-zero origin are copied correctly. If the wrapped Env or its
// image returns an error, that error is returned with a nil Obs.
func (c *cursorEnv) Observe() (obs Obs, err error) {
	obs, err = c.Env.Observe()
	if err != nil {
		return nil, err
	}
	img, err := obs.Image()
	if err != nil {
		return nil, err
	}
	src := img.Bounds()
	newImg := image.NewRGBA(image.Rect(0, 0, src.Dx(), src.Dy()))
	// Copy the frame in one pass. A per-pixel At/Set loop boxes a
	// color.Color for every pixel, although only the cursor changes.
	draw.Draw(newImg, newImg.Bounds(), img, src.Min, draw.Src)
	c.drawCursor(newImg)
	return &imageObs{Img: newImg}, nil
}

// drawCursor paints the cursor onto dst. It visits only the pixels the
// cursor can cover, clipped to dst's bounds.
func (c *cursorEnv) drawCursor(dst *image.RGBA) {
	// The head extends at most mouseSize/2 pixels either side of the tip
	// column. The stem outline extends 1 pixel either side.
	reach := mouseSize/2 + 1
	area := image.Rect(
		c.curX-reach, c.curY,
		c.curX+reach+1, c.curY+mouseSize+mouseStem+1,
	).Intersect(dst.Bounds())
	for y := area.Min.Y; y < area.Max.Y; y++ {
		for x := area.Min.X; x < area.Max.X; x++ {
			if col, ok := c.cursorPixel(x, y); ok {
				dst.Set(x, y, col)
			}
		}
	}
}

// cursorPixel returns the cursor colour at (x, y). It reports ok=false
// when the pixel is not part of the cursor. Black cases are checked
// before white ones, so a pixel that matches both ends up black.
func (c *cursorEnv) cursorPixel(x, y int) (col color.Gray, ok bool) {
	dx, dy := x-c.curX, y-c.curY
	halfWidth := dy / 2 // half-width of the head on this row
	inHead := dy > 0 && dy < mouseSize
	inStem := dy >= mouseSize && dy < mouseSize+mouseStem
	switch {
	case inStem && dx == 0, // stem centre
		inHead && dx > -halfWidth && dx < halfWidth: // head interior
		return color.Gray{Y: 0}, true
	case inHead && (dx == halfWidth || dx == -halfWidth), // head sides
		dy == 0 && dx == 0, // tip
		inStem && (dx == -1 || dx == 1), // stem sides
		dy == mouseSize+mouseStem && dx >= -1 && dx <= 1, // stem bottom
		dy == mouseSize && dx > -halfWidth && dx < halfWidth: // head bottom
		return color.Gray{Y: 0xff}, true
	}
	return color.Gray{}, false
}