
Commit a5b4912

Merge branch 'master' of github.com:ccnmaastricht/dexterous-robot-hand
2 parents ab7a542 + 9c88c2c commit a5b4912

1 file changed: +5 -23 lines changed

README.md

Lines changed: 5 additions & 23 deletions
@@ -64,32 +64,14 @@ Where ID is the agent's ID given when its created (`train.py` prints this outt,
 To train agents with custom models, environments, etc. you write your own script. The following is a minimal example:

 ```python
-from angorapy.agent.ppo_agent import PPOAgent
-from angorapy.common.policies import BetaPolicyDistribution
-from angorapy.common.transformers import RewardNormalizationTransformer, StateNormalizationTransformer
 from angorapy.common.wrappers import make_env
 from angorapy.models import get_model_builder
+from angorapy.agent.ppo_agent import PPOAgent

-wrappers = [StateNormalizationTransformer, RewardNormalizationTransformer]
-env = make_env("LunarLanderContinuous-v2", reward_config=None, transformers=wrappers)
-
-# make policy distribution
-distribution = BetaPolicyDistribution(env)
-
-# the agent needs to create the model itself, so we build a method that builds a model
-build_models = get_model_builder(model="simple", model_type="ffn", shared=False)
-
-# given the model builder and the environment we can create an agent
-agent = PPOAgent(build_models, env, horizon=1024, workers=12, distribution=distribution)
-
-# let's check the agents ID, so we can find its saved states after training
-print(f"My Agent's ID: {agent.agent_id}")
-
-# ... and then train that agent for n cycles
-agent.drill(n=100, epochs=3, batch_size=64)
-
-# after training, we can save the agent for analysis or the like
-agent.save_agent_state()
+env = make_env("LunarLanderContinuous-v2")
+model_builder = get_model_builder("simple", "ffn")
+agent = PPOAgent(model_builder, env)
+agent.drill(100, 10, 512)
 ```

 For more details, consult the [examples](examples).
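
The shortened example passes the `drill` arguments positionally, which hides what each number means. The sketch below spells them out by keyword. It assumes that the positional order matches the keyword names used in the removed lines (`n`, `epochs`, `batch_size`) and that `agent_id` and `save_agent_state()`, which appear only in the removed lines, are still available. The function name `train_minimal_agent` is hypothetical and not part of the library.

```python
def train_minimal_agent(env_name="LunarLanderContinuous-v2", cycles=100, epochs=10, batch_size=512):
    """Train a PPO agent, naming the drill arguments explicitly."""
    # Imported inside the function so that defining it does not require angorapy.
    from angorapy.agent.ppo_agent import PPOAgent
    from angorapy.common.wrappers import make_env
    from angorapy.models import get_model_builder

    env = make_env(env_name)
    model_builder = get_model_builder("simple", "ffn")
    agent = PPOAgent(model_builder, env)

    # The ID is what identifies the agent's saved state after training
    # (attribute taken from the removed lines; assumed unchanged).
    print(f"Agent ID: {agent.agent_id}")

    # Assumed to be equivalent to agent.drill(100, 10, 512).
    agent.drill(n=cycles, epochs=epochs, batch_size=batch_size)

    # Persist the trained agent (method taken from the removed lines).
    agent.save_agent_state()
    return agent
```

Calling `train_minimal_agent()` with no arguments should then reproduce the settings of the shortened README example, provided the assumptions above hold.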
