# utils.py
import jax.numpy as jnp
from evosax import Strategies, NetworkMapper
import gymnax
from brax import envs


def get_network_and_pholder(task, args):
    """Return a task-specific network module and a matching placeholder input."""
    if task in ['MNIST', 'FMNIST']:
        return NetworkMapper['CNN'](
            depth_1=1,
            depth_2=1,
            features_1=8,
            features_2=16,
            kernel_1=5,
            kernel_2=5,
            strides_1=1,
            strides_2=1,
            num_linear_layers=0,
            num_output_units=10,
        ), jnp.zeros((1, 28, 28, 1))
    elif task == 'CIFAR10':
        return NetworkMapper['CNN'](
            depth_1=1,
            depth_2=1,
            features_1=64,
            features_2=128,
            kernel_1=5,
            kernel_2=5,
            strides_1=1,
            strides_2=1,
            num_linear_layers=1,
            num_hidden_units=256,
            num_output_units=10,
        ), jnp.zeros((1, 32, 32, 3))
    elif task in ['CartPole-v1', 'Acrobot-v1', 'MountainCar-v0', 'Asterix-MinAtar']:
        env, env_param = gymnax.make(task)
        pholder = jnp.zeros(env.observation_space(env_param).shape)
        # Debug output: observation shape and number of discrete actions.
        print(env.observation_space(env_param).shape, env.num_actions)
        if args.recurrent:
            network = NetworkMapper["LSTM"](
                num_hidden_units=32,
                num_output_units=env.num_actions,
                output_activation="categorical",
            )
        else:
            network = NetworkMapper["MLP"](
                num_hidden_units=64,
                num_hidden_layers=2,
                num_output_units=env.num_actions,
                hidden_activation="relu",
                output_activation="categorical",
            )
        return network, pholder
    elif task in ['Pendulum-v1', 'MountainCarContinuous-v0']:
        env, env_param = gymnax.make(task)
        pholder = jnp.zeros(env.observation_space(env_param).shape)
        if args.recurrent:
            network = NetworkMapper["LSTM"](
                num_hidden_units=32,
                num_output_units=env.num_actions,
                output_activation="gaussian",
            )
        else:
            network = NetworkMapper["MLP"](
                num_hidden_units=64,
                num_hidden_layers=2,
                num_output_units=env.num_actions,
                hidden_activation="relu",
                output_activation="gaussian",
            )
        return network, pholder
    elif task in [
        "ant",
        "halfcheetah",
        "hopper",
        "humanoid",
        "reacher",
        "walker2d",
        "fetch",
        "grasp",
        "ur5e",
    ]:
        env = envs.create(env_name=task)
        pholder = jnp.zeros((1, env.observation_size))
        if args.recurrent:
            network = NetworkMapper["LSTM"](
                num_hidden_units=32,
                num_output_units=env.action_size,
                output_activation="tanh",
            )
        else:
            network = NetworkMapper["MLP"](
                num_hidden_units=32,
                num_hidden_layers=4,
                num_output_units=env.action_size,
                hidden_activation="tanh",
                output_activation="tanh",
            )
        return network, pholder
    else:
        raise ValueError(f"Task '{task}' is not supported.")
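
# Example usage (a sketch, not part of the original file): the returned Flax
# module is initialized with its placeholder input, and evosax's
# ParameterReshaper can then flatten the parameter pytree to obtain the
# `num_dims` expected by get_strategy_and_params below. Exact init signatures
# can differ across evosax/flax versions, so treat this as an assumption:
#
#     import jax
#     from evosax import ParameterReshaper
#
#     network, pholder = get_network_and_pholder('MNIST', args)
#     params = network.init(jax.random.PRNGKey(0), pholder)
#     reshaper = ParameterReshaper(params)
#     num_dims = reshaper.total_params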


def get_strategy_and_params(pop_size, num_dims, args):
    """Instantiate the chosen evosax strategy and fill in its parameters."""
    MyStrategy = Strategies[args.strategy]
    strategy = MyStrategy(popsize=pop_size, num_dims=num_dims, opt_name=args.opt_name)
    es_params = strategy.default_params
    if args.strategy == 'OpenES':
        # Update basic parameters of the OpenES strategy
        es_params = strategy.default_params.replace(
            sigma_init=args.sigma_init,  # Initial scale of isotropic Gaussian noise
            sigma_decay=args.sigma_decay,  # Multiplicative decay factor
            sigma_limit=args.sigma_limit,  # Smallest possible scale
            init_min=args.init_min,  # Range of parameter mean initialization - Min
            init_max=args.init_max,  # Range of parameter mean initialization - Max
            clip_min=args.clip_min,  # Range of parameter proposals - Min
            clip_max=args.clip_max,  # Range of parameter proposals - Max
        )
        # Update optimizer-specific parameters of Adam
        es_params = es_params.replace(opt_params=es_params.opt_params.replace(
            lrate_init=args.lr_init,  # Initial learning rate
            lrate_decay=args.lrate_decay,  # Multiplicative decay factor
            lrate_limit=args.lrate_limit,  # Smallest possible lrate
            beta_1=args.beta_1,  # Adam - beta_1
            beta_2=args.beta_2,  # Adam - beta_2
            eps=args.eps,  # Adam - eps constant
        ))
    elif args.strategy == 'PGPE':
        # Update basic parameters of the PGPE strategy
        es_params = strategy.default_params.replace(
            sigma_init=args.sigma_init,  # Initial scale of isotropic Gaussian noise
            sigma_decay=args.sigma_decay,  # Multiplicative decay factor
            sigma_limit=args.sigma_limit,  # Smallest possible scale
            sigma_lrate=args.sigma_lrate,  # Learning rate for scale
            sigma_max_change=args.sigma_max_change,  # Clips adaptive sigma change (e.g. 20%)
            init_min=args.init_min,  # Range of parameter mean initialization - Min
            init_max=args.init_max,  # Range of parameter mean initialization - Max
            clip_min=args.clip_min,  # Range of parameter proposals - Min
            clip_max=args.clip_max,  # Range of parameter proposals - Max
        )
        # Update optimizer-specific parameters of Adam
        es_params = es_params.replace(opt_params=es_params.opt_params.replace(
            lrate_init=args.lr_init,  # Initial learning rate
            lrate_decay=args.lrate_decay,  # Multiplicative decay factor
            lrate_limit=args.lrate_limit,  # Smallest possible lrate
            beta_1=args.beta_1,  # Adam - beta_1
            beta_2=args.beta_2,  # Adam - beta_2
            eps=args.eps,  # Adam - eps constant
        ))
    return strategy, es_params
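

# A minimal end-to-end sketch (an addition, not part of the original file):
# it wires get_strategy_and_params into evosax's ask/tell loop. The
# SimpleNamespace below is a hypothetical stand-in for the argparse flags
# referenced above, `num_dims` would normally come from flattening the
# network parameters (see the sketch after get_network_and_pholder), and the
# zero fitness is a placeholder for a real evaluation.
if __name__ == '__main__':
    import jax
    from types import SimpleNamespace

    args = SimpleNamespace(
        strategy='OpenES', opt_name='adam',
        sigma_init=0.03, sigma_decay=0.999, sigma_limit=0.01,
        init_min=0.0, init_max=0.0, clip_min=-10.0, clip_max=10.0,
        lr_init=0.01, lrate_decay=0.999, lrate_limit=0.001,
        beta_1=0.99, beta_2=0.999, eps=1e-8,
    )
    strategy, es_params = get_strategy_and_params(pop_size=64, num_dims=100, args=args)

    rng = jax.random.PRNGKey(0)
    state = strategy.initialize(rng, es_params)
    for _ in range(5):
        rng, rng_ask = jax.random.split(rng)
        # Sample a population of candidate parameter vectors ...
        x, state = strategy.ask(rng_ask, state, es_params)
        # ... evaluate them (dummy zeros here) and update the search state.
        fitness = jnp.zeros(x.shape[0])
        state = strategy.tell(x, fitness, state, es_params)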