[WIP] Add a stablized function entropic_partial_wasserstein_logscale#724
[WIP] Add a stablized function entropic_partial_wasserstein_logscale#724wzm2256 wants to merge 5 commits intoPythonOT:masterfrom
Conversation
ot.partial. This function solves the same problem as entropic\_partial\_wasserstein but is computed in logscale, so it is more robust. 2. Test exampless are provided in compare_logscale_POT.py. Test data is in data\entropic_partial_OT_cost.txt
|
Hello @wzm2256 and thanks for the PR we are indeed missing a stabilized partial entropic ot solver. When we add a stabilized solver, we usually add the function and make the original function as a wrapper when we add the log implementation (see how it is done for function ot.sinkhorn ) so that with an additional parameter Finally your example is nice we usually try to add a visualization and we really try not to add raw data to the git. Could you design a simulated example instead of a new dataset? Thanks again for your work, we will do a proper code review as soon as possible but wold prefer to do that after the comments above are taken into account. |
|
Sure! I will work on that. |
|
Hello @wzm2256 , quick reminder that we are waiting for a few changes here. |
|
I tested the current implementation, for some reason its extremely slow at the first iteration itself. |
add a new function called entropic_partial_wasserstein_logscale to ot.partial. This function solves the same problem as entropic_partial_wasserstein but is computed in logscale, so it is more robust.
Test exampless are provided in compare_logscale_POT.py. Test data is in data\entropic_partial_OT_cost.txt
Types of changes
I implement a new function
entropic_partial_wasserstein_logscalethat solves exactly the same problem as the one inentropic_partial_wassersteinin log scale. The new function is a line-to-line translation of the old one, and the input/output format is exactly the same.I do not remove the old function because the new function can be slower due to the use of the logsumexp trick. So when there is no Nan error, the old function is favored.
Motivation and context / Related issue
#723
How has this been tested (if it applies)
I test the new function
entropic_partial_wasserstein_logscaleagainst the old oneentropic_partial_wassersteinin the example file 'compare_logscale_POT.py` for both numpy and pytorch.PR checklist
I could not build the document in my laptop due to some errors:
so I am not completely sure whether the document is fine, although I only added a few sentences to the docs.
Also, I do not know how to use pytest to test my code. If this is necessary, I may need some help here.