Instead of implementing a DQN ourselves again, in this tutorial we will use stable-baselines3: a Python library for Reinforcement Learning!
Notebook Cell
import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn import MSELoss
from matplotlib import animation
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
# set animations to jshtml to render them in browser
# plt.rcParams["animation.html"] = "jshtml"
# Fixed seed so reruns of the notebook are reproducible
SEED = 19
# Gymnasium environment id used throughout this tutorial
ENV_NAME = "LunarLander-v3"
# Shared NumPy random generator, seeded once at the top of the notebook
rng = np.random.default_rng(SEED)
def new_seed(rng):
    """Draw a fresh integer seed in [0, 10_000) from *rng*.

    Parameters
    ----------
    rng : np.random.Generator
        Source of randomness to draw from.

    Returns
    -------
    int
        A plain Python int (``.item()`` unwraps the NumPy scalar so the
        value can be passed anywhere a builtin ``int`` is expected).
    """
    return rng.integers(10_000).item()
def replay(frames):
    """Turn a list of RGB frames into a browser-renderable animation.

    Parameters
    ----------
    frames : list of array-like
        RGB images captured from ``env.render()``.

    Returns
    -------
    IPython.display.HTML
        A jshtml animation widget playing the frames.
    """
    figure, axes = plt.subplots()
    image = axes.imshow(frames[0])
    axes.axis("off")

    def draw(frame):
        image.set_data(frame)
        return [image]

    movie = FuncAnimation(figure, draw, frames=frames, interval=30, blit=True)
    # close the figure so the static (non-animated) plot is not shown as well
    plt.close(figure)
    return HTML(movie.to_jshtml())
def play_episode(agent, env, seed=19):
    """Run one episode with *agent* in *env* and return its replay animation.

    Parameters
    ----------
    agent : Agent
        Policy providing an ``act(observation)`` method.
    env : gymnasium.Env
        Environment created with ``render_mode="rgb_array"``.
    seed : int
        Seed passed to ``env.reset`` for reproducibility.
    """
    obs, _info = env.reset(seed=seed)
    frames = []
    done = False
    while not done:
        # capture the frame *before* acting, so the initial state is shown
        frames.append(env.render())
        obs, _reward, terminated, truncated, _info = env.step(agent.act(obs))
        done = terminated or truncated
    return replay(frames)
class Agent:
    """Abstract base class for the agents used in this tutorial."""

    def act(self, obs):
        """Return an action for observation *obs*; subclasses must override."""
        raise NotImplementedError()
class RandomAgent(Agent):
    """Agent that picks actions uniformly at random from {0, ..., n-1}."""

    def __init__(self, n=4, rng=None):
        # default_rng accepts None, an int seed, or an existing Generator
        self.rng = np.random.default_rng(rng)
        self.n = n

    def act(self, obs):
        # the observation is ignored: the policy is purely random
        return self.rng.integers(self.n)
import stable_baselines3 as sb3

Let's look at the LunarLander environment, which is a bit more tricky than CartPole:
# Roll out one episode with uniformly random actions and record the frames.
env = gym.make("LunarLander-v3", render_mode="rgb_array")
obs, info = env.reset(seed=SEED)
frames = []
terminated, truncated = False, False
while not (terminated or truncated):
    frames.append(env.render())
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
env.close()
replay(frames)
Now, let’s train a DQN for the LunarLander environment via sb3:
env = gym.make("LunarLander-v3", render_mode="rgb_array")
# DQN with a multilayer-perceptron policy; verbose=1 prints training progress.
model = sb3.DQN("MlpPolicy", env, verbose=1)
# 50k timesteps trains in a few minutes and produces visible (if imperfect)
# landing behaviour; log every 10k steps to keep the output short.
model.learn(total_timesteps=50_000, log_interval=10_000)
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
class SB3Agent(Agent):
    """Adapter exposing a stable-baselines3 model through our Agent interface."""

    def __init__(self, model):
        # a trained sb3 model (e.g. DQN) providing .predict(obs)
        self.model = model

    def act(self, obs):
        # deterministic=True selects the greedy action: for evaluation we do
        # not want the exploration noise used during training
        action, _states = self.model.predict(obs, deterministic=True)
        return action

sb3_agent = SB3Agent(model)
play_episode(sb3_agent, env)