From 5380aeb97aff493452e13abf7e3ead029213055c Mon Sep 17 00:00:00 2001 From: Alphadius Date: Thu, 21 Dec 2023 12:57:47 +0100 Subject: [PATCH 1/4] - Add structure code for the MCTS implementation --- rlBlocking/CMakeLists.txt | 12 +++- .../inc/gmds/rlBlocking/MCTSAlgorithm.h | 42 ++++++++++++++ rlBlocking/inc/gmds/rlBlocking/MCTSMove.h | 29 ++++++++++ .../inc/gmds/rlBlocking/MCTSMovePolycube.h | 30 ++++++++++ rlBlocking/inc/gmds/rlBlocking/MCTSState.h | 57 +++++++++++++++++++ .../inc/gmds/rlBlocking/MCTSStatePolycube.h | 47 +++++++++++++++ rlBlocking/inc/gmds/rlBlocking/MCTSTree.h | 38 +++++++++++++ rlBlocking/src/MCTSAlgorithm.cpp | 15 +++++ rlBlocking/src/MCTSMovePolycube.cpp | 12 ++++ rlBlocking/src/MCTSStatePolycube.cpp | 26 +++++++++ rlBlocking/src/MCTSTree.cpp | 14 +++++ rlBlocking/tst/BlockQualityTestSuite.h | 4 +- rlBlocking/tst/CMakeLists.txt | 1 + rlBlocking/tst/MCTSTestSuite.h | 17 ++++++ rlBlocking/tst/main_test.cpp | 1 + 15 files changed, 342 insertions(+), 3 deletions(-) create mode 100644 rlBlocking/inc/gmds/rlBlocking/MCTSAlgorithm.h create mode 100644 rlBlocking/inc/gmds/rlBlocking/MCTSMove.h create mode 100644 rlBlocking/inc/gmds/rlBlocking/MCTSMovePolycube.h create mode 100644 rlBlocking/inc/gmds/rlBlocking/MCTSState.h create mode 100644 rlBlocking/inc/gmds/rlBlocking/MCTSStatePolycube.h create mode 100644 rlBlocking/inc/gmds/rlBlocking/MCTSTree.h create mode 100644 rlBlocking/src/MCTSAlgorithm.cpp create mode 100644 rlBlocking/src/MCTSMovePolycube.cpp create mode 100644 rlBlocking/src/MCTSStatePolycube.cpp create mode 100644 rlBlocking/src/MCTSTree.cpp create mode 100644 rlBlocking/tst/MCTSTestSuite.h diff --git a/rlBlocking/CMakeLists.txt b/rlBlocking/CMakeLists.txt index d987d2e8e..0d06204a1 100644 --- a/rlBlocking/CMakeLists.txt +++ b/rlBlocking/CMakeLists.txt @@ -9,11 +9,21 @@ set(GMDS_INC inc/gmds/rlBlocking/BlockingQuality.h inc/gmds/rlBlocking/LinkerBlockingGeom.h inc/gmds/rlBlocking/ValidBlocking.h + inc/gmds/rlBlocking/MCTSAlgorithm.h + inc/gmds/rlBlocking/MCTSState.h + inc/gmds/rlBlocking/MCTSTree.h + inc/gmds/rlBlocking/MCTSMove.h + inc/gmds/rlBlocking/MCTSMovePolycube.h + inc/gmds/rlBlocking/MCTSStatePolycube.h ) set(GMDS_SRC src/BlockingQuality.cpp src/LinkerBlockingGeom.cpp - src/ValidBlocking.cpp) + src/ValidBlocking.cpp + src/MCTSTree.cpp + src/MCTSAlgorithm.cpp + src/MCTSMovePolycube.cpp + src/MCTSStatePolycube.cpp) #============================================================================== add_library(${GMDS_LIB} ${GMDS_INC} ${GMDS_SRC}) #============================================================================== diff --git a/rlBlocking/inc/gmds/rlBlocking/MCTSAlgorithm.h b/rlBlocking/inc/gmds/rlBlocking/MCTSAlgorithm.h new file mode 100644 index 000000000..9ef280e91 --- /dev/null +++ b/rlBlocking/inc/gmds/rlBlocking/MCTSAlgorithm.h @@ -0,0 +1,42 @@ +// +// Created by bourmaudp on 02/12/22. +// +/*----------------------------------------------------------------------------------------*/ +#ifndef GMDS_MCTSALGORITHM_H +#define GMDS_MCTSALGORITHM_H +/*----------------------------------------------------------------------------------------*/ +#include "LIB_GMDS_RLBLOCKING_export.h" +/*----------------------------------------------------------------------------------------*/ +namespace gmds { +/*----------------------------------------------------------------------------------------*/ +/** @class MCTSAlgorithm + * @brief Class that provides .... + */ +class LIB_GMDS_RLBLOCKING_API MCTSAlgorithm +{ + public: + + /*------------------------------------------------------------------------*/ + /** @brief Constructor. + * @param + */ + MCTSAlgorithm(); + + /*------------------------------------------------------------------------*/ + /** @brief Destructor. */ + virtual ~MCTSAlgorithm(); + + /*------------------------------------------------------------------------*/ + /** @brief Performs the MCTS algorithm + */ + void execute(); + + private: + /** a mesh */ + //Mesh* m_mesh; +}; +/*----------------------------------------------------------------------------*/ +} +/*----------------------------------------------------------------------------------------*/ +#endif // GMDS_MCTSALGORITHM_H +/*----------------------------------------------------------------------------------------*/ diff --git a/rlBlocking/inc/gmds/rlBlocking/MCTSMove.h b/rlBlocking/inc/gmds/rlBlocking/MCTSMove.h new file mode 100644 index 000000000..1105bdd93 --- /dev/null +++ b/rlBlocking/inc/gmds/rlBlocking/MCTSMove.h @@ -0,0 +1,29 @@ +// +// Created by bourmaudp on 02/12/22. +// +/*----------------------------------------------------------------------------------------*/ +#ifndef GMDS_MCTSMOVE_H +#define GMDS_MCTSMOVE_H +/*----------------------------------------------------------------------------------------*/ +#include "LIB_GMDS_RLBLOCKING_export.h" +/*----------------------------------------------------------------------------------------*/ +namespace gmds { +/*----------------------------------------------------------------------------------------*/ +/** @class MCTSMove + * @brief Structure that provides .... + */ +struct LIB_GMDS_RLBLOCKING_API MCTSMove { + /*------------------------------------------------------------------------*/ + /** @brief Destructor + */ + virtual ~MCTSMove() = default; + /*------------------------------------------------------------------------*/ + /** @brief Overloaded == + */ + virtual bool operator==(const MCTSMove& AOther) const = 0; +}; +/*----------------------------------------------------------------------------*/ +} +/*----------------------------------------------------------------------------------------*/ +#endif // GMDS_MCTSMOVE_H +/*----------------------------------------------------------------------------------------*/ diff --git a/rlBlocking/inc/gmds/rlBlocking/MCTSMovePolycube.h b/rlBlocking/inc/gmds/rlBlocking/MCTSMovePolycube.h new file mode 100644 index 000000000..ebb5fc0c2 --- /dev/null +++ b/rlBlocking/inc/gmds/rlBlocking/MCTSMovePolycube.h @@ -0,0 +1,30 @@ +// +// Created by bourmaudp on 02/12/22. +// +/*----------------------------------------------------------------------------------------*/ +#ifndef GMDS_MCTSMOVE_POLYCUBE_H +#define GMDS_MCTSMOVE_POLYCUBE_H +/*----------------------------------------------------------------------------------------*/ +#include "LIB_GMDS_RLBLOCKING_export.h" +#include +/*----------------------------------------------------------------------------------------*/ +namespace gmds { +/*----------------------------------------------------------------------------------------*/ +/** @class MCTSMove + * @brief Structure that provides .... + */ +struct LIB_GMDS_RLBLOCKING_API MCTSMovePolycube: public MCTSMove { + /*------------------------------------------------------------------------*/ + /** @brief Destructor + */ + virtual ~MCTSMovePolycube(); + /*------------------------------------------------------------------------*/ + /** @brief Overloaded == + */ + virtual bool operator==(const MCTSMove& AOther) const; +}; +/*----------------------------------------------------------------------------*/ +} +/*----------------------------------------------------------------------------------------*/ +#endif // GMDS_MCTSMOVE_POLYCUBE_H +/*----------------------------------------------------------------------------------------*/ diff --git a/rlBlocking/inc/gmds/rlBlocking/MCTSState.h b/rlBlocking/inc/gmds/rlBlocking/MCTSState.h new file mode 100644 index 000000000..dcc1fe153 --- /dev/null +++ b/rlBlocking/inc/gmds/rlBlocking/MCTSState.h @@ -0,0 +1,57 @@ +// +// Created by bourmaudp on 02/12/22. +// +/*----------------------------------------------------------------------------------------*/ +#ifndef GMDS_MCTSSTATE_H +#define GMDS_MCTSSTATE_H +/*----------------------------------------------------------------------------------------*/ +#include "LIB_GMDS_RLBLOCKING_export.h" +#include +/*----------------------------------------------------------------------------------------*/ +#include +/*----------------------------------------------------------------------------------------*/ +namespace gmds { +/*----------------------------------------------------------------------------------------*/ +/** @class MCTSState + * @brief Class that provides the interface to be implemented for performing the + * MCST algorithm + */ +class LIB_GMDS_RLBLOCKING_API MCTSState { + public: + /*--------------------------------------------------------------------*/ + /** @enum Status code for rollout execution + */ + typedef enum { + WIN, + LOSE, + DRAW + } ROLLOUT_STATUS; + /*------------------------------------------------------------------------*/ + /** @brief Destructor + */ + virtual ~MCTSState() = default; + /*------------------------------------------------------------------------*/ + /** @brief Gives the set of actions that can be tried from the current state + */ + virtual std::queue *actions_to_try() const = 0; + /*------------------------------------------------------------------------*/ + /** @brief Performs the @p AMove to change of states + * @param[in] AMove the movement to apply to get to a new state + */ + virtual MCTSState *next_state(const MCTSMove *AMove) const = 0; + /*------------------------------------------------------------------------*/ + /** @brief Rollout from this state (random simulation) + * @return the rollout status + */ + virtual ROLLOUT_STATUS rollout() const = 0; + /*------------------------------------------------------------------------*/ + /** @brief Indicate if we have a terminal state (win=true, fail=false) + * @return true if we have a leaf (in the sense of a traditional tree) + */ + virtual bool is_terminal() const = 0; +}; +/*----------------------------------------------------------------------------*/ +} +/*----------------------------------------------------------------------------------------*/ +#endif // GMDS_MCTSSTATE_H +/*----------------------------------------------------------------------------------------*/ diff --git a/rlBlocking/inc/gmds/rlBlocking/MCTSStatePolycube.h b/rlBlocking/inc/gmds/rlBlocking/MCTSStatePolycube.h new file mode 100644 index 000000000..fcf51649d --- /dev/null +++ b/rlBlocking/inc/gmds/rlBlocking/MCTSStatePolycube.h @@ -0,0 +1,47 @@ +// +// Created by bourmaudp on 02/12/22. +// +/*----------------------------------------------------------------------------------------*/ +#ifndef GMDS_MCTSSTATE_POLYCUBE_H +#define GMDS_MCTSSTATE_POLYCUBE_H +/*----------------------------------------------------------------------------------------*/ +#include "LIB_GMDS_RLBLOCKING_export.h" +#include +/*----------------------------------------------------------------------------------------*/ +namespace gmds { +/*----------------------------------------------------------------------------------------*/ +/** @class MCTSState + * @brief Class that provides the interface to be implemented for performing the + * MCST algorithm + */ +class LIB_GMDS_RLBLOCKING_API MCTSStatePolycube: public MCTSState{ + public: + /*------------------------------------------------------------------------*/ + /** @brief Destructor + */ + virtual ~MCTSStatePolycube(); + /*------------------------------------------------------------------------*/ + /** @brief Gives the set of actions that can be tried from the current state + */ + virtual std::queue *actions_to_try() const ; + /*------------------------------------------------------------------------*/ + /** @brief Performs the @p AMove to change of states + * @param[in] AMove the movement to apply to get to a new state + */ + virtual MCTSState *next_state(const MCTSMove *AMove) const; + /*------------------------------------------------------------------------*/ + /** @brief Rollout from this state (random simulation) + * @return the rollout status + */ + virtual ROLLOUT_STATUS rollout() const; + /*------------------------------------------------------------------------*/ + /** @brief Indicate if we have a terminal state (win=true, fail=false) + * @return true if we have a leaf (in the sense of a traditional tree) + */ + virtual bool is_terminal() const; +}; +/*----------------------------------------------------------------------------*/ +} +/*----------------------------------------------------------------------------------------*/ +#endif // GMDS_MCTSSTATE_POLYCUBE_H +/*----------------------------------------------------------------------------------------*/ diff --git a/rlBlocking/inc/gmds/rlBlocking/MCTSTree.h b/rlBlocking/inc/gmds/rlBlocking/MCTSTree.h new file mode 100644 index 000000000..e73250292 --- /dev/null +++ b/rlBlocking/inc/gmds/rlBlocking/MCTSTree.h @@ -0,0 +1,38 @@ +// +// Created by bourmaudp on 02/12/22. +// +/*----------------------------------------------------------------------------------------*/ +#ifndef GMDS_MCTSTREE_H +#define GMDS_MCTSTREE_H +/*----------------------------------------------------------------------------------------*/ +#include "LIB_GMDS_RLBLOCKING_export.h" +/*----------------------------------------------------------------------------------------*/ +namespace gmds { +/*----------------------------------------------------------------------------------------*/ +/** @class MCTSAlgorithm + * @brief Class that provides .... + */ +class LIB_GMDS_RLBLOCKING_API MCTSTree +{ + public: + + /*------------------------------------------------------------------------*/ + /** @brief Constructor. + * @param + */ + MCTSTree(); + + /*------------------------------------------------------------------------*/ + /** @brief Destructor. */ + virtual ~MCTSTree(); + + + private: + /** a mesh */ + //Mesh* m_mesh; +}; +/*----------------------------------------------------------------------------*/ +} +/*----------------------------------------------------------------------------------------*/ +#endif // GMDS_MCTSTREE_H +/*----------------------------------------------------------------------------------------*/ diff --git a/rlBlocking/src/MCTSAlgorithm.cpp b/rlBlocking/src/MCTSAlgorithm.cpp new file mode 100644 index 000000000..dc055ae17 --- /dev/null +++ b/rlBlocking/src/MCTSAlgorithm.cpp @@ -0,0 +1,15 @@ +/*----------------------------------------------------------------------------*/ +#include +/*----------------------------------------------------------------------------*/ +#include +#include +/*----------------------------------------------------------------------------*/ +using namespace gmds; +/*----------------------------------------------------------------------------*/ +MCTSAlgorithm::MCTSAlgorithm(){;} +/*----------------------------------------------------------------------------*/ +MCTSAlgorithm::~MCTSAlgorithm(){;} +/*----------------------------------------------------------------------------*/ +void MCTSAlgorithm::execute() +{} +/*----------------------------------------------------------------------------*/ diff --git a/rlBlocking/src/MCTSMovePolycube.cpp b/rlBlocking/src/MCTSMovePolycube.cpp new file mode 100644 index 000000000..edccb829e --- /dev/null +++ b/rlBlocking/src/MCTSMovePolycube.cpp @@ -0,0 +1,12 @@ +/*----------------------------------------------------------------------------*/ +#include +/*----------------------------------------------------------------------------*/ +using namespace gmds; +/*----------------------------------------------------------------------------*/ +MCTSMovePolycube::~MCTSMovePolycube() +{} +/*----------------------------------------------------------------------------*/ +bool +MCTSMovePolycube::operator==(const gmds::MCTSMove &AOther) const +{} +/*----------------------------------------------------------------------------*/ diff --git a/rlBlocking/src/MCTSStatePolycube.cpp b/rlBlocking/src/MCTSStatePolycube.cpp new file mode 100644 index 000000000..ef06f0ff8 --- /dev/null +++ b/rlBlocking/src/MCTSStatePolycube.cpp @@ -0,0 +1,26 @@ +/*----------------------------------------------------------------------------*/ +#include +/*----------------------------------------------------------------------------*/ +using namespace gmds; +/*----------------------------------------------------------------------------*/ +MCTSStatePolycube::~MCTSStatePolycube() noexcept +{} +/*----------------------------------------------------------------------------*/ +std::queue * +MCTSStatePolycube::actions_to_try() const +{} +/*----------------------------------------------------------------------------*/ +MCTSState * +MCTSStatePolycube::next_state(const gmds::MCTSMove *AMove) const +{} +/*----------------------------------------------------------------------------*/ +MCTSState::ROLLOUT_STATUS +MCTSStatePolycube::rollout() const +{ + return MCTSState::WIN; +} +/*----------------------------------------------------------------------------*/ +bool +MCTSStatePolycube::is_terminal() const +{} +/*----------------------------------------------------------------------------*/ diff --git a/rlBlocking/src/MCTSTree.cpp b/rlBlocking/src/MCTSTree.cpp new file mode 100644 index 000000000..9a83b4146 --- /dev/null +++ b/rlBlocking/src/MCTSTree.cpp @@ -0,0 +1,14 @@ +/*----------------------------------------------------------------------------*/ +#include +/*----------------------------------------------------------------------------*/ +#include +#include +/*----------------------------------------------------------------------------*/ +using namespace gmds; +/*----------------------------------------------------------------------------*/ +MCTSTree::MCTSTree() +{} +/*----------------------------------------------------------------------------*/ +MCTSTree::~MCTSTree() +{} +/*----------------------------------------------------------------------------*/ diff --git a/rlBlocking/tst/BlockQualityTestSuite.h b/rlBlocking/tst/BlockQualityTestSuite.h index 6ec4c98aa..94e7e3b21 100644 --- a/rlBlocking/tst/BlockQualityTestSuite.h +++ b/rlBlocking/tst/BlockQualityTestSuite.h @@ -1,8 +1,6 @@ #ifndef GMDS_BLOCKQUALITYTESTSUITE_H #define GMDS_BLOCKQUALITYTESTSUITE_H -#endif // GMDS_BLOCKQUALITYTESTSUITE_H - // // Created by ledouxf on 1/22/19. @@ -102,3 +100,5 @@ TEST(BlockQualityTestSuite, test_Rubiks) ASSERT_EQ(2, 2);//linker.getGeomId(n1)); } + +#endif // GMDS_BLOCKQUALITYTESTSUITE_H diff --git a/rlBlocking/tst/CMakeLists.txt b/rlBlocking/tst/CMakeLists.txt index 67aaac829..1f8376e84 100644 --- a/rlBlocking/tst/CMakeLists.txt +++ b/rlBlocking/tst/CMakeLists.txt @@ -1,5 +1,6 @@ add_executable(GMDS_BLOCKINGQUALITY_TEST BlockQualityTestSuite.h + MCTSTestSuite.h main_test.cpp) target_link_libraries(GMDS_BLOCKINGQUALITY_TEST PUBLIC diff --git a/rlBlocking/tst/MCTSTestSuite.h b/rlBlocking/tst/MCTSTestSuite.h new file mode 100644 index 000000000..6fc36811f --- /dev/null +++ b/rlBlocking/tst/MCTSTestSuite.h @@ -0,0 +1,17 @@ +#ifndef GMDS_MCTSTESTSUITE_H +#define GMDS_MCTSTESTSUITE_H +/*----------------------------------------------------------------------------*/ +#include +/*----------------------------------------------------------------------------*/ +#include +/*----------------------------------------------------------------------------*/ +using namespace gmds; +/*----------------------------------------------------------------------------*/ + +TEST(MCTSTestSuite, test1) +{ + +} +/*----------------------------------------------------------------------------*/ + +#endif // GMDS_MCTSTESTSUITE_H diff --git a/rlBlocking/tst/main_test.cpp b/rlBlocking/tst/main_test.cpp index 180fda83e..3fdec386f 100644 --- a/rlBlocking/tst/main_test.cpp +++ b/rlBlocking/tst/main_test.cpp @@ -4,6 +4,7 @@ // Files containing the different test suites to launch #include "BlockQualityTestSuite.h" +#include "MCTSTestSuite.h" /*----------------------------------------------------------------------------*/ int main(int argc, char ** argv) { ::testing::InitGoogleTest(&argc, argv); From e65aa0f8aad3f148580bd4a1fa8afe971d662061 Mon Sep 17 00:00:00 2001 From: Alphadius Date: Thu, 11 Jan 2024 10:39:38 +0100 Subject: [PATCH 2/4] add differents elements of the mcts cpp algorithm. Currently, error with the terminal state --- blocking/src/CurvedBlockingClassifier.cpp | 2 +- rlBlocking/CMakeLists.txt | 5 +- rlBlocking/inc/gmds/rlBlocking/MCTSAgent.h | 24 ++ .../inc/gmds/rlBlocking/MCTSAlgorithm.h | 19 +- rlBlocking/inc/gmds/rlBlocking/MCTSMove.h | 2 + .../inc/gmds/rlBlocking/MCTSMovePolycube.h | 17 +- rlBlocking/inc/gmds/rlBlocking/MCTSState.h | 17 +- .../inc/gmds/rlBlocking/MCTSStatePolycube.h | 53 +++- rlBlocking/inc/gmds/rlBlocking/MCTSTree.h | 111 +++++++- rlBlocking/src/MCTSAgent.cpp | 46 ++++ rlBlocking/src/MCTSAlgorithm.cpp | 34 ++- rlBlocking/src/MCTSMovePolycube.cpp | 10 +- rlBlocking/src/MCTSStatePolycube.cpp | 200 +++++++++++++- rlBlocking/src/MCTSTree.cpp | 257 +++++++++++++++++- rlBlocking/src/main_rlBlocking.cpp | 2 +- rlBlocking/tst/MCTSTestSuite.h | 45 ++- 16 files changed, 799 insertions(+), 45 deletions(-) create mode 100644 rlBlocking/inc/gmds/rlBlocking/MCTSAgent.h create mode 100644 rlBlocking/src/MCTSAgent.cpp diff --git a/blocking/src/CurvedBlockingClassifier.cpp b/blocking/src/CurvedBlockingClassifier.cpp index cad4c0fda..83ba7ce79 100644 --- a/blocking/src/CurvedBlockingClassifier.cpp +++ b/blocking/src/CurvedBlockingClassifier.cpp @@ -632,7 +632,7 @@ std::vector> CurvedBlockingClassifier::list_Possible_Cuts() { std::vector> list_actions; - auto no_capt_elements = classify(); + auto no_capt_elements = this->classify(); auto no_points_capt = no_capt_elements.non_captured_points; auto no_curves_capt = no_capt_elements.non_captured_curves; diff --git a/rlBlocking/CMakeLists.txt b/rlBlocking/CMakeLists.txt index 0d06204a1..ce7fe6055 100644 --- a/rlBlocking/CMakeLists.txt +++ b/rlBlocking/CMakeLists.txt @@ -15,6 +15,7 @@ set(GMDS_INC inc/gmds/rlBlocking/MCTSMove.h inc/gmds/rlBlocking/MCTSMovePolycube.h inc/gmds/rlBlocking/MCTSStatePolycube.h + inc/gmds/rlBlocking/MCTSAgent.h ) set(GMDS_SRC src/BlockingQuality.cpp @@ -23,7 +24,9 @@ set(GMDS_SRC src/MCTSTree.cpp src/MCTSAlgorithm.cpp src/MCTSMovePolycube.cpp - src/MCTSStatePolycube.cpp) + src/MCTSStatePolycube.cpp + src/MCTSAgent.cpp +) #============================================================================== add_library(${GMDS_LIB} ${GMDS_INC} ${GMDS_SRC}) #============================================================================== diff --git a/rlBlocking/inc/gmds/rlBlocking/MCTSAgent.h b/rlBlocking/inc/gmds/rlBlocking/MCTSAgent.h new file mode 100644 index 000000000..ebd3b9059 --- /dev/null +++ b/rlBlocking/inc/gmds/rlBlocking/MCTSAgent.h @@ -0,0 +1,24 @@ +#ifndef GMDS_MCTSAGENT_H +#define GMDS_MCTSAGENT_H + +#include +/*----------------------------------------------------------------------------------------*/ +namespace gmds { +/*----------------------------------------------------------------------------------------*/ +class LIB_GMDS_RLBLOCKING_API MCTSAgent +{ + // example of an agent based on the MCTS_tree. One can also use the tree directly. + MCTSTree *tree; + int max_iter, max_seconds, max_same_quality; + + public: + MCTSAgent(MCTSState *starting_state, int max_iter = 100000, int max_seconds = 30, int max_same_quality=3); + ~MCTSAgent(); + const MCTSMove *genmove(); + const MCTSState *get_current_state() const; + void feedback() const {tree->print_stats();} +}; +} +/*----------------------------------------------------------------------------------------*/ +#endif // GMDS_MCTSAGENT_H +/*----------------------------------------------------------------------------------------*/ diff --git a/rlBlocking/inc/gmds/rlBlocking/MCTSAlgorithm.h b/rlBlocking/inc/gmds/rlBlocking/MCTSAlgorithm.h index 9ef280e91..083ff9fd4 100644 --- a/rlBlocking/inc/gmds/rlBlocking/MCTSAlgorithm.h +++ b/rlBlocking/inc/gmds/rlBlocking/MCTSAlgorithm.h @@ -1,11 +1,11 @@ -// -// Created by bourmaudp on 02/12/22. -// /*----------------------------------------------------------------------------------------*/ #ifndef GMDS_MCTSALGORITHM_H #define GMDS_MCTSALGORITHM_H /*----------------------------------------------------------------------------------------*/ #include "LIB_GMDS_RLBLOCKING_export.h" +#include +#include +#include /*----------------------------------------------------------------------------------------*/ namespace gmds { /*----------------------------------------------------------------------------------------*/ @@ -14,17 +14,19 @@ namespace gmds { */ class LIB_GMDS_RLBLOCKING_API MCTSAlgorithm { + MCTSTree *tree; + int max_iter, max_seconds,max_same_quality; public: /*------------------------------------------------------------------------*/ /** @brief Constructor. * @param */ - MCTSAlgorithm(); + MCTSAlgorithm(gmds::cad::GeomManager *AGeom,gmds::blocking::CurvedBlocking *ABlocking,int max_iter = 100000, int max_seconds = 30,int max_same_quality = 10); /*------------------------------------------------------------------------*/ /** @brief Destructor. */ - virtual ~MCTSAlgorithm(); + ~MCTSAlgorithm(); /*------------------------------------------------------------------------*/ /** @brief Performs the MCTS algorithm @@ -32,8 +34,11 @@ class LIB_GMDS_RLBLOCKING_API MCTSAlgorithm void execute(); private: - /** a mesh */ - //Mesh* m_mesh; + /** a geom */ + gmds::cad::GeomManager *m_geom; + /** a blocking */ + gmds::blocking::CurvedBlocking *m_blocking; + }; /*----------------------------------------------------------------------------*/ } diff --git a/rlBlocking/inc/gmds/rlBlocking/MCTSMove.h b/rlBlocking/inc/gmds/rlBlocking/MCTSMove.h index 1105bdd93..75f061783 100644 --- a/rlBlocking/inc/gmds/rlBlocking/MCTSMove.h +++ b/rlBlocking/inc/gmds/rlBlocking/MCTSMove.h @@ -6,6 +6,7 @@ #define GMDS_MCTSMOVE_H /*----------------------------------------------------------------------------------------*/ #include "LIB_GMDS_RLBLOCKING_export.h" +#include /*----------------------------------------------------------------------------------------*/ namespace gmds { /*----------------------------------------------------------------------------------------*/ @@ -21,6 +22,7 @@ struct LIB_GMDS_RLBLOCKING_API MCTSMove { /** @brief Overloaded == */ virtual bool operator==(const MCTSMove& AOther) const = 0; + virtual std::string sprint() const { return "Not implemented"; } // and optionally this }; /*----------------------------------------------------------------------------*/ } diff --git a/rlBlocking/inc/gmds/rlBlocking/MCTSMovePolycube.h b/rlBlocking/inc/gmds/rlBlocking/MCTSMovePolycube.h index ebb5fc0c2..aaf9bdc43 100644 --- a/rlBlocking/inc/gmds/rlBlocking/MCTSMovePolycube.h +++ b/rlBlocking/inc/gmds/rlBlocking/MCTSMovePolycube.h @@ -1,12 +1,10 @@ -// -// Created by bourmaudp on 02/12/22. -// /*----------------------------------------------------------------------------------------*/ #ifndef GMDS_MCTSMOVE_POLYCUBE_H #define GMDS_MCTSMOVE_POLYCUBE_H /*----------------------------------------------------------------------------------------*/ #include "LIB_GMDS_RLBLOCKING_export.h" #include +#include /*----------------------------------------------------------------------------------------*/ namespace gmds { /*----------------------------------------------------------------------------------------*/ @@ -17,11 +15,20 @@ struct LIB_GMDS_RLBLOCKING_API MCTSMovePolycube: public MCTSMove { /*------------------------------------------------------------------------*/ /** @brief Destructor */ - virtual ~MCTSMovePolycube(); + ~MCTSMovePolycube(); /*------------------------------------------------------------------------*/ + TCellID m_AIdEdge; + TCellID m_AIdBlock; + double m_AParamCut; + /** @brief if typeMove=0: delete block, typeMove=1 cut block + */ + bool m_typeMove; + /** @brief Overloaded == */ - virtual bool operator==(const MCTSMove& AOther) const; + MCTSMovePolycube(TCellID AIdEdge,TCellID AIdBlock, double AParamCut,bool ATypeMove); + bool operator==(const MCTSMove& AOther) const; + }; /*----------------------------------------------------------------------------*/ } diff --git a/rlBlocking/inc/gmds/rlBlocking/MCTSState.h b/rlBlocking/inc/gmds/rlBlocking/MCTSState.h index dcc1fe153..9952850fc 100644 --- a/rlBlocking/inc/gmds/rlBlocking/MCTSState.h +++ b/rlBlocking/inc/gmds/rlBlocking/MCTSState.h @@ -7,6 +7,7 @@ /*----------------------------------------------------------------------------------------*/ #include "LIB_GMDS_RLBLOCKING_export.h" #include +#include /*----------------------------------------------------------------------------------------*/ #include /*----------------------------------------------------------------------------------------*/ @@ -43,12 +44,26 @@ class LIB_GMDS_RLBLOCKING_API MCTSState { /** @brief Rollout from this state (random simulation) * @return the rollout status */ - virtual ROLLOUT_STATUS rollout() const = 0; + virtual double state_rollout() const = 0; + /*------------------------------------------------------------------------*/ + /** @brief check the result of a terminal state + * @return the value of the result: Win, Lose, Draw + */ + virtual ROLLOUT_STATUS result_terminal() const = 0; /*------------------------------------------------------------------------*/ /** @brief Indicate if we have a terminal state (win=true, fail=false) * @return true if we have a leaf (in the sense of a traditional tree) */ virtual bool is_terminal() const = 0; + /*------------------------------------------------------------------------*/ + /** @brief Indicate if we have a terminal state (win=true, fail=false) + * @return true if we have a leaf (in the sense of a traditional tree) + */ + virtual double get_quality() const = 0; + + virtual void print() const { + std::cout << "Printing not implemented" << std::endl; + } }; /*----------------------------------------------------------------------------*/ } diff --git a/rlBlocking/inc/gmds/rlBlocking/MCTSStatePolycube.h b/rlBlocking/inc/gmds/rlBlocking/MCTSStatePolycube.h index fcf51649d..354396ea3 100644 --- a/rlBlocking/inc/gmds/rlBlocking/MCTSStatePolycube.h +++ b/rlBlocking/inc/gmds/rlBlocking/MCTSStatePolycube.h @@ -1,12 +1,12 @@ -// -// Created by bourmaudp on 02/12/22. -// /*----------------------------------------------------------------------------------------*/ #ifndef GMDS_MCTSSTATE_POLYCUBE_H #define GMDS_MCTSSTATE_POLYCUBE_H /*----------------------------------------------------------------------------------------*/ #include "LIB_GMDS_RLBLOCKING_export.h" #include +#include +#include +#include /*----------------------------------------------------------------------------------------*/ namespace gmds { /*----------------------------------------------------------------------------------------*/ @@ -15,30 +15,67 @@ namespace gmds { * MCST algorithm */ class LIB_GMDS_RLBLOCKING_API MCTSStatePolycube: public MCTSState{ + public: + /*------------------------------------------------------------------------*/ + /** @brief Constructore + */ + MCTSStatePolycube(gmds::cad::GeomManager *Ageom, gmds::blocking::CurvedBlocking *ABlocking, + std::vector AHist); /*------------------------------------------------------------------------*/ /** @brief Destructor */ - virtual ~MCTSStatePolycube(); + ~MCTSStatePolycube(); /*------------------------------------------------------------------------*/ /** @brief Gives the set of actions that can be tried from the current state */ - virtual std::queue *actions_to_try() const ; + std::queue *actions_to_try() const ; /*------------------------------------------------------------------------*/ /** @brief Performs the @p AMove to change of states * @param[in] AMove the movement to apply to get to a new state */ - virtual MCTSState *next_state(const MCTSMove *AMove) const; + MCTSState *next_state(const MCTSMove *AMove) const; /*------------------------------------------------------------------------*/ /** @brief Rollout from this state (random simulation) * @return the rollout status */ - virtual ROLLOUT_STATUS rollout() const; + double state_rollout() const; + + /** @brief check the history of qualities + * @return nb of same quality from the history + */ + int check_nb_same_quality() const; + /** @brief check the result of a terminal state + * @return Win = all elements are capt, Lose: parent_quality < enfant_quality, + * Draw : same quality for a long time + */ + ROLLOUT_STATUS result_terminal() const; /*------------------------------------------------------------------------*/ /** @brief Indicate if we have a terminal state (win=true, fail=false) * @return true if we have a leaf (in the sense of a traditional tree) */ - virtual bool is_terminal() const; + bool is_terminal() const; + /** @brief return the blocking quality + * */ + double get_quality() const; + /** @brief return the geom */ + gmds::cad::GeomManager *get_geom(); + /** @brief return the current blocking */ + gmds::blocking::CurvedBlocking *get_blocking(); + /** @brief return the current classifier */ + gmds::blocking::CurvedBlockingClassifier *get_class(); + /** @brief return the current classification */ + gmds::blocking::ClassificationErrors get_errors(); + /** @brief return the history of the parents quality */ + std::vector get_history() const; + + private : + /** @brief the curved blocking of the current state */ + gmds::blocking::CurvedBlocking* m_blocking; + gmds::cad::GeomManager* m_geom; + gmds::blocking::CurvedBlockingClassifier* m_class_blocking; + gmds::blocking::ClassificationErrors m_class_errors; + std::vector m_history; }; /*----------------------------------------------------------------------------*/ } diff --git a/rlBlocking/inc/gmds/rlBlocking/MCTSTree.h b/rlBlocking/inc/gmds/rlBlocking/MCTSTree.h index e73250292..e55407b21 100644 --- a/rlBlocking/inc/gmds/rlBlocking/MCTSTree.h +++ b/rlBlocking/inc/gmds/rlBlocking/MCTSTree.h @@ -1,36 +1,137 @@ -// -// Created by bourmaudp on 02/12/22. -// /*----------------------------------------------------------------------------------------*/ #ifndef GMDS_MCTSTREE_H #define GMDS_MCTSTREE_H /*----------------------------------------------------------------------------------------*/ #include "LIB_GMDS_RLBLOCKING_export.h" +#include +#include +#include +#include +#include /*----------------------------------------------------------------------------------------*/ namespace gmds { /*----------------------------------------------------------------------------------------*/ -/** @class MCTSAlgorithm +/** @class MCTSNode + * @brief Class that provides .... + */ +class LIB_GMDS_RLBLOCKING_API MCTSNode { + /** @brief Yes if the node have no childs */ + bool terminal; + /** @brief Number of nodes in the tree from the node. */ + unsigned int size; + /** @brief number of parent nodes having the same quality*/ + unsigned int nb_same_quality; + /** @brief Number of visits*/ + unsigned int number_of_simulations; + /** @brief e.g. number of wins (could be int but double is more general if we use evaluation functions)*/ + double score; + /** @brief state of the current node */ + MCTSState *state; + /** @brief move to get here from parent node's state*/ + const MCTSMove *move; + /** @brief the chilren for the current node */ + std::vector *children; + /** @brief the parent for the current node*/ + MCTSNode *parent; + /** @brief queue of untried actions*/ + std::queue *untried_actions; + /** @brief update the nb simulations and the score after a rollout*/ + void backpropagate(double w, int n); + public: + + /*------------------------------------------------------------------------*/ + /** @brief Constructor. + * @param AParent the parent the node + * @param AMove the action to access at this node + */ + MCTSNode(MCTSNode *AParent, const MCTSMove *AMove, MCTSState *AState); + + /*------------------------------------------------------------------------*/ + /** @brief Destructor. */ + virtual ~MCTSNode(); + + /** @brief Check if the node is fully expanded */ + bool is_fully_expanded() const; + /** @brief Check if the node is terminal. + * @param Number max of parents nodes with the same quality + * */ + bool is_terminal() const; + /** @brief Return the different moves/actions possible for a node */ + const MCTSMove *get_move() const; + /** @brief Return the size. */ + unsigned int get_size() const; + /** @brief Expand the node. */ + void expand(); + /** @brief Do a rollout. */ + void rollout(); + /** @brief Select the most promising child of the root node */ + MCTSNode *select_best_child(double c) const; + /** @brief Find child with this m and delete all others. + * @param m the selected move + * @return the next root*/ + MCTSNode *advance_tree(const MCTSMove *m); + /** @brief Return the state of the node. */ + const MCTSState *get_current_state() const; + /** @brief Print the tree and the stats. */ + void print_stats() const; + /** @brief Calculate the q rate of a node. It's: wins-looses */ + double q_rate() const; + /** @brief Calculate UCT. */ + double calculate_UCT() const; + /** @brief Calculate winrate. */ + double calculate_winrate() const; + + + + private: + /** a mesh */ + //Mesh* m_mesh; +}; + +/*----------------------------------------------------------------------------------------*/ +/** @class MCTSTree * @brief Class that provides .... */ class LIB_GMDS_RLBLOCKING_API MCTSTree { + MCTSNode *root; public: /*------------------------------------------------------------------------*/ /** @brief Constructor. * @param */ - MCTSTree(); + MCTSTree(MCTSState *starting_state); /*------------------------------------------------------------------------*/ /** @brief Destructor. */ virtual ~MCTSTree(); + /** @brief select child node to expand according to tree policy (UCT). + * @param c exploration parameter, theoretically equal to √2 + * @return a node + */ + MCTSNode *select(double c=1.41); + MCTSNode *select_best_child(); + void grow_tree(int max_iter, double max_time_in_seconds); + /** @brief if the move is applicable advance the tree, else start over + * @param move the move to do + * */ + void advance_tree(const MCTSMove *move); + /** @brief get the size of the tree. */ + unsigned int get_size() const; + /** @brief get the size of the tree. */ + const MCTSState *get_current_state() const; + /** @brief Print stats. */ + void print_stats() const; + private: /** a mesh */ //Mesh* m_mesh; }; +/*----------------------------------------------------------------------------*/ + /*----------------------------------------------------------------------------*/ } /*----------------------------------------------------------------------------------------*/ diff --git a/rlBlocking/src/MCTSAgent.cpp b/rlBlocking/src/MCTSAgent.cpp new file mode 100644 index 000000000..1975c7b49 --- /dev/null +++ b/rlBlocking/src/MCTSAgent.cpp @@ -0,0 +1,46 @@ +/*----------------------------------------------------------------------------*/ +#include +/*----------------------------------------------------------------------------*/ +using namespace gmds; +/*----------------------------------------------------------------------------*/ +MCTSAgent::MCTSAgent(gmds::MCTSState *starting_state, int max_iter, int max_seconds, int max_same_quality) + :max_iter(max_iter),max_seconds(max_seconds),max_same_quality(max_same_quality) +{ + tree = new MCTSTree(starting_state); +} + +/*----------------------------------------------------------------------------*/ +MCTSAgent::~MCTSAgent(){ + delete tree; +} +/*----------------------------------------------------------------------------*/ +const MCTSMove *MCTSAgent::genmove() { + // If game ended from opponent move, we can't do anything + if (tree->get_current_state()->is_terminal()) { + return NULL; + } +#ifdef DEBUG + std::cout << "___ DEBUG ______________________" << endl + << "Growing tree..." << std::endl; +#endif + tree->grow_tree(max_iter, max_seconds); +#ifdef DEBUG + cout << "Tree size: " << tree->get_size() << endl + << "________________________________" << endl; +#endif + MCTSNode *best_child = tree->select_best_child(); + if (best_child == NULL) { + std::cerr << "Warning: Tree root has no children! Possibly terminal node!" << std::endl; + return NULL; + } + const MCTSMove *best_move = best_child->get_move(); + tree->advance_tree(best_move); + return best_move; +} +/*----------------------------------------------------------------------------*/ +const + MCTSState *MCTSAgent::get_current_state() const +{ + return tree->get_current_state(); +} +/*----------------------------------------------------------------------------*/ diff --git a/rlBlocking/src/MCTSAlgorithm.cpp b/rlBlocking/src/MCTSAlgorithm.cpp index dc055ae17..166e142b2 100644 --- a/rlBlocking/src/MCTSAlgorithm.cpp +++ b/rlBlocking/src/MCTSAlgorithm.cpp @@ -6,10 +6,40 @@ /*----------------------------------------------------------------------------*/ using namespace gmds; /*----------------------------------------------------------------------------*/ -MCTSAlgorithm::MCTSAlgorithm(){;} +MCTSAlgorithm::MCTSAlgorithm(gmds::cad::GeomManager* AGeom,gmds::blocking::CurvedBlocking* ABlocking,int max_iter, int max_seconds, int max_same_quality) + : m_geom(AGeom),m_blocking(ABlocking),max_iter(max_iter), max_seconds(max_seconds),max_same_quality(max_same_quality) +{ std::vector hist_empty; + MCTSStatePolycube *init_state = new MCTSStatePolycube(m_geom,m_blocking,hist_empty); + tree = new MCTSTree(init_state);} /*----------------------------------------------------------------------------*/ MCTSAlgorithm::~MCTSAlgorithm(){;} /*----------------------------------------------------------------------------*/ void MCTSAlgorithm::execute() -{} +{ + std::cout<<"==========================================================="<m_geom, this->m_blocking, std::vector ()); + //state->print(); // IMPORTANT: state will be garbage after advance_tree() + MCTSAgent agent(state, 1000); + do { + agent.feedback(); + agent.genmove(); + // TODO: This way we don't check if the enemy move ends the game but it's our responsibility to check that, not the tree's... + const MCTSState *new_state = agent.get_current_state(); + new_state->print(); +// if (new_state->is_terminal()) { +// winner = ((const TicTacToe_state *) new_state)->get_winner(); +// break; +// } + done = new_state->is_terminal(); + } while (!done); + std::cout<<"==========================================================="< /*----------------------------------------------------------------------------*/ using namespace gmds; + +/*----------------------------------------------------------------------------*/ +MCTSStatePolycube::MCTSStatePolycube(gmds::cad::GeomManager* AGeom, gmds::blocking::CurvedBlocking* ABlocking, + std::vector hist ) + :m_geom(AGeom),m_blocking(ABlocking),m_history(hist) +{ + gmds::blocking::CurvedBlockingClassifier classifier(m_blocking); + m_class_blocking = new blocking::CurvedBlockingClassifier(classifier); + m_class_errors = m_class_blocking->classify(); + ;} /*----------------------------------------------------------------------------*/ MCTSStatePolycube::~MCTSStatePolycube() noexcept -{} +{ delete m_class_blocking;} /*----------------------------------------------------------------------------*/ std::queue * MCTSStatePolycube::actions_to_try() const -{} +{ + std::queue *Q = new std::queue(); + if (m_class_errors.non_captured_points.size()== 0){ + if(m_class_errors.non_captured_curves.size()==0){ + auto blocks = m_blocking->get_all_id_blocks(); + for(auto b : blocks){ + Q->push(new MCTSMovePolycube(NullID,b,0,0)); + } + } + else{ + auto listPossibleCuts = m_class_blocking->list_Possible_Cuts(); + for(auto c : listPossibleCuts){ + Q->push(new MCTSMovePolycube(c.first,NullID,c.second,1)); + } + } + } + else{ + auto listPossibleCuts = m_class_blocking->list_Possible_Cuts(); + for(auto c : listPossibleCuts){ + Q->push(new MCTSMovePolycube(c.first,NullID,c.second,1)); + } + } + return Q; +} +/*----------------------------------------------------------------------------*/ +MCTSState + *MCTSStatePolycube::next_state(const gmds::MCTSMove *AMove) const +{ + MCTSMovePolycube *m = (MCTSMovePolycube *) AMove; + std::vector hist_update = get_history(); + hist_update.push_back(get_quality()); + MCTSStatePolycube *new_state = new MCTSStatePolycube(this->m_geom,this->m_blocking,hist_update); + if(m->m_typeMove == 0){ + new_state->m_blocking->remove_block(m->m_AIdBlock); + return new_state; + } + else if(m->m_typeMove ==1) { + new_state->m_blocking->cut_sheet(m->m_AIdEdge,m->m_AParamCut); + return new_state; + } + else{ + std::cerr << "Warning: Bad type move !" << std::endl; + return new_state; + } + +} +/*----------------------------------------------------------------------------*/ +double +MCTSStatePolycube::state_rollout() const +{ + std::cout<<"STATE ROLLOUT"< *list_action = actions_to_try(); + long long r; + int a; + MCTSStatePolycube *curstate = (MCTSStatePolycube *) this; // TODO: ignore const... + srand(time(NULL)); + bool first = true; + do { + if (list_action->empty()) { + std::cerr << "Warning: Ran out of available moves and state is not terminal?"; + return 0.0; + } + //Get first move/action + //But, maybe, better to take rand move if its a delete move... + MCTSMove *firstMove = list_action->front(); //TODO: implement random move when only delete moves is possible + list_action->pop(); + + MCTSStatePolycube *old = curstate; + curstate = (MCTSStatePolycube *) curstate->next_state(firstMove); + if (!first) { + delete old; + } + first = false; + } while (!curstate->is_terminal()); + + if(MCTSStatePolycube::result_terminal() == WIN){ + res=1; + } + else if (MCTSStatePolycube::result_terminal() == LOSE) { + res=-1; + } + else{ + //Draw + res=0; + } + delete curstate; + return res; +} /*----------------------------------------------------------------------------*/ -MCTSState * -MCTSStatePolycube::next_state(const gmds::MCTSMove *AMove) const -{} +MCTSStatePolycube::ROLLOUT_STATUS +MCTSStatePolycube::result_terminal() const +{ + int max_nb_same = 3; + if (get_quality() == 0) { + return WIN; + } + else if (check_nb_same_quality() >= max_nb_same){ + return DRAW; + } + else if (m_history.back() < get_quality()){ + return LOSE; + } + std::cerr << "ERROR: NOT terminal state !" << std::endl; + return DRAW; +} /*----------------------------------------------------------------------------*/ -MCTSState::ROLLOUT_STATUS -MCTSStatePolycube::rollout() const +int MCTSStatePolycube::check_nb_same_quality() const { - return MCTSState::WIN; + int nb_same = 0; + double state_quality = get_quality(); + for (int i = m_history.size() - 1; i >= 0; --i) { + if(m_history[i] == state_quality){ + nb_same++; + } + else{ + break; + } + } + return nb_same; } /*----------------------------------------------------------------------------*/ bool MCTSStatePolycube::is_terminal() const -{} +{ + if (m_class_errors.non_captured_points.empty() && m_class_errors.non_captured_curves.empty() && m_class_errors.non_captured_surfaces.empty()) { + return true; + } + else if(check_nb_same_quality() >= 3){ + return true; + } + else if(!m_history.empty() && m_history.back() < get_quality()){ + return true; + } + else { + std::cout<<"NOT TERMINAL STATE"<geom_model(); +} + +/*----------------------------------------------------------------------------*/ +gmds::blocking::CurvedBlocking *MCTSStatePolycube::get_blocking() +{ + return m_blocking; +} + +/*----------------------------------------------------------------------------*/ +gmds::blocking::CurvedBlockingClassifier *MCTSStatePolycube::get_class() +{ + return m_class_blocking; +} +/*----------------------------------------------------------------------------*/ +gmds::blocking::ClassificationErrors MCTSStatePolycube::get_errors() +{ + return m_class_errors; +} +/*----------------------------------------------------------------------------*/ +std::vector MCTSStatePolycube::get_history() const +{ + return m_history; +} /*----------------------------------------------------------------------------*/ diff --git a/rlBlocking/src/MCTSTree.cpp b/rlBlocking/src/MCTSTree.cpp index 9a83b4146..dfb6f6d80 100644 --- a/rlBlocking/src/MCTSTree.cpp +++ b/rlBlocking/src/MCTSTree.cpp @@ -6,9 +6,260 @@ /*----------------------------------------------------------------------------*/ using namespace gmds; /*----------------------------------------------------------------------------*/ -MCTSTree::MCTSTree() -{} +/*-------------------------------- MCTS NODE ---------------------------------*/ +/*----------------------------------------------------------------------------*/ +MCTSNode::MCTSNode(gmds::MCTSNode *AParent, const gmds::MCTSMove *AMove, MCTSState *AState) + :parent(AParent), move(AMove),state(AState), score(0.0), number_of_simulations(0), size(0) +{ + children = new std::vector(); + untried_actions = state->actions_to_try(); + terminal = state->is_terminal(); + + + +} +/*----------------------------------------------------------------------------*/ +MCTSNode::~MCTSNode() { + delete state; + delete move; + for (auto *child : *children) { + delete child; + } + delete children; + while (!untried_actions->empty()) { + delete untried_actions->front(); // if a move is here then it is not a part of a child node and needs to be deleted here + untried_actions->pop(); + } + delete untried_actions; +} + +/*----------------------------------------------------------------------------*/ +void MCTSNode::expand() { + if (is_terminal()) { // can legitimately happen in end-game situations + rollout(); // keep rolling out, eventually causing UCT to pick another node to expand due to exploration + return; + } else if (is_fully_expanded()) { + std::cerr << "Warning: Cannot expanded this node any more!" << std::endl; + return; + } + // get next untried action + MCTSMove *next_move = untried_actions->front(); // get value + untried_actions->pop(); // remove it + MCTSState *next_state = state->next_state(next_move); + + if(state->get_quality() == next_state->get_quality()){ + const unsigned int nb_same_quality = this->nb_same_quality + 1; + } + else{ + const unsigned int nb_same_quality = 0; + } + // build a new MCTS node from it + MCTSNode *new_node = new MCTSNode(this,next_move,next_state); + // rollout, updating its stats + new_node->rollout(); + // add new node to tree + children->push_back(new_node); +} +/*----------------------------------------------------------------------------*/ +const MCTSState *MCTSNode::get_current_state() const +{ + return state; +} +/*----------------------------------------------------------------------------*/ +bool +MCTSNode::is_terminal() const +{ + return terminal; +} + +/*----------------------------------------------------------------------------*/ +const MCTSMove *MCTSNode::get_move() const { + return move; +} +/*----------------------------------------------------------------------------*/ +bool +MCTSNode::is_fully_expanded() const +{ + return is_terminal() || untried_actions->empty(); +} +/*----------------------------------------------------------------------------*/ +unsigned int MCTSNode::get_size() const { + return size; +} +/*----------------------------------------------------------------------------*/ +MCTSNode *MCTSNode::select_best_child(double c) const { + /** selects best child based on the winrate of whose turn it is to play */ + if (children->empty()) return NULL; + else if (children->size() == 1) return children->at(0); + else { + double uct, max = -1; + MCTSNode *argmax = NULL; + for (auto *child : *children) { + double winrate = child->score / ((double) child->number_of_simulations); + if (c > 0) { + uct = winrate + + c * sqrt(log((double) this->number_of_simulations) / ((double) child->number_of_simulations)); + } else { + uct = winrate; + } + if (uct > max) { + max = uct; + argmax = child; + } + } + return argmax; + } +} +/*----------------------------------------------------------------------------*/ +void +MCTSNode::rollout() +{ + double w = state->state_rollout(); + backpropagate(w, 1); +} +/*----------------------------------------------------------------------------*/ +void MCTSNode::backpropagate(double w, int n) { + score += w; + number_of_simulations += n; + if (parent != NULL) { + parent->size++; + parent->backpropagate(w, n); + } +} + +/*----------------------------------------------------------------------------*/ +MCTSNode *MCTSNode::advance_tree(const MCTSMove *m) { + //TODO + // Find child with this m and delete all others + MCTSNode *next = NULL; + for (auto *child: *children) { + if (*(child->move) == *(m)) { + next = child; + } else { + delete child; + } + } + // remove children from queue so that they won't be re-deleted by the destructor when this node dies (!) + this->children->clear(); + // if not found then we have to create a new node + if (next == NULL) { + // Note: UCT may lead to not fully explored tree even for short-term children due to terminal nodes being chosen + std::cout << "INFO: Didn't find child node. Had to start over." << std::endl; + MCTSState *next_state = state->next_state(m); + next = new MCTSNode(this, m,next_state); + } else { + next->parent = NULL; // make parent NULL + // IMPORTANT: m and next->move can be the same here if we pass the move from select_best_child() + // (which is what we will typically be doing). If not then it's the caller's responsibility to delete m (!) + } + // return the next root + return next; +} +/*----------------------------------------------------------------------------*/ +void MCTSNode::print_stats() const { +#define TOPK 10 + if (number_of_simulations == 0) { + std::cout << "Tree not expanded yet" << std::endl; + return; + } + std::cout << "___ INFO _______________________" << std::endl + << "Tree size: " << size << std::endl + << "Number of simulations: " << number_of_simulations << std::endl + << "Branching factor at root: " << children->size() << std::endl; + // print TOPK of them along with their winrates +// std::cout << "Best moves:" << std::endl; +// for (int i = 0 ; i < children->size() && i < TOPK ; i++) { +// std::cout << " " << i + 1 << ". " << children->at(i)->move->sprint() << " --> " +// << std::setprecision(4) << 100.0 * children->at(i)->calculate_winrate(state->player1_turn()) << "%" << endl; +// } + std::cout << "________________________________" << std::endl; +} + + +/*----------------------------------------------------------------------------*/ +double MCTSNode::calculate_winrate() const { + return score / number_of_simulations; +} + +/*----------------------------------------------------------------------------*/ +/*-------------------------------- MCTS TREE --------------------------------*/ +/*----------------------------------------------------------------------------*/ +MCTSTree::MCTSTree(MCTSState* starting_state) +{ + assert(starting_state != NULL); + root = new MCTSNode(NULL, NULL, starting_state); +} /*----------------------------------------------------------------------------*/ MCTSTree::~MCTSTree() -{} +{ + delete root; +} +/*----------------------------------------------------------------------------*/ +MCTSNode *MCTSTree::select(double c) { + MCTSNode *node = root; + while (!node->is_terminal()) { + if (!node->is_fully_expanded()) { + return node; + } else { + node = node->select_best_child(c); + } + } + return node; +} +/*----------------------------------------------------------------------------*/ +void MCTSTree::grow_tree(int max_iter, double max_time_in_seconds) { + MCTSNode *node; + double dt; +#ifdef DEBUG + std::cout << "Growing tree..." << std::endl; +#endif + time_t start_t, now_t; + time(&start_t); + for (int i = 0 ; i < max_iter ; i++){ + // select node to expand according to tree policy + node = select(); + // expand it (this will perform a rollout and backpropagate the results) + node->expand(); + // check if we need to stop + time(&now_t); + dt = difftime(now_t, start_t); + if (dt > max_time_in_seconds) { +#ifdef DEBUG + std::cout << "Early stopping: Made " << (i + 1) << " iterations in " << dt << " seconds." << std::endl; +#endif + break; + } + } +#ifdef DEBUG + time(&now_t); + dt = difftime(now_t, start_t); + cout << "Finished in " << dt << " seconds." << endl; +#endif +} +/*----------------------------------------------------------------------------*/ +MCTSNode *MCTSTree::select_best_child() { + return root->select_best_child(0.0); +} +/*----------------------------------------------------------------------------*/ +void MCTSTree::advance_tree(const MCTSMove *move) { + MCTSNode *old_root = root; + root = root->advance_tree(move); + delete old_root; // this won't delete the new root since we have emptied old_root's children +} + +/*----------------------------------------------------------------------------*/ +unsigned int MCTSTree::get_size() const { + return root->get_size(); +} +/*----------------------------------------------------------------------------*/ +const MCTSState *MCTSTree::get_current_state() const +{ + return root->get_current_state(); +} +/*----------------------------------------------------------------------------*/ +void MCTSTree::print_stats() const +{ + root->print_stats(); +} /*----------------------------------------------------------------------------*/ + diff --git a/rlBlocking/src/main_rlBlocking.cpp b/rlBlocking/src/main_rlBlocking.cpp index 5cfce3aca..789c7546a 100644 --- a/rlBlocking/src/main_rlBlocking.cpp +++ b/rlBlocking/src/main_rlBlocking.cpp @@ -44,7 +44,7 @@ int main(int argc, char* argv[]) gmds::Mesh vol_mesh(gmds::MeshModel(gmds::DIM3 | gmds::R | gmds::F | gmds::E | gmds::N | gmds::R2N | gmds::R2F | gmds::R2E | gmds::F2N | gmds::F2R | gmds::F2E | gmds::E2F | gmds::E2N | gmds::N2E)); - std::string vtk_file = "/home/bourmaudp/Documents/mambo-master/Basic/vtk/B0.vtk"; + std::string vtk_file = "/home/bourmaudp/Documents/mambo-master/Basic/vtk/cb3.vtk"; gmds::IGMeshIOService ioServiceA(&vol_mesh); gmds::VTKReader vtkReaderA(&ioServiceA); vtkReaderA.setCellOptions(gmds::N | gmds::R); diff --git a/rlBlocking/tst/MCTSTestSuite.h b/rlBlocking/tst/MCTSTestSuite.h index 6fc36811f..de5bce2b8 100644 --- a/rlBlocking/tst/MCTSTestSuite.h +++ b/rlBlocking/tst/MCTSTestSuite.h @@ -4,13 +4,56 @@ #include /*----------------------------------------------------------------------------*/ #include +#include /*----------------------------------------------------------------------------*/ using namespace gmds; /*----------------------------------------------------------------------------*/ +/**@brief setup function that initialize a geometric model using the faceted + * representation and an input vtk file name. The vtk file must contain a + * tetrahedral mesh + * + * @param AGeomModel geometric model we initialize + * @param AFileName vtk filename + */ +void set_up_MCTS(gmds::cad::FACManager* AGeomModel, const std::string AFileName) +{ + gmds::Mesh vol_mesh(gmds::MeshModel(gmds::DIM3 | gmds::R | gmds::F | gmds::E | gmds::N | gmds::R2N | gmds::R2F | gmds::R2E | gmds::F2N | gmds::F2R | gmds::F2E + | gmds::E2F | gmds::E2N | gmds::N2E)); + std::string dir(TEST_SAMPLES_DIR); + std::string vtk_file = dir +"/"+ AFileName; + gmds::IGMeshIOService ioService(&vol_mesh); + gmds::VTKReader vtkReader(&ioService); + vtkReader.setCellOptions(gmds::N | gmds::R); + vtkReader.read(vtk_file); + gmds::MeshDoctor doc(&vol_mesh); + doc.buildFacesAndR2F(); + doc.buildEdgesAndX2E(); + doc.updateUpwardConnectivity(); + AGeomModel->initFrom3DMesh(&vol_mesh); -TEST(MCTSTestSuite, test1) +} +/*----------------------------------------------------------------------------*/ +TEST(MCTSTestSuite, testExAglo) { + gmds::cad::FACManager geom_model; + set_up_MCTS(&geom_model,"M1.vtk"); + gmds::blocking::CurvedBlocking bl(&geom_model,true); + bl.save_vtk_blocking("/home/bourmaudp/Documents/PROJETS/gmds/gmds_Correction_Class_Dev/saveResults/M1_init_blocking.vtk"); + std::cout<<"NB points : "<< geom_model.getPoints().size()<execute(); + } /*----------------------------------------------------------------------------*/ From 5ffaab5555c2c3610d62d054b8fa0a2ca0ed80df Mon Sep 17 00:00:00 2001 From: Alphadius Date: Tue, 16 Jan 2024 16:10:51 +0100 Subject: [PATCH 3/4] the algorithm work. but, we need to add method to check the terminal tree. Moreover, the name of the blocking save impact the execution. Need to fix it --- blocking/inc/gmds/blocking/CurvedBlocking.h | 4 +- blocking/tst/ExecutionActionsTestSuite.h | 81 ++++++++++++ rlBlocking/inc/gmds/rlBlocking/MCTSAgent.h | 2 + rlBlocking/inc/gmds/rlBlocking/MCTSMove.h | 3 +- .../inc/gmds/rlBlocking/MCTSMovePolycube.h | 7 +- rlBlocking/inc/gmds/rlBlocking/MCTSState.h | 2 +- .../inc/gmds/rlBlocking/MCTSStatePolycube.h | 5 +- rlBlocking/inc/gmds/rlBlocking/MCTSTree.h | 4 +- rlBlocking/src/MCTSAgent.cpp | 3 +- rlBlocking/src/MCTSAlgorithm.cpp | 4 +- rlBlocking/src/MCTSMovePolycube.cpp | 8 +- rlBlocking/src/MCTSStatePolycube.cpp | 118 ++++++++++++++---- rlBlocking/src/MCTSTree.cpp | 38 ++++-- rlBlocking/tst/MCTSTestSuite.h | 17 ++- 14 files changed, 239 insertions(+), 57 deletions(-) diff --git a/blocking/inc/gmds/blocking/CurvedBlocking.h b/blocking/inc/gmds/blocking/CurvedBlocking.h index c4d756a51..05ccc8e3e 100644 --- a/blocking/inc/gmds/blocking/CurvedBlocking.h +++ b/blocking/inc/gmds/blocking/CurvedBlocking.h @@ -54,7 +54,7 @@ struct CellInfo * @param AGeomDim on-classify geometric cell dimension (4 if not classified) * @param AGeomId on-classify geometric cell unique id */ - CellInfo(cad::GeomManager* AManager, const int ATopoDim = 4, const int AGeomDim = 4, const int AGeomId = NullID) : + CellInfo(cad::GeomManager* AManager=NULL, const int ATopoDim = 4, const int AGeomDim = 4, const int AGeomId = NullID) : topo_dim(ATopoDim), topo_id(m_counter_global_id++), geom_manager(AManager),geom_dim(AGeomDim), geom_id(AGeomId) { } @@ -74,7 +74,7 @@ struct NodeInfo : CellInfo * @param AGeomId on-classify geometric cell unique id * @param APoint geometric location */ - NodeInfo(cad::GeomManager* AManager, const int AGeomDim = 4, const int AGeomId = NullID, const math::Point &APoint = math::Point(0, 0, 0)) : + NodeInfo(cad::GeomManager* AManager=NULL, const int AGeomDim = 4, const int AGeomId = NullID, const math::Point &APoint = math::Point(0, 0, 0)) : CellInfo(AManager, 0, AGeomDim, AGeomId), point(APoint) { } diff --git a/blocking/tst/ExecutionActionsTestSuite.h b/blocking/tst/ExecutionActionsTestSuite.h index bf32499b4..27711d43d 100644 --- a/blocking/tst/ExecutionActionsTestSuite.h +++ b/blocking/tst/ExecutionActionsTestSuite.h @@ -706,3 +706,84 @@ TEST(ExecutionActionsTestSuite,cb5){ vtk_writer_edges.write("debug_blocking_edges.vtk"); } + + +TEST(ExecutionActionsTestSuite,cb2_auto) { + gmds::cad::FACManager geom_model; + set_up_file(&geom_model,"cb2.vtk"); + gmds::blocking::CurvedBlocking bl(&geom_model,true); + gmds::blocking::CurvedBlockingClassifier classifier(&bl); + + + classifier.clear_classification(); + + auto errors = classifier.classify(); + + + //Check nb points of the geometry and nb nodes of the blocking + ASSERT_EQ(16,geom_model.getNbPoints()); + ASSERT_EQ(24,geom_model.getNbCurves()); + ASSERT_EQ(10,geom_model.getNbSurfaces()); + ASSERT_EQ(8,bl.get_all_nodes().size()); + ASSERT_EQ(12,bl.get_all_edges().size()); + ASSERT_EQ(6,bl.get_all_faces().size()); + + + + //Check elements class and captured + //Check nb nodes/edges/faces no classified + ASSERT_EQ(0,errors.non_classified_nodes.size()); + ASSERT_EQ(0,errors.non_classified_edges.size()); + ASSERT_EQ(6,errors.non_classified_faces.size()); + + //Check nb points/curves/surfaces no captured + ASSERT_EQ(8,errors.non_captured_points.size()); + ASSERT_EQ(12,errors.non_captured_curves.size()); + ASSERT_EQ(10,errors.non_captured_surfaces.size()); + + auto listEdgesCut = classifier.list_Possible_Cuts(); + //Do 1 cut + bl.cut_sheet(listEdgesCut.front().first,listEdgesCut.front().second); + + + classifier.classify(); + + listEdgesCut = classifier.list_Possible_Cuts(); + //Do 1 cut + bl.cut_sheet(listEdgesCut.front().first,listEdgesCut.front().second); + + classifier.classify(); + + listEdgesCut = classifier.list_Possible_Cuts(); + //Do 1 cut + bl.cut_sheet(listEdgesCut.front().first,listEdgesCut.front().second); + + classifier.classify(); + + listEdgesCut = classifier.list_Possible_Cuts(); + //Do 1 cut + bl.cut_sheet(listEdgesCut.front().first,listEdgesCut.front().second); + + + gmds::Mesh m(gmds::MeshModel(gmds::DIM3|gmds::N|gmds::E|gmds::F|gmds::R|gmds::E2N|gmds::F2N|gmds::R2N)); + bl.convert_to_mesh(m); + + + gmds::IGMeshIOService ios(&m); + gmds::VTKWriter vtk_writer(&ios); + vtk_writer.setCellOptions(gmds::N|gmds::R); + vtk_writer.setDataOptions(gmds::N|gmds::R); + vtk_writer.write("cb2_debug_blocking.vtk"); + gmds::VTKWriter vtk_writer_edges(&ios); + vtk_writer_edges.setCellOptions(gmds::N|gmds::E); + vtk_writer_edges.setDataOptions(gmds::N|gmds::E); + vtk_writer_edges.write("cb2_debug_blocking_edges.vtk"); + gmds::VTKWriter vtk_writer_faces(&ios); + vtk_writer_faces.setCellOptions(gmds::N|gmds::F); + vtk_writer_faces.setDataOptions(gmds::N|gmds::F); + vtk_writer_faces.write("cb2_debug_blocking_faces.vtk"); + + +} + + diff --git a/rlBlocking/inc/gmds/rlBlocking/MCTSAgent.h b/rlBlocking/inc/gmds/rlBlocking/MCTSAgent.h index ebd3b9059..f85f81346 100644 --- a/rlBlocking/inc/gmds/rlBlocking/MCTSAgent.h +++ b/rlBlocking/inc/gmds/rlBlocking/MCTSAgent.h @@ -2,6 +2,8 @@ #define GMDS_MCTSAGENT_H #include +#include +#include /*----------------------------------------------------------------------------------------*/ namespace gmds { /*----------------------------------------------------------------------------------------*/ diff --git a/rlBlocking/inc/gmds/rlBlocking/MCTSMove.h b/rlBlocking/inc/gmds/rlBlocking/MCTSMove.h index 75f061783..8b276a22a 100644 --- a/rlBlocking/inc/gmds/rlBlocking/MCTSMove.h +++ b/rlBlocking/inc/gmds/rlBlocking/MCTSMove.h @@ -22,7 +22,8 @@ struct LIB_GMDS_RLBLOCKING_API MCTSMove { /** @brief Overloaded == */ virtual bool operator==(const MCTSMove& AOther) const = 0; - virtual std::string sprint() const { return "Not implemented"; } // and optionally this + virtual std::string sprint() const { return "Not implemented"; } + virtual void print() const =0; // and optionally this }; /*----------------------------------------------------------------------------*/ } diff --git a/rlBlocking/inc/gmds/rlBlocking/MCTSMovePolycube.h b/rlBlocking/inc/gmds/rlBlocking/MCTSMovePolycube.h index aaf9bdc43..0af5a6449 100644 --- a/rlBlocking/inc/gmds/rlBlocking/MCTSMovePolycube.h +++ b/rlBlocking/inc/gmds/rlBlocking/MCTSMovePolycube.h @@ -20,14 +20,15 @@ struct LIB_GMDS_RLBLOCKING_API MCTSMovePolycube: public MCTSMove { TCellID m_AIdEdge; TCellID m_AIdBlock; double m_AParamCut; - /** @brief if typeMove=0: delete block, typeMove=1 cut block + /** @brief if typeMove=2: delete block, typeMove=1 cut block */ - bool m_typeMove; + unsigned int m_typeMove; /** @brief Overloaded == */ - MCTSMovePolycube(TCellID AIdEdge,TCellID AIdBlock, double AParamCut,bool ATypeMove); + MCTSMovePolycube(TCellID AIdEdge = -1,TCellID AIdBlock = -1 , double AParamCut = 0,unsigned int ATypeMove = -1); bool operator==(const MCTSMove& AOther) const; + void print() const; }; /*----------------------------------------------------------------------------*/ diff --git a/rlBlocking/inc/gmds/rlBlocking/MCTSState.h b/rlBlocking/inc/gmds/rlBlocking/MCTSState.h index 9952850fc..92301494e 100644 --- a/rlBlocking/inc/gmds/rlBlocking/MCTSState.h +++ b/rlBlocking/inc/gmds/rlBlocking/MCTSState.h @@ -34,7 +34,7 @@ class LIB_GMDS_RLBLOCKING_API MCTSState { /*------------------------------------------------------------------------*/ /** @brief Gives the set of actions that can be tried from the current state */ - virtual std::queue *actions_to_try() const = 0; + virtual std::deque *actions_to_try() const = 0; /*------------------------------------------------------------------------*/ /** @brief Performs the @p AMove to change of states * @param[in] AMove the movement to apply to get to a new state diff --git a/rlBlocking/inc/gmds/rlBlocking/MCTSStatePolycube.h b/rlBlocking/inc/gmds/rlBlocking/MCTSStatePolycube.h index 354396ea3..bd4341015 100644 --- a/rlBlocking/inc/gmds/rlBlocking/MCTSStatePolycube.h +++ b/rlBlocking/inc/gmds/rlBlocking/MCTSStatePolycube.h @@ -29,7 +29,7 @@ class LIB_GMDS_RLBLOCKING_API MCTSStatePolycube: public MCTSState{ /*------------------------------------------------------------------------*/ /** @brief Gives the set of actions that can be tried from the current state */ - std::queue *actions_to_try() const ; + std::deque *actions_to_try() const ; /*------------------------------------------------------------------------*/ /** @brief Performs the @p AMove to change of states * @param[in] AMove the movement to apply to get to a new state @@ -69,6 +69,9 @@ class LIB_GMDS_RLBLOCKING_API MCTSStatePolycube: public MCTSState{ /** @brief return the history of the parents quality */ std::vector get_history() const; + /** @brief update the classification of a state */ + void update_class(); + private : /** @brief the curved blocking of the current state */ gmds::blocking::CurvedBlocking* m_blocking; diff --git a/rlBlocking/inc/gmds/rlBlocking/MCTSTree.h b/rlBlocking/inc/gmds/rlBlocking/MCTSTree.h index e55407b21..e8e2328c0 100644 --- a/rlBlocking/inc/gmds/rlBlocking/MCTSTree.h +++ b/rlBlocking/inc/gmds/rlBlocking/MCTSTree.h @@ -34,7 +34,7 @@ class LIB_GMDS_RLBLOCKING_API MCTSNode { /** @brief the parent for the current node*/ MCTSNode *parent; /** @brief queue of untried actions*/ - std::queue *untried_actions; + std::deque *untried_actions; /** @brief update the nb simulations and the score after a rollout*/ void backpropagate(double w, int n); public: @@ -72,6 +72,8 @@ class LIB_GMDS_RLBLOCKING_API MCTSNode { MCTSNode *advance_tree(const MCTSMove *m); /** @brief Return the state of the node. */ const MCTSState *get_current_state() const; + /** @brief Return the children of the node. */ + std::vector *get_children(); /** @brief Print the tree and the stats. */ void print_stats() const; /** @brief Calculate the q rate of a node. It's: wins-looses */ diff --git a/rlBlocking/src/MCTSAgent.cpp b/rlBlocking/src/MCTSAgent.cpp index 1975c7b49..01bfd526e 100644 --- a/rlBlocking/src/MCTSAgent.cpp +++ b/rlBlocking/src/MCTSAgent.cpp @@ -14,7 +14,8 @@ MCTSAgent::~MCTSAgent(){ delete tree; } /*----------------------------------------------------------------------------*/ -const MCTSMove *MCTSAgent::genmove() { +const MCTSMove *MCTSAgent::genmove() +{ // If game ended from opponent move, we can't do anything if (tree->get_current_state()->is_terminal()) { return NULL; diff --git a/rlBlocking/src/MCTSAlgorithm.cpp b/rlBlocking/src/MCTSAlgorithm.cpp index 166e142b2..4338cbb17 100644 --- a/rlBlocking/src/MCTSAlgorithm.cpp +++ b/rlBlocking/src/MCTSAlgorithm.cpp @@ -25,7 +25,7 @@ void MCTSAlgorithm::execute() MCTSState *state = new MCTSStatePolycube(this->m_geom, this->m_blocking, std::vector ()); //state->print(); // IMPORTANT: state will be garbage after advance_tree() - MCTSAgent agent(state, 1000); + MCTSAgent agent(state, 100); do { agent.feedback(); agent.genmove(); @@ -38,6 +38,8 @@ void MCTSAlgorithm::execute() // } done = new_state->is_terminal(); } while (!done); + + std::cout<<"==========================================================="< +#include /*----------------------------------------------------------------------------*/ using namespace gmds; + /*----------------------------------------------------------------------------*/ MCTSStatePolycube::MCTSStatePolycube(gmds::cad::GeomManager* AGeom, gmds::blocking::CurvedBlocking* ABlocking, std::vector hist ) - :m_geom(AGeom),m_blocking(ABlocking),m_history(hist) + :m_geom(AGeom),m_history(hist) { + m_blocking = new blocking::CurvedBlocking(*ABlocking); gmds::blocking::CurvedBlockingClassifier classifier(m_blocking); m_class_blocking = new blocking::CurvedBlockingClassifier(classifier); - m_class_errors = m_class_blocking->classify(); + m_class_errors = m_class_blocking->classify(0.2); ;} /*----------------------------------------------------------------------------*/ MCTSStatePolycube::~MCTSStatePolycube() noexcept -{ delete m_class_blocking;} +{ + delete m_class_blocking; + delete m_blocking; +} /*----------------------------------------------------------------------------*/ -std::queue * +std::deque * MCTSStatePolycube::actions_to_try() const { - std::queue *Q = new std::queue(); + std::deque *Q = new std::deque(); if (m_class_errors.non_captured_points.size()== 0){ + std::cout<<"POINTS CAPT :"<get_all_id_blocks(); for(auto b : blocks){ - Q->push(new MCTSMovePolycube(NullID,b,0,0)); + Q->push_back(new MCTSMovePolycube(NullID,b,0,2)); } } else{ + std::cout<<"NB CURVES CAPT :"<< m_class_errors.non_captured_curves.size()<list_Possible_Cuts(); for(auto c : listPossibleCuts){ - Q->push(new MCTSMovePolycube(c.first,NullID,c.second,1)); + Q->push_back(new MCTSMovePolycube(c.first,NullID,c.second,1)); } } } else{ + std::cout<<"POINTS NO CAPT :"<list_Possible_Cuts(); for(auto c : listPossibleCuts){ - Q->push(new MCTSMovePolycube(c.first,NullID,c.second,1)); + Q->push_back(new MCTSMovePolycube(c.first,NullID,c.second,1)); } } + std::cout<<"LIST ACTIONS :"<print(); + } return Q; } /*----------------------------------------------------------------------------*/ MCTSState *MCTSStatePolycube::next_state(const gmds::MCTSMove *AMove) const { + std::cout<<"==================== EXECUTE ACTION ! ===================="< hist_update = get_history(); hist_update.push_back(get_quality()); - MCTSStatePolycube *new_state = new MCTSStatePolycube(this->m_geom,this->m_blocking,hist_update); - if(m->m_typeMove == 0){ - new_state->m_blocking->remove_block(m->m_AIdBlock); + gmds::blocking::CurvedBlocking* new_b = new gmds::blocking::CurvedBlocking(*m_blocking); + MCTSStatePolycube *new_state = new MCTSStatePolycube(this->m_geom,new_b,hist_update); + if(m->m_typeMove == 2){ + //TODO ERROR, sometimes, block select not in the current blocks list...Check why !!! + std::cout<<"LIST BLOCK BLOCKING : "<get_all_id_blocks()){ + std::cout<m_AIdBlock){ + b_in_list = true; + break; + } + } + if(b_in_list){ + std::cout<<"BLOCK A DELETE :"<m_AIdBlock<m_blocking->remove_block(m->m_AIdBlock); + } + else{ + std::cout<<"BLOCK A DELETE :"<get_all_id_blocks().front()<m_blocking->remove_block(m_blocking->get_all_id_blocks().front()); + } + + new_state->update_class(); + //SAVE Blocking vtk + std::string name_save_folder = "/home/bourmaudp/Documents/PROJETS/gmds/gmds_Correction_Class_Dev/saveResults/cb2/"; + std::string id_act = std::to_string(m->m_AIdEdge); + std::string name_file = "cb2_action"+ id_act +".vtk"; + new_state->m_blocking->save_vtk_blocking(name_save_folder+name_file); return new_state; } else if(m->m_typeMove ==1) { new_state->m_blocking->cut_sheet(m->m_AIdEdge,m->m_AParamCut); + new_state->update_class(); + //SAVE Blocking vtk + std::string name_save_folder = "/home/bourmaudp/Documents/PROJETS/gmds/gmds_Correction_Class_Dev/saveResults/cb2/"; + std::string id_act = std::to_string(m->m_AIdEdge); + std::string name_file = "cb2_action"+ id_act +".vtk"; + new_state->m_blocking->save_vtk_blocking(name_save_folder+name_file); return new_state; } else{ - std::cerr << "Warning: Bad type move !" << std::endl; + std::cerr << "Warning: Bad type move ! \n Type move :" << m->m_typeMove << " & ID " << m->m_AIdEdge<< std::endl; return new_state; } + std::string name_save_folder = "/home/bourmaudp/Documents/PROJETS/gmds/gmds_Correction_Class_Dev/saveResults/"; + std::string id_act = std::to_string(m->m_AIdEdge); + std::string name_file = "M1_action"+ id_act +".vtk"; + m_blocking->save_vtk_blocking(name_save_folder+name_file); } /*----------------------------------------------------------------------------*/ @@ -69,6 +122,7 @@ double MCTSStatePolycube::state_rollout() const { std::cout<<"STATE ROLLOUT"< *list_action = actions_to_try(); + long long r; int a; MCTSStatePolycube *curstate = (MCTSStatePolycube *) this; // TODO: ignore const... srand(time(NULL)); bool first = true; do { - if (list_action->empty()) { - std::cerr << "Warning: Ran out of available moves and state is not terminal?"; - return 0.0; - } + std::deque *list_action = actions_to_try(); //Get first move/action //But, maybe, better to take rand move if its a delete move... MCTSMove *firstMove = list_action->front(); //TODO: implement random move when only delete moves is possible - list_action->pop(); + list_action->pop_front(); MCTSStatePolycube *old = curstate; + std::cout<<"===== SIZE UNTRIED ACTIONS : "<size()+1<<" ====="<next_state(firstMove); if (!first) { delete old; @@ -109,10 +161,10 @@ MCTSStatePolycube::state_rollout() const first = false; } while (!curstate->is_terminal()); - if(MCTSStatePolycube::result_terminal() == WIN){ + if(curstate->result_terminal() == WIN){ res=1; } - else if (MCTSStatePolycube::result_terminal() == LOSE) { + else if (curstate->result_terminal() == LOSE) { res=-1; } else{ @@ -126,17 +178,19 @@ MCTSStatePolycube::state_rollout() const MCTSStatePolycube::ROLLOUT_STATUS MCTSStatePolycube::result_terminal() const { - int max_nb_same = 3; - if (get_quality() == 0) { + if (m_class_errors.non_captured_points.empty() && m_class_errors.non_captured_curves.empty() && m_class_errors.non_captured_surfaces.empty()) { return WIN; } - else if (check_nb_same_quality() >= max_nb_same){ + else if (check_nb_same_quality() >= 3){ return DRAW; } - else if (m_history.back() < get_quality()){ + else if (!m_history.empty() && m_history.back() < this->get_quality()){ + return LOSE; + } + else if (this->actions_to_try()->empty()){ return LOSE; } - std::cerr << "ERROR: NOT terminal state !" << std::endl; + std::cerr << "ERROR: NOT terminal state ..." << std::endl; return DRAW; } /*----------------------------------------------------------------------------*/ @@ -164,7 +218,10 @@ MCTSStatePolycube::is_terminal() const else if(check_nb_same_quality() >= 3){ return true; } - else if(!m_history.empty() && m_history.back() < get_quality()){ + else if(!m_history.empty() && m_history.back() < this->get_quality()){ + return true; + } + else if(this->actions_to_try()->empty()){ return true; } else { @@ -178,7 +235,7 @@ double { return m_class_errors.non_captured_points.size() * 0.8 + m_class_errors.non_captured_curves.size() * 0.6 + m_class_errors.non_captured_surfaces.size() * 0.4; -} + } /*----------------------------------------------------------------------------*/ gmds::cad::GeomManager* MCTSStatePolycube::get_geom(){ return m_blocking->geom_model(); @@ -206,3 +263,10 @@ std::vector MCTSStatePolycube::get_history() const return m_history; } /*----------------------------------------------------------------------------*/ +void MCTSStatePolycube::update_class() +{ + gmds::blocking::CurvedBlockingClassifier classifier(m_blocking); + m_class_blocking = new blocking::CurvedBlockingClassifier(classifier); + m_class_errors = m_class_blocking->classify(0.2); +} +/*----------------------------------------------------------------------------*/ diff --git a/rlBlocking/src/MCTSTree.cpp b/rlBlocking/src/MCTSTree.cpp index dfb6f6d80..24044c6a4 100644 --- a/rlBlocking/src/MCTSTree.cpp +++ b/rlBlocking/src/MCTSTree.cpp @@ -28,7 +28,7 @@ MCTSNode::~MCTSNode() { delete children; while (!untried_actions->empty()) { delete untried_actions->front(); // if a move is here then it is not a part of a child node and needs to be deleted here - untried_actions->pop(); + untried_actions->pop_front(); } delete untried_actions; } @@ -44,7 +44,7 @@ void MCTSNode::expand() { } // get next untried action MCTSMove *next_move = untried_actions->front(); // get value - untried_actions->pop(); // remove it + untried_actions->pop_front(); // remove it MCTSState *next_state = state->next_state(next_move); if(state->get_quality() == next_state->get_quality()){ @@ -66,6 +66,12 @@ const MCTSState *MCTSNode::get_current_state() const return state; } /*----------------------------------------------------------------------------*/ +std::vector +*MCTSNode::get_children() +{ + return children; +} +/*----------------------------------------------------------------------------*/ bool MCTSNode::is_terminal() const { @@ -89,7 +95,9 @@ unsigned int MCTSNode::get_size() const { /*----------------------------------------------------------------------------*/ MCTSNode *MCTSNode::select_best_child(double c) const { /** selects best child based on the winrate of whose turn it is to play */ - if (children->empty()) return NULL; + if (children->empty()) { + return NULL; + } else if (children->size() == 1) return children->at(0); else { double uct, max = -1; @@ -166,11 +174,25 @@ void MCTSNode::print_stats() const { << "Tree size: " << size << std::endl << "Number of simulations: " << number_of_simulations << std::endl << "Branching factor at root: " << children->size() << std::endl; - // print TOPK of them along with their winrates -// std::cout << "Best moves:" << std::endl; -// for (int i = 0 ; i < children->size() && i < TOPK ; i++) { -// std::cout << " " << i + 1 << ". " << children->at(i)->move->sprint() << " --> " -// << std::setprecision(4) << 100.0 * children->at(i)->calculate_winrate(state->player1_turn()) << "%" << endl; + // Print the best move for a current node +// MCTSNode *bestChild; +// bool first = true; +// double winRateChild = 0; +// if(!children->empty()) { +// for (int i = 0; i < children->size(); i++) { +// if (first) { +// bestChild = children->at(i); +// winRateChild = bestChild->calculate_winrate(); +// first = false; +// } +// +// else if (winRateChild < children->at(i)->calculate_winrate()) { +// bestChild = children->at(i); +// winRateChild = bestChild->calculate_winrate(); +// } +// } +// std::cout << "Best Move :" << std::endl; +// bestChild->move->print(); // } std::cout << "________________________________" << std::endl; } diff --git a/rlBlocking/tst/MCTSTestSuite.h b/rlBlocking/tst/MCTSTestSuite.h index de5bce2b8..3d40cb22d 100644 --- a/rlBlocking/tst/MCTSTestSuite.h +++ b/rlBlocking/tst/MCTSTestSuite.h @@ -37,23 +37,20 @@ TEST(MCTSTestSuite, testExAglo) { gmds::cad::FACManager geom_model; - set_up_MCTS(&geom_model,"M1.vtk"); + set_up_MCTS(&geom_model,"cb2.vtk"); gmds::blocking::CurvedBlocking bl(&geom_model,true); - bl.save_vtk_blocking("/home/bourmaudp/Documents/PROJETS/gmds/gmds_Correction_Class_Dev/saveResults/M1_init_blocking.vtk"); - std::cout<<"NB points : "<< geom_model.getPoints().size()<execute(); + + std::cout<<"==================== END TEST ! ===================="< Date: Mon, 22 Jan 2024 10:11:29 +0100 Subject: [PATCH 4/4] removed the static global id for CurvedBlocking entities in favor of a local counter for each blocking --- blocking/inc/gmds/blocking/CurvedBlocking.h | 45 ++++++++++++++++----- blocking/src/CurvedBlocking.cpp | 36 +++++++++++++---- 2 files changed, 65 insertions(+), 16 deletions(-) diff --git a/blocking/inc/gmds/blocking/CurvedBlocking.h b/blocking/inc/gmds/blocking/CurvedBlocking.h index 05ccc8e3e..47149d57d 100644 --- a/blocking/inc/gmds/blocking/CurvedBlocking.h +++ b/blocking/inc/gmds/blocking/CurvedBlocking.h @@ -22,6 +22,15 @@ namespace gmds { /*----------------------------------------------------------------------------*/ namespace blocking { /*----------------------------------------------------------------------------*/ +class Counter{ + public: + Counter(int c) + : m_counter_global_id(c){} + int get_and_increment_id(){return m_counter_global_id++;} + int value(){return m_counter_global_id;} + private: + int m_counter_global_id; +}; /**@struct CellInfo * @brief This structure gather the pieces of data that are shared by any * blocking cell. Each cell is defined by: @@ -41,22 +50,29 @@ struct CellInfo int topo_id; /*** link to the cad manager to have access to geometric cells */ cad::GeomManager* geom_manager; + /*** link to the counter used to assign a unique id to each entity */ + Counter* counter; /*** dimension of the geometrical cell we are classifid on */ int geom_dim; /*** unique id of the geomtrical cell */ int geom_id; - /*** global counter used to assign an unique id to each block */ - static int m_counter_global_id; /** @brief Constructor + * @param Ac the id counter; the CGAL gmap copy constructor requires a CellInfo() + * call with no params * @param AManager the geometric manager to access cells * @param ATopoDim Cell dimension * @param AGeomDim on-classify geometric cell dimension (4 if not classified) * @param AGeomId on-classify geometric cell unique id */ - CellInfo(cad::GeomManager* AManager=NULL, const int ATopoDim = 4, const int AGeomDim = 4, const int AGeomId = NullID) : - topo_dim(ATopoDim), topo_id(m_counter_global_id++), geom_manager(AManager),geom_dim(AGeomDim), geom_id(AGeomId) + CellInfo(Counter* Ac=nullptr, cad::GeomManager* AManager=nullptr, const int ATopoDim = 4, const int AGeomDim = 4, const int AGeomId = NullID) : + topo_dim(ATopoDim), geom_manager(AManager), counter(Ac), geom_dim(AGeomDim), geom_id(AGeomId) { + if(Ac != nullptr) { + topo_id = Ac->get_and_increment_id(); + } else { + topo_id = -1; + } } }; /*----------------------------------------------------------------------------*/ @@ -69,13 +85,15 @@ struct NodeInfo : CellInfo /*** node location in space, i.e. a single point */ math::Point point; /** @brief Constructor + * @param Ac the id counter; the CGAL gmap copy constructor requires a CellInfo() + * call with no params * @param AManager the geometric manager to access cells * @param AGeomDim on-classify geometric cell dimension (4 if not classified) * @param AGeomId on-classify geometric cell unique id * @param APoint geometric location */ - NodeInfo(cad::GeomManager* AManager=NULL, const int AGeomDim = 4, const int AGeomId = NullID, const math::Point &APoint = math::Point(0, 0, 0)) : - CellInfo(AManager, 0, AGeomDim, AGeomId), point(APoint) + NodeInfo(Counter* Ac=nullptr, cad::GeomManager* AManager=nullptr, const int AGeomDim = 4, const int AGeomId = NullID, const math::Point &APoint = math::Point(0, 0, 0)) : + CellInfo(Ac, AManager, 0, AGeomDim, AGeomId), point(APoint) { } }; @@ -208,8 +226,7 @@ struct SplitFunctor ca2.info().geom_dim = ca1.info().geom_dim; ca2.info().geom_id = ca1.info().geom_id; ca2.info().topo_dim = ca1.info().topo_dim; - ca2.info().topo_id = CellInfo::m_counter_global_id++; - + ca2.info().topo_id = ca1.info().counter->get_and_increment_id(); } }; @@ -227,7 +244,7 @@ struct SplitFunctorNode ca2.info().geom_id = ca1.info().geom_id; ca2.info().point = ca1.info().point; ca2.info().topo_dim = ca1.info().topo_dim; - ca2.info().topo_id = CellInfo::m_counter_global_id++; + ca2.info().topo_id = ca1.info().counter->get_and_increment_id(); } }; /*----------------------------------------------------------------------------*/ @@ -278,6 +295,8 @@ class LIB_GMDS_BLOCKING_API CurvedBlocking */ CurvedBlocking(cad::GeomManager *AGeomModel, bool AInitAsBoundingBox = false); + CurvedBlocking(const CurvedBlocking &ABl); + /** @brief Destructor */ virtual ~CurvedBlocking(); @@ -685,6 +704,11 @@ class LIB_GMDS_BLOCKING_API CurvedBlocking */ std::vector> get_projection_info(math::Point &AP, std::vector &AEdges); + Counter* getCounter() + { + return &m_counter; + } + private: /**@brief Mark with @p AMark all the darts of orbit <0,1>(@p ADart) @@ -736,6 +760,9 @@ class LIB_GMDS_BLOCKING_API CurvedBlocking cad::GeomManager *m_geom_model; /*** the underlying n-g-map model*/ GMap3 m_gmap; + + /*** id counter*/ + Counter m_counter; }; /*----------------------------------------------------------------------------*/ } // namespace blocking diff --git a/blocking/src/CurvedBlocking.cpp b/blocking/src/CurvedBlocking.cpp index b6ff874cc..de461f295 100644 --- a/blocking/src/CurvedBlocking.cpp +++ b/blocking/src/CurvedBlocking.cpp @@ -4,10 +4,12 @@ using namespace gmds; using namespace gmds::blocking; /*----------------------------------------------------------------------------*/ -int CellInfo::m_counter_global_id = 0; +//int CellInfo::m_counter_global_id = 0; /*----------------------------------------------------------------------------*/ -CurvedBlocking::CurvedBlocking(cad::GeomManager *AGeomModel, bool AInitAsBoundingBox) : m_geom_model(AGeomModel) { +CurvedBlocking::CurvedBlocking(cad::GeomManager *AGeomModel, bool AInitAsBoundingBox) + : m_geom_model(AGeomModel), m_counter(0) +{ if (AInitAsBoundingBox) { TCoord min[3] = {MAXFLOAT, MAXFLOAT, MAXFLOAT}; TCoord max[3] = {-MAXFLOAT, -MAXFLOAT, -MAXFLOAT}; @@ -32,7 +34,27 @@ CurvedBlocking::CurvedBlocking(cad::GeomManager *AGeomModel, bool AInitAsBoundin create_block(p1, p2, p3, p4, p5, p6, p7, p8); } } - +/*----------------------------------------------------------------------------*/ +CurvedBlocking::CurvedBlocking(const CurvedBlocking &ABl) +: m_geom_model(ABl.m_geom_model), m_gmap(ABl.m_gmap), m_counter(ABl.m_counter) +{ + auto listBlocks = get_all_blocks(); + for(auto b : listBlocks){ + b->info().counter = &m_counter; + } + auto listFaces = get_all_faces(); + for(auto b : listFaces){ + b->info().counter = &m_counter; + } + auto listEdges = get_all_edges(); + for(auto b : listEdges){ + b->info().counter = &m_counter; + } + auto listNodes = get_all_nodes(); + for(auto b : listNodes){ + b->info().counter = &m_counter; + } +} /*----------------------------------------------------------------------------*/ CurvedBlocking::~CurvedBlocking() {} @@ -51,25 +73,25 @@ CurvedBlocking::geom_model() { /*----------------------------------------------------------------------------*/ CurvedBlocking::Node CurvedBlocking::create_node(const int AGeomDim, const int AGeomId, const math::Point &APoint) { - return m_gmap.create_attribute<0>(NodeInfo(m_geom_model,AGeomDim, AGeomId, APoint)); + return m_gmap.create_attribute<0>(NodeInfo(this->getCounter(),m_geom_model,AGeomDim, AGeomId, APoint)); } /*----------------------------------------------------------------------------*/ CurvedBlocking::Edge CurvedBlocking::create_edge(const int AGeomDim, const int AGeomId) { - return m_gmap.create_attribute<1>(CellInfo(m_geom_model,1, AGeomDim, AGeomId)); + return m_gmap.create_attribute<1>(CellInfo(this->getCounter(),m_geom_model,1, AGeomDim, AGeomId)); } /*----------------------------------------------------------------------------*/ CurvedBlocking::Face CurvedBlocking::create_face(const int AGeomDim, const int AGeomId) { - return m_gmap.create_attribute<2>(CellInfo(m_geom_model,2, AGeomDim, AGeomId)); + return m_gmap.create_attribute<2>(CellInfo(this->getCounter(),m_geom_model,2, AGeomDim, AGeomId)); } /*----------------------------------------------------------------------------*/ CurvedBlocking::Block CurvedBlocking::create_block(const int AGeomDim, const int AGeomId) { - return m_gmap.create_attribute<3>(CellInfo(m_geom_model,3, AGeomDim, AGeomId)); + return m_gmap.create_attribute<3>(CellInfo(this->getCounter(),m_geom_model,3, AGeomDim, AGeomId)); } /*----------------------------------------------------------------------------*/