-
Notifications
You must be signed in to change notification settings - Fork 3.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support of Python 3.12 #2098
Open
miguelgfierro
wants to merge
35
commits into
staging
Choose a base branch
from
python312
base: staging
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Support of Python 3.12 #2098
Changes from all commits
Commits
Show all changes
35 commits
Select commit
Hold shift + click to select a range
b3abc8a
Enabling python 3.12 #2097
miguelgfierro 809a1d3
Enabling python 3.12 #2097
miguelgfierro 496b22f
Cornac>2
miguelgfierro 2b9ddcc
setuptools >= 67
daviddavo 69f2d36
Temporary use lightfm patch
daviddavo d497fc0
Merge branch 'staging' into python312
daviddavo c223e3d
Updated tensorflow
daviddavo 75b884b
Update pymanopt
daviddavo 0c28cf1
Installing pymanopt from git
daviddavo ad0c79d
Add setuptools (for python 3.12)
daviddavo cdf9515
Use numpy 1.26 on python 3.12
daviddavo a442edb
Update python version used to launch tests
daviddavo dff4b0c
Added numpy 1.26 setup requires on python 3.12
daviddavo 2af59a1
[test ci] Force numpy 1.26
daviddavo 100a6f5
Merge remote-tracking branch 'origin/staging' into python312
daviddavo a7208e8
Updated conda in azureml tests
daviddavo 94ccdeb
Update pymanopt usage
daviddavo cca9669
Use numpy 1.26.4 in Python 3.12
daviddavo 12924d6
Merge branch 'staging' into python312
daviddavo 5f02f45
Updated pymanopt in aml_utils
daviddavo 18ff5b7
Relaxed numpy requirements
daviddavo 913abde
Removed Python 3.8 tests
daviddavo 3b0d1a1
Bumped required python version
daviddavo 108c8a0
Upgraded azure python version to python 3.12
daviddavo 2e2c99a
Updated tf-keras for new python versions
daviddavo 8f0265a
Revert "Removed Python 3.8 tests"
daviddavo 7705b5d
Revert "Bumped required python version"
daviddavo 4d4779a
#2138 Using requirements-external to specify git deps
daviddavo 1cbe6b6
Fixed conda env file
daviddavo 734e4ac
Changed requirements-external order
daviddavo 7b01eff
Install requirements-external.txt from GitHub
daviddavo 8515dbf
Relaxed scipy version for Pyton 3.8
daviddavo 4a544b2
Merge branch 'staging' into python312
daviddavo 450bf17
Merge in staging
SimonYansenZhao 9e00c0c
Solve merge conflict
miguelgfierro File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,8 +11,9 @@ | |
from numba import njit, prange | ||
from pymanopt import Problem | ||
from pymanopt.manifolds import Stiefel, Product, SymmetricPositiveDefinite | ||
from pymanopt.solvers import ConjugateGradient | ||
from pymanopt.solvers.linesearch import LineSearchBackTracking | ||
from pymanopt.autodiff.backends import numpy as backend_decorator | ||
from pymanopt.optimizers import ConjugateGradient | ||
from pymanopt.optimizers.line_search import BackTrackingLineSearcher | ||
|
||
|
||
class IMCProblem(object): | ||
|
@@ -68,23 +69,19 @@ def _computeLoss_csrmatrix(a, b, cd, indices, indptr, residual_global): | |
residual_global[j] = num - cd[j] | ||
return residual_global | ||
|
||
def _cost(self, params, residual_global): | ||
def _cost(self, U, S, VT, residual_global): | ||
"""Compute the cost of GeoIMC optimization problem | ||
|
||
Args: | ||
params (Iterator): An iterator containing the manifold point at which | ||
the cost needs to be evaluated. | ||
residual_global (csr_matrix): Residual matrix. | ||
""" | ||
U = params[0] | ||
B = params[1] | ||
V = params[2] | ||
|
||
regularizer = 0.5 * self.lambda1 * np.sum(B**2) | ||
regularizer = 0.5 * self.lambda1 * np.sum(S**2) | ||
|
||
IMCProblem._computeLoss_csrmatrix( | ||
self.X.dot(U.dot(B)), | ||
V.T.dot(self.Z.T), | ||
self.X.dot(U.dot(S)), | ||
VT.T.dot(self.Z.T), | ||
self.Y.data, | ||
self.Y.indices, | ||
self.Y.indptr, | ||
|
@@ -94,37 +91,33 @@ def _cost(self, params, residual_global): | |
|
||
return cost | ||
|
||
def _egrad(self, params, residual_global): | ||
def _egrad(self, U, S, VT, residual_global): | ||
"""Computes the euclidean gradient | ||
|
||
Args: | ||
params (Iterator): An iterator containing the manifold point at which | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please update docstring |
||
the cost needs to be evaluated. | ||
residual_global (csr_matrix): Residual matrix. | ||
""" | ||
U = params[0] | ||
B = params[1] | ||
V = params[2] | ||
|
||
residual_global_csr = csr_matrix( | ||
(residual_global, self.Y.indices, self.Y.indptr), | ||
shape=self.shape, | ||
) | ||
|
||
gradU = ( | ||
np.dot(self.X.T, residual_global_csr.dot(self.Z.dot(V.dot(B.T)))) | ||
np.dot(self.X.T, residual_global_csr.dot(self.Z.dot(VT.dot(S.T)))) | ||
/ self.nSamples | ||
) | ||
|
||
gradB = ( | ||
np.dot((self.X.dot(U)).T, residual_global_csr.dot(self.Z.dot(V))) | ||
np.dot((self.X.dot(U)).T, residual_global_csr.dot(self.Z.dot(VT))) | ||
/ self.nSamples | ||
+ self.lambda1 * B | ||
+ self.lambda1 * S | ||
) | ||
gradB_sym = (gradB + gradB.T) / 2 | ||
|
||
gradV = ( | ||
np.dot((self.X.dot(U.dot(B))).T, residual_global_csr.dot(self.Z)).T | ||
np.dot((self.X.dot(U.dot(S))).T, residual_global_csr.dot(self.Z)).T | ||
/ self.nSamples | ||
) | ||
|
||
|
@@ -154,20 +147,29 @@ def _optimize(self, max_opt_time, max_opt_iter, verbosity): | |
residual_global = np.zeros(self.Y.data.shape) | ||
|
||
solver = ConjugateGradient( | ||
maxtime=max_opt_time, | ||
maxiter=max_opt_iter, | ||
linesearch=LineSearchBackTracking(), | ||
max_time=max_opt_time, | ||
max_iterations=max_opt_iter, | ||
line_searcher=BackTrackingLineSearcher(), | ||
verbosity=verbosity, | ||
) | ||
|
||
@backend_decorator(self.manifold) | ||
def _cost(u, s, vt): | ||
return self._cost(u, s, vt, residual_global) | ||
|
||
@backend_decorator(self.manifold) | ||
def _egrad(u, s, vt): | ||
return self._egrad(u, s, vt, residual_global) | ||
|
||
prb = Problem( | ||
manifold=self.manifold, | ||
cost=lambda x: self._cost(x, residual_global), | ||
egrad=lambda z: self._egrad(z, residual_global), | ||
verbosity=verbosity, | ||
cost=_cost, | ||
euclidean_gradient=_egrad, | ||
) | ||
solution = solver.solve(prb, x=self.W) | ||
self.W = [solution[0], solution[1], solution[2]] | ||
solution = solver.run(prb, initial_point=self.W) | ||
self.W = [solution.point[0], solution.point[1], solution.point[2]] | ||
|
||
return self._cost(self.W, residual_global) | ||
return solution.cost | ||
|
||
def reset(self): | ||
"""Reset the model.""" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# 2024/02/29: pymanopt bumped to python 3.8 | ||
pymanopt @git+https://github.com/pymanopt/pymanopt@e13cecaec3089c790cc93174840b2f557d179b3f ; python_version<'3.12' | ||
|
||
# Jun 2024: Fixes py312 | ||
pymanopt @git+https://github.com/pymanopt/pymanopt@1de3b6f47258820fdc072fceaeaa763b9fd263b0 ; python_version>='3.12' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need to update the docstring too