diff --git a/pufferlib/config/ocean/minatar_freeway.ini b/pufferlib/config/ocean/minatar_freeway.ini new file mode 100644 index 0000000000..72cd09c49f --- /dev/null +++ b/pufferlib/config/ocean/minatar_freeway.ini @@ -0,0 +1,13 @@ +[base] +package = ocean +env_name = puffer_minatar_freeway +policy_name = Policy +rnn_name = Recurrent + +[env] +num_envs = 1024 +use_minimal_action_set = True + +[train] +total_timesteps = 50_000_000 +minibatch_size = 32768 diff --git a/pufferlib/ocean/environment.py b/pufferlib/ocean/environment.py index 6c56a4ea20..eadbf845d7 100644 --- a/pufferlib/ocean/environment.py +++ b/pufferlib/ocean/environment.py @@ -127,6 +127,7 @@ def make_multiagent(buf=None, **kwargs): 'enduro': 'Enduro', 'tetris': 'Tetris', 'cartpole': 'Cartpole', + 'minatar_freeway': 'MinAtarFreeway', 'moba': 'Moba', 'matsci': 'Matsci', 'memory': 'Memory', diff --git a/pufferlib/ocean/minatar_freeway/binding.c b/pufferlib/ocean/minatar_freeway/binding.c new file mode 100644 index 0000000000..d31b72a3a9 --- /dev/null +++ b/pufferlib/ocean/minatar_freeway/binding.c @@ -0,0 +1,19 @@ +#include "minatar_freeway.h" + +#define Env MinAtarFreeway +#include "../env_binding.h" + +static int my_init(Env* env, PyObject* args, PyObject* kwargs) { + env->use_minimal_action_set = unpack(kwargs, "use_minimal_action_set"); + env->sticky_action_prob = unpack(kwargs, "sticky_action_prob"); + init(env); + return 0; +} + +static int my_log(PyObject* dict, Log* log) { + assign_to_dict(dict, "perf", log->perf); + assign_to_dict(dict, "score", log->score); + assign_to_dict(dict, "episode_return", log->episode_return); + assign_to_dict(dict, "episode_length", log->episode_length); + return 0; +} diff --git a/pufferlib/ocean/minatar_freeway/minatar_freeway.c b/pufferlib/ocean/minatar_freeway/minatar_freeway.c new file mode 100644 index 0000000000..019642abcf --- /dev/null +++ b/pufferlib/ocean/minatar_freeway/minatar_freeway.c @@ -0,0 +1,33 @@ +#include "minatar_freeway.h" + +int main() { + MinAtarFreeway env = { + .use_minimal_action_set = false, + .sticky_action_prob = 0.1f, + }; + init(&env); + env.observations = (int*)calloc(10*10*7, sizeof(int)); + env.actions = (int*)calloc(1, sizeof(int)); + env.rewards = (float*)calloc(1, sizeof(float)); + env.terminals = (unsigned char*)calloc(1, sizeof(unsigned char)); + + c_reset(&env); + c_render(&env); + while (!WindowShouldClose()) { + if (IsKeyDown(KEY_LEFT_SHIFT)) { + env.actions[0] = 0; + if (IsKeyDown(KEY_UP) || IsKeyDown(KEY_W)) env.actions[0] = UP; + if (IsKeyDown(KEY_DOWN) || IsKeyDown(KEY_S)) env.actions[0] = DOWN; + } else { + env.actions[0] = NOOP; + } + c_step(&env); + c_render(&env); + } + free(env.observations); + free(env.actions); + free(env.rewards); + free(env.terminals); + c_close(&env); +} + diff --git a/pufferlib/ocean/minatar_freeway/minatar_freeway.h b/pufferlib/ocean/minatar_freeway/minatar_freeway.h new file mode 100644 index 0000000000..c102b92df9 --- /dev/null +++ b/pufferlib/ocean/minatar_freeway/minatar_freeway.h @@ -0,0 +1,272 @@ + +#include +#include +#include +#include "raylib.h" + +const unsigned char NOOP = 0; +const unsigned char LEFT = 1; +const unsigned char UP = 2; +const unsigned char RIGHT = 3; +const unsigned char DOWN = 4; +const unsigned char FIRE = 5; + +const unsigned char PLAYER_SPEED = 3; +const unsigned short int TIME_LIMIT = 2500; +// 9 moves to get across freeway +const unsigned int MAX_SCORE = (TIME_LIMIT / PLAYER_SPEED) / 9; + +const unsigned char FULL_ACTION_SET[6] = {NOOP, LEFT, UP, RIGHT, DOWN, FIRE}; +const unsigned char MINIMAL_ACTION_SET[3] = {NOOP, UP, DOWN}; + +typedef struct { + float perf; // Recommended 0-1 normalized single real number perf metric + float score; // Recommended unnormalized single real number perf metric + float episode_return; // Recommended metric: sum of agent rewards over episode + float episode_length; // Recommended metric: number of steps of agent episode + // Any extra fields you add here may be exported to Python in binding.c + float n; // Required as the last field +} Log; + + +typedef struct { + Log log; + // 10 x 10 x 7 + int* observations; // Required. You can use any obs type, but make sure it matches in Python! + int* actions; // Required. int* for discrete/multidiscrete, float* for box + float* rewards; + unsigned char* terminals; // Required. We don't yet have truncations as standard yet + int* prev_action; + bool use_minimal_action_set; + float sticky_action_prob; + int** cars; // 8 x 4 + int position; + int move_timer; + int terminate_timer; + float episode_score; +} MinAtarFreeway; + +void add_log(MinAtarFreeway* env) { + env->log.perf += env->episode_score / (float)MAX_SCORE; + env->log.score += env->episode_score; + env->log.episode_length += env->terminate_timer; + env->log.episode_return += env->episode_score; + env->log.n++; +} + +int random_int(int min, int max){ + // from: https://c-faq.com/lib/randrange.html + return min + rand() / (RAND_MAX / (max - min + 1) + 1); +} + +int min(int a, int b) { + if (a < b) { + return a; + } + return b; +} + +int max(int a, int b) { + if (a > b) { + return a; + } + return b; +} + +void init(MinAtarFreeway* env) { + env->cars = (int**)(calloc(8, sizeof(int*))); + for (int i = 0; i < 8; i++) { + env->cars[i] = (int*)(calloc(4, sizeof(int))); + } + env->prev_action = (int*)calloc(1, sizeof(int)); +} + +void randomize_cars(MinAtarFreeway* env, bool initialize) { + for (int i = 0; i < 8; i++) { + int speed = random_int(1, 5); + int direction = 2 * random_int(0, 1) - 1; + if (initialize) { + env->cars[i][0] = 0; + env->cars[i][1] = i + 1; + } + env->cars[i][2] = speed; + env->cars[i][3] = speed * direction; + } + return; +} + +int get_index(int h, int w, int c) { + return h + 10 * w + 100 * c; +} + +void get_obs(MinAtarFreeway* env) { + memset(env->observations, 0, 10 * 10 * 7*sizeof(int)); + env->observations[get_index(env->position, 4, 0)] = 1; + for (int i = 0; i < 8; i++) { + env->observations[get_index(env->cars[i][1], env->cars[i][0], 1)] = 1; + int back_x0; + if (env->cars[i][3] > 0) { + back_x0 = env->cars[i][0] - 1; + } else { + back_x0 = env->cars[i][0] + 1; + } + if (back_x0 < 0) { + back_x0 = 9; + } else if (back_x0 > 9) { + back_x0 = 0; + } + int trail = abs(env->cars[i][3]) + 1; + env->observations[get_index(env->cars[i][1], back_x0, trail)] = 1; + } +} + +void c_reset(MinAtarFreeway* env) { + env->position = 9; + env->episode_score = 0.0f; + env->move_timer = PLAYER_SPEED; + env->terminate_timer = 0; + memset(env->prev_action, 0, sizeof(int)); + randomize_cars(env, true); + get_obs(env); +} + + +void c_step(MinAtarFreeway* env) { + env->terminals[0] = 0; + int action; + float reward = 0.0; + + if (rand() < ((RAND_MAX + 1u) * env->sticky_action_prob)){ + action = env->prev_action[0]; + } else { + if (env->use_minimal_action_set) { + action = MINIMAL_ACTION_SET[env->actions[0]]; + } else { + action = FULL_ACTION_SET[env->actions[0]]; + } + } + env->prev_action[0] = action; + + // update player + + if (env->move_timer == 0) { + env->move_timer = PLAYER_SPEED; + if (action == 2) { + env->position = max(0, env->position - 1); + } else if (action == 4) { + env->position = min(9, env->position + 1); + } + } else { + env->move_timer--; + } + + if (env->position == 0) { + reward++; + randomize_cars(env, false); + env->position = 9; + } + + // update cars + for (int i = 0; i < 8; i++) { + // player is always in column 4 + if ((env->cars[i][0] == 4) && (env->cars[i][1] == env->position)) { + env->position = 9; + } else if (env->cars[i][2] == 0) { + env->cars[i][2] = abs(env->cars[i][3]); + if (env->cars[i][3] > 0) { + env->cars[i][0]++; + } else { + env->cars[i][0]--; + } + if (env->cars[i][0] < 0) { + env->cars[i][0] = 9; + } else if (env->cars[i][0] > 9) { + env->cars[i][0] = 0; + } + if ((env->cars[i][0] == 4) && (env->cars[i][1] == env->position)) { + env->position = 9; + } + } else { + env->cars[i][2]--; + } + } + + env->terminate_timer++; + env->rewards[0] = reward; + env->episode_score += reward; + if (env->terminate_timer > TIME_LIMIT) { + env->terminals[0] = 1; + add_log(env); + c_reset(env); + } + get_obs(env); + return; +} + +unsigned char U8(float x) { + int v = (int)(x * 255.0f + 0.5f); + if (v < 0) { + v = 0; + } + if (v > 255) { + v = 255; + } + return (unsigned char)v; +} + +Color RGBf(float r, float g, float b) { + return (Color){U8(r), U8(g), U8(b), 255}; +} + +void c_render(MinAtarFreeway* env) { + if (!IsWindowReady()) { + InitWindow(30 * 10, 30 * 10, "PufferLib MinAtar Freeway"); + SetTargetFPS(10); + } + + // Standard across our envs so exiting is always the same + if (IsKeyDown(KEY_ESCAPE)) { + exit(0); + } + + // from https://github.com/sotetsuk/pgx-minatar/blob/main/utils.py + const Color palette[8] = { + BLACK, + RGBf(0.1041941874f, 0.1163201922f, 0.2327552016f), + RGBf(0.0852351161f, 0.3266177900f, 0.2973201283f), + RGBf(0.2653876155f, 0.4675654910f, 0.1908220645f), + RGBf(0.6328422475f, 0.4747981096f, 0.2907020921f), + RGBf(0.8306875711f, 0.5175161304f, 0.6628221029f), + RGBf(0.7779565181f, 0.7069421943f, 0.9314406084f), + RGBf(0.7964528048f, 0.9086689735f, 0.9398253501f), + }; + BeginDrawing(); + ClearBackground(BLACK); + + for (int h = 0; h < 10; h++) { + for (int w = 0; w < 10; w++) { + int code = 0; + for (int c = 0; c < 7; c++) { + if (env->observations[get_index(h, w, c)]) { + code = c + 1; + } + } + int x = w * 30; + int y = h * 30; + DrawRectangle(x, y, 30, 30, palette[code]); + } + } + + EndDrawing(); +} + +void c_close(MinAtarFreeway* env) { + if (IsWindowReady()) { + CloseWindow(); + } + free(env->prev_action); + for (int i = 0; i < 8; i++) { + free(env->cars[i]); + } + free(env->cars); +} diff --git a/pufferlib/ocean/minatar_freeway/minatar_freeway.py b/pufferlib/ocean/minatar_freeway/minatar_freeway.py new file mode 100644 index 0000000000..2c9f0b9fa1 --- /dev/null +++ b/pufferlib/ocean/minatar_freeway/minatar_freeway.py @@ -0,0 +1,86 @@ +import gymnasium +import numpy as np + +import pufferlib +from pufferlib.ocean.minatar_freeway import binding + + +class MinAtarFreeway(pufferlib.PufferEnv): + def __init__( + self, + num_envs=1, + render_mode=None, + log_interval=128, + use_minimal_action_set=False, + sticky_action_prob=0.1, + buf=None, + seed=0, + ): + self.single_observation_space = gymnasium.spaces.Box( + low=0, high=1, shape=(10 * 10 * 7,), dtype=np.uint8 + ) + if use_minimal_action_set: + self.single_action_space = gymnasium.spaces.Discrete(6) + else: + self.single_action_space = gymnasium.spaces.Discrete(3) + self.render_mode = render_mode + self.num_agents = num_envs + self.log_interval = log_interval + + super().__init__(buf) + self.c_envs = binding.vec_init( + self.observations, + self.actions, + self.rewards, + self.terminals, + self.truncations, + num_envs, + seed, + use_minimal_action_set=use_minimal_action_set, + sticky_action_prob=sticky_action_prob, + ) + + def reset(self, seed=0): + binding.vec_reset(self.c_envs, seed) + self.tick = 0 + return self.observations, [] + + def step(self, actions): + self.tick += 1 + + self.actions[:] = actions + binding.vec_step(self.c_envs) + + info = [] + if self.tick % self.log_interval == 0: + info.append(binding.vec_log(self.c_envs)) + + return (self.observations, self.rewards, self.terminals, self.truncations, info) + + def render(self): + binding.vec_render(self.c_envs, 0) + + def close(self): + binding.vec_close(self.c_envs) + + +if __name__ == "__main__": + N = 4096 + + env = MinAtarFreeway(num_envs=N) + env.reset() + steps = 0 + + CACHE = 1024 + actions = np.random.randint(0, 6, (CACHE, N)) + + i = 0 + import time + + start = time.time() + while time.time() - start < 10: + env.step(actions[i % CACHE]) + steps += N + i += 1 + + print("MinAtar Freeway SPS:", int(steps / (time.time() - start)))