
Commit 3ce795f

vmoens and svekars authored
Update torchrl==0.3.0 tutorials (#2759)
* Update dqn_with_rnn_tutorial collector device

Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
1 parent fb63044 commit 3ce795f

File tree

6 files changed: +118 -84 lines changed


advanced_source/coding_ddpg.py  (+20 -35)
@@ -65,26 +65,33 @@
 
 # sphinx_gallery_start_ignore
 import warnings
+
 warnings.filterwarnings("ignore")
-import multiprocessing
+from torch import multiprocessing
+
 # TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside
 # `__main__` method call, but for the easy of reading the code switch to fork
 # which is also a default spawn method in Google's Colaboratory
 try:
     multiprocessing.set_start_method("fork")
 except RuntimeError:
-    assert multiprocessing.get_start_method() == "fork"
+    pass
+
 # sphinx_gallery_end_ignore
 
 
-import torchrl
 import torch
 import tqdm
-from typing import Tuple
+
 
 ###############################################################################
 # We will execute the policy on CUDA if available
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+is_fork = multiprocessing.get_start_method() == "fork"
+device = (
+    torch.device(0)
+    if torch.cuda.is_available() and not is_fork
+    else torch.device("cpu")
+)
 collector_device = torch.device("cpu")  # Change the device to ``cuda`` to use CUDA
 
 ###############################################################################
@@ -244,23 +251,18 @@ def make_value_estimator(self, value_type: ValueEstimators, **hyperparams):
         hp.update(hyperparams)
         value_key = "state_action_value"
         if value_type == ValueEstimators.TD1:
-            self._value_estimator = TD1Estimator(
-                value_network=self.actor_critic, value_key=value_key, **hp
-            )
+            self._value_estimator = TD1Estimator(value_network=self.actor_critic, **hp)
         elif value_type == ValueEstimators.TD0:
-            self._value_estimator = TD0Estimator(
-                value_network=self.actor_critic, value_key=value_key, **hp
-            )
+            self._value_estimator = TD0Estimator(value_network=self.actor_critic, **hp)
         elif value_type == ValueEstimators.GAE:
             raise NotImplementedError(
                 f"Value type {value_type} it not implemented for loss {type(self)}."
             )
         elif value_type == ValueEstimators.TDLambda:
-            self._value_estimator = TDLambdaEstimator(
-                value_network=self.actor_critic, value_key=value_key, **hp
-            )
+            self._value_estimator = TDLambdaEstimator(value_network=self.actor_critic, **hp)
         else:
             raise NotImplementedError(f"Unknown value type {value_type}")
+        self._value_estimator.set_keys(value=value_key)
 
 
 ###############################################################################
@@ -311,7 +313,7 @@ def _loss_actor(
     def _loss_value(
         self,
         tensordict,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    ):
         td_copy = tensordict.clone()
 
         # V(s, a)
@@ -349,7 +351,7 @@ def _loss_value(
 # value and actor loss, collect the cost values and write them in a ``TensorDict``
 # delivered to the user.
 
-from tensordict.tensordict import TensorDict, TensorDictBase
+from tensordict import TensorDict, TensorDictBase
 
 
 def _forward(self, input_tensordict: TensorDictBase) -> TensorDict:
@@ -457,6 +459,7 @@ def make_env(from_pixels=False):
         raise NotImplementedError
 
     env_kwargs = {
+        "device": device,
         "from_pixels": from_pixels,
         "pixels_only": from_pixels,
         "frame_skip": 2,
@@ -519,16 +522,6 @@ def make_transformed_env(
     # syntax.
     env.append_transform(RewardScaling(loc=0.0, scale=reward_scaling))
 
-    double_to_float_list = []
-    double_to_float_inv_list = []
-    if env_library is DMControlEnv:
-        # ``DMControl`` requires double-precision
-        double_to_float_list += [
-            "reward",
-            "action",
-        ]
-        double_to_float_inv_list += ["action"]
-
     # We concatenate all states into a single "observation_vector"
     # even if there is a single tensor, it'll be renamed in "observation_vector".
     # This facilitates the downstream operations as we know the name of the
@@ -544,12 +537,7 @@ def make_transformed_env(
     # version of the transform
     env.append_transform(ObservationNorm(in_keys=[out_key], standard_normal=True))
 
-    double_to_float_list.append(out_key)
-    env.append_transform(
-        DoubleToFloat(
-            in_keys=double_to_float_list, in_keys_inv=double_to_float_inv_list
-        )
-    )
+    env.append_transform(DoubleToFloat())
 
     env.append_transform(StepCounter(max_frames_per_traj))
 
@@ -874,9 +862,6 @@ def make_ddpg_actor(
     reset_at_each_iter=False,
    split_trajs=False,
     device=collector_device,
-    # device for execution
-    storing_device=collector_device,
-    # device where data will be stored and passed
     exploration_type=ExplorationType.RANDOM,
 )
 
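The device-selection change above recurs across the updated tutorials: CUDA is only targeted when it is available and the multiprocessing start method is not "fork" (presumably because CUDA cannot be re-initialized in forked subprocesses). A minimal standalone sketch of the idiom, copied from the diff, with a print added purely for illustration:

# Sketch of the fork-aware device selection used by the torchrl==0.3.0 tutorials.
import torch
from torch import multiprocessing

try:
    # Prefer "fork" for readability in notebooks/Colab; the start method can
    # only be set once per process, so a second attempt is simply ignored.
    multiprocessing.set_start_method("fork")
except RuntimeError:
    pass

is_fork = multiprocessing.get_start_method() == "fork"
device = (
    torch.device(0)  # i.e. cuda:0
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu")
)
print(device)  # illustration only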

advanced_source/pendulum.py  (+25 -7)
@@ -10,7 +10,7 @@
 is an integrative part of reinforcement learning and control engineering.
 
 TorchRL provides a set of tools to do this in multiple contexts.
-This tutorial demonstrates how to use PyTorch and TorchRL code a pendulum 
+This tutorial demonstrates how to use PyTorch and TorchRL code a pendulum
 simulator from the ground up.
 It is freely inspired by the Pendulum-v1 implementation from `OpenAI-Gym/Farama-Gymnasium
 control library <https://github.com/Farama-Foundation/Gymnasium>`__.
@@ -49,9 +49,9 @@
 # cover a broader range of features of the environment API in TorchRL.
 #
 # Modeling stateless environments gives users full control over the input and
-# outputs of the simulator: one can reset an experiment at any stage or actively 
-# modify the dynamics from the outside. However, it assumes that we have some control 
-# over a task, which may not always be the case: solving a problem where we cannot 
+# outputs of the simulator: one can reset an experiment at any stage or actively
+# modify the dynamics from the outside. However, it assumes that we have some control
+# over a task, which may not always be the case: solving a problem where we cannot
 # control the current state is more challenging but has a much wider set of applications.
 #
 # Another advantage of stateless environments is that they can enable
@@ -73,14 +73,31 @@
 # simulation graph.
 # * Finally, we will train a simple policy to solve the system we implemented.
 #
+
+# sphinx_gallery_start_ignore
+import warnings
+
+warnings.filterwarnings("ignore")
+from torch import multiprocessing
+
+# TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside
+# `__main__` method call, but for the easy of reading the code switch to fork
+# which is also a default spawn method in Google's Colaboratory
+try:
+    multiprocessing.set_start_method("fork")
+except RuntimeError:
+    pass
+
+# sphinx_gallery_end_ignore
+
 from collections import defaultdict
 from typing import Optional
 
 import numpy as np
 import torch
 import tqdm
+from tensordict import TensorDict, TensorDictBase
 from tensordict.nn import TensorDictModule
-from tensordict.tensordict import TensorDict, TensorDictBase
 from torch import nn
 
 from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec
@@ -167,7 +184,7 @@
 # of :meth:`~torchrl.envs.EnvBase.step` in the input ``tensordict`` to enforce
 # input/output consistency.
 #
-# Typically, for stateful environments, this will look like this: 
+# Typically, for stateful environments, this will look like this:
 #
 # .. code-block::
 #
@@ -221,6 +238,7 @@
 # needed as the state needs to be read from the environment.
 #
 
+
 def _step(tensordict):
     th, thdot = tensordict["th"], tensordict["thdot"]  # th := theta
 
@@ -896,7 +914,7 @@ def plot():
 ######################################################################
 # Conclusion
 # ----------
-# 
+#
 # In this tutorial, we have learned how to code a stateless environment from
 # scratch. We touched the subjects of:
 #
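For reference, the tensordict import change above (here and in coding_ddpg.py) reflects the public import path in tensordict 0.3: ``TensorDict`` and ``TensorDictBase`` are imported from the package root rather than from ``tensordict.tensordict``. A minimal sketch, reusing the tutorial's "th"/"thdot" keys with arbitrary shapes:

# Sketch: top-level tensordict imports as used by the updated tutorials.
import torch
from tensordict import TensorDict, TensorDictBase  # formerly tensordict.tensordict

td = TensorDict({"th": torch.zeros(3), "thdot": torch.zeros(3)}, batch_size=[3])
assert isinstance(td, TensorDictBase)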

intermediate_source/dqn_with_rnn_tutorial.py  (+32 -7)
@@ -68,6 +68,22 @@
 # -----
 #
 
+# sphinx_gallery_start_ignore
+import warnings
+
+warnings.filterwarnings("ignore")
+from torch import multiprocessing
+
+# TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside
+# `__main__` method call, but for the easy of reading the code switch to fork
+# which is also a default spawn method in Google's Colaboratory
+try:
+    multiprocessing.set_start_method("fork")
+except RuntimeError:
+    pass
+
+# sphinx_gallery_end_ignore
+
 import torch
 import tqdm
 from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
@@ -88,10 +104,15 @@
     TransformedEnv,
 )
 from torchrl.envs.libs.gym import GymEnv
-from torchrl.modules import ConvNet, EGreedyWrapper, LSTMModule, MLP, QValueModule
+from torchrl.modules import ConvNet, EGreedyModule, LSTMModule, MLP, QValueModule
 from torchrl.objectives import DQNLoss, SoftUpdate
 
-device = torch.device(0) if torch.cuda.device_count() else torch.device("cpu")
+is_fork = multiprocessing.get_start_method() == "fork"
+device = (
+    torch.device(0)
+    if torch.cuda.is_available() and not is_fork
+    else torch.device("cpu")
+)
 
 ######################################################################
 # Environment
@@ -293,11 +314,15 @@
 # DQN being a deterministic algorithm, exploration is a crucial part of it.
 # We'll be using an :math:`\epsilon`-greedy policy with an epsilon of 0.2 decaying
 # progressively to 0.
-# This decay is achieved via a call to :meth:`~torchrl.modules.EGreedyWrapper.step`
+# This decay is achieved via a call to :meth:`~torchrl.modules.EGreedyModule.step`
 # (see training loop below).
 #
-stoch_policy = EGreedyWrapper(
-    stoch_policy, annealing_num_steps=1_000_000, spec=env.action_spec, eps_init=0.2
+exploration_module = EGreedyModule(
+    annealing_num_steps=1_000_000, spec=env.action_spec, eps_init=0.2
+)
+stoch_policy = Seq(
+    stoch_policy,
+    exploration_module,
 )
 
 ######################################################################
@@ -362,7 +387,7 @@
 # For the sake of efficiency, we're only running a few thousands iterations
 # here. In a real setting, the total number of frames should be set to 1M.
 #
-collector = SyncDataCollector(env, stoch_policy, frames_per_batch=50, total_frames=200)
+collector = SyncDataCollector(env, stoch_policy, frames_per_batch=50, total_frames=200, device=device)
 rb = TensorDictReplayBuffer(
     storage=LazyMemmapStorage(20_000), batch_size=4, prefetch=10
 )
@@ -403,7 +428,7 @@
     pbar.set_description(
         f"steps: {longest}, loss_val: {loss_vals['loss'].item(): 4.4f}, action_spread: {data['action'].sum(0)}"
     )
-    stoch_policy.step(data.numel())
+    exploration_module.step(data.numel())
     updater.step()
 
 with set_exploration_type(ExplorationType.MODE), torch.no_grad():
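The exploration change above replaces ``EGreedyWrapper`` with an ``EGreedyModule`` appended to the deterministic policy through ``TensorDictSequential``, and epsilon is now annealed by calling ``step()`` on the exploration module rather than on the wrapped policy. A minimal standalone sketch of that composition, assuming gymnasium is installed and using a small MLP Q-network on CartPole-v1 as a stand-in for the tutorial's LSTM-based actor:

# Sketch: EGreedyModule composed with a deterministic Q-value policy
# (replaces the former EGreedyWrapper). The env and MLP are illustrative stand-ins.
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import EGreedyModule, MLP, QValueModule

env = GymEnv("CartPole-v1")
n_obs = env.observation_spec["observation"].shape[-1]
n_act = env.action_spec.shape[-1]

# Q-value network: maps observations to one value per action.
value_net = Mod(
    MLP(in_features=n_obs, out_features=n_act, num_cells=[32]),
    in_keys=["observation"],
    out_keys=["action_value"],
)
policy = Seq(value_net, QValueModule(action_space=None, spec=env.action_spec))

exploration_module = EGreedyModule(
    annealing_num_steps=1_000_000, spec=env.action_spec, eps_init=0.2
)
stoch_policy = Seq(policy, exploration_module)

rollout = env.rollout(3, stoch_policy)
exploration_module.step(rollout.numel())  # anneal epsilon after each collected batch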

intermediate_source/mario_rl_tutorial.py  (+2 -2)
@@ -32,8 +32,8 @@
 #
 # %%bash
 # pip install gym-super-mario-bros==7.4.0
-# pip install tensordict==0.2.0
-# pip install torchrl==0.2.0
+# pip install tensordict==0.3.0
+# pip install torchrl==0.3.0
 #
 
 import torch

0 commit comments
