-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Added the initial implementation of KT-split #871
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added the comments which we mostly discussed previously. After you have a final working implementation, add some unit tests or create an issue for those (but it might be helpful for you to check that everything is working).
probabilistic components in the algorithm. | ||
""" | ||
|
||
kernel: ScalarValuedKernel |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add the defaults for some parameters as discussed previously
coreax/solvers/coresubset.py
Outdated
random_key: KeyArrayLike | ||
|
||
@classmethod | ||
def get_swap_params( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this seems to be the same as get_a_and_param in kt_half
, do you need both?
coreax/solvers/coresubset.py
Outdated
final_coresets = self.kt_split(dataset) | ||
return self.kt_refine(self.kt_choose(final_coresets, dataset)), solver_state | ||
|
||
def kt_half_recursive(self, points, m, original_dataset): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would rename points
to something like current_subset
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, add type annotations.
""" | ||
n = len(points) // 2 | ||
original_array = points.data | ||
arr1 = jnp.zeros(n, dtype=jnp.int32) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use more descriptive variable names and preferable add a comment before to explain what they are for. If the variables are simply temporary placeholders, explain that in a comment.
subset1 = eqx.tree_at(lambda x: x.nodes.data, subset1, subset1_indices) | ||
subset2 = eqx.tree_at(lambda x: x.nodes.data, subset2, subset2_indices) | ||
|
||
# Recur for both subsets and concatenate results |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
recur -> recurse
alpha = term1 + term2 | ||
return alpha, bool_arr_1, bool_arr_2 | ||
|
||
def final_function( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Again, just be more descriptive in the name, e.g., apply_probabilistic_assignment
(just an example, feel free to make it more appropriate of course).
coreax/solvers/coresubset.py
Outdated
|
||
return Coresubset(final_arr1, points), Coresubset(final_arr2, points) | ||
|
||
def kt_split(self, points: _Data) -> list[Coresubset[_Data]]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As discussed, we probably want to remove this from here for now but save it for later.
|
||
return final_coresets | ||
|
||
def kt_choose( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will likely have to be changed to be jit-compatible, e.g., using vmap to get a vector of MMD values and then jnp.argmin to select the best.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, implement the baseline coreset computation and expose the method for this as parameter (random is probably a good default).
coreax/solvers/coresubset.py
Outdated
|
||
return best_coreset | ||
|
||
def kt_refine(self, candidate_coreset: Coresubset[_Data]) -> Coresubset[_Data]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Feel free to use the Kernel Herding refine method here.
coreax/solvers/coresubset.py
Outdated
|
||
return a, new_sigma | ||
|
||
def reduce( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we want to make this an ExplicitSizeSolver, this might be a place to do the logic of discarding and padding the points. Also, this will provide the coreset_size
parameter, so you will probably want to remove m
as a parameter and compute it as log2(data_size/coreset_size)
(after discarding etc).
# Conflicts: # coreax/solvers/coresubset.py
…hinning back to required. #893
PR Type
Description
How Has This Been Tested?
Checklist before requesting a review