Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding updates from pwhiddy patch pokemon_red_minimal.py #50

Open
wants to merge 2 commits into
base: 0.5-cleanup
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 66 additions & 7 deletions pufferlib/environments/pokemon_red_minimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,17 @@
WindowEvent.PRESS_ARROW_RIGHT,
WindowEvent.PRESS_ARROW_UP,
WindowEvent.PRESS_BUTTON_A,
WindowEvent.PRESS_BUTTON_B,
WindowEvent.PRESS_BUTTON_START,
WindowEvent.PASS
WindowEvent.PRESS_BUTTON_B
]


if self.extra_buttons:
self.valid_actions.extend([
WindowEvent.PRESS_BUTTON_START,
WindowEvent.PASS
])


RELEASE_ARROW = [
WindowEvent.RELEASE_ARROW_DOWN,
WindowEvent.RELEASE_ARROW_LEFT,
Expand Down Expand Up @@ -77,6 +83,11 @@ def __init__(
self.downsample_factor = downsample_factor
self.init_state = init_state

self.explore_weight = 1 if 'explore_weight' not in config else config['explore_weight']
self.use_screen_explore = True if 'use_screen_explore' not in config else config['use_screen_explore']
self.reward_scale = 1 if 'reward_scale' not in config else config['reward_scale']
self.extra_buttons = False if 'extra_buttons' not in config else config['extra_buttons']

# Reward
self.r_healing = reward_scale_healing
self.r_level = reward_scale_level
Expand Down Expand Up @@ -127,15 +138,30 @@ def __init__(

self.reset()

def update_seen_coords(self):
x_pos = self.read_m(0xD362)
y_pos = self.read_m(0xD361)
map_n = self.read_m(0xD35E)
coord_string = f"x:{x_pos} y:{y_pos} m:{map_n}"
if self.get_levels_sum() >= 22 and not self.levels_satisfied:
self.levels_satisfied = True
self.base_explore = len(self.seen_coords)
self.seen_coords = {}

self.seen_coords[coord_string] = self.step_count

def reset(self, seed=None):
self.seed = seed

# restart game, skipping credits
#with open(self.init_state, "rb") as f:
# self.pyboy.load_state(f)

self.init_knn()

if self.use_screen_explore:
self.init_knn()
else:
self.init_map_mem

self.rewards = {
"healing" : 0,
"event" : 0,
Expand Down Expand Up @@ -166,11 +192,16 @@ def render(self, reduce_res=True):

def step(self, action):
self.run_action_on_emulator(action)
self.append_agent_stats(action)
ob = self.render()

# trim off memory from frame for knn index
self.update_frame_knn_index(ob)
if self.use_screen_explore:
self.update_frame_knn_index(obs_flat)
else:
self.update_seen_coords()

self.update_heal_reward()
reward, _ = self.compute_rewards()
self.s_health = self.read_hp_fraction()

Expand All @@ -184,6 +215,9 @@ def step(self, action):
self.step_count += 1
return ob, reward, False, step_limit_reached, info

def init_map_mem(self):
self.seen_coords = {}

def run_action_on_emulator(self, action):
# press button then release after some steps
self.pyboy.send_input(VALID_ACTIONS[action])
Expand All @@ -199,11 +233,33 @@ def run_action_on_emulator(self, action):
if action > 3 and action < 6:
# release button
self.pyboy.send_input(RELEASE_BUTTON[action - 4])
if action == WindowEvent.PRESS_BUTTON_START:
if self.valid_actions[action] == WindowEvent.PRESS_BUTTON_START:
self.pyboy.send_input(WindowEvent.RELEASE_BUTTON_START)
if i == self.act_freq - 1:
self.pyboy._rendering(True)
self.pyboy.tick()

def compute_rewards(self):
# addresses from https://datacrystal.romhacking.net/wiki/Pok%C3%A9mon_Red/Blue:RAM_map
# https://github.com/pret/pokered/blob/91dc3c9f9c8fd529bb6e8307b58b96efa0bec67e/constants/event_constants.asm

self.rewards_old = self.rewards.copy()

# adds up all event flags, exclude museum ticket
event_flags_start = 0xD747
event_flags_end = 0xD886
museum_ticket = (0xD754, 0)
base_event_flags = 13
return max(
sum(
[
self.bit_count(self.read_m(i))
for i in range(event_flags_start, event_flags_end)
]
)
- base_event_flags
- int(self.read_bit(museum_ticket[0], museum_ticket[1])),
0,)

def get_agent_stats(self, action):
return {
Expand Down Expand Up @@ -263,6 +319,9 @@ def compute_rewards(self):
self.rewards["badges"] = self.r_badge * self.get_badges()

# exploration reward
pre_rew = self.explore_weight * 0.005
post_rew = self.explore_weight * 0.01
cur_size = self.knn_index.get_current_count() if self.use_screen_explore else len(self.seen_coords)
curr_size = self.knn_index.get_current_count()
if self.s_levels_satisfied:
base = self.s_base_explore
Expand Down