Skip to article frontmatter · Skip to article content
Site not loading correctly?

This may be due to an incorrect BASE_URL configuration. See the MyST Documentation for reference.

Stable Baselines 3

Instead of implementing a DQN ourselves again, in this tutorial we will use stable-baselines3: a Python library for Reinforcement Learning!

Notebook Cell
import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn import MSELoss

from matplotlib import animation
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

# set animations to jshtml to render them in browser
# plt.rcParams["animation.html"] = "jshtml"

SEED = 19  # global seed: used for env.reset() and to initialize the module RNG
ENV_NAME = "LunarLander-v3"  # Gymnasium environment id used throughout the tutorial

# module-level generator; new_seed() draws fresh per-reset seeds from it
rng = np.random.default_rng(SEED)

def new_seed(rng):
    """Draw a fresh integer seed in [0, 10_000) from the generator *rng*."""
    drawn = rng.integers(10_000)
    # .item() converts the NumPy scalar to a plain Python int
    return drawn.item()
Notebook Cell
def replay(frames):
    """Turn a list of RGB frames into an in-browser JS/HTML animation."""
    figure, axis = plt.subplots()
    image = axis.imshow(frames[0])
    axis.axis("off")

    def _draw(frame):
        image.set_data(frame)
        return [image]

    movie = FuncAnimation(figure, _draw, frames=frames, interval=30, blit=True)
    # close the figure so the static plot is not rendered alongside the animation
    plt.close(figure)
    return HTML(movie.to_jshtml())
Notebook Cell
def play_episode(agent, env, seed=19):
    """Roll out a single episode of *agent* in *env* and return its replay."""
    obs, _info = env.reset(seed=seed)

    captured = []
    done = False

    while not done:
        # record the frame before acting so the starting state is included
        captured.append(env.render())
        obs, _reward, terminated, truncated, _info = env.step(agent.act(obs))
        done = terminated or truncated

    return replay(captured)
Notebook Cell
class Agent:
    """Minimal agent interface: subclasses must implement act(obs) -> action."""

    def act(self, obs):
        """Choose an action for observation *obs*; must be overridden."""
        raise NotImplementedError()

class RandomAgent(Agent):
    """Agent that picks uniformly among the n discrete actions."""

    def __init__(self, n=4, rng=None):
        self.n = n
        self.rng = np.random.default_rng(rng)

    def act(self, obs):
        # the observation is ignored; sample an action index in [0, n)
        return self.rng.integers(self.n)

Import stable-baselines3:

import stable_baselines3 as sb3

Let’s look at the LunarLander environment, which is a bit trickier than CartPole:

# Roll out one episode with uniformly random actions and record every frame.
# Use the ENV_NAME constant declared above instead of re-hardcoding the id.
env = gym.make(ENV_NAME, render_mode="rgb_array")

obs, info = env.reset(seed=SEED)

terminated = False
truncated = False
frames = []

while not terminated and not truncated:
    # capture the current frame before stepping the environment
    frames.append(env.render())
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)

env.close()
replay(frames)
Loading...

Now, let’s train a DQN for the LunarLander environment via sb3:

# Train a DQN on the same environment via stable-baselines3.
# (Removed a commented-out duplicate of the model line; use ENV_NAME for consistency.)
env = gym.make(ENV_NAME, render_mode="rgb_array")

model = sb3.DQN("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=50_000, log_interval=10_000)
Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
<stable_baselines3.dqn.dqn.DQN at 0x7f230facbbf0>
class SB3Agent(Agent):
    """Adapter exposing a trained stable-baselines3 model through the Agent API."""

    def __init__(self, model):
        self.model = model

    def act(self, obs):
        # predict() returns (action, hidden_states); only the action is needed
        chosen, _ = self.model.predict(obs)
        return chosen
# Wrap the trained model in the Agent interface and watch it play one episode.
sb3_agent = SB3Agent(model)
play_episode(sb3_agent, env)
Loading...