Fixed send message computations in Shafer-Shenoy #15
base: master
Big thanks for looking into this! Hugely appreciated! 👍
I had one comment/question regarding special-case handling.
Also, if it's easy, could you write a single minimal unit test that reproduces the original issue, so it's easy to confirm that it was broken before and is fixed now? It would then also stay fixed in the future.
Thanks a lot!
junctiontree/computation.py
msg_prod = 1 if len(messages) == 0 else dl.einsum(
    *messages,
    neighbor_vars
Is this special-case value 1 related only to the normal sum-product? If some other distributive law was used, would the value be different then? If so, should this if-else be inside the specific "sum-product distributive law" so it wouldn't affect all distributive laws? Or, alternatively, should each distributive law define an identity element ("empty value") that is used in cases like this? My gut feeling is that this latter option would make perfect sense.
Also, now that this if-else is only here, it won't fix all the other places where dl.einsum is used. I suppose this special-case handling should be applied wherever dl.einsum is used, right? Therefore, even more so, I would suggest defining an identity element for the distributive laws.
Does this make sense or am I misunderstanding something?
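To make the suggestion concrete, here is a minimal sketch of a distributive law that carries its own identity element. The class name SumProduct, the identity attribute, and the helper product_of_messages are hypothetical illustrations, not the package's actual API; only numpy.einsum is a real call. For sum-product the identity is 1; a law working in the log domain would use 0 instead.

```python
import numpy as np


class SumProduct:
    """Hypothetical sum-product distributive law (illustrative only)."""

    # Multiplicative identity: the value of an empty product of messages.
    identity = 1.0

    def einsum(self, *args):
        # Delegates to NumPy's interleaved form:
        # einsum(op0, labels0, op1, labels1, ..., output_labels)
        return np.einsum(*args)


def product_of_messages(dl, messages, out_vars):
    """Multiply messages under the law `dl`.

    `messages` is a flat list [array, labels, array, labels, ...].
    An empty list yields the law's identity, so callers need no if-else.
    """
    if len(messages) == 0:
        return dl.identity
    return dl.einsum(*messages, out_vars)
```

With this shape, the empty-case check lives in one place and each law decides what "empty" means.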
I've been trying to come up with a simple unit test for the inconsistent results that were identified in the original issue. However, the issue derives from the lack of a guaranteed order when creating Python sets. The remove_messages function implementation assumes a particular ordering of the indices, so using the set function prior to calling remove_messages produces inconsistent behavior. As a result, the incorrect results are only produced occasionally.
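A small illustration of the ordering point, using made-up variable indices: a Python set gives the right membership but no ordering guarantee, whereas NumPy's set routines return sorted arrays.

```python
import numpy as np

# Made-up variable indices for a clique and one of its neighbors.
clique_vars = [8, 1, 3]
neighbor_vars = [3, 8, 5]

# Python set: correct membership, but iteration order is not guaranteed,
# so code that relies on index positions can silently misalign axes.
shared_set = set(clique_vars) & set(neighbor_vars)

# NumPy set routines return sorted arrays, so the order is reproducible.
shared_sorted = np.intersect1d(clique_vars, neighbor_vars).tolist()
all_sorted = np.union1d(clique_vars, neighbor_vars).tolist()
```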
There are 2 possible fixes:
- Use NumPy's set functions in place of the standard Python set function so that the indices are always ordered prior to being used as an argument to remove_message.
- Rely on numpy.einsum to compute the product of the messages with the message from the current neighbor being excluded. I originally wrote the remove_message function in order to avoid repeated calculations of the product of the messages. But the complexity of the code for performing the calculation to divide out the neighbor's message may not be worth any small efficiency gains from avoiding the repeated calculations.
My commit included both replacing the standard set function calls with NumPy set implementations and removing the remove_message function to simplify the code.
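The second fix can be sketched as follows. This is an illustration under assumptions, not the code from the commit: the function name send_message, its signature, and the dict layout of incoming messages are all hypothetical. The idea is to leave the target neighbor's message out of the product from the start, instead of multiplying everything and dividing it out afterwards.

```python
import numpy as np


def send_message(potential, clique_vars, incoming, target, sepset_vars):
    """Hypothetical Shafer-Shenoy message from a clique to `target`.

    incoming: dict mapping neighbor name -> (message array, variable labels).
    The target's own message is skipped, then everything is multiplied and
    marginalized onto `sepset_vars` in a single einsum call.
    """
    args = [potential, clique_vars]
    for neighbor, (msg, msg_vars) in incoming.items():
        if neighbor == target:
            continue  # exclude rather than divide out later
        args.extend([msg, msg_vars])
    return np.einsum(*args, sepset_vars)


# Usage with a 2x2 potential over variables 0 and 1.
potential = np.array([[1.0, 2.0], [3.0, 4.0]])
incoming = {
    "b": (np.array([1.0, 10.0]), [0]),  # message over variable 0
    "c": (np.array([2.0, 3.0]), [1]),   # message over variable 1
}
to_c = send_message(potential, [0, 1], incoming, "c", [1])
to_b = send_message(potential, [0, 1], incoming, "b", [0])
```

Because every operand carries explicit variable labels, the result does not depend on the order in which neighbors are visited.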
Added an additional commit in this branch to support an identity element for the SumProduct distributive law.
Ok! If it's not easy to add a unit test, then that can be left out. 👍
Your most recent commit added new files computation.py and sum_product.py that are in the root of the repository, not inside the junctiontree package. I suppose this was a mistake, right? Being outside the package, they aren't used anywhere. Perhaps you meant to replace junctiontree/computation.py and junctiontree/sum_product.py?
This is for the Sum Product distributive law, to avoid special-case handling in the get_messages function.
I added one comment about the changes in the most recent commit.
Sorry about committing those files in the wrong location. The versions in the root directory were deleted and the two updated files are in the junctiontree directory where they belong.
Most recent changes to computation.py are a direct implementation of the Shafer-Shenoy updates. My original implementation attempted some premature optimizations, which resulted in an incorrect implementation. This implementation uses Einstein summations for all computations. As a result, there are no longer concerns about matrix index mismatches and the code is simplified.
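As a small, made-up illustration of why labeled Einstein summation avoids index mismatches: plain broadcasting aligns a one-dimensional message with the last axis of the potential, whichever variable that happens to be, while numpy.einsum aligns it by the label given.

```python
import numpy as np

potential = np.array([[1.0, 2.0], [3.0, 4.0]])  # axes: variable 0, variable 1
msg = np.array([1.0, 10.0])                     # message over variable 0

# Broadcasting lines msg up with the LAST axis (variable 1): wrong here,
# and no error is raised because the shapes happen to be compatible.
naive = potential * msg

# einsum lines msg up with the axis labeled 0, as intended.
labeled = np.einsum(potential, [0, 1], msg, [0], [0, 1])
```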
Fixed index mismatches causing the #13 results and non-reproducibility. Removed the remove_messages function: the excluded message is now left out of the product-of-messages calculation, rather than being removed after the product is computed.