Skip to content

Geonove/update to gymnasium api #22

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,7 @@ build*

# Dummy folder with all gym recordings and metadata
/dummy

#.vscode folders
/cpp/.vscode/*
/.vscode/*
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# gym-tcp-api

This project provides a distributed infrastructure (TCP API) to the OpenAI Gym toolkit, allowing development in languages other than python.
This project provides a distributed infrastructure (TCP API) to the Farama-Foundation Gymnasium toolkit, allowing development in languages other than python.

The server is written in Elixir, enabling a distributed infrastructure in which each node makes use of a limited set of processes that can be used to perform time-consuming tasks (4 Python instances by default).

Expand All @@ -20,7 +20,7 @@ The server has the following dependencies:

Python3
Elixir >= 1.0
OpenAI Gym.
Farama-Foundation Gymnasium.

The c++ example agent has the following dependencies:

Expand Down Expand Up @@ -105,7 +105,7 @@ We use JSON as the format to communicate with the server.

Create the specified environment:

{"env" {"name": "CartPole-v0"}}
{"env": {"name": "CartPole-v0", "render_mode": "human"}}

Close the environment:

Expand Down
5 changes: 3 additions & 2 deletions cpp/environment.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,16 @@ class Environment
*/
Environment(const std::string& host,
const std::string& port,
const std::string& environment);
const std::string& environment,
const std::string& render_mode);

/*
* Instantiate the environment object using the specified environment name.
*
* @param environment Name of the environments used to train/evaluate
* the model.
*/
void make(const std::string& environment);
void make(const std::string& environment, const std::string& render_mode);

/*
* Renders the environment.
Expand Down
9 changes: 5 additions & 4 deletions cpp/environment_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,17 @@ inline Environment::Environment(const std::string& host, const std::string& port
inline Environment::Environment(
const std::string& host,
const std::string& port,
const std::string& environment) :
const std::string& environment,
const std::string& render_mode="") :
renderValue(false)
{
client.connect(host, port);
make(environment);
make(environment, render_mode);
}

inline void Environment::make(const std::string& environment)
inline void Environment::make(const std::string& environment, const std::string& render_mode=nullptr)
{
client.send(messages::EnvironmentName(environment));
client.send(messages::EnvironmentName(environment, render_mode));

std::string json;
client.receive(json);
Expand Down
3 changes: 2 additions & 1 deletion cpp/example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@ int main(int argc, char* argv[])
const std::string environment = "CartPole-v1";
const std::string host = "127.0.0.1";
const std::string port = "4040";
const std::string render_mode = "human";

double totalReward = 0;
size_t totalSteps = 0;

Environment env(host, port, environment);
Environment env(host, port, environment, render_mode);
env.compression(9);
env.record_episode_stats.start();

Expand Down
4 changes: 2 additions & 2 deletions cpp/messages.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ namespace gym {
namespace messages {

//! Create message to set the environment name.
static inline std::string EnvironmentName(const std::string& name)
static inline std::string EnvironmentName(const std::string& name, const std::string& render_mode)
{
return "{\"env\":{\"name\": \"" + name + "\"}}";
return "{\"env\":{\"name\": \"" + name + "\", \"render_mode\": \"" + render_mode + "\"}}";
}

//! Create message to reset the environment.
Expand Down
53 changes: 34 additions & 19 deletions python/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
import numpy as np
import socket
import os
import platform
from _thread import *
import glob

import gym
from gym.wrappers import RecordEpisodeStatistics
import gymnasium as gym
from gymnasium.wrappers import RecordEpisodeStatistics

try:
import zlib
Expand All @@ -30,7 +31,7 @@
except socket.error as e:
print(str(e))

print('Waitiing for a Connection..')
print('Waiting for a Connection..')
ServerSocket.listen(5)


Expand Down Expand Up @@ -77,9 +78,9 @@ def _remove_env(self, instance_id):
except KeyError:
raise InvalidUsage('Instance_id {} unknown'.format(instance_id))

def create(self, env_id):
def create(self, env_id, render_mode=None):
try:
env = gym.make(env_id)
env = gym.make(env_id, render_mode=render_mode)
except gym.error.Error:
raise InvalidUsage(
"Attempted to look up malformed environment ID '{}'".format(env_id))
Expand All @@ -90,17 +91,17 @@ def create(self, env_id):

def reset(self, instance_id):
env = self._lookup_env(instance_id)
obs = env.reset()
obs, _ = env.reset()
return env.observation_space.to_jsonable(obs)

def step(self, instance_id, action, render):
env = self._lookup_env(instance_id)
action_from_json = env.action_space.from_jsonable(action)
action_from_json = env.action_space.from_jsonable([action])
if (not isinstance(action_from_json, (list))):
action_from_json = int(action_from_json)

if render: env.render()
[observation, reward, done, info] = env.step(action_from_json)
[observation, reward, done, _, info] = env.step(action_from_json[0])

obs_jsonable = env.observation_space.to_jsonable(observation)
return [obs_jsonable, reward, done, info]
Expand Down Expand Up @@ -264,12 +265,15 @@ def process_response(response, connection, envs, enviroment, instance_id, close,
print(jsonMessage)

enviroment = get_optional_params(jsonMessage, "env", "name")
render_mode = get_optional_params(jsonMessage, "env", "render_mode")
render_mode = render_mode if render_mode != "" else None

if isinstance(enviroment, basestring):
compressionLevel = 0
if instance_id != None:
envs.env_close(instance_id)

instance_id = envs.create(enviroment)
instance_id = envs.create(enviroment, render_mode)
data = json.dumps({"instance" : instance_id}, cls = NDArrayEncoder)
connection.send(process_data(data, compressionLevel))
return enviroment, instance_id, close, compressionLevel
Expand Down Expand Up @@ -321,9 +325,7 @@ def process_response(response, connection, envs, enviroment, instance_id, close,

render = True if (render is not None and render == 1) else False

[obs, reward, done, info] = envs.step(
instance_id, action, render)

[obs, reward, done, info] = envs.step(instance_id, action, render)
data = json.dumps({"observation" : obs,
"reward" : reward,
"done" : done,
Expand Down Expand Up @@ -367,10 +369,23 @@ def process_response(response, connection, envs, enviroment, instance_id, close,
return
connection.close()

while True:
Client, address = ServerSocket.accept()
print('Connected to: ' + address[0] + ':' + str(address[1]))
start_new_thread(threaded_client, (Client, ))
ThreadCount += 1
print('Thread Number: ' + str(ThreadCount))
ServerSocket.close()

if __name__ == "__main__":
if platform.system() == 'Darwin':
# This is needed to make it work on macOS: macOS does not allow rendering on threads
# other than the main thread.
while True:
Client, address = ServerSocket.accept()
print('Connected to: ' + address[0] + ':' + str(address[1]))
ThreadCount += 1
print('Thread Number: ' + str(ThreadCount))
threaded_client(Client) # Directly call the function without threading
ServerSocket.close()
else:
while True:
Client, address = ServerSocket.accept()
print('Connected to: ' + address[0] + ':' + str(address[1]))
start_new_thread(threaded_client, (Client, ))
ThreadCount += 1
print('Thread Number: ' + str(ThreadCount))
ServerSocket.close()