@@ -50,7 +50,7 @@ def test_drill_discrete(self):
50
50
except Exception :
51
51
self .fail ("Discrete drill raises error." )
52
52
53
- def test_drill_dexterity_multicontinuous (self ):
53
+ def test_drill_manipulate_multicontinuous (self ):
54
54
"""Test drilling of discrete agent (LunarLander)."""
55
55
56
56
try :
@@ -62,7 +62,7 @@ def test_drill_dexterity_multicontinuous(self):
62
62
except Exception :
63
63
self .fail ("HumanoidManipulateBlockDiscreteAsynchronous drill raises error." )
64
64
65
- def test_drill_dexterity_continuous (self ):
65
+ def test_drill_manipulate_continuous (self ):
66
66
"""Test drilling of discrete agent (LunarLander)."""
67
67
68
68
try :
@@ -74,6 +74,30 @@ def test_drill_dexterity_continuous(self):
74
74
except Exception :
75
75
self .fail ("HumanoidManipulateBlockDiscreteAsynchronous drill raises error." )
76
76
77
+ def test_drill_reach (self ):
78
+ """Test drilling of discrete agent (LunarLander)."""
79
+
80
+ try :
81
+ wrappers = [StateNormalizationTransformer , RewardNormalizationTransformer ]
82
+ env = make_env ("ReachAbsolute-v0" , reward_config = None , transformers = wrappers )
83
+ build_models = get_model_builder (model = "shadow" , model_type = "lstm" , shared = False )
84
+ agent = PPOAgent (build_models , env , workers = 2 , horizon = 128 , distribution = BetaPolicyDistribution (env ))
85
+ agent .drill (n = 2 , epochs = 2 , batch_size = 64 )
86
+ except Exception :
87
+ self .fail ("ReachAbsolute drill raises error." )
88
+
89
+ def test_drill_freereach (self ):
90
+ """Test drilling of free reach agent."""
91
+
92
+ try :
93
+ wrappers = [StateNormalizationTransformer , RewardNormalizationTransformer ]
94
+ env = make_env ("FreeReachAbsolute-v0" , reward_config = None , transformers = wrappers )
95
+ build_models = get_model_builder (model = "shadow" , model_type = "lstm" , shared = False )
96
+ agent = PPOAgent (build_models , env , workers = 2 , horizon = 128 , distribution = BetaPolicyDistribution (env ))
97
+ agent .drill (n = 2 , epochs = 2 , batch_size = 64 )
98
+ except Exception :
99
+ self .fail ("FreeReachAbsolute drill raises error." )
100
+
77
101
78
102
if __name__ == '__main__' :
79
103
unittest .main ()
0 commit comments