Skip to content

Commit cf760e5

Browse files
committed
Fixed FreeReach
1 parent edd963e commit cf760e5

File tree

2 files changed

+27
-3
lines changed

2 files changed

+27
-3
lines changed

angorapy/environments/reach.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ def _sample_goal(self):
318318

319319
def get_target_finger_position(self):
320320
"""Get position of the target finger in space."""
321-
return self.sim.data.get_site_xpos(FINGERTIP_SITE_NAMES[np.where(self.goal == 1)[0].item()]).flatten()
321+
return self.data.site(FINGERTIP_SITE_NAMES[np.where(self.goal == 1)[0].item()]).xpos.flatten()
322322

323323
def _is_success(self, achieved_goal, desired_goal):
324324
d = get_fingertip_distance(self.get_thumb_position(), self.get_target_finger_position())

tests/test_agent.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def test_drill_discrete(self):
5050
except Exception:
5151
self.fail("Discrete drill raises error.")
5252

53-
def test_drill_dexterity_multicontinuous(self):
53+
def test_drill_manipulate_multicontinuous(self):
5454
"""Test drilling of discrete agent (LunarLander)."""
5555

5656
try:
@@ -62,7 +62,7 @@ def test_drill_dexterity_multicontinuous(self):
6262
except Exception:
6363
self.fail("HumanoidManipulateBlockDiscreteAsynchronous drill raises error.")
6464

65-
def test_drill_dexterity_continuous(self):
65+
def test_drill_manipulate_continuous(self):
6666
"""Test drilling of discrete agent (LunarLander)."""
6767

6868
try:
@@ -74,6 +74,30 @@ def test_drill_dexterity_continuous(self):
7474
except Exception:
7575
self.fail("HumanoidManipulateBlockDiscreteAsynchronous drill raises error.")
7676

77+
def test_drill_reach(self):
78+
"""Test drilling of discrete agent (LunarLander)."""
79+
80+
try:
81+
wrappers = [StateNormalizationTransformer, RewardNormalizationTransformer]
82+
env = make_env("ReachAbsolute-v0", reward_config=None, transformers=wrappers)
83+
build_models = get_model_builder(model="shadow", model_type="lstm", shared=False)
84+
agent = PPOAgent(build_models, env, workers=2, horizon=128, distribution=BetaPolicyDistribution(env))
85+
agent.drill(n=2, epochs=2, batch_size=64)
86+
except Exception:
87+
self.fail("ReachAbsolute drill raises error.")
88+
89+
def test_drill_freereach(self):
90+
"""Test drilling of free reach agent."""
91+
92+
try:
93+
wrappers = [StateNormalizationTransformer, RewardNormalizationTransformer]
94+
env = make_env("FreeReachAbsolute-v0", reward_config=None, transformers=wrappers)
95+
build_models = get_model_builder(model="shadow", model_type="lstm", shared=False)
96+
agent = PPOAgent(build_models, env, workers=2, horizon=128, distribution=BetaPolicyDistribution(env))
97+
agent.drill(n=2, epochs=2, batch_size=64)
98+
except Exception:
99+
self.fail("FreeReachAbsolute drill raises error.")
100+
77101

78102
if __name__ == '__main__':
79103
unittest.main()

0 commit comments

Comments
 (0)