@@ -65,26 +65,33 @@
 
 # sphinx_gallery_start_ignore
 import warnings
+
 warnings.filterwarnings("ignore")
-import multiprocessing
+from torch import multiprocessing
+
 # TorchRL prefers the spawn start method, which restricts creation of ``~torchrl.envs.ParallelEnv``
 # to code inside the `__main__` guard; for ease of reading, the code switches to fork,
 # which is also the default start method in Google's Colaboratory.
 try:
     multiprocessing.set_start_method("fork")
 except RuntimeError:
-    assert multiprocessing.get_start_method() == "fork"
+    pass
+
 # sphinx_gallery_end_ignore
 
 
-import torchrl
 import torch
 import tqdm
-from typing import Tuple
+
 
 ###############################################################################
 # We will execute the policy on CUDA if available
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+is_fork = multiprocessing.get_start_method() == "fork"
+device = (
+    torch.device(0)
+    if torch.cuda.is_available() and not is_fork
+    else torch.device("cpu")
+)
 collector_device = torch.device("cpu")  # Change the device to ``cuda`` to use CUDA
 
 ###############################################################################
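The new device selection skips CUDA when the start method is fork, since CUDA generally cannot be re-initialized in a forked subprocess. Under the spawn method that TorchRL prefers, parallel environments have to be created inside the `__main__` guard; a minimal sketch of that pattern (the `make_env` factory and the Pendulum environment are illustrative, not part of this change):

    from torchrl.envs import GymEnv, ParallelEnv

    def make_env():
        # hypothetical factory; any TorchRL environment constructor works here
        return GymEnv("Pendulum-v1")

    if __name__ == "__main__":
        # with the "spawn" start method, worker processes are created here, under the guard
        env = ParallelEnv(2, make_env)
        env.rollout(3)
        env.close()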
@@ -244,23 +251,18 @@ def make_value_estimator(self, value_type: ValueEstimators, **hyperparams):
         hp.update(hyperparams)
         value_key = "state_action_value"
         if value_type == ValueEstimators.TD1:
-            self._value_estimator = TD1Estimator(
-                value_network=self.actor_critic, value_key=value_key, **hp
-            )
+            self._value_estimator = TD1Estimator(value_network=self.actor_critic, **hp)
         elif value_type == ValueEstimators.TD0:
-            self._value_estimator = TD0Estimator(
-                value_network=self.actor_critic, value_key=value_key, **hp
-            )
+            self._value_estimator = TD0Estimator(value_network=self.actor_critic, **hp)
         elif value_type == ValueEstimators.GAE:
             raise NotImplementedError(
                 f"Value type {value_type} is not implemented for loss {type(self)}."
             )
         elif value_type == ValueEstimators.TDLambda:
-            self._value_estimator = TDLambdaEstimator(
-                value_network=self.actor_critic, value_key=value_key, **hp
-            )
+            self._value_estimator = TDLambdaEstimator(value_network=self.actor_critic, **hp)
         else:
             raise NotImplementedError(f"Unknown value type {value_type}")
+        self._value_estimator.set_keys(value=value_key)
 
 
 ###############################################################################
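The estimator constructors no longer accept `value_key`; the key is configured afterwards through `set_keys`, which the added line does once for all branches. A minimal sketch of the same pattern in isolation (the `value_net` module below is a stand-in, not the tutorial's actor-critic):

    import torch
    from tensordict.nn import TensorDictModule
    from torchrl.objectives.value import TD0Estimator

    # stand-in Q-value module that writes its estimate under "state_action_value"
    value_net = TensorDictModule(
        torch.nn.LazyLinear(1), in_keys=["observation"], out_keys=["state_action_value"]
    )
    estimator = TD0Estimator(gamma=0.99, value_network=value_net)
    estimator.set_keys(value="state_action_value")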
@@ -311,7 +313,7 @@ def _loss_actor(
     def _loss_value(
         self,
         tensordict,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    ):
         td_copy = tensordict.clone()
 
         # V(s, a)
@@ -349,7 +351,7 @@ def _loss_value(
 # value and actor loss, collect the cost values and write them in a ``TensorDict``
 # delivered to the user.
 
-from tensordict.tensordict import TensorDict, TensorDictBase
+from tensordict import TensorDict, TensorDictBase
 
 
 def _forward(self, input_tensordict: TensorDictBase) -> TensorDict:
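`TensorDict` and `TensorDictBase` are importable straight from the top-level `tensordict` package, which is what the updated import uses. A short usage sketch:

    import torch
    from tensordict import TensorDict

    # a batch of four transitions packed into a single TensorDict
    data = TensorDict(
        {"observation": torch.randn(4, 3), "reward": torch.zeros(4, 1)},
        batch_size=[4],
    )
    print(data["observation"].shape)  # torch.Size([4, 3])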
@@ -457,6 +459,7 @@ def make_env(from_pixels=False):
         raise NotImplementedError
 
     env_kwargs = {
+        "device": device,
         "from_pixels": from_pixels,
         "pixels_only": from_pixels,
         "frame_skip": 2,
@@ -519,16 +522,6 @@ def make_transformed_env(
     # syntax.
     env.append_transform(RewardScaling(loc=0.0, scale=reward_scaling))
 
-    double_to_float_list = []
-    double_to_float_inv_list = []
-    if env_library is DMControlEnv:
-        # ``DMControl`` requires double-precision
-        double_to_float_list += [
-            "reward",
-            "action",
-        ]
-        double_to_float_inv_list += ["action"]
-
     # We concatenate all states into a single "observation_vector"
     # even if there is a single tensor, it'll be renamed in "observation_vector".
     # This facilitates the downstream operations as we know the name of the
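The context above concatenates all observation entries into a single "observation_vector"; in TorchRL this is typically done with the `CatTensors` transform. A hedged sketch under that assumption (the environment and key names are illustrative):

    from torchrl.envs import CatTensors, TransformedEnv
    from torchrl.envs.libs.dm_control import DMControlEnv

    env = TransformedEnv(DMControlEnv("cheetah", "run"))
    # concatenate the per-field observations into one flat vector
    env.append_transform(
        CatTensors(in_keys=["position", "velocity"], out_key="observation_vector")
    )
    print(env.reset()["observation_vector"].shape)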
@@ -544,12 +537,7 @@ def make_transformed_env(
     # version of the transform
     env.append_transform(ObservationNorm(in_keys=[out_key], standard_normal=True))
 
-    double_to_float_list.append(out_key)
-    env.append_transform(
-        DoubleToFloat(
-            in_keys=double_to_float_list, in_keys_inv=double_to_float_inv_list
-        )
-    )
+    env.append_transform(DoubleToFloat())
 
     env.append_transform(StepCounter(max_frames_per_traj))
 
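These `DoubleToFloat` changes drop the manual bookkeeping of double-precision keys: constructed without `in_keys`, the transform casts every float64 entry exposed by the environment to float32 on its own. A minimal sketch, assuming a working DeepMind Control installation:

    from torchrl.envs import DoubleToFloat, TransformedEnv
    from torchrl.envs.libs.dm_control import DMControlEnv

    # DMControl emits float64 tensors; the argument-free transform casts them to float32
    env = TransformedEnv(DMControlEnv("cheetah", "run"), DoubleToFloat())
    td = env.reset()
    print({key: value.dtype for key, value in td.items() if value.is_floating_point()})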
@@ -874,9 +862,6 @@ def make_ddpg_actor(
     reset_at_each_iter=False,
     split_trajs=False,
     device=collector_device,
-    # device for execution
-    storing_device=collector_device,
-    # device where data will be stored and passed
     exploration_type=ExplorationType.RANDOM,
 )
 
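Without an explicit `storing_device`, the collector stores batches on the device it runs on, so a single `device=collector_device` is enough here. A hedged sketch of a collector built with the same keyword arguments (the environment factory and the linear policy are placeholders):

    import torch
    from tensordict.nn import TensorDictModule
    from torchrl.collectors import SyncDataCollector
    from torchrl.envs import GymEnv
    from torchrl.envs.utils import ExplorationType

    collector_device = torch.device("cpu")
    # placeholder policy mapping observations to a 1-dimensional action
    policy = TensorDictModule(
        torch.nn.LazyLinear(1), in_keys=["observation"], out_keys=["action"]
    )
    collector = SyncDataCollector(
        lambda: GymEnv("Pendulum-v1"),
        policy,
        frames_per_batch=64,
        total_frames=256,
        reset_at_each_iter=False,
        split_trajs=False,
        device=collector_device,  # execution device; storage follows it by default
        exploration_type=ExplorationType.RANDOM,
    )
    for batch in collector:
        print(batch.shape)  # torch.Size([64])
        break
    collector.shutdown()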